質問編集履歴
2
修正
title
CHANGED
File without changes
|
body
CHANGED
@@ -100,7 +100,8 @@
|
|
100
100
|
if not os.path.exists(count_image_path):
|
101
101
|
r_running = False
|
102
102
|
|
103
|
+
```
|
103
|
-
|
104
|
+
###model
|
104
105
|
|
105
106
|
```ここに言語を入力
|
106
107
|
import torch
|
1
修正
title
CHANGED
File without changes
|
body
CHANGED
@@ -98,6 +98,65 @@
|
|
98
98
|
print(r_running)
|
99
99
|
|
100
100
|
if not os.path.exists(count_image_path):
|
101
|
+
r_running = False
|
102
|
+
|
103
|
+
```model
|
101
104
|
|
102
|
-
|
105
|
+
```ここに言語を入力
|
106
|
+
import torch
|
107
|
+
import torch.nn as nn
|
108
|
+
import torchvision.models as models
|
109
|
+
from torch.nn.utils.rnn import pack_padded_sequence
|
110
|
+
|
111
|
+
|
112
|
+
class EncoderCNN(nn.Module):
    """Image encoder: a ResNet-152 trunk followed by a trainable projection.

    The trunk (every ResNet-152 child except the final classifier) is run
    without gradient tracking; only ``linear`` and ``bn`` receive gradients
    through ``forward``.
    """

    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer.

        Args:
            embed_size: Size of each output feature vector.
        """
        super().__init__()
        backbone = models.resnet152(pretrained=True)
        fc_in = backbone.fc.in_features
        # Drop only the last child (the fc classifier). Attribute names and
        # registration order (resnet, linear, bn) are kept so state_dict keys
        # stay the same.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.linear = nn.Linear(fc_in, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Extract feature vectors from input images.

        Returns:
            A ``(batch, embed_size)`` tensor: trunk output flattened per
            sample, projected by ``linear`` and normalized by ``bn``.
        """
        with torch.no_grad():
            trunk_out = self.resnet(images)
        flat = trunk_out.reshape(trunk_out.size(0), -1)
        return self.bn(self.linear(flat))
|
129
|
+
|
130
|
+
|
131
|
+
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    the embedded caption tokens.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers.

        Args:
            embed_size: Word-embedding size; must equal the size of the image
                feature vectors given to ``forward``/``sample`` (they are
                concatenated along the time axis).
            hidden_size: Number of features in the LSTM hidden state.
            vocab_size: Number of token ids in the vocabulary.
            num_layers: Number of stacked LSTM layers.
            max_seq_length: Number of tokens ``sample`` generates.
        """
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # NOTE: the attribute keeps its original "seg" spelling so existing
        # code (and pickled models) that read ``max_seg_length`` still work.
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Decode image feature vectors and score every caption position.

        Args:
            features: ``(batch, embed_size)`` image features.
            captions: ``(batch, seq_len)`` token ids.
            lengths: Number of valid time steps per sample, counting the
                image step (so at most ``seq_len + 1``), as a list or 1-D
                tensor. With ``pack_padded_sequence``'s default
                ``enforce_sorted=True``, they must be in decreasing order.

        Returns:
            ``(sum(lengths), vocab_size)`` scores in packed-sequence order.
        """
        embeddings = self.embed(captions)
        # Prepend the image feature as time step 0.
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        if torch.is_tensor(lengths):
            # pack_padded_sequence rejects non-CPU length tensors (e.g. CUDA).
            lengths = lengths.cpu()
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        packed_out, _ = self.lstm(packed)
        return self.linear(packed_out.data)

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search.

        Args:
            features: ``(batch, embed_size)`` image features.
            states: Optional initial LSTM state ``(h_0, c_0)``; ``None``
                lets the LSTM start from zeros.

        Returns:
            A ``(batch, max_seg_length)`` long tensor of predicted token ids.
            It has zero columns when ``max_seg_length <= 0``.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)  # (batch, 1, embed_size)
        for _ in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)  # (batch, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))  # (batch, vocab_size)
            _, predicted = outputs.max(1)  # (batch,)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)  # (batch, 1, embed_size)
        if not sampled_ids:
            # torch.stack rejects an empty list; return an empty id matrix.
            return torch.empty(features.size(0), 0, dtype=torch.long, device=features.device)
        return torch.stack(sampled_ids, 1)
|
103
162
|
```
|