質問編集履歴
2
title
CHANGED
File without changes
|
body
CHANGED
@@ -1,6 +1,83 @@
|
|
1
1
|
``````ここに言語を入力
|
2
|
+
ここに言語を入力```import os, pickle
|
3
|
+
import numpy as np
|
4
|
+
from PIL import Image
|
5
|
+
from sklearn.model_selection import train_test_split
|
6
|
+
|
7
|
+
class DigitImage:
|
8
|
+
def __init__(self):
|
9
|
+
self.dataset = {}
|
10
|
+
|
11
|
+
def load_data(self, save_file, normalize=True, flatten=False, one_hot_label=True):
|
12
|
+
# save_fileがなければデータセットを作成, あればデータを読み込む
|
13
|
+
if not os.path.exists(save_file):
|
14
|
+
self.create_dataset(save_file)
|
15
|
+
else:
|
16
|
+
with open(save_file, 'rb') as f:
|
17
|
+
self.dataset = pickle.load(f)
|
18
|
+
|
19
|
+
if normalize:
|
20
|
+
# float32に変換して0~1の間になるよう正規化
|
21
|
+
self.dataset['img'] = self.dataset['img'].astype(np.float32) / 255
|
22
|
+
|
23
|
+
if not flatten:
|
24
|
+
self.dataset['img'] = self.dataset['img'].reshape(-1, 1, 28, 28)
|
25
|
+
|
26
|
+
if one_hot_label:
|
27
|
+
self.dataset['label'] = self.change_one_hot_label()
|
28
|
+
|
29
|
+
return self.dataset['img'], self.dataset['label']
|
30
|
+
|
31
|
+
def change_one_hot_label(self):
|
32
|
+
T = np.zeros((self.dataset['label'].size, 10))
|
33
|
+
for idx, row in enumerate(T):
|
34
|
+
row[self.dataset['label'][idx]] = 1
|
2
|
-
|
35
|
+
return T
|
36
|
+
|
37
|
+
def create_dataset(self, save_file):
|
38
|
+
image, label = [], []
|
39
|
+
for i in range(10):
|
40
|
+
dir_path = os.path.join('dataset', str(i))
|
41
|
+
filelist = [os.path.join(dir_path, file) for file in os.listdir(dir_path)]
|
42
|
+
|
43
|
+
# ファイル毎に画像を1次元のnumpy配列として追加していく
|
44
|
+
images = []
|
45
|
+
for file in filelist:
|
46
|
+
_img = Image.open(file).convert('L')
|
47
|
+
_img = _img.resize((28, 28))
|
48
|
+
_img = np.asarray(_img, dtype=np.uint8).reshape(-1)
|
49
|
+
images.append(_img)
|
50
|
+
image.extend(images)
|
51
|
+
|
52
|
+
label.extend([i] * len(filelist))
|
53
|
+
|
54
|
+
self.dataset = {'img': np.array(image), 'label': np.array(label)}
|
55
|
+
with open(save_file, 'wb') as f:
|
56
|
+
pickle.dump(self.dataset, f)
|
57
|
+
|
58
|
+
|
59
|
+
if __name__ == '__main__':
|
60
|
+
|
61
|
+
digit = DigitImage()
|
62
|
+
X, y = digit.load_data('digit_dataset.pkl', normalize=True,
|
63
|
+
flatten=False, one_hot_label=True)
|
64
|
+
print(X.shape)
|
65
|
+
print(y.shape)
|
66
|
+
|
67
|
+
# ここでロードしたデータを学習用とテスト用に分割. シャッフルもされる
|
68
|
+
X_train, X_test, y_train, y_test = train_test_split(
|
69
|
+
X.reshape((len(X), 28, 28, 1)),
|
70
|
+
y, train_size=0.6)
|
71
|
+
print(y_train[:10])
|
72
|
+
|
73
|
+
digit = DigitImage()
|
74
|
+
X, y = digit.load_data('digit_dataset.pkl', normalize=True,
|
75
|
+
flatten=False, one_hot_label=True)
|
76
|
+
print(X.shape) # (10000, 1, 28, 28)
|
77
|
+
print(y.shape) # (10000, 10)
|
78
|
+
コード
|
3
79
|
```
|
80
|
+
```
|
4
81
|
コード
|
5
82
|
```### 前提・実現したいこと
|
6
83
|
自作データセットをコード一行で読み込めるようにしたいです。
|
1
title
CHANGED
File without changes
|
body
CHANGED
@@ -1,4 +1,8 @@
|
|
1
|
+
``````ここに言語を入力
|
2
|
+
ここに言語を入力
|
3
|
+
```
|
4
|
+
コード
|
1
|
-
### 前提・実現したいこと
|
5
|
+
```### 前提・実現したいこと
|
2
6
|
自作データセットをコード一行で読み込めるようにしたいです。
|
3
7
|
プログラム初心者で全くわからない状況なのでおかしいことたくさん言っていると思います。すみません
|
4
8
|
ここに質問の内容を詳しく書いてください。
|