teratail header banner
teratail header banner
質問する ログイン 新規登録

質問編集履歴

2

修正

2021/12/14 07:58

投稿

退会済みユーザー
title CHANGED
File without changes
body CHANGED
@@ -100,7 +100,8 @@
100
100
  if not os.path.exists(count_image_path):
101
101
  r_running = False
102
102
 
103
+ ```
103
- ```model
104
+ ###model
104
105
 
105
106
  ```ここに言語を入力
106
107
  import torch

1

修正

2021/12/14 07:58

投稿

退会済みユーザー
title CHANGED
File without changes
body CHANGED
@@ -98,6 +98,65 @@
98
98
  print(r_running)
99
99
 
100
100
  if not os.path.exists(count_image_path):
101
+ r_running = False
102
+
103
+ ```model
101
104
 
102
-
105
+ ```ここに言語を入力
106
+ import torch
107
+ import torch.nn as nn
108
+ import torchvision.models as models
109
+ from torch.nn.utils.rnn import pack_padded_sequence
110
+
111
+
112
+ class EncoderCNN(nn.Module):
113
+ def __init__(self, embed_size):
114
+ """Load the pretrained ResNet-152 and replace top fc layer."""
115
+ super(EncoderCNN, self).__init__()
116
+ resnet = models.resnet152(pretrained=True)
117
+ modules = list(resnet.children())[:-1] # delete the last fc layer.
118
+ self.resnet = nn.Sequential(*modules)
119
+ self.linear = nn.Linear(resnet.fc.in_features, embed_size)
120
+ self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
121
+
122
+ def forward(self, images):
123
+ """Extract feature vectors from input images."""
124
+ with torch.no_grad():
125
+ features = self.resnet(images)
126
+ features = features.reshape(features.size(0), -1)
127
+ features = self.bn(self.linear(features))
128
+ return features
129
+
130
+
131
+ class DecoderRNN(nn.Module):
132
+ def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
133
+ """Set the hyper-parameters and build the layers."""
134
+ super(DecoderRNN, self).__init__()
135
+ self.embed = nn.Embedding(vocab_size, embed_size)
136
+ self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
137
+ self.linear = nn.Linear(hidden_size, vocab_size)
138
+ self.max_seg_length = max_seq_length
139
+
140
+ def forward(self, features, captions, lengths):
141
+ """Decode image feature vectors and generates captions."""
142
+ embeddings = self.embed(captions)
143
+ embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
144
+ packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
145
+ hiddens, _ = self.lstm(packed)
146
+ outputs = self.linear(hiddens[0])
147
+ return outputs
148
+
149
+ def sample(self, features, states=None):
150
+ """Generate captions for given image features using greedy search."""
151
+ sampled_ids = []
152
+ inputs = features.unsqueeze(1)
153
+ for i in range(self.max_seg_length):
154
+ hiddens, states = self.lstm(inputs, states) # hiddens: (batch_size, 1, hidden_size)
155
+ outputs = self.linear(hiddens.squeeze(1)) # outputs: (batch_size, vocab_size)
156
+ _, predicted = outputs.max(1) # predicted: (batch_size)
157
+ sampled_ids.append(predicted)
158
+ inputs = self.embed(predicted) # inputs: (batch_size, embed_size)
159
+ inputs = inputs.unsqueeze(1) # inputs: (batch_size, 1, embed_size)
160
+ sampled_ids = torch.stack(sampled_ids, 1) # sampled_ids: (batch_size, max_seq_length)
161
+ return sampled_ids
103
162
  ```