質問編集履歴
2
    
        title	
    CHANGED
    
    | 
         
            File without changes
         
     | 
    
        body	
    CHANGED
    
    | 
         @@ -1,6 +1,83 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ``````ここに言語を入力
         
     | 
| 
      
 2 
     | 
    
         
            +
            ここに言語を入力```import os, pickle
         
     | 
| 
      
 3 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 4 
     | 
    
         
            +
            from PIL import Image
         
     | 
| 
      
 5 
     | 
    
         
            +
            from sklearn.model_selection import train_test_split
         
     | 
| 
      
 6 
     | 
    
         
            +
             
         
     | 
| 
      
 7 
     | 
    
         
            +
            class DigitImage:
         
     | 
| 
      
 8 
     | 
    
         
            +
                def __init__(self):
         
     | 
| 
      
 9 
     | 
    
         
            +
                    self.dataset = {}
         
     | 
| 
      
 10 
     | 
    
         
            +
             
         
     | 
| 
      
 11 
     | 
    
         
            +
                def load_data(self, save_file, normalize=True, flatten=False, one_hot_label=True):
         
     | 
| 
      
 12 
     | 
    
         
            +
                    # save_fileがなければデータセットを作成, あればデータを読み込む
         
     | 
| 
      
 13 
     | 
    
         
            +
                    if not os.path.exists(save_file):
         
     | 
| 
      
 14 
     | 
    
         
            +
                        self.create_dataset(save_file)
         
     | 
| 
      
 15 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 16 
     | 
    
         
            +
                        with open(save_file, 'rb') as f:
         
     | 
| 
      
 17 
     | 
    
         
            +
                            self.dataset = pickle.load(f)
         
     | 
| 
      
 18 
     | 
    
         
            +
             
         
     | 
| 
      
 19 
     | 
    
         
            +
                    if normalize:
         
     | 
| 
      
 20 
     | 
    
         
            +
                        # float32に変換して0~1の間になるよう正規化
         
     | 
| 
      
 21 
     | 
    
         
            +
                        self.dataset['img'] = self.dataset['img'].astype(np.float32) / 255
         
     | 
| 
      
 22 
     | 
    
         
            +
             
         
     | 
| 
      
 23 
     | 
    
         
            +
                    if not flatten:
         
     | 
| 
      
 24 
     | 
    
         
            +
                        self.dataset['img'] = self.dataset['img'].reshape(-1, 1, 28, 28)
         
     | 
| 
      
 25 
     | 
    
         
            +
             
         
     | 
| 
      
 26 
     | 
    
         
            +
                    if one_hot_label:
         
     | 
| 
      
 27 
     | 
    
         
            +
                        self.dataset['label'] = self.change_one_hot_label()
         
     | 
| 
      
 28 
     | 
    
         
            +
             
         
     | 
| 
      
 29 
     | 
    
         
            +
                    return self.dataset['img'], self.dataset['label']
         
     | 
| 
      
 30 
     | 
    
         
            +
             
         
     | 
| 
      
 31 
     | 
    
         
            +
                def change_one_hot_label(self):
         
     | 
| 
      
 32 
     | 
    
         
            +
                    T = np.zeros((self.dataset['label'].size, 10))
         
     | 
| 
      
 33 
     | 
    
         
            +
                    for idx, row in enumerate(T):
         
     | 
| 
      
 34 
     | 
    
         
            +
                        row[self.dataset['label'][idx]] = 1
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
      
 35 
     | 
    
         
            +
                    return T
         
     | 
| 
      
 36 
     | 
    
         
            +
             
         
     | 
| 
      
 37 
     | 
    
         
            +
                def create_dataset(self, save_file):
         
     | 
| 
      
 38 
     | 
    
         
            +
                    image, label = [], []
         
     | 
| 
      
 39 
     | 
    
         
            +
                    for i in range(10):
         
     | 
| 
      
 40 
     | 
    
         
            +
                        dir_path = os.path.join('dataset', str(i))
         
     | 
| 
      
 41 
     | 
    
         
            +
                        filelist = [os.path.join(dir_path, file) for file in os.listdir(dir_path)]
         
     | 
| 
      
 42 
     | 
    
         
            +
             
         
     | 
| 
      
 43 
     | 
    
         
            +
                        # ファイル毎に画像を1次元のnumpy配列として追加していく
         
     | 
| 
      
 44 
     | 
    
         
            +
                        images = []
         
     | 
| 
      
 45 
     | 
    
         
            +
                        for file in filelist:
         
     | 
| 
      
 46 
     | 
    
         
            +
                            _img = Image.open(file).convert('L')
         
     | 
| 
      
 47 
     | 
    
         
            +
                            _img = _img.resize((28, 28))
         
     | 
| 
      
 48 
     | 
    
         
            +
                            _img = np.asarray(_img, dtype=np.uint8).reshape(-1)
         
     | 
| 
      
 49 
     | 
    
         
            +
                            images.append(_img)
         
     | 
| 
      
 50 
     | 
    
         
            +
                        image.extend(images)
         
     | 
| 
      
 51 
     | 
    
         
            +
             
         
     | 
| 
      
 52 
     | 
    
         
            +
                        label.extend([i] * len(filelist))
         
     | 
| 
      
 53 
     | 
    
         
            +
             
         
     | 
| 
      
 54 
     | 
    
         
            +
                    self.dataset = {'img': np.array(image), 'label': np.array(label)}
         
     | 
| 
      
 55 
     | 
    
         
            +
                    with open(save_file, 'wb') as f:
         
     | 
| 
      
 56 
     | 
    
         
            +
                        pickle.dump(self.dataset, f)
         
     | 
| 
      
 57 
     | 
    
         
            +
             
         
     | 
| 
      
 58 
     | 
    
         
            +
             
         
     | 
| 
      
 59 
     | 
    
         
            +
            if __name__ == '__main__':
         
     | 
| 
      
 60 
     | 
    
         
            +
             
         
     | 
| 
      
 61 
     | 
    
         
            +
                digit = DigitImage()
         
     | 
| 
      
 62 
     | 
    
         
            +
                X, y = digit.load_data('digit_dataset.pkl', normalize=True,
         
     | 
| 
      
 63 
     | 
    
         
            +
                                flatten=False, one_hot_label=True)
         
     | 
| 
      
 64 
     | 
    
         
            +
                print(X.shape)
         
     | 
| 
      
 65 
     | 
    
         
            +
                print(y.shape)
         
     | 
| 
      
 66 
     | 
    
         
            +
             
         
     | 
| 
      
 67 
     | 
    
         
            +
                # ここでロードしたデータを学習用とテスト用に分割. シャッフルもされる
         
     | 
| 
      
 68 
     | 
    
         
            +
                X_train, X_test, y_train, y_test = train_test_split(
         
     | 
| 
      
 69 
     | 
    
         
            +
                                                       X.reshape((len(X), 28, 28, 1)),
         
     | 
| 
      
 70 
     | 
    
         
            +
                                                        y, train_size=0.6)
         
     | 
| 
      
 71 
     | 
    
         
            +
                print(y_train[:10])
         
     | 
| 
      
 72 
     | 
    
         
            +
             
     | 
| 
      
 73 
     | 
    
         
            +
            digit = DigitImage()
         
     | 
| 
      
 74 
     | 
    
         
            +
            X, y = digit.load_data('digit_dataset.pkl', normalize=True,
         
     | 
| 
      
 75 
     | 
    
         
            +
                                flatten=False, one_hot_label=True)
         
     | 
| 
      
 76 
     | 
    
         
            +
            print(X.shape)    # (10000, 1, 28, 28)
         
     | 
| 
      
 77 
     | 
    
         
            +
            print(y.shape)    # (10000, 10)
         
     | 
| 
      
 78 
     | 
    
         
            +
            コード
         
     | 
| 
       3 
79 
     | 
    
         
             
            ```
         
     | 
