質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

89.13%

TensorFlowでk-meansを実装

受付中

回答 0

投稿

  • 評価
  • クリップ 0
  • VIEW 1,499

Rikuto.F

score 8

現在 Learning TensorFlow :: Clustering and k-meansで学習しているのですが,このアルゴリズムで10回ないし特定の閾値を超えるまでk-meansを繰り返したいとき,どのようにすればいいのでしょうか?

上述のサイトと同様のコードですが,以下にコードを掲出します。

functions.py

import tensorflow as tf
import numpy as np

def create_samples(n_clusters, n_samples_per_cluster, n_features, embiggen_factor, seed):
    np.random.seed(seed)
    slices = []
    centroids = []
    for i in range(n_clusters):
        samples = tf.random_normal((n_samples_per_cluster, n_features),
                                   mean=0.0, stddev=5.0, dtype=tf.float32, seed=seed, name="cluster_{}".format(i))
        current_centroid = (np.random.random((1, n_features)) * embiggen_factor) - (embiggen_factor/2)
        centroids.append(current_centroid)
        samples += current_centroid
        slices.append(samples)
    samples = tf.concat(slices, 0, name='samples')
    centroids = tf.concat(centroids, 0, name='centroids')
    return centroids, samples

def plot_clusters(all_samples, centroids, n_samples_per_cluster):
    import matplotlib.pyplot as plt
    colour = plt.cm.rainbow(np.linspace(0,1,len(centroids)))
    for i, centroid in enumerate(centroids):
         samples = all_samples[i*n_samples_per_cluster:(i+1)*n_samples_per_cluster]
         plt.scatter(samples[:,0], samples[:,1], c=colour[i])
         plt.plot(centroid[0], centroid[1], markersize=35, marker="x", color='k', mew=10)
         plt.plot(centroid[0], centroid[1], markersize=30, marker="x", color='m', mew=5)
    plt.show()

def choose_random_centroids(samples, n_clusters, seed=None):
    n_samples = tf.shape(samples)[0]
    random_indices = tf.random_shuffle(tf.range(0, n_samples), seed=seed)
    begin = [0,]
    size = [n_clusters,]
    size[0] = n_clusters
    centroid_indices = tf.slice(random_indices, begin, size)
    initial_centroids = tf.gather(samples, centroid_indices)
    return initial_centroids

def assign_to_nearest(samples, centroids):
    expanded_vectors = tf.expand_dims(samples, 0)
    expanded_centroids = tf.expand_dims(centroids, 1)
    distances = tf.reduce_sum( tf.square(
        tf.subtract(expanded_vectors, expanded_centroids)), 2)
    mins = tf.argmin(distances, 0)

    nearest_indices = mins
    return nearest_indices

def update_centroids(samples, nearest_indices, n_clusters):
    nearest_indices = tf.to_int32(nearest_indices)
    partitions = tf.dynamic_partition(samples, nearest_indices, n_clusters)
    new_centroids = tf.concat([tf.expand_dims(tf.reduce_mean(partition, 0), 0) for partition in partitions], 0)
    return new_centroids

generate_samples.py

import tensorflow as tf
import numpy as np

from functions import *

n_features = 2
n_clusters = 5
n_samples_per_cluster = 500
seed = 700
embiggen_factor = 70

data_centroids, samples = create_samples(n_clusters, n_samples_per_cluster, n_features, embiggen_factor, seed)
initial_centroids = choose_random_centroids(samples, n_clusters)
nearest_indices = assign_to_nearest(samples, initial_centroids)
updated_centroids = update_centroids(samples, nearest_indices, n_clusters)


model = tf.global_variables_initializer()
with tf.Session() as session:
    sample_values = session.run(samples)
    updated_centroid_value = session.run(updated_centroids)
    print(updated_centroid_value)

    plot_clusters(sample_values, updated_centroid_value, n_samples_per_cluster)

nearest_indices および updated_centroids を繰り返し実行するようにすればいいのだと思うのですが,どのようにして更新されたセントロイドの情報を保持し,それをもとに nearest_indices を実行すればよいのか悩んでいます。

よろしくお願いします。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 89.13%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる