質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.92%

scipy minimize 一度の実行で複数の結果を取得したい

解決済

回答 3

投稿

  • 評価
  • クリップ 0
  • VIEW 268

taku_ant

score 2

前提・実現したいこと

制約付き非線形関数の最小化問題で制約条件の係数の値を色々変えながら
scipy.optimize.minimizeでそれぞれの値に対応する最小値を全て取得して結果を配列に格納する

該当のソースコード

import numpy as np
from scipy.optimize import minimize

def objective_function(x):
    return - x[0]**0.5 * x[1]**0.5

def constraint(x,a,b):
    x,a = np.array(x),np.array(a)
    return b - x@a

reps = 100 #繰り返し回数

a_init = [50,200] #制約条件係数の基準値
a_shock = np.array([np.random.rand(reps)*50+50,np.zeros(reps)]).T #係数をランダムに変更
a = a_init + a_shock
b = 10000

cons = lambda x,i: constraint(x,a[i],b)
cons_list = [lambda x: cons(x,i) for i in range(len(a))]
bnds = ((0,None),(0,None))
opts = lambda i: minimize(objective_function, #目的関数
                    (15,5), #初期値
                    method='SLSQP',
                    bounds=bnds,
                    constraints={'type':'eq', 'fun':cons_list[i]})
#以下を実行
print(opts(0).x)
# array([33.62294395, 25.00001409])

#異なるインデックスに対して再度実行すると別の値が返ってくる
print(opts(1).x)
# array([35.37724898, 25.0000091 ])

#しかし以下のように配列化すると全て同じ結果が返ってくる
results = [opts(i).x for i in range(len(a))]
print(results)
# [array([37.31991237, 24.99999784]), array([37.31991237, 24.99999784]), ...

#一度の実行で複数のインデックスについて計算すると同じ値が返ってくる
print(f"インデックス1: {opts(1).x}\nインデックス2: {opts(2).x}")
# インデックス1: [44.29645281 25.00000902]
# インデックス2: [44.29645281 25.00000902]

試したこと

〇普通にfor loopしたり、リスト内記法で結果を取得すると全て同じ値になってしまう
〇lambda関数化して1つずつ結果を取り出すと問題なく毎回異なる結果が返ってくる
〇↑をforで回して結果を配列に格納してもすべて同じ値になってしまう

補足情報(FW/ツールのバージョンなど)

実行環境はGoogle colaboratoryです

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 3

check解決した方法

+1

問題はどうやらcons_listだったようです。

【該当部分】

cons_list = [lambda x: cons(x,i) for i in range(len(a))]


ここでリスト内記法を用いると中身の関数が全て同じになってしまうようです。
なので、cons_listは作らずにoptsを以下のようにすると異なる結果を一度に取得できました。

import numpy as np
from scipy.optimize import minimize

def objective_function(x):
    return - x[0]**0.5 * x[1]**0.5

def constraint(x,a,b):
    x,a = np.array(x),np.array(a)
    return b - x@a

reps = 100 #繰り返し回数

a_init = [50,200] #制約条件係数の基準値
a_shock = np.array([np.random.rand(reps)*50+50,np.zeros(reps)]).T #係数をランダムに変更
a = a_init + a_shock
b = 10000

cons = lambda x,i: constraint(x,a[i],b)
bnds = ((0,None),(0,None))

opts = lambda i: minimize(objective_function, #目的関数
                    (15,5), #初期値
                    method='SLSQP',
                    bounds=bnds,
                    constraints={'type':'eq', 'fun':lambda x:cons(x,i)})

res = [opts(i).x for i in range(reps)]
res

実行結果

#毎回結果が異なっている
[array([37.59090689, 25.00025669]), array([40.62532566, 24.99999821]), array([45.94143763, 24.99998405]), array([38.30967904, 24.99999733]), ...

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

+1

ご質問のコードを何度実行しても再現できませんでした(全て同じ値が出力されました)。
(ある特定のa_shockが原因か、あるいは、モジュールのバージョンが原因かわかりません)
少なくともprint(opts(0).x)とprint(opts(1).x)が異なることはありませんでした。
本当にご質問のような結果が得られたのですか?

[41.14691214 24.99999917]
[41.14691214 24.99999917]
[array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917]), array([41.14691214, 24.99999917])]
インデックス1: [41.14691214 24.99999917]
インデックス2: [41.14691214 24.99999917]

この結果を再現したa_shockの値

