画像データを読み込んで表示させるプログラムです。
機械学習用にデータを読み込みたいです。
画像はうまく出力されるのですが、x.shapeがおかしいです。
プログラムを実行すると、画像の出力とともにtorch.sizeなどが端末に表示されます。
3というのはどこにも定義していないのですが、、
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)input画像はRGB(3チャンネル)、output画像は白黒(1チャンネル)
バッチサイズは1
Input,Targetの画像はそれぞれ44枚。合計で88枚。
モデルはUNet
現在の表示 理想
torch.Size([1, 8, 256, 256, 3]) →[1, 8, 256, 256]
tensor(0.) tensor(1.)
torch.Size([8, 256, 256])
tensor([0, 1])
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1, 88, 256, 256]
tensor(0.) tensor(1.)
torch.Size([1, 8, 256, 256]) →[1,88,256,256]
tensor([0, 1]) →[0, 1, 2]
python
1# Imports 2import pathlib 3 4import albumentations 5import numpy as np 6import torch 7from torch.utils.data import DataLoader 8 9from customdatasets import SegmentationDataSet3 10from transformations import ( 11 ComposeDouble, 12 normalize_01, 13 AlbuSeg3d, 14 FunctionWrapperDouble, 15 create_dense_target, 16) 17 18#add 19from sklearn.model_selection import train_test_split 20from unet import UNet 21from trainer import Trainer 22from skimage.transform import resize 23 24 25# root directory 26#root = pathlib.Path.cwd() / "Microtubules3D" 27root = pathlib.Path.cwd() / "Cow3D" 28 29 30def get_filenames_of_path(path: pathlib.Path, ext: str = "*"): 31 """Returns a list of files in a directory/path. Uses pathlib.""" 32 filenames = [file for file in path.glob(ext) if file.is_file()] 33 return filenames 34 35 36# input and target files 37inputs = get_filenames_of_path(root / "Input") 38targets = get_filenames_of_path(root / "Target") 39 40#add1 41# pre-transformations 42pre_transforms = ComposeDouble( 43 [ 44 FunctionWrapperDouble( 45 resize, input=True, target=False, output_shape=(128, 128, 128, 1) 46 ), 47 FunctionWrapperDouble( 48 resize, 49 input=False, 50 target=True, 51 output_shape=(128, 128, 128, 1), 52 order=0, 53 anti_aliasing=False, 54 preserve_range=True, 55 ), 56 ] 57) 58 59 60# training transformations and augmentations 61# example how to properly resize and use AlbuSeg3d 62# please note that the input is grayscale and the channel dimension of size 1 is added 63# also note that the AlbuSeg3d currently only works with input that does not have a C dim! 64transforms_training = ComposeDouble( 65 [ 66 # FunctionWrapperDouble(resize, input=True, target=False, output_shape=(16, 100, 100)), 67 # FunctionWrapperDouble(resize, input=False, target=True, output_shape=(16, 100, 100), order=0, anti_aliasing=False, preserve_range=True), 68 # AlbuSeg3d(albumentations.HorizontalFlip(p=0.5)), 69 # AlbuSeg3d(albumentations.VerticalFlip(p=0.5)), 70 # AlbuSeg3d(albumentations.Rotate(p=0.5)), 71 AlbuSeg3d(albumentations.RandomRotate90(p=0.5)), 72 FunctionWrapperDouble(create_dense_target, input=False, target=True), 73 FunctionWrapperDouble(np.expand_dims, axis=0), 74 # RandomFlip(ndim_spatial=3), 75 FunctionWrapperDouble(normalize_01), 76 ] 77) 78 79#add2 80# validation transformations 81transforms_validation = ComposeDouble( 82 [ 83 FunctionWrapperDouble( 84 resize, input=True, target=False, output_shape=(128, 128, 128, 1) 85 ), 86 FunctionWrapperDouble( 87 resize, 88 input=False, 89 target=True, 90 output_shape=(128, 128, 128, 1), 91 order=0, 92 anti_aliasing=False, 93 preserve_range=True, 94 ), 95 FunctionWrapperDouble(create_dense_target, input=False, target=True), 96 FunctionWrapperDouble( 97 np.moveaxis, input=True, target=False, source=-1, destination=0 98 ), 99 FunctionWrapperDouble(normalize_01), 100 ] 101) 102 103# random seed 104random_seed = 42 105 106#add3 107# split dataset into training set and validation set 108train_size = 0.8 # 80:20 split 109inputs_train, inputs_valid = train_test_split( 110 inputs, random_state=random_seed, train_size=train_size, shuffle=True 111) 112targets_train, targets_valid = train_test_split( 113 targets, random_state=random_seed, train_size=train_size, shuffle=True 114) 115 116 117# dataset training 118dataset_train = SegmentationDataSet3( 119 inputs=inputs, 120 targets=targets, 121 transform=transforms_training, 122 use_cache=False, 123 pre_transform=None, 124) 125 126#データの数によってdataset_train[]を変化させる もとはdataset_train[1] 127x, y = dataset_train[1] 128print(x.shape) 129print(x.min(), x.max()) 130print(y.shape) 131print(torch.unique(y)) 132 133#add4 134# dataset validation 135dataset_valid = SegmentationDataSet3( 136 inputs=inputs_valid, 137 targets=targets_valid, 138 transform=transforms_validation, 139 use_cache=True, 140 pre_transform=pre_transforms, 141) 142 143# dataloader training 144dataloader_training = DataLoader( 145 dataset=dataset_train, 146 batch_size=1, 147 # batch_size of 2 won't work because the depth dimension is different between the 2 samples 148 shuffle=True, 149) 150 151 152#add5 153dataloader_validation = DataLoader( 154 dataset=dataset_train, 155 batch_size=1, 156 # batch_size of 2 won't work because the depth dimension is different between the 2 samples 157 shuffle=True, 158) 159 160 161batch = next(iter(dataloader_training)) 162x, y = batch 163print(x.shape) 164print(x.min(), x.max()) 165print(y.shape) 166print(torch.unique(y)) 167 168# create DatasetViewer instances 169from visual import DatasetViewer 170 171dataset_viewer_training = DatasetViewer(dataset_train) 172dataset_viewer_training.napari() # navigate with 'n' for next and 'b' for back 173
UNet
1# model 2model = UNet( 3 in_channels=3, 4 #in_channels=1, 5 out_channels=1, 6 #out_channels=3, 7 n_blocks=4, 8 start_filters=32, 9 activation="relu", 10 normalization="batch", 11 conv_mode="same", 12 dim=3, 13).to(device) 14
python
1class SegmentationDataSet3(data.Dataset): 2 """Image segmentation dataset with caching, pretransforms and multiprocessing.""" 3 4 def __init__( 5 self, 6 inputs: list, 7 targets: list, 8 transform=None, 9 use_cache=False, 10 pre_transform=None, 11 ): 12 self.inputs = inputs 13 self.targets = targets 14 self.transform = transform 15 self.inputs_dtype = torch.float32 16 self.targets_dtype = torch.long 17 self.use_cache = use_cache 18 self.pre_transform = pre_transform 19 20 if self.use_cache: 21 from itertools import repeat 22 from multiprocessing import Pool 23 24 with Pool() as pool: 25 self.cached_data = pool.starmap( 26 self.read_images, zip(inputs, targets, repeat(self.pre_transform)) 27 ) 28 29 def __len__(self): 30 return len(self.inputs) 31 32 def __getitem__(self, index: int): 33 if self.use_cache: 34 x, y = self.cached_data[index] 35 else: 36 # Select the sample 37 input_ID = self.inputs[index] 38 target_ID = self.targets[index] 39 40 # Load input and target 41 x, y = imread(str(input_ID)), imread(str(target_ID)) 42 43 # Preprocessing 44 if self.transform is not None: 45 x, y = self.transform(x, y) 46 47 # Typecasting 48 x, y = torch.from_numpy(x).type(self.inputs_dtype), torch.from_numpy(y).type( 49 self.targets_dtype 50 ) 51 52 return x, y 53 54 @staticmethod 55 def read_images(inp, tar, pre_transform): 56 inp, tar = imread(str(inp)), imread(str(tar)) 57 if pre_transform: 58 inp, tar = pre_transform(inp, tar) 59 return inp, tar
回答1件
あなたの回答
tips
プレビュー