質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
WSL(Windows Subsystem for Linux)

WSL (Windows Subsystem for Linux) は、Windows10のOS上でLinux向けのバイナリプログラムを実行可能にする機能です。また、WindowsOSのAPIを用いた仕組みを提供しており、Linux側からWindowsOSへのファイルアクセスもできます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

解決済

2回答

1194閲覧

自作データセットを活用する際のモデルを作成時におけるconv2d引数の入れ方がわからない

oinari03

総合スコア59

WSL(Windows Subsystem for Linux)

WSL (Windows Subsystem for Linux) は、Windows10のOS上でLinux向けのバイナリプログラムを実行可能にする機能です。また、WindowsOSのAPIを用いた仕組みを提供しており、Linux側からWindowsOSへのファイルアクセスもできます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

0クリップ

投稿2020/08/13 08:12

初めに

以前質問させてもらってあるディレクトリから画像をとってきてlabelを付与するという
いわゆる自作データセットみたいなのができました。

次にしようと思ったことがモデルの作成です。条件としてはinitだったりfowardだったりで必ず作りたいです。

前提条件

以下にモデルを作成する前段階を状態を記します。

ディレクトリ構成

├─animal_dataset ├─train │ ├─cat(70枚くらい) │ └─dog(70枚くらい) └─val ├─cat(30枚くらい) └─dog(30枚くらい)

個の画像から0/1でのlabelを付与したdataset.py
以下のような形で画像(data)とlabelを紐づけ、出力しています。

python

1 transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()]) 2 train_dataset = MyDatasets("./animal_dataset", "train", transform) 3 train_dataloader = data.DataLoader(train_dataset, batch_size=32, shuffle= True) 4 5 for data, labels in train_dataloader: 6 print(data.shape, labels.shape) 7 datas, labels = iter(train_dataloader).next() 8 9 s=10 10 pic = transforms.ToPILImage(mode='RGB')(datas[s]) 11 pic.save('./result.jpg') 12 if labels[s].numpy() == 0: 13 print("label: cat") 14 else: 15 print("label: dog")

出力結果
画像dataとラベルlabelのサイズを出力しています。

torch.Size([32, 3, 256, 256]) torch.Size([32]) torch.Size([32, 3, 256, 256]) torch.Size([32]) torch.Size([32, 3, 256, 256]) torch.Size([32]) torch.Size([32, 3, 256, 256]) torch.Size([32]) torch.Size([1, 3, 256, 256]) torch.Size([1]) label: dog

書いてみたコード

これは以下を参照したモデルです。
モデルの書き方参照

class Net(nn.Module): def __init__(self): #親クラスのnn.Moduleのコンストラクタを呼ぶ super(Net,self).__init__() #畳み込み層を定義する #引数は順番に、サンプル数、チャネル数、フィルタのサイズ self.conv1=nn.Conv2d(1,6,(5,5)) #フィルタのサイズは正方形であればタプルではなく整数でも可(8行目と10行目は同じ意味) self.conv2=nn.Conv2d(6,16,5) #全結合層を定義する #fc1の第一引数は、チャネル数*最後のプーリング層の出力のマップのサイズ=特徴量の数 self.fc1=nn.Linear(16*5*5,120) self.fc2=nn.Linear(120,84) self.fc3=nn.Linear(84,10)

わかっていること

・元画像が256×256の正方形であること
・最終的に犬or猫(1 or 0)であることから最後のLinear(x,2)の2になるだろうと予想しています。
・プーリング層の意味や全結合などの意味は把握しているつもりです。

わからないこと(よろしくお願いします。)

modelの作成に際して

・torch.nn.Conv1dやtorch.nn.Conv2dの違いがわからないです。なのでconv2dを使う意図もわかっていません

・このコードを見た時にサンプル数とチャンネル数ってのがなんだかわかりません。何を入れたらいいのでしょうか。
・なのでサンプル数が1でチャンネルが6でフィルタのサイズが5なのでしょうか。どのようなきじゅんで決めるのか?という疑問があります。

・またどうして次のチャンネル数が16なのでしょうか

・conv2dだけでなくlinearの引数の数字にどのような値を入れたらいいのかわかりません。

自分の出力結果からどのような数字を入れるべきかアドバイスをくださいますでしょうか。

また、そもそもcnnでやろうとしているのが間違っていたり、足りない情報があればご指摘くださいませ。
追記いたします

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答2

0

Conv2dを使う理由は、おそらく入力される値が2次元だからだと思います。
また、それ以降のサンプル数や、畳込みのチャンネル数などはほとんど勘です。
Linearにはなるべく大きい値が好ましいですが、値が大きければ大きいほど学習に時間がかかります。

投稿2020/08/13 08:20

Luke02561

総合スコア404

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

0

ベストアンサー

conv2dとconv1dの違いは畳み込むフィルタの形状です。
画像の場合は基本的にconv2dです。

また引数として最低限必要なのは、in_channels: int, out_channels: int, kernel_size: int or tuple,strideです。(サンプル数というのはおそらくin_channelsのことかと)
公式ドキュメントに書いてありますが、
in_channelsは入力のチャンネル数です。(例えば、画像の場合は3です(カラー画像はRGBで3チャンネル、グレースケールの場合は1チャンネル))
out_channelsは出力のチャンネル数です。(これは自由に決めることができます)
kernel_sizeはCNNで畳み込むフィルタのサイズです(例えば5x5のフィルタだと画像における5x5 pixelをたたみ込みます)
strideはフィルタ(kernel)をどのくらい移動させて畳み込むかを表しています。

おすすめはPytorchの公式ドキュメントをみることです。
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

また、引数の意味するところは、CNNの仕組みがわからないと厳しいと思います。「ゼロから作るDeepLearning」などの入門書を一度読んでみるのはいかがでしょうか。

投稿2020/08/13 09:53

msoniku

総合スコア36

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問