| 
      
 80 
     | 
    
         
            +
            ```
         
     | 
| 
       4 
81 
     | 
    
         
             
            コード
         
     | 
| 
       5 
82 
     | 
    
         
             
            ```### 前提・実現したいこと
         
     | 
| 
       6 
83 
     | 
    
         
             
            自作データセットをコード一行で読み込めるようにしたいです。
         
     | 
1
    
        title	
    CHANGED
    
    | 
         
            File without changes
         
     | 
    
        body	
    CHANGED
    
    | 
         @@ -1,4 +1,8 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            ``````ここに言語を入力
         
     | 
| 
      
 2 
     | 
    
         
            +
            ここに言語を入力
         
     | 
| 
      
 3 
     | 
    
         
            +
            ```
         
     | 
| 
      
 4 
     | 
    
         
            +
            コード
         
     | 
| 
       1 
     | 
    
         
            -
            ### 前提・実現したいこと
         
     | 
| 
      
 5 
     | 
    
         
            +
            ```### 前提・実現したいこと
         
     | 
| 
       2 
6 
     | 
    
         
             
            自作データセットをコード一行で読み込めるようにしたいです。
         
     | 
| 
       3 
7 
     | 
    
         
             
            プログラム初心者で全くわからない状況なのでおかしいことたくさん言っていると思います。すみません
         
     | 
| 
       4 
8 
     | 
    
         
             
            ここに質問の内容を詳しく書いてください。
         
     |