質問編集履歴

2

2021/07/23 14:09

投稿

syosinnsya88
syosinnsya88

スコア0

test CHANGED
File without changes
test CHANGED
@@ -1,9 +1,163 @@
1
1
  ``````ここに言語を入力
2
2
 
3
+ ここに言語を入力```import os, pickle
4
+
5
+ import numpy as np
6
+
7
+ from PIL import Image
8
+
9
+ from sklearn.model_selection import train_test_split
10
+
11
+
12
+
13
+ class DigitImage:
14
+
15
+ def __init__(self):
16
+
17
+ self.dataset = {}
18
+
19
+
20
+
21
+ def load_data(self, save_file, normalize=True, flatten=False, one_hot_label=True):
22
+
23
+ # save_fileがなければデータセットを作成, あればデータを読み込む
24
+
25
+ if not os.path.exists(save_file):
26
+
27
+ self.create_dataset(save_file)
28
+
29
+ else:
30
+
31
+ with open(save_file, 'rb') as f:
32
+
33
+ self.dataset = pickle.load(f)
34
+
35
+
36
+
37
+ if normalize:
38
+
39
+ # float32に変換して0~1の間になるよう正規化
40
+
41
+ self.dataset['img'] = self.dataset['img'].astype(np.float32) / 255
42
+
43
+
44
+
45
+ if not flatten:
46
+
47
+ self.dataset['img'] = self.dataset['img'].reshape(-1, 1, 28, 28)
48
+
49
+
50
+
51
+ if one_hot_label:
52
+
53
+ self.dataset['label'] = self.change_one_hot_label()
54
+
55
+
56
+
57
+ return self.dataset['img'], self.dataset['label']
58
+
59
+
60
+
61
+ def change_one_hot_label(self):
62
+
63
+ T = np.zeros((self.dataset['label'].size, 10))
64
+
65
+ for idx, row in enumerate(T):
66
+
67
+ row[self.dataset['label'][idx]] = 1
68
+
3
- ここに言語を入力
69
+ return T
70
+
71
+
72
+
73
+ def create_dataset(self, save_file):
74
+
75
+ image, label = [], []
76
+
77
+ for i in range(10):
78
+
79
+ dir_path = os.path.join('dataset', str(i))
80
+
81
+ filelist = [os.path.join(dir_path, file) for file in os.listdir(dir_path)]
82
+
83
+
84
+
85
+ # ファイル毎に画像を1次元のnumpy配列として追加していく
86
+
87
+ images = []
88
+
89
+ for file in filelist:
90
+
91
+ _img = Image.open(file).convert('L')
92
+
93
+ _img = _img.resize((28, 28))
94
+
95
+ _img = np.asarray(_img, dtype=np.uint8).reshape(-1)
96
+
97
+ images.append(_img)
98
+
99
+ image.extend(images)
100
+
101
+
102
+
103
+ label.extend([i] * len(filelist))
104
+
105
+
106
+
107
+ self.dataset = {'img': np.array(image), 'label': np.array(label)}
108
+
109
+ with open(save_file, 'wb') as f:
110
+
111
+ pickle.dump(self.dataset, f)
112
+
113
+
114
+
115
+
116
+
117
+ if __name__ == '__main__':
118
+
119
+
120
+
121
+ digit = DigitImage()
122
+
123
+ X, y = digit.load_data('digit_dataset.pkl', normalize=True,
124
+
125
+ flatten=False, one_hot_label=True)
126
+
127
+ print(X.shape)
128
+
129
+ print(y.shape)
130
+
131
+
132
+
133
+ # ここでロードしたデータを学習用とテスト用に分割. シャッフルもされる
134
+
135
+ X_train, X_test, y_train, y_test = train_test_split(
136
+
137
+ X.reshape((len(X), 28, 28, 1)),
138
+
139
+ y, train_size=0.6)
140
+
141
+ print(y_train[:10])
142
+
143
+
144
+
145
+ digit = DigitImage()
146
+
147
+ X, y = digit.load_data('digit_dataset.pkl', normalize=True,
148
+
149
+ flatten=False, one_hot_label=True)
150
+
151
+ print(X.shape) # (10000, 1, 28, 28)
152
+
153
+ print(y.shape) # (10000, 10)
154
+
155
+ コード
4
156
 
5
157
  ```
6
158
 
159
+ ```
160
+
7
161
  コード
8
162
 
9
163
  ```### 前提・実現したいこと

1

2021/07/23 14:09

投稿

syosinnsya88
syosinnsya88

スコア0

test CHANGED
File without changes
test CHANGED
@@ -1,4 +1,12 @@
1
+ ``````ここに言語を入力
2
+
3
+ ここに言語を入力
4
+
5
+ ```
6
+
7
+ コード
8
+
1
- ### 前提・実現したいこと
9
+ ```### 前提・実現したいこと
2
10
 
3
11
  自作データセットをコード一行で読み込めるようにしたいです。
4
12