質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.37%

NumPy配列を初期化しないとkmeansが収束しない

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 820

score 11

K-meansで画像圧縮のコードを書いていました

128 * 128 pixelsの画像があり、各pixelには(R, G, B)の値が0~255でNumPy Arrayで格納されています。k-meansアルゴリズムを用いて、各pixelを1つのサンプルとみなし(つまり128*128サンプルある)、(R, G, B)の空間上で16色のクラスタに分け、各pixelをクラスタ中心点に置き換えることで、画像圧縮をしようとしました。

発生している問題・エラーメッセージ

以下のコードは、kmeansにおいて各クラスタの中心点をアップデートする部分の関数です。
以下のソースコードにおいて、

#---この行がないと、kmeansが収束しない!---#
centroids = np.zeros((num_clusters, c))  # init centroids


の部分がないと、k-meansが収束せず、逆にこの行を加えるとkmeansが動いたのですが、なぜこの行でcentroids(Numpy Array)の初期化が必要なのか、理解できません。

該当のソースコード

以下のコードは、centroids=init_centroids(k-meansにおいて初めにランダムに中心点を初期化したもの)を引数として受け取ります。クラスタの中心点をアップデートする部分の関数です。

def update_centroids(centroids, image, max_iter=30, print_every=10):
    """
    Carry out k-means centroid update step `max_iter` times

    Parameters
    ----------
    centroids : nparray
        The centroids stored as an nparray
    image : nparray
        (H, W, C) image represented as an nparray
    max_iter : int
        Number of iterations to run
    print_every : int
        Frequency of status update

    Returns
    -------
    new_centroids : nparray
        Updated centroids
    """
    (h, w, c, num_clusters) = (image.shape[0], image.shape[1], image.shape[2], centroids.shape[0])  # 画像のh高さ、w幅、c色の数、num_clustersクラスタの数(16)
    image_flat = image.reshape(h * w, c)  # flatten the image --> each row: pixel, each column: color
    for _iter in range(max_iter):
        dist = np.array([np.linalg.norm(pixel - centroids, ord=2, axis=1) for pixel in image_flat])  # 各pixelと各クラスタの中心点までの距離を計算
        idx = np.argmin(dist, axis=1)
        # Find closest centroid and update `new_centroids`

        #---この行がないと、kmeansが収束しない!---#
        centroids = np.zeros((num_clusters, c))  # centroidsをゼロで初期化
        #----- -----#

        for k in range(num_clusters):
            centroids[k] = np.mean(image_flat[idx == k], axis=0)
        if _iter % print_every == 0:
            print("{} iterations done.".format(_iter))
            print("index of cluster (idx[::100]): \n")
            print(idx[::100])
            print("centroids' RGB values: \n")
            print(centroids)
            print(25*"-")
    new_centroids = np.copy(centroids)
    return new_centroids
  • 全体の関数設計(補足)
centroids_init = init_centroids(num_clusters, image)
centroids = update_centroids(centroids_init, image, max_iter, print_every)


init_centroids関数は、以下のようになっています。

def init_centroids(num_clusters, image):
    (h, w, c) = (image.shape[0], image.shape[1], image.shape[2])
    image_flat = image.reshape(h * w, c)
    init_centroid_index = np.array([random.randint(0, h * w) for _ in range(num_clusters)])
    centroids_init = image_flat[init_centroid_index]
    return centroids_init

仮説

NumPy配列の変数アドレスの問題か何かだと思うのですが、毎イタレーションごとに、各要素を書き換えているので、問題ない気がしています。なのでなぜ該当行がないと動かないのか分かりません。メモリの問題?なお、distはshape=(128*128, 16)の配列です。

補足情報(FW/ツールのバージョンなど)

使用しているライブラリはNumPyのみです。

追記(2019/8/11/09:00)

なお、centroids配列は次のようになっています。

  • 収束しない時(centroids配列をゼロで初期化しないとき)
10 iterations done.
index of cluster (idx[::100]):

[ 9  9  6 15  9  9 15 14 15 15 15  9  9 15 13  1  9 14 15  9  9  9  9  9
 11  9  9 15  9 14 14  9  9  9  9  9  9 15 15  7  9 11  9 13  9  9  9 14
  9 15  9 15 14 14 15  9 14  9 15 14 15 14 14  9  9 14 11 11  9  6  9 14
 15  9  9 15  9  9 11 15 14 15  9  9 15  9 15  9 13  9 15  9  9  9  9 15
  9 15 11 14  9  9 13  9 15  9 14 13  9  9 14  9  9 15  9  9  9  9 15  9
  9  9  9 14  9  9  9 15  9  9  9 14  9  9  9  9 15 15  9  9 14 15  9  9
  9  9 15 12  9 14 15  7 14  9  2 15  9 14  3  2  9  9  9  1]
