質問編集履歴
4
補足
title
CHANGED
File without changes
|
body
CHANGED
@@ -209,4 +209,65 @@
|
|
209
209
|
dim=3,
|
210
210
|
).to(device)
|
211
211
|
|
212
|
+
```
|
213
|
+
```python
|
214
|
+
class SegmentationDataSet3(data.Dataset):
|
215
|
+
"""Image segmentation dataset with caching, pretransforms and multiprocessing."""
|
216
|
+
|
217
|
+
def __init__(
|
218
|
+
self,
|
219
|
+
inputs: list,
|
220
|
+
targets: list,
|
221
|
+
transform=None,
|
222
|
+
use_cache=False,
|
223
|
+
pre_transform=None,
|
224
|
+
):
|
225
|
+
self.inputs = inputs
|
226
|
+
self.targets = targets
|
227
|
+
self.transform = transform
|
228
|
+
self.inputs_dtype = torch.float32
|
229
|
+
self.targets_dtype = torch.long
|
230
|
+
self.use_cache = use_cache
|
231
|
+
self.pre_transform = pre_transform
|
232
|
+
|
233
|
+
if self.use_cache:
|
234
|
+
from itertools import repeat
|
235
|
+
from multiprocessing import Pool
|
236
|
+
|
237
|
+
with Pool() as pool:
|
238
|
+
self.cached_data = pool.starmap(
|
239
|
+
self.read_images, zip(inputs, targets, repeat(self.pre_transform))
|
240
|
+
)
|
241
|
+
|
242
|
+
def __len__(self):
|
243
|
+
return len(self.inputs)
|
244
|
+
|
245
|
+
def __getitem__(self, index: int):
|
246
|
+
if self.use_cache:
|
247
|
+
x, y = self.cached_data[index]
|
248
|
+
else:
|
249
|
+
# Select the sample
|
250
|
+
input_ID = self.inputs[index]
|
251
|
+
target_ID = self.targets[index]
|
252
|
+
|
253
|
+
# Load input and target
|
254
|
+
x, y = imread(str(input_ID)), imread(str(target_ID))
|
255
|
+
|
256
|
+
# Preprocessing
|
257
|
+
if self.transform is not None:
|
258
|
+
x, y = self.transform(x, y)
|
259
|
+
|
260
|
+
# Typecasting
|
261
|
+
x, y = torch.from_numpy(x).type(self.inputs_dtype), torch.from_numpy(y).type(
|
262
|
+
self.targets_dtype
|
263
|
+
)
|
264
|
+
|
265
|
+
return x, y
|
266
|
+
|
267
|
+
@staticmethod
|
268
|
+
def read_images(inp, tar, pre_transform):
|
269
|
+
inp, tar = imread(str(inp)), imread(str(tar))
|
270
|
+
if pre_transform:
|
271
|
+
inp, tar = pre_transform(inp, tar)
|
272
|
+
return inp, tar
|
212
273
|
```
|
3
理想の配列を変えた
title
CHANGED
File without changes
|
body
CHANGED
@@ -5,15 +5,17 @@
|
|
5
5
|
3というのはどこにも定義していないのですが、、
|
6
6
|
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)input画像はRGB(3チャンネル)、output画像は白黒(1チャンネル)
|
7
7
|
バッチサイズは1
|
8
|
+
Input,Targetの画像はそれぞれ44枚。合計で88枚。
|
9
|
+
モデルはUNet
|
8
10
|
|
9
11
|
現在の表示 理想
|
10
12
|
torch.Size([1, 8, 256, 256, 3]) →[1, 8, 256, 256]
|
11
13
|
tensor(0.) tensor(1.)
|
12
14
|
torch.Size([8, 256, 256])
|
13
15
|
tensor([0, 1])
|
14
|
-
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1,
|
16
|
+
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1, 88, 256, 256]
|
15
17
|
tensor(0.) tensor(1.)
|
16
|
-
torch.Size([1, 8, 256, 256])
|
18
|
+
torch.Size([1, 8, 256, 256]) →[1,88,256,256]
|
17
19
|
tensor([0, 1]) →[0, 1, 2]
|
18
20
|
|
19
21
|
|
@@ -191,5 +193,20 @@
|
|
191
193
|
dataset_viewer_training = DatasetViewer(dataset_train)
|
192
194
|
dataset_viewer_training.napari() # navigate with 'n' for next and 'b' for back
|
193
195
|
|
196
|
+
```
|
197
|
+
```UNet
|
198
|
+
# model
|
199
|
+
model = UNet(
|
200
|
+
in_channels=3,
|
201
|
+
#in_channels=1,
|
202
|
+
out_channels=1,
|
203
|
+
#out_channels=3,
|
204
|
+
n_blocks=4,
|
205
|
+
start_filters=32,
|
206
|
+
activation="relu",
|
207
|
+
normalization="batch",
|
208
|
+
conv_mode="same",
|
209
|
+
dim=3,
|
210
|
+
).to(device)
|
194
211
|
|
195
212
|
```
|
2
補足
title
CHANGED
File without changes
|
body
CHANGED
@@ -3,7 +3,8 @@
|
|
3
3
|
画像はうまく出力されるのですが、x.shapeがおかしいです。
|
4
4
|
プログラムを実行すると、画像の出力とともにtorch.sizeなどが端末に表示されます。
|
5
5
|
3というのはどこにも定義していないのですが、、
|
6
|
-
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)
|
6
|
+
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)input画像はRGB(3チャンネル)、output画像は白黒(1チャンネル)
|
7
|
+
バッチサイズは1
|
7
8
|
|
8
9
|
現在の表示 理想
|
9
10
|
torch.Size([1, 8, 256, 256, 3]) →[1, 8, 256, 256]
|
1
補足
title
CHANGED
File without changes
|
body
CHANGED
@@ -3,17 +3,17 @@
|
|
3
3
|
画像はうまく出力されるのですが、x.shapeがおかしいです。
|
4
4
|
プログラムを実行すると、画像の出力とともにtorch.sizeなどが端末に表示されます。
|
5
5
|
3というのはどこにも定義していないのですが、、
|
6
|
-
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。
|
6
|
+
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)
|
7
7
|
|
8
|
-
現在の表示
|
8
|
+
現在の表示 理想
|
9
9
|
torch.Size([1, 8, 256, 256, 3]) →[1, 8, 256, 256]
|
10
10
|
tensor(0.) tensor(1.)
|
11
11
|
torch.Size([8, 256, 256])
|
12
12
|
tensor([0, 1])
|
13
|
-
torch.Size([1, 1, 8, 256, 256, 3])
|
13
|
+
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1, 8, 256, 256]
|
14
14
|
tensor(0.) tensor(1.)
|
15
15
|
torch.Size([1, 8, 256, 256])
|
16
|
-
tensor([0, 1])
|
16
|
+
tensor([0, 1]) →[0, 1, 2]
|
17
17
|
|
18
18
|
|
19
19
|
```python
|