質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.51%

  • Python 3.x

    9833questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

関数classifyがどのように動いているのかわかりません

受付中

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 667

mi56

score 6

前提・実現したいこと

クラスタリングについて勉強しています。
その中のkmeans法を試そうとしています。
コードは、https://github.com/joelgrus/data-science-from-scratch/blob/master/code/clustering.pyにあります

発生している問題・エラーメッセージ

下のコードにinputのデータを用意して関数classifyが何を返すか見た所、0,1,2が表示されました。
入力データと3つのself.meansとの距離を計算しminでその中の最小値を出しているので私は最小の距離が出力されると思っていました。
なぜreturnでo,1,2が返るのか教えていただきたいです。

該当のソースコード

from __future__ import division
import math, random
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from functools import reduce
import re, math, random
from collections import defaultdict, Counter

def vector_add(v, w):
    """adds two vectors componentwise"""
    return [v_i + w_i for v_i, w_i in zip(v,w)]

def vector_subtract(v, w):
    """subtracts two vectors componentwise"""
    return [v_i - w_i for v_i, w_i in zip(v,w)]

def vector_sum(vectors):
    return reduce(vector_add, vectors)

def scalar_multiply(c, v):
    return [c * v_i for v_i in v]

def vector_mean(vectors):
    """compute the vector whose i-th element is the mean of the
    i-th elements of the input vectors"""
    n = len(vectors)
    return scalar_multiply(1/n, vector_sum(vectors))

def dot(v, w):
    """v_1 * w_1 + ... + v_n * w_n"""
    return sum(v_i * w_i for v_i, w_i in zip(v, w))

def sum_of_squares(v):
    """v_1 * v_1 + ... + v_n * v_n"""
    return dot(v, v)

def magnitude(v):
    return math.sqrt(sum_of_squares(v))

def squared_distance(v, w):
    return sum_of_squares(vector_subtract(v, w))

def distance(v, w):
    return math.sqrt(squared_distance(v, w))

class KMeans:
    """performs k-means clustering"""

    def __init__(self, k):
        self.k = k          # number of clusters
        self.means = None   # means of clusters


    def classify(self, input):
        """return the index of the cluster closest to the input"""

        return min(range(self.k),
                   key=lambda i: squared_distance(input, self.means[i]))

    def train(self, inputs):

        self.means = random.sample(inputs, self.k)

        assignments = None

        while True:

            new_assignments = map(self.classify, inputs)


            if assignments == new_assignments:                
                return


            assignments = new_assignments    

            for i in range(self.k):
                i_points = [p for p, a in zip(inputs, assignments) if a == i]
                # avoid divide-by-zero if i_points is empty
                if i_points:                                
                    self.means[i] = vector_mean(i_points)    


if __name__ == "__main__":

    inputs = [[-14,-5],[13,13],[20,23],[-19,-11],[-9,-16],[21,27],[-49,15],[26,13],[-46,5],[-34,-1],[11,15],[-49,0],[-22,-16],[19,28],[-12,-8],[-13,-19],[-41,8],[-11,-6],[-25,-9],[-18,-3]]


    random.seed(0) # so you get the same results as me
    clusterer = KMeans(3)
    try:
        clusterer.train(inputs)
    except:
        val =+ 1
    print ("3-means:")
    print (clusterer.means)

first,second = zip(*inputs)
plt.scatter(first,second)

補足情報(言語/FW/ツール等のバージョンなど)

Python3.5

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • tell_k

    2017/05/24 17:57

    もう少し具体的にどのような方法で 0,1,2を確認したのか書くと回答がつきやすいと思います。 > 関数classifyが何を返すか見た所、0,1,2が表示されました。

    キャンセル

回答 1

0

    def classify(self, input):
        """return the index of the cluster closest to the input"""

        return min(range(self.k),
                   key=lambda i: squared_distance(input, self.means[i]))

この関数のことだと思うんですが、例えば self.k=3 だった場合 min(range(3)) を実行しているのと同じです。 これを実行するとわかりますが。 range(3) => [0,1,2] から 最小の値を取得するということです。 また第二引数に key を渡すと keyの順序にしたがって値を取り出してくれます。

https://docs.python.jp/3/library/functions.html#min

min(range(3)) # => [0,1,2] の最小の値は 0なので常に0が返ってくる

# key に各要素に「-」をつけて逆順になるような lambda関数を指定する
min(range(3), key=lambda i: -i) # => 常に 2 が返ってくる

ここまでくれば下記コードが何をやってるのか読み取れると思います。

  return min(range(self.k), key=lambda i: squared_distance(input, self.means[i]))

これは k回 squared_distance で距離を計算して、その中で 何番目の距離が最小だったかを返す関数となります。 つまり 0 が返って来たら、1回目の距離が最小だったということを意味しますし、2だったら、3回目の距離が最小だったということを意味します。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

同じタグがついた質問を見る

  • Python 3.x

    9833questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。