centroids' RGB values:

[[197 220 187]
 [185 212 158]
 [195 207 116]
 [197 211 153]
 [200 216 159]
 [203 221 196]
 [194 216 184]
 [207  89  67]
 [192 212 139]
 [140  93  52]
 [202 214 108]
 [183 204 109]
 [196 144 134]
 [142 189 120]
 [125  30  24]
 [168 184  91]]
-------------------------
20 iterations done.
index of cluster (idx[::100]):

[ 7  7  2  9  7  7  9  7 15 15  9  9  7  9 11 12  7  7  1  7  1  7  7  9
 12  1  7  6  7  1  1  7  7  7  1  1  7  9 15  1  7 12  7  0  7  9  1  7
  7  9  7  9  7  1  9  7 14  7  9 14  9  1  7  9  7 14  6 12  7  3  7  7
  9  7  7  6  7 14 12  9 14 12 14  7  9  7  9  7  9  9 12  7  7  9  7  9
  7  6  6 14  7  7  9  7  9  7 14  7  7  7  7  9  7  9  7  7  7  7  9  7
  7  7  9 14  7  7  7  9  7  9  7 14  7  7  7  7 12  9 14  7 14  9  7  7
  7 14  7 13  7 14 12  1 14  7 12 12  7 14  3 12  7  7  7 12]
centroids' RGB values:

[[206 173 150]
 [201  74  52]
 [192 213 163]
 [195 216 174]
 [199 216 188]
 [203 219 183]
 [171 198 125]
 [147  87  50]
 [209 203 183]
 [151 182  89]
 [202 214 108]
 [152 177 125]
 [188 206 111]
 [203 221 204]
 [ 81  12  10]
 [197 170  88]]
-------------------------


値が激しく変化しています。クラスタの配属変化も激しいです。

  • 収束するとき(centroids配列をきちんとゼロで初期化する時)
10 iterations done.
index of cluster (idx[::100]):

[11 10  3  2 11  1  6 15 14  7  8 10 15  2  6  3 15 15  0  1 15 15  4  2
  7 15  1  3 10 15 15 15  1  1 15 15  4  6 14 14  1  7 10  5 15 10  1 15
 15  8  4  6 15 15  0 10 12 11  8 12  2 15 15  2  4 12  5  7 11  3 15 15
  2 14  4  6  1  9  8  2  4  7  9  4  2 13  2 15 13 10  7  9  9 13 13  2
  9  5  3 12  4 15  6 13  8  9 12 13  1  0 15 13 10  2  4  9  1  1  2 15
 13  1  0 12  9 10  1  8  4 13 10 12  9  9 10  1  8  8 12  4 12  2  9 10
  1 12  0  3 12 12  7 15 12  4  7  7 11 12  5  7  1 11 11  5]
centroids' RGB values:

[[156.91060291 148.36382536  65.34303534]
 [115.60351413 113.881589    55.83346066]
 [149.43956044 180.28671329  82.11788212]
 [192.38899083 218.05321101 191.79449541]
 [106.05243902  60.4597561   34.56585366]
 [189.56716418 201.32338308 148.12271973]
 [151.72988506 196.17241379 129.73754789]
 [191.81183317 204.19786615  92.38312318]
 [172.34967623 194.25531915  84.92969473]
 [143.43722564  21.54784899  24.15276558]
 [115.6759195  156.87369882  74.57113116]
 [172.82359081  52.58037578  41.93841336]
 [ 64.18624044   6.95969423   5.67268937]
 [126.88624339 180.91269841  96.17107584]
 [203.56716418 106.52487562  76.47761194]
 [200.2748184   49.78006457  41.54560129]]
-------------------------
20 iterations done.
index of cluster (idx[::100]):

[11 10  3  2 15  1  6 15 14  7  8 10 15  2  6  3 15 15  8 10 15 15  4  2
  7 15  1  3 10 15 15 15 10  1 15 15  4  5 14 14  1  7 13  5 11 10  1 15
 15  8  4  5 15 15  2 13 12 11  8 12  2 15 15  2  4 12  5  7 11  3 15 15
  2 14  4  6  1  9  7  2  4  7  9  1  2 13  2 15  6 13  7  9 11 13 13  2
  9  5  3 12  4 15  6 13  8  9 12 13  1  0 15 13 10  8  4  9 10  1  8 11
 13  1  0 12 11 10  1  8  1 13 10 12  9  9 10  1  8  8 12  4 12  2  9 10
  1 12  0  3 12 12  7 15 12  4  7  7 11 12  5  7  1 11 11  5]
centroids' RGB values:

[[154.13868613 140.83698297  62.75912409]
 [116.4886562  106.15794066  53.08202443]
 [147.43695015 177.84359726  80.51221896]
 [192.87184116 218.00180505 191.34115523]
 [104.56811989  55.91008174  32.53814714]
 [184.55113636 201.33238636 147.74431818]
 [144.18470418 192.13275613 118.23809524]
 [189.66607302 205.07569012  93.33926981]
 [172.52888087 190.85920578  82.72292419]
 [140.20383451  19.89101917  22.60443996]
 [112.52730375 147.36433447  68.76450512]
 [171.89906542  43.12149533  38.25420561]
 [ 63.49573257   6.6116643    5.42318634]
 [122.00744048 174.20089286  88.81994048]
 [202.94634146 107.63902439  76.84634146]
 [200.11342685  52.06092184  42.48857715]]
-------------------------


中心点の値はそれほど変動せず、クラスタの配属が安定的です。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • meg_

    2019/08/11 09:19

    ”「centroids = np.zeros((num_clusters, c))」の前後で”centroids”の値がどうなっているのか”というのはプログラム中で”centroids”がどう変化するのかデバッグしたら何か判るのでは?という意味でした。

    ところで、関数の引数に「centroids」を渡しているのに、関数中で初期化して問題ないのですか?
    (初期化するなら渡す意味もないような)

    キャンセル

  • mokemokechicken

    2019/08/11 09:22

    あまり関係ないかもですが、
    この update_centroids() を呼び出すときの 最初の centroids はどういう値を与えているのでしょうか?

    キャンセル

  • crows_007

    2019/08/11 09:30

    printしてみたのですが、centroidsがゼロにきちんと初期化されている、ということ以外には分かりませんでした。

    関数内の初期化については、全体の関数設計について僕の説明不足でした。k-meansにおいて、一番初めにクラスタの中心点を決める際、ランダムに中心点を決めるので、その中心点をinit_centroidsと名付けており、update_centroids関数は、init_centroidsを引数として受け取ります。引数の名前がややこしくて申し訳ありません。
    失礼いたしました。質問内容に補足いたします。

    キャンセル

回答 1

checkベストアンサー

+1

もしかすると、その関数の入力の centroids の型が int系なのではないでしょうか。
なので、 centroids[k] = np.mean(..)  の結果が int に cast されて計算が正しくされないから収束しないのでは、と思いました。

import numpy as np

z_int = np.arange(6).reshape(3, 2)
z_float = z_int * 1.
print(z_int)
"""
[[0 1]
 [2 3]
 [4 5]]
"""

print(z_float)
"""
[[0. 1.]
 [2. 3.]
 [4. 5.]]
"""

#%%

z_int[0] = np.mean(z_float) * 0.567
print(z_int)  # ndarrayの型が int だと castされる
"""
[[1 1]
 [2 3]
 [4 5]]
"""

z_float[0] = np.mean(z_float) * 0.567
print(z_float)  # float系だと大丈夫
"""
[[1.4175 1.4175]
 [2.     3.    ]
 [4.     5.    ]]
"""

■ 追記1

あとひとつ気になったのが、

centroids[k] = np.mean(image_flat[idx == k], axis=0)

で cluster-k に所属するPixelが一つもない場合 
np.mean(image_flat[idx == k], axis=0) が nan になってしまって少し気持ち悪いので、
idx == k が存在しないときは更新しないような処理を入れるほうが良い気がしました。

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2019/08/11 09:46

    まさにこれでした。image_flatというnp.arrayのデータ型を見ると、uint8となっておりました。RGBのピクセル値なので、納得です。
    dist = np.array([np.linalg.norm(np.array(pixel, dtype='float64') - centroids, ord=2, axis=1) for pixel in image_flat])のように、dist配列を求めるときに、pixelのデータ型をきちんと指定してあげれば、挙動が正常になりました。

    追記1について、ありがとうございます。更新しない処理を含めておきます。

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.37%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る