質問編集履歴
2
test
CHANGED
File without changes
|
test
CHANGED
@@ -1,9 +1,163 @@
|
|
1
1
|
``````ここに言語を入力
|
2
2
|
|
3
|
+
ここに言語を入力```import os, pickle
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from PIL import Image
|
8
|
+
|
9
|
+
from sklearn.model_selection import train_test_split
|
10
|
+
|
11
|
+
|
12
|
+
|
13
|
+
class DigitImage:
|
14
|
+
|
15
|
+
def __init__(self):
|
16
|
+
|
17
|
+
self.dataset = {}
|
18
|
+
|
19
|
+
|
20
|
+
|
21
|
+
def load_data(self, save_file, normalize=True, flatten=False, one_hot_label=True):
|
22
|
+
|
23
|
+
# save_fileがなければデータセットを作成, あればデータを読み込む
|
24
|
+
|
25
|
+
if not os.path.exists(save_file):
|
26
|
+
|
27
|
+
self.create_dataset(save_file)
|
28
|
+
|
29
|
+
else:
|
30
|
+
|
31
|
+
with open(save_file, 'rb') as f:
|
32
|
+
|
33
|
+
self.dataset = pickle.load(f)
|
34
|
+
|
35
|
+
|
36
|
+
|
37
|
+
if normalize:
|
38
|
+
|
39
|
+
# float32に変換して0~1の間になるよう正規化
|
40
|
+
|
41
|
+
self.dataset['img'] = self.dataset['img'].astype(np.float32) / 255
|
42
|
+
|
43
|
+
|
44
|
+
|
45
|
+
if not flatten:
|
46
|
+
|
47
|
+
self.dataset['img'] = self.dataset['img'].reshape(-1, 1, 28, 28)
|
48
|
+
|
49
|
+
|
50
|
+
|
51
|
+
if one_hot_label:
|
52
|
+
|
53
|
+
self.dataset['label'] = self.change_one_hot_label()
|
54
|
+
|
55
|
+
|
56
|
+
|
57
|
+
return self.dataset['img'], self.dataset['label']
|
58
|
+
|
59
|
+
|
60
|
+
|
61
|
+
def change_one_hot_label(self):
|
62
|
+
|
63
|
+
T = np.zeros((self.dataset['label'].size, 10))
|
64
|
+
|
65
|
+
for idx, row in enumerate(T):
|
66
|
+
|
67
|
+
row[self.dataset['label'][idx]] = 1
|
68
|
+
|
3
|
-
|
69
|
+
return T
|
70
|
+
|
71
|
+
|
72
|
+
|
73
|
+
def create_dataset(self, save_file):
|
74
|
+
|
75
|
+
image, label = [], []
|
76
|
+
|
77
|
+
for i in range(10):
|
78
|
+
|
79
|
+
dir_path = os.path.join('dataset', str(i))
|
80
|
+
|
81
|
+
filelist = [os.path.join(dir_path, file) for file in os.listdir(dir_path)]
|
82
|
+
|
83
|
+
|
84
|
+
|
85
|
+
# ファイル毎に画像を1次元のnumpy配列として追加していく
|
86
|
+
|
87
|
+
images = []
|
88
|
+
|
89
|
+
for file in filelist:
|
90
|
+
|
91
|
+
_img = Image.open(file).convert('L')
|
92
|
+
|
93
|
+
_img = _img.resize((28, 28))
|
94
|
+
|
95
|
+
_img = np.asarray(_img, dtype=np.uint8).reshape(-1)
|
96
|
+
|
97
|
+
images.append(_img)
|
98
|
+
|
99
|
+
image.extend(images)
|
100
|
+
|
101
|
+
|
102
|
+
|
103
|
+
label.extend([i] * len(filelist))
|
104
|
+
|
105
|
+
|
106
|
+
|
107
|
+
self.dataset = {'img': np.array(image), 'label': np.array(label)}
|
108
|
+
|
109
|
+
with open(save_file, 'wb') as f:
|
110
|
+
|
111
|
+
pickle.dump(self.dataset, f)
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
|
116
|
+
|
117
|
+
if __name__ == '__main__':
|
118
|
+
|
119
|
+
|
120
|
+
|
121
|
+
digit = DigitImage()
|
122
|
+
|
123
|
+
X, y = digit.load_data('digit_dataset.pkl', normalize=True,
|
124
|
+
|
125
|
+
flatten=False, one_hot_label=True)
|
126
|
+
|
127
|
+
print(X.shape)
|
128
|
+
|
129
|
+
print(y.shape)
|
130
|
+
|
131
|
+
|
132
|
+
|
133
|
+
# ここでロードしたデータを学習用とテスト用に分割. シャッフルもされる
|
134
|
+
|
135
|
+
X_train, X_test, y_train, y_test = train_test_split(
|
136
|
+
|
137
|
+
X.reshape((len(X), 28, 28, 1)),
|
138
|
+
|
139
|
+
y, train_size=0.6)
|
140
|
+
|
141
|
+
print(y_train[:10])
|
142
|
+
|
143
|
+
|
144
|
+
|
145
|
+
digit = DigitImage()
|
146
|
+
|
147
|
+
X, y = digit.load_data('digit_dataset.pkl', normalize=True,
|
148
|
+
|
149
|
+
flatten=False, one_hot_label=True)
|
150
|
+
|
151
|
+
print(X.shape) # (10000, 1, 28, 28)
|
152
|
+
|
153
|
+
print(y.shape) # (10000, 10)
|
154
|
+
|
155
|
+
コード
|
4
156
|
|
5
157
|
```
|
6
158
|
|
159
|
+
```
|
160
|
+
|
7
161
|
コード
|
8
162
|
|
9
163
|
```### 前提・実現したいこと
|
1
test
CHANGED
File without changes
|
test
CHANGED
@@ -1,4 +1,12 @@
|
|
1
|
+
``````ここに言語を入力
|
2
|
+
|
3
|
+
ここに言語を入力
|
4
|
+
|
5
|
+
```
|
6
|
+
|
7
|
+
コード
|
8
|
+
|
1
|
-
### 前提・実現したいこと
|
9
|
+
```### 前提・実現したいこと
|
2
10
|
|
3
11
|
自作データセットをコード一行で読み込めるようにしたいです。
|
4
12
|
|