Question edit history

2  Fix

test  CHANGED
File without changes

test  CHANGED
@@ -202,7 +202,9 @@
 
 
 
-```model
+```
+
+###model
 
 
 
1  Fix

test  CHANGED
File without changes

test  CHANGED
@@ -198,8 +198,126 @@
 
 if not os.path.exists(count_image_path):
 
-
+    r_running = False
+
+
+
-
+```model
+
+
+
-
+```ここに言語を入力
+
+import torch
+
+import torch.nn as nn
+
+import torchvision.models as models
+
+from torch.nn.utils.rnn import pack_padded_sequence
+
+
+
+
+
+class EncoderCNN(nn.Module):
+
+    def __init__(self, embed_size):
+
+        """Load the pretrained ResNet-152 and replace top fc layer."""
+
+        super(EncoderCNN, self).__init__()
+
+        resnet = models.resnet152(pretrained=True)
+
+        modules = list(resnet.children())[:-1]  # delete the last fc layer.
+
+        self.resnet = nn.Sequential(*modules)
+
+        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
+
+        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
+
+
+
+    def forward(self, images):
+
+        """Extract feature vectors from input images."""
+
+        with torch.no_grad():
+
+            features = self.resnet(images)
+
+        features = features.reshape(features.size(0), -1)
+
+        features = self.bn(self.linear(features))
+
+        return features
+
+
+
+
+
+class DecoderRNN(nn.Module):
+
+    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
+
+        """Set the hyper-parameters and build the layers."""
+
+        super(DecoderRNN, self).__init__()
+
+        self.embed = nn.Embedding(vocab_size, embed_size)
+
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+
+        self.linear = nn.Linear(hidden_size, vocab_size)
+
+        self.max_seg_length = max_seq_length
+
+
+
+    def forward(self, features, captions, lengths):
+
+        """Decode image feature vectors and generates captions."""
+
+        embeddings = self.embed(captions)
+
+        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
+
+        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
+
+        hiddens, _ = self.lstm(packed)
+
+        outputs = self.linear(hiddens[0])
+
+        return outputs
+
+
+
+    def sample(self, features, states=None):
+
+        """Generate captions for given image features using greedy search."""
+
+        sampled_ids = []
+
+        inputs = features.unsqueeze(1)
+
+        for i in range(self.max_seg_length):
+
+            hiddens, states = self.lstm(inputs, states)  # hiddens: (batch_size, 1, hidden_size)
+
+            outputs = self.linear(hiddens.squeeze(1))  # outputs: (batch_size, vocab_size)
+
+            _, predicted = outputs.max(1)  # predicted: (batch_size)
+
+            sampled_ids.append(predicted)
+
+            inputs = self.embed(predicted)  # inputs: (batch_size, embed_size)
+
+            inputs = inputs.unsqueeze(1)  # inputs: (batch_size, 1, embed_size)
+
+        sampled_ids = torch.stack(sampled_ids, 1)  # sampled_ids: (batch_size, max_seq_length)
+
+        return sampled_ids
 
 ```
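For reference, a minimal sketch of how the EncoderCNN / DecoderRNN pair added in revision 1 is typically driven at inference time. Everything below is an assumption for illustration: the module name `model`, the layer sizes, and the dummy input are not taken from the question.

```python
# Minimal usage sketch (assumptions: the classes above live in a file
# named model.py, and the sizes below are arbitrary illustrative values).
import torch

from model import EncoderCNN, DecoderRNN  # hypothetical module name

embed_size, hidden_size, vocab_size, num_layers = 256, 512, 5000, 1

# eval() matters here: EncoderCNN contains a BatchNorm1d layer, which needs
# batch statistics in training mode and would fail on a single image.
encoder = EncoderCNN(embed_size).eval()
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).eval()

image = torch.randn(1, 3, 224, 224)        # one dummy RGB image
with torch.no_grad():
    features = encoder(image)              # (1, embed_size)
    word_ids = decoder.sample(features)    # (1, max_seq_length) token ids
print(word_ids.shape)                      # torch.Size([1, 20])
```

sample() performs greedy decoding, so the returned ids still need to be mapped back to words with whatever vocabulary the question's training script uses.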