質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.37%

torch.sizeの出力がおかしい

受付中

回答 0

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 75

score 13

画像データを読み込んで表示させるプログラムです。
機械学習用にデータを読み込みたいです。
画像はうまく出力されるのですが、x.shapeがおかしいです。
プログラムを実行すると、画像の出力とともにtorch.sizeなどが端末に表示されます。
3というのはどこにも定義していないのですが、、
データそのものがおかしいですかね?画像はtiffファイルを8枚重ねています。(3次元)input画像はRGB(3チャンネル)、output画像は白黒(1チャンネル)
バッチサイズは1
Input,Targetの画像はそれぞれ44枚。合計で88枚。
モデルはUNet

現在の表示                               理想
torch.Size([1, 8, 256, 256, 3])   →[1, 8, 256, 256]
tensor(0.) tensor(1.)
torch.Size([8, 256, 256])
tensor([0, 1])
torch.Size([1, 1, 8, 256, 256, 3])   →[1, 1, 88, 256, 256]
tensor(0.) tensor(1.)
torch.Size([1, 8, 256, 256])          →[1,88,256,256]
tensor([0, 1])           →[0, 1, 2]

# Imports
import pathlib

import albumentations
import numpy as np
import torch
from torch.utils.data import DataLoader

from customdatasets import SegmentationDataSet3
from transformations import (
    ComposeDouble,
    normalize_01,
    AlbuSeg3d,
    FunctionWrapperDouble,
    create_dense_target,
)

#add
from sklearn.model_selection import train_test_split
from unet import UNet
from trainer import Trainer
from skimage.transform import resize


# root directory
#root = pathlib.Path.cwd() / "Microtubules3D"
root = pathlib.Path.cwd() / "Cow3D"


def get_filenames_of_path(path: pathlib.Path, ext: str = "*"):
    """Returns a list of files in a directory/path. Uses pathlib."""
    filenames = [file for file in path.glob(ext) if file.is_file()]
    return filenames


# input and target files
inputs = get_filenames_of_path(root / "Input")
targets = get_filenames_of_path(root / "Target")

#add1
# pre-transformations
pre_transforms = ComposeDouble(
    [
        FunctionWrapperDouble(
            resize, input=True, target=False, output_shape=(128, 128, 128, 1)
        ),
        FunctionWrapperDouble(
            resize,
            input=False,
            target=True,
            output_shape=(128, 128, 128, 1),
            order=0,
            anti_aliasing=False,
            preserve_range=True,
        ),
    ]
)


# training transformations and augmentations
# example how to properly resize and use AlbuSeg3d
# please note that the input is grayscale and the channel dimension of size 1 is added
# also note that the AlbuSeg3d currently only works with input that does not have a C dim!
transforms_training = ComposeDouble(
    [
        # FunctionWrapperDouble(resize, input=True, target=False, output_shape=(16, 100, 100)),
        # FunctionWrapperDouble(resize, input=False, target=True, output_shape=(16, 100, 100), order=0, anti_aliasing=False, preserve_range=True),
        # AlbuSeg3d(albumentations.HorizontalFlip(p=0.5)),
        # AlbuSeg3d(albumentations.VerticalFlip(p=0.5)),
        # AlbuSeg3d(albumentations.Rotate(p=0.5)),
        AlbuSeg3d(albumentations.RandomRotate90(p=0.5)),
        FunctionWrapperDouble(create_dense_target, input=False, target=True),
        FunctionWrapperDouble(np.expand_dims, axis=0),
        # RandomFlip(ndim_spatial=3),
        FunctionWrapperDouble(normalize_01),
    ]
)

#add2
# validation transformations
transforms_validation = ComposeDouble(
    [
        FunctionWrapperDouble(
            resize, input=True, target=False, output_shape=(128, 128, 128, 1)
        ),
        FunctionWrapperDouble(
            resize,
            input=False,
            target=True,
            output_shape=(128, 128, 128, 1),
            order=0,
            anti_aliasing=False,
            preserve_range=True,
        ),
        FunctionWrapperDouble(create_dense_target, input=False, target=True),
        FunctionWrapperDouble(
            np.moveaxis, input=True, target=False, source=-1, destination=0
        ),
        FunctionWrapperDouble(normalize_01),
    ]
)

# random seed
random_seed = 42

