質問編集履歴

2

修正

2021/12/14 07:58

投稿

退会済みユーザー
test CHANGED
File without changes
test CHANGED
@@ -202,7 +202,9 @@
202
202
 
203
203
 
204
204
 
205
+ ```
206
+
205
- ```model
207
+ ###model
206
208
 
207
209
 
208
210
 

1

修正

2021/12/14 07:58

投稿

退会済みユーザー
test CHANGED
File without changes
test CHANGED
@@ -198,8 +198,126 @@
198
198
 
199
199
  if not os.path.exists(count_image_path):
200
200
 
201
-
201
+ r_running = False
202
+
203
+
204
+
202
-
205
+ ```model
206
+
207
+
208
+
203
-
209
+ ```python
210
+
211
+ import torch
212
+
213
+ import torch.nn as nn
214
+
215
+ import torchvision.models as models
216
+
217
+ from torch.nn.utils.rnn import pack_padded_sequence
218
+
219
+
220
+
221
+
222
+
223
+ class EncoderCNN(nn.Module):
224
+
225
+ def __init__(self, embed_size):
226
+
227
+ """Load the pretrained ResNet-152 and replace top fc layer."""
228
+
229
+ super(EncoderCNN, self).__init__()
230
+
231
+ resnet = models.resnet152(pretrained=True)
232
+
233
+ modules = list(resnet.children())[:-1] # delete the last fc layer.
234
+
235
+ self.resnet = nn.Sequential(*modules)
236
+
237
+ self.linear = nn.Linear(resnet.fc.in_features, embed_size)
238
+
239
+ self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
240
+
241
+
242
+
243
+ def forward(self, images):
244
+
245
+ """Extract feature vectors from input images."""
246
+
247
+ with torch.no_grad():
248
+
249
+ features = self.resnet(images)
250
+
251
+ features = features.reshape(features.size(0), -1)
252
+
253
+ features = self.bn(self.linear(features))
254
+
255
+ return features
256
+
257
+
258
+
259
+
260
+
261
+ class DecoderRNN(nn.Module):
262
+
263
+ def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
264
+
265
+ """Set the hyper-parameters and build the layers."""
266
+
267
+ super(DecoderRNN, self).__init__()
268
+
269
+ self.embed = nn.Embedding(vocab_size, embed_size)
270
+
271
+ self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
272
+
273
+ self.linear = nn.Linear(hidden_size, vocab_size)
274
+
275
+ self.max_seg_length = max_seq_length
276
+
277
+
278
+
279
+ def forward(self, features, captions, lengths):
280
+
281
+ """Decode image feature vectors and generate captions."""
282
+
283
+ embeddings = self.embed(captions)
284
+
285
+ embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
286
+
287
+ packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
288
+
289
+ hiddens, _ = self.lstm(packed)
290
+
291
+ outputs = self.linear(hiddens[0])
292
+
293
+ return outputs
294
+
295
+
296
+
297
+ def sample(self, features, states=None):
298
+
299
+ """Generate captions for given image features using greedy search."""
300
+
301
+ sampled_ids = []
302
+
303
+ inputs = features.unsqueeze(1)
304
+
305
+ for i in range(self.max_seg_length):
306
+
307
+ hiddens, states = self.lstm(inputs, states) # hiddens: (batch_size, 1, hidden_size)
308
+
309
+ outputs = self.linear(hiddens.squeeze(1)) # outputs: (batch_size, vocab_size)
310
+
311
+ _, predicted = outputs.max(1) # predicted: (batch_size)
312
+
313
+ sampled_ids.append(predicted)
314
+
315
+ inputs = self.embed(predicted) # inputs: (batch_size, embed_size)
316
+
317
+ inputs = inputs.unsqueeze(1) # inputs: (batch_size, 1, embed_size)
318
+
319
+ sampled_ids = torch.stack(sampled_ids, 1) # sampled_ids: (batch_size, max_seq_length)
320
+
321
+ return sampled_ids
204
322
 
205
323
  ```