質問編集履歴
4
補足
test
CHANGED
File without changes
|
test
CHANGED
@@ -421,3 +421,125 @@
|
|
421
421
|
|
422
422
|
|
423
423
|
```
|
424
|
+
|
425
|
+
```python
|
426
|
+
|
427
|
+
class SegmentationDataSet3(data.Dataset):
|
428
|
+
|
429
|
+
"""Image segmentation dataset with caching, pretransforms and multiprocessing."""
|
430
|
+
|
431
|
+
|
432
|
+
|
433
|
+
def __init__(
|
434
|
+
|
435
|
+
self,
|
436
|
+
|
437
|
+
inputs: list,
|
438
|
+
|
439
|
+
targets: list,
|
440
|
+
|
441
|
+
transform=None,
|
442
|
+
|
443
|
+
use_cache=False,
|
444
|
+
|
445
|
+
pre_transform=None,
|
446
|
+
|
447
|
+
):
|
448
|
+
|
449
|
+
self.inputs = inputs
|
450
|
+
|
451
|
+
self.targets = targets
|
452
|
+
|
453
|
+
self.transform = transform
|
454
|
+
|
455
|
+
self.inputs_dtype = torch.float32
|
456
|
+
|
457
|
+
self.targets_dtype = torch.long
|
458
|
+
|
459
|
+
self.use_cache = use_cache
|
460
|
+
|
461
|
+
self.pre_transform = pre_transform
|
462
|
+
|
463
|
+
|
464
|
+
|
465
|
+
if self.use_cache:
|
466
|
+
|
467
|
+
from itertools import repeat
|
468
|
+
|
469
|
+
from multiprocessing import Pool
|
470
|
+
|
471
|
+
|
472
|
+
|
473
|
+
with Pool() as pool:
|
474
|
+
|
475
|
+
self.cached_data = pool.starmap(
|
476
|
+
|
477
|
+
self.read_images, zip(inputs, targets, repeat(self.pre_transform))
|
478
|
+
|
479
|
+
)
|
480
|
+
|
481
|
+
|
482
|
+
|
483
|
+
def __len__(self):
|
484
|
+
|
485
|
+
return len(self.inputs)
|
486
|
+
|
487
|
+
|
488
|
+
|
489
|
+
def __getitem__(self, index: int):
|
490
|
+
|
491
|
+
if self.use_cache:
|
492
|
+
|
493
|
+
x, y = self.cached_data[index]
|
494
|
+
|
495
|
+
else:
|
496
|
+
|
497
|
+
# Select the sample
|
498
|
+
|
499
|
+
input_ID = self.inputs[index]
|
500
|
+
|
501
|
+
target_ID = self.targets[index]
|
502
|
+
|
503
|
+
|
504
|
+
|
505
|
+
# Load input and target
|
506
|
+
|
507
|
+
x, y = imread(str(input_ID)), imread(str(target_ID))
|
508
|
+
|
509
|
+
|
510
|
+
|
511
|
+
# Preprocessing
|
512
|
+
|
513
|
+
if self.transform is not None:
|
514
|
+
|
515
|
+
x, y = self.transform(x, y)
|
516
|
+
|
517
|
+
|
518
|
+
|
519
|
+
# Typecasting
|
520
|
+
|
521
|
+
x, y = torch.from_numpy(x).type(self.inputs_dtype), torch.from_numpy(y).type(
|
522
|
+
|
523
|
+
self.targets_dtype
|
524
|
+
|
525
|
+
)
|
526
|
+
|
527
|
+
|
528
|
+
|
529
|
+
return x, y
|
530
|
+
|
531
|
+
|
532
|
+
|
533
|
+
@staticmethod
|
534
|
+
|
535
|
+
def read_images(inp, tar, pre_transform):
|
536
|
+
|
537
|
+
inp, tar = imread(str(inp)), imread(str(tar))
|
538
|
+
|
539
|
+
if pre_transform:
|
540
|
+
|
541
|
+
inp, tar = pre_transform(inp, tar)
|
542
|
+
|
543
|
+
return inp, tar
|
544
|
+
|
545
|
+
```
|
3
理想の配列を変えた
test
CHANGED
File without changes
|
test
CHANGED
@@ -12,6 +12,10 @@
|
|
12
12
|
|
13
13
|
バッチサイズは1
|
14
14
|
|
15
|
+
Input,Targetの画像はそれぞれ44枚。合計で88枚。
|
16
|
+
|
17
|
+
モデルはUNet
|
18
|
+
|
15
19
|
|
16
20
|
|
17
21
|
現在の表示 理想
|
@@ -24,11 +28,11 @@
|
|
24
28
|
|
25
29
|
tensor([0, 1])
|
26
30
|
|
27
|
-
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1, 8, 256, 256]
|
31
|
+
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1, 88, 256, 256]
|
28
32
|
|
29
33
|
tensor(0.) tensor(1.)
|
30
34
|
|
31
|
-
torch.Size([1, 8, 256, 256])
|
35
|
+
torch.Size([1, 8, 256, 256]) →[1,88,256,256]
|
32
36
|
|
33
37
|
tensor([0, 1]) →[0, 1, 2]
|
34
38
|
|
@@ -384,6 +388,36 @@
|
|
384
388
|
|
385
389
|
|
386
390
|
|
387
|
-
|
388
|
-
|
389
391
|
```
|
392
|
+
|
393
|
+
```UNet
|
394
|
+
|
395
|
+
# model
|
396
|
+
|
397
|
+
model = UNet(
|
398
|
+
|
399
|
+
in_channels=3,
|
400
|
+
|
401
|
+
#in_channels=1,
|
402
|
+
|
403
|
+
out_channels=1,
|
404
|
+
|
405
|
+
#out_channels=3,
|
406
|
+
|
407
|
+
n_blocks=4,
|
408
|
+
|
409
|
+
start_filters=32,
|
410
|
+
|
411
|
+
activation="relu",
|
412
|
+
|
413
|
+
normalization="batch",
|
414
|
+
|
415
|
+
conv_mode="same",
|
416
|
+
|
417
|
+
dim=3,
|
418
|
+
|
419
|
+
).to(device)
|
420
|
+
|
421
|
+
|
422
|
+
|
423
|
+
```
|
2
補足
test
CHANGED
File without changes
|
test
CHANGED
@@ -8,7 +8,9 @@
|
|
8
8
|
|
9
9
|
3というのはどこにも定義していないのですが、、
|
10
10
|
|
11
|
-
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)
|
11
|
+
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)input画像はRGB(3チャンネル)、output画像は白黒(1チャンネル)
|
12
|
+
|
13
|
+
バッチサイズは1
|
12
14
|
|
13
15
|
|
14
16
|
|
1
補足
test
CHANGED
File without changes
|
test
CHANGED
@@ -8,11 +8,11 @@
|
|
8
8
|
|
9
9
|
3というのはどこにも定義していないのですが、、
|
10
10
|
|
11
|
-
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。
|
11
|
+
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)
|
12
|
-
|
13
|
-
|
14
|
-
|
12
|
+
|
13
|
+
|
14
|
+
|
15
|
-
現在の表示 理想
|
15
|
+
現在の表示 理想
|
16
16
|
|
17
17
|
torch.Size([1, 8, 256, 256, 3]) →[1, 8, 256, 256]
|
18
18
|
|
@@ -22,13 +22,13 @@
|
|
22
22
|
|
23
23
|
tensor([0, 1])
|
24
24
|
|
25
|
-
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1, 8, 256, 256]
|
25
|
+
torch.Size([1, 1, 8, 256, 256, 3]) →[1, 1, 8, 256, 256]
|
26
26
|
|
27
27
|
tensor(0.) tensor(1.)
|
28
28
|
|
29
29
|
torch.Size([1, 8, 256, 256])
|
30
30
|
|
31
|
-
tensor([0, 1])
|
31
|
+
tensor([0, 1]) →[0, 1, 2]
|
32
32
|
|
33
33
|
|
34
34
|
|