teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

2

2021/07/23 14:09

投稿

syosinnsya88
syosinnsya88

スコア0

title CHANGED
File without changes
body CHANGED
@@ -1,6 +1,83 @@
1
1
  ``````ここに言語を入力
2
+ ここに言語を入力```import os, pickle
3
+ import numpy as np
4
+ from PIL import Image
5
+ from sklearn.model_selection import train_test_split
6
+
7
+ class DigitImage:
8
+ def __init__(self):
9
+ self.dataset = {}
10
+
11
+ def load_data(self, save_file, normalize=True, flatten=False, one_hot_label=True):
12
+ # save_fileがなければデータセットを作成, あればデータを読み込む
13
+ if not os.path.exists(save_file):
14
+ self.create_dataset(save_file)
15
+ else:
16
+ with open(save_file, 'rb') as f:
17
+ self.dataset = pickle.load(f)
18
+
19
+ if normalize:
20
+ # float32に変換して0~1の間になるよう正規化
21
+ self.dataset['img'] = self.dataset['img'].astype(np.float32) / 255
22
+
23
+ if not flatten:
24
+ self.dataset['img'] = self.dataset['img'].reshape(-1, 1, 28, 28)
25
+
26
+ if one_hot_label:
27
+ self.dataset['label'] = self.change_one_hot_label()
28
+
29
+ return self.dataset['img'], self.dataset['label']
30
+
31
+ def change_one_hot_label(self):
32
+ T = np.zeros((self.dataset['label'].size, 10))
33
+ for idx, row in enumerate(T):
34
+ row[self.dataset['label'][idx]] = 1
2
- ここに言語を入力
35
+ return T
36
+
37
+ def create_dataset(self, save_file):
38
+ image, label = [], []
39
+ for i in range(10):
40
+ dir_path = os.path.join('dataset', str(i))
41
+ filelist = [os.path.join(dir_path, file) for file in os.listdir(dir_path)]
42
+
43
+ # ファイル毎に画像を1次元のnumpy配列として追加していく
44
+ images = []
45
+ for file in filelist:
46
+ _img = Image.open(file).convert('L')
47
+ _img = _img.resize((28, 28))
48
+ _img = np.asarray(_img, dtype=np.uint8).reshape(-1)
49
+ images.append(_img)
50
+ image.extend(images)
51
+
52
+ label.extend([i] * len(filelist))
53
+
54
+ self.dataset = {'img': np.array(image), 'label': np.array(label)}
55
+ with open(save_file, 'wb') as f:
56
+ pickle.dump(self.dataset, f)
57
+
58
+
59
+ if __name__ == '__main__':
60
+
61
+ digit = DigitImage()
62
+ X, y = digit.load_data('digit_dataset.pkl', normalize=True,
63
+ flatten=False, one_hot_label=True)
64
+ print(X.shape)
65
+ print(y.shape)
66
+
67
+ # ここでロードしたデータを学習用とテスト用に分割. シャッフルもされる
68
+ X_train, X_test, y_train, y_test = train_test_split(
69
+ X.reshape((len(X), 28, 28, 1)),
70
+ y, train_size=0.6)
71
+ print(y_train[:10])
72
+
73
+ digit = DigitImage()
74
+ X, y = digit.load_data('digit_dataset.pkl', normalize=True,
75
+ flatten=False, one_hot_label=True)
76
+ print(X.shape) # (10000, 1, 28, 28)
77
+ print(y.shape) # (10000, 10)
78
+ コード
3
79
  ```
80
+ ```
4
81
  コード
5
82
  ```### 前提・実現したいこと
6
83
  自作データセットをコード一行で読み込めるようにしたいです。

1

2021/07/23 14:09

投稿

syosinnsya88
syosinnsya88

スコア0

title CHANGED
File without changes
body CHANGED
@@ -1,4 +1,8 @@
1
+ ``````ここに言語を入力
2
+ ここに言語を入力
3
+ ```
4
+ コード
1
- ### 前提・実現したいこと
5
+ ```### 前提・実現したいこと
2
6
  自作データセットをコード一行で読み込めるようにしたいです。
3
7
  プログラム初心者で全くわからない状況なのでおかしいことたくさん言っていると思います。すみません
4
8
  ここに質問の内容を詳しく書いてください。