#add3
# split dataset into training set and validation set
train_size = 0.8  # 80:20 split
inputs_train, inputs_valid = train_test_split(
    inputs, random_state=random_seed, train_size=train_size, shuffle=True
)
targets_train, targets_valid = train_test_split(
    targets, random_state=random_seed, train_size=train_size, shuffle=True
)


# dataset training
dataset_train = SegmentationDataSet3(
    inputs=inputs,
    targets=targets,
    transform=transforms_training,
    use_cache=False,
    pre_transform=None,
)

#データの数によってdataset_train[]を変化させる もとはdataset_train[1]
x, y = dataset_train[1]
print(x.shape)
print(x.min(), x.max())
print(y.shape)
print(torch.unique(y))

#add4
# dataset validation
dataset_valid = SegmentationDataSet3(
    inputs=inputs_valid,
    targets=targets_valid,
    transform=transforms_validation,
    use_cache=True,
    pre_transform=pre_transforms,
)

# dataloader training
dataloader_training = DataLoader(
    dataset=dataset_train,
    batch_size=1,
    # batch_size of 2 won't work because the depth dimension is different between the 2 samples
    shuffle=True,
)


#add5
dataloader_validation = DataLoader(
    dataset=dataset_train,
    batch_size=1,
    # batch_size of 2 won't work because the depth dimension is different between the 2 samples
    shuffle=True,
)


batch = next(iter(dataloader_training))
x, y = batch
print(x.shape)
print(x.min(), x.max())
print(y.shape)
print(torch.unique(y))

# create DatasetViewer instances
from visual import DatasetViewer

dataset_viewer_training = DatasetViewer(dataset_train)
dataset_viewer_training.napari()  # navigate with 'n' for next and 'b' for back
# model
model = UNet(
    in_channels=3,
    #in_channels=1,
    out_channels=1,
    #out_channels=3,
    n_blocks=4,
    start_filters=32,
    activation="relu",
    normalization="batch",
    conv_mode="same",
    dim=3,
).to(device)
class SegmentationDataSet3(data.Dataset):
    """Image segmentation dataset with caching, pretransforms and multiprocessing."""

    def __init__(
        self,
        inputs: list,
        targets: list,
        transform=None,
        use_cache=False,
        pre_transform=None,
    ):
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long
        self.use_cache = use_cache
        self.pre_transform = pre_transform

        if self.use_cache:
            from itertools import repeat
            from multiprocessing import Pool

            with Pool() as pool:
                self.cached_data = pool.starmap(
                    self.read_images, zip(inputs, targets, repeat(self.pre_transform))
                )

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index: int):
        if self.use_cache:
            x, y = self.cached_data[index]
        else:
            # Select the sample
            input_ID = self.inputs[index]
            target_ID = self.targets[index]

            # Load input and target
            x, y = imread(str(input_ID)), imread(str(target_ID))

        # Preprocessing
        if self.transform is not None:
            x, y = self.transform(x, y)

        # Typecasting
        x, y = torch.from_numpy(x).type(self.inputs_dtype), torch.from_numpy(y).type(
            self.targets_dtype
        )

        return x, y

    @staticmethod
    def read_images(inp, tar, pre_transform):
        inp, tar = imread(str(inp)), imread(str(tar))
        if pre_transform:
            inp, tar = pre_transform(inp, tar)
        return inp, tar
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正の依頼

  • tiroha

    2021/11/26 18:04

    セグメンテーションは対象:1(白)、それ以外:0(黒)という感じです。
    自身のデータセットは牛:1、それ以外:0
    にしています。

    キャンセル

  • HRCo4

    2021/11/26 19:15

    0,1 による判別であれば softmax ではなく sigmoid あたりかな。

    Dataset をざっくり見てみましたが、
    x, y = imread(str(input_ID)), imread(str(target_ID))
    ここで x がおそらく [8, 256, 256, 3](=[1tiff画像8枚, 幅, 高さ, (RGB)ch数])となっています。
    その後、
    x, y = self.transform(x, y) で FunctionWrapperDouble(np.expand_dims, axis=0) によって [1, 8, 256, 256, 3] になっています。
    この形が UNet の入力に適切なのかどうかといわれると UNet を触っていないので何ともですが、一般的ではないとは思います。

    キャンセル

  • tiroha

    2021/11/26 21:48

    そうですね。
    恐らく、[1,3,8,256,256]みたいにチャネル数が前に来ると思います。

    キャンセル

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.37%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る