Question edit history

Revision 2
    
title: no changes

body: CHANGED

````diff
@@ -100,7 +100,8 @@
         if not os.path.exists(count_image_path):
             r_running = False
 
-```model
+```
+###model
 
 ```ここに言語を入力
 import torch
````
Revision 1
    
title: no changes

body: CHANGED

````diff
@@ -98,6 +98,65 @@
             print(r_running)
 
         if not os.path.exists(count_image_path):
+            r_running = False
+
+```model
 
-
+```ここに言語を入力
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from torch.nn.utils.rnn import pack_padded_sequence
+
+
+class EncoderCNN(nn.Module):
+    def __init__(self, embed_size):
+        """Load the pretrained ResNet-152 and replace top fc layer."""
+        super(EncoderCNN, self).__init__()
+        resnet = models.resnet152(pretrained=True)
+        modules = list(resnet.children())[:-1]      # delete the last fc layer.
+        self.resnet = nn.Sequential(*modules)
+        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
+        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
+
+    def forward(self, images):
+        """Extract feature vectors from input images."""
+        with torch.no_grad():
+            features = self.resnet(images)
+        features = features.reshape(features.size(0), -1)
+        features = self.bn(self.linear(features))
+        return features
+
+
+class DecoderRNN(nn.Module):
+    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
+        """Set the hyper-parameters and build the layers."""
+        super(DecoderRNN, self).__init__()
+        self.embed = nn.Embedding(vocab_size, embed_size)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+        self.linear = nn.Linear(hidden_size, vocab_size)
+        self.max_seg_length = max_seq_length
+
+    def forward(self, features, captions, lengths):
+        """Decode image feature vectors and generates captions."""
+        embeddings = self.embed(captions)
+        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
+        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
+        hiddens, _ = self.lstm(packed)
+        outputs = self.linear(hiddens[0])
+        return outputs
+
+    def sample(self, features, states=None):
+        """Generate captions for given image features using greedy search."""
+        sampled_ids = []
+        inputs = features.unsqueeze(1)
+        for i in range(self.max_seg_length):
+            hiddens, states = self.lstm(inputs, states)          # hiddens: (batch_size, 1, hidden_size)
+            outputs = self.linear(hiddens.squeeze(1))            # outputs: (batch_size, vocab_size)
+            _, predicted = outputs.max(1)                        # predicted: (batch_size)
+            sampled_ids.append(predicted)
+            inputs = self.embed(predicted)                       # inputs: (batch_size, embed_size)
+            inputs = inputs.unsqueeze(1)                         # inputs: (batch_size, 1, embed_size)
+        sampled_ids = torch.stack(sampled_ids, 1)                # sampled_ids: (batch_size, max_seq_length)
+        return sampled_ids
 ```
````
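
For orientation: the block added in revision 1 defines the CNN encoder / LSTM decoder pair of a standard image-captioning model (it appears to match the model.py from the yunjey pytorch-tutorial). Below is a minimal usage sketch, not part of the question; the hyper-parameter values, dummy tensors, and caption ids are illustrative assumptions.

```python
# Minimal usage sketch (not from the question). EncoderCNN / DecoderRNN are the
# classes defined in the diff above; all sizes and tensors here are assumptions.
import torch

embed_size, hidden_size, vocab_size, num_layers = 256, 512, 5000, 1

encoder = EncoderCNN(embed_size)    # instantiating downloads ResNet-152 weights
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
encoder.eval()
decoder.eval()

images = torch.randn(2, 3, 224, 224)             # dummy batch of two RGB images
features = encoder(images)                       # -> (2, embed_size)

# Training-style call: padded word-id captions plus lengths that count the
# prepended image feature, sorted in descending order for pack_padded_sequence.
captions = torch.tensor([[1, 4, 8, 2], [1, 9, 2, 0]])   # dummy ids, 0 = padding
outputs = decoder(features, captions, [5, 4])    # -> (sum(lengths), vocab_size)

# Inference: greedy decoding, one word id per step.
with torch.no_grad():
    sampled_ids = decoder.sample(features)       # -> (2, max_seq_length)
```

Note that in DecoderRNN.forward the image feature is prepended to the caption embeddings before packing, so each entry in lengths must count the caption tokens plus one for the image feature, and pack_padded_sequence expects the batch sorted by length in descending order.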