array([[86.94631736,  0.        ],
       [88.24059994,  0.        ],
       [63.88071978,  0.        ],
       [97.57583615,  0.        ],
       [59.1515068 ,  0.        ],
       [76.57500374,  0.        ],
       [62.57463254,  0.        ],
       [71.99813083,  0.        ],
       [79.11855668,  0.        ],
       [84.37097279,  0.        ],
       [56.33228482,  0.        ],
       [76.90794869,  0.        ],
       [88.78803925,  0.        ],
       [72.86546644,  0.        ],
       [64.55730109,  0.        ],
       [87.66662926,  0.        ],
       [68.26486145,  0.        ],
       [89.0884346 ,  0.        ],
       [51.9560971 ,  0.        ],
       [91.22909041,  0.        ],
       [75.41843306,  0.        ],
       [91.10976738,  0.        ],
       [73.51168527,  0.        ],
       [89.40328009,  0.        ],
       [53.81101371,  0.        ],
       [61.98393568,  0.        ],
       [54.71236022,  0.        ],
       [67.17937382,  0.        ],
       [98.55481394,  0.        ],
       [77.88477456,  0.        ],
       [73.64123646,  0.        ],
       [94.04620222,  0.        ],
       [58.52606914,  0.        ],
       [81.7636329 ,  0.        ],
       [69.39601077,  0.        ],
       [72.49074165,  0.        ],
       [77.33081577,  0.        ],
       [77.02725975,  0.        ],
       [59.50171833,  0.        ],
       [95.2699818 ,  0.        ],
       [97.16985503,  0.        ],
       [83.2670286 ,  0.        ],
       [52.5908398 ,  0.        ],
       [56.48330684,  0.        ],
       [87.88075307,  0.        ],
       [63.01896622,  0.        ],
       [89.89313833,  0.        ],
       [51.57691706,  0.        ],
       [81.922707  ,  0.        ],
       [94.3653957 ,  0.        ],
       [85.20211311,  0.        ],
       [78.60473673,  0.        ],
       [60.04666014,  0.        ],
       [54.09208572,  0.        ],
       [72.95415372,  0.        ],
       [75.09743132,  0.        ],
       [60.1501105 ,  0.        ],
       [68.80185706,  0.        ],
       [83.84797023,  0.        ],
       [98.53367147,  0.        ],
       [65.52482817,  0.        ],
       [50.90285463,  0.        ],
       [82.66954479,  0.        ],
       [88.0250928 ,  0.        ],
       [60.12744854,  0.        ],
       [95.11834392,  0.        ],
       [98.78607007,  0.        ],
       [98.39624051,  0.        ],
       [94.87834786,  0.        ],
       [76.22672687,  0.        ],
       [60.67453039,  0.        ],
       [99.30529285,  0.        ],
       [67.59547283,  0.        ],
       [70.21665278,  0.        ],
       [58.52796051,  0.        ],
       [82.91705426,  0.        ],
       [99.79190275,  0.        ],
       [94.23850625,  0.        ],
       [96.80085235,  0.        ],
       [88.94066684,  0.        ],
       [88.09091175,  0.        ],
       [52.63167066,  0.        ],
       [82.86003197,  0.        ],
       [91.62778404,  0.        ],
       [98.19051657,  0.        ],
       [69.39058311,  0.        ],
       [97.77421829,  0.        ],
       [83.57513457,  0.        ],
       [54.43381656,  0.        ],
       [54.03169961,  0.        ],
       [54.59933567,  0.        ],
       [78.88361243,  0.        ],
       [82.9944889 ,  0.        ],
       [55.94026688,  0.        ],
       [56.15304813,  0.        ],
       [63.84645291,  0.        ],
       [56.79068478,  0.        ],
       [97.46527693,  0.        ],
       [86.68915797,  0.        ],
       [64.58626962,  0.        ]])

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

+1

解決しているので蛇足かも知れませんが、ラムダ式の内包表現に問題ありです。

def func(x, a):
    print(f"x={x},a={a},id(a)={id(a)}")


a = list(range(10))
print("a=", a)

lfunc = lambda x, i: func(x, a[i])
lfuncs = [lambda x: lfunc(x, i) for i in range(len(a))]

lfuncs[0](123)
lfuncs[1](456)
lfuncs[2](789)


結果

a= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
x=123,a=9,id(a)=4452857216
x=456,a=9,id(a)=4452857216
x=789,a=9,id(a)=4452857216

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.92%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る