pythonで下図のようなファイルを基にSOMを作成したいと考えています.
参考:https://qiita.com/T_Shinaji/items/609fe9aabd99c287b389
python
1import numpy as np 2from matplotlib import pyplot as plt 3 4class SOM(): 5 6 def __init__(self, teachers, N, seed=None): 7 self.teachers = np.array(teachers) 8 self.n_teacher = self.teachers.shape[0] 9 self.N = N 10 if not seed is None: 11 np.random.seed(seed) 12 13 x, y = np.meshgrid(range(self.N), range(self.N)) 14 self.c = np.hstack((y.flatten()[:, np.newaxis], 15 x.flatten()[:, np.newaxis])) 16 self.nodes = np.random.rand(self.N*self.N, 17 self.teachers.shape[1]) 18 19 def train(self): 20 for i, teacher in enumerate(self.teachers): 21 bmu = self._best_matching_unit(teacher) 22 d = np.linalg.norm(self.c - bmu, axis=1) 23 L = self._learning_ratio(i) 24 S = self._learning_radius(i, d) 25 self.nodes += L * S[:, np.newaxis] * (teacher - self.nodes) 26 return self.nodes 27 28 def _best_matching_unit(self, teacher): 29 #compute all norms (square) 30 norms = np.linalg.norm(self.nodes - teacher, axis=1) 31 bmu = np.argmin(norms) #argment with minimum element 32 return np.unravel_index(bmu,(self.N, self.N)) 33 34 def _neighbourhood(self, t):#neighbourhood radious 35 halflife = float(self.n_teacher/4) #for testing 36 initial = float(self.N/2) 37 return initial*np.exp(-t/halflife) 38 39 def _learning_ratio(self, t): 40 halflife = float(self.n_teacher/4) #for testing 41 initial = 0.1 42 return initial*np.exp(-t/halflife) 43 44 def _learning_radius(self, t, d): 45 # d is distance from BMU 46 s = self._neighbourhood(t) 47 return np.exp(-d**2/(2*s**2)) 48 49 50N = 20 51//このteachersにコレクションとして図のようなふぁい 52teachers = np.random.rand(10000, 3) 53som = SOM(teachers, N=N, seed=10) 54 55# Initial map 56plt.imshow(som.nodes.reshape((N, N, 3)), 57 interpolation='none') 58plt.show() 59 60# Train 61som.train() 62 63# Trained MAP 64plt.imshow(som.nodes.reshape((N, N, 3)), 65 interpolation='none') 66plt.show()
//追記
申し訳ございません途中で送信してしまいました.
続きを追記させていただきます.
上記のプログラムのteachers部分に図のようなコレクションを挿入したく下記のようなプログラムに書き直しました.
python
1import numpy as np 2import csv 3data = [] 4 5with open("data.csv","rb") as f: 6 reader = csv.reader(f) 7 header = next(reader) 8 9 for row in reader: 10 data.append(row) 11 12from matplotlib import pyplot as plt 13 14class SOM(): 15 16 def __init__(self, teachers, N, seed=None): 17 self.teachers = np.array(teachers) 18 self.n_teacher = self.teachers.shape[0] 19 self.N = N 20 if not seed is None: 21 np.random.seed(seed) 22 23 x, y = np.meshgrid(range(self.N), range(self.N)) 24 self.c = np.hstack((y.flatten()[:, np.newaxis], 25 x.flatten()[:, np.newaxis])) 26 self.nodes = np.random.rand(self.N*self.N, 27 self.teachers.shape[1]) 28 29 def train(self): 30 for i, teacher in enumerate(self.teachers): 31 bmu = self._best_matching_unit(teacher) 32 d = np.linalg.norm(self.c - bmu, axis=1) 33 L = self._learning_ratio(i) 34 S = self._learning_radius(i, d) 35 self.nodes += L * S[:, np.newaxis] * (teacher - self.nodes) 36 return self.nodes 37 38 def _best_matching_unit(self, teacher): 39 #compute all norms (square) 40 norms = np.linalg.norm(self.nodes - teacher, axis=1) 41 bmu = np.argmin(norms) #argment with minimum element 42 return np.unravel_index(bmu,(self.N, self.N)) 43 44 def _neighbourhood(self, t):#neighbourhood radious 45 halflife = float(self.n_teacher/4) #for testing 46 initial = float(self.N/2) 47 return initial*np.exp(-t/halflife) 48 49 def _learning_ratio(self, t): 50 halflife = float(self.n_teacher/4) #for testing 51 initial = 0.1 52 return initial*np.exp(-t/halflife) 53 54 def _learning_radius(self, t, d): 55 # d is distance from BMU 56 s = self._neighbourhood(t) 57 return np.exp(-d**2/(2*s**2)) 58 59 60N = 20 61teachers = data 62som = SOM(teachers, N=N, seed=10) 63 64# Initial map 65plt.imshow(som.nodes.reshape((N, N, 3)),interpolation='none') 66plt.show() 67 68# Train 69som.train() 70 71# Trained MAP 72plt.imshow(som.nodes.reshape((N, N, 3)), 73 interpolation='none') 74plt.show()
しかし下記のようなエラーがでてしまいます.
実際のコレクションは33×500の行列なので,(N, N, 3)部分を(N, N, 33)にしたりしてみましたがうまくいきませんでした.
Traceback (most recent call last): File "soms.py", line 65, in <module> plt.imshow(som.nodes.reshape((N, N, 3)),interpolation='none') ValueError: cannot reshape array of size 13200 into shape (20,20,3)
python初学者のため,根本的な話なのかもしれませんがどうかご教授いただけると幸いです.