質問編集履歴

1

追記

2018/05/24 04:37

投稿

kokawa2003
kokawa2003

スコア217

test CHANGED
File without changes
test CHANGED
@@ -2,12 +2,14 @@
2
2
 
3
3
  画像の読み込みでエラーが出て困っています。
4
4
 
5
- 以下の処理を実行すると最後の行でエラーになります。
5
+ 以下の処理を実行するとload_imageの最後でエラーになります。
6
6
 
7
7
 
8
8
 
9
9
  ```lang-python
10
10
 
11
+ from __future__ import print_function, division
12
+
11
13
  import numpy as np
12
14
 
13
15
  import os, re
@@ -50,6 +52,270 @@
50
52
 
51
53
  return xp.asarray(image, dtype=np.float32).transpose(2, 0, 1)
52
54
 
55
+
56
+
57
+ def gram_matrix(y):
58
+
59
+ b, ch, h, w = y.data.shape
60
+
61
+ features = F.reshape(y, (b, ch, w*h))
62
+
63
+ gram = F.batch_matmul(features, features, transb=True)/np.float32(ch*w*h)
64
+
65
+ return gram
66
+
67
+
68
+
69
+ def total_variation(x):
70
+
71
+ xp = cuda.get_array_module(x.data)
72
+
73
+ b, ch, h, w = x.data.shape
74
+
75
+ wh = Variable(xp.asarray([[[[1], [-1]], [[0], [0]], [[0], [0]]], [[[0], [0]], [[1], [-1]], [[0], [0]]], [[[0], [0]], [[0], [0]], [[1], [-1]]]], dtype=np.float32), volatile=x.volatile)
76
+
77
+ ww = Variable(xp.asarray([[[[1, -1]], [[0, 0]], [[0, 0]]], [[[0, 0]], [[1, -1]], [[0, 0]]], [[[0, 0]], [[0, 0]], [[1, -1]]]], dtype=np.float32), volatile=x.volatile)
78
+
79
+ return F.sum(F.convolution_2d(x, W=wh) ** 2) + F.sum(F.convolution_2d(x, W=ww) ** 2)
80
+
81
+
82
+
83
+ parser = argparse.ArgumentParser(description='Real-time style transfer')
84
+
85
+ parser.add_argument('--gpu', '-g', default=-1, type=int,
86
+
87
+ help='GPU ID (negative value indicates CPU)')
88
+
89
+ parser.add_argument('--dataset', '-d', default='dataset', type=str,
90
+
91
+ help='dataset directory path (according to the paper, use MSCOCO 80k images)')
92
+
93
+ parser.add_argument('--style_image', '-s', type=str, required=True,
94
+
95
+ help='style image path')
96
+
97
+ parser.add_argument('--batchsize', '-b', type=int, default=1,
98
+
99
+ help='batch size (default value is 1)')
100
+
101
+ parser.add_argument('--initmodel', '-i', default=None, type=str,
102
+
103
+ help='initialize the model from given file')
104
+
105
+ parser.add_argument('--resume', '-r', default=None, type=str,
106
+
107
+ help='resume the optimization from snapshot')
108
+
109
+ parser.add_argument('--output', '-o', default=None, type=str,
110
+
111
+ help='output model file path without extension')
112
+
113
+ parser.add_argument('--lambda_tv', default=1e-6, type=float,
114
+
115
+ help='weight of total variation regularization according to the paper to be set between 10e-4 and 10e-6.')
116
+
117
+ parser.add_argument('--lambda_feat', default=1.0, type=float)
118
+
119
+ parser.add_argument('--lambda_style', default=5.0, type=float)
120
+
121
+ parser.add_argument('--epoch', '-e', default=2, type=int)
122
+
123
+ parser.add_argument('--lr', '-l', default=1e-3, type=float)
124
+
125
+ parser.add_argument('--checkpoint', '-c', default=0, type=int)
126
+
127
+ parser.add_argument('--image_size', default=256, type=int)
128
+
129
+ args = parser.parse_args()
130
+
131
+
132
+
133
+ batchsize = args.batchsize
134
+
135
+
136
+
137
+ image_size = args.image_size
138
+
139
+ n_epoch = args.epoch
140
+
141
+ lambda_tv = args.lambda_tv
142
+
143
+ lambda_f = args.lambda_feat
144
+
145
+ lambda_s = args.lambda_style
146
+
147
+ style_prefix, _ = os.path.splitext(os.path.basename(args.style_image))
148
+
149
+ output = style_prefix if args.output == None else args.output
150
+
151
+ fs = os.listdir(args.dataset)
152
+
153
+ imagepaths = []
154
+
155
+ for fn in fs:
156
+
157
+ base, ext = os.path.splitext(fn)
158
+
159
+ if ext == '.jpg' or ext == '.png':
160
+
161
+ imagepath = os.path.join(args.dataset,fn)
162
+
163
+ imagepaths.append(imagepath)
164
+
165
+ n_data = len(imagepaths)
166
+
167
+ print('num traning images:', n_data)
168
+
169
+ n_iter = n_data // batchsize
170
+
171
+ print(n_iter, 'iterations,', n_epoch, 'epochs')
172
+
173
+
174
+
175
+ model = FastStyleNet()
176
+
177
+ vgg = VGG()
178
+
179
+ serializers.load_npz('vgg16.model', vgg)
180
+
181
+ if args.initmodel:
182
+
183
+ print('load model from', args.initmodel)
184
+
185
+ serializers.load_npz(args.initmodel, model)
186
+
187
+ if args.gpu >= 0:
188
+
189
+ cuda.get_device(args.gpu).use()
190
+
191
+ model.to_gpu()
192
+
193
+ vgg.to_gpu()
194
+
195
+ xp = np if args.gpu < 0 else cuda.cupy
196
+
197
+
198
+
199
+ O = optimizers.Adam(alpha=args.lr)
200
+
201
+ O.setup(model)
202
+
203
+ if args.resume:
204
+
205
+ print('load optimizer state from', args.resume)
206
+
207
+ serializers.load_npz(args.resume, O)
208
+
209
+
210
+
211
+ style = vgg.preprocess(np.asarray(Image.open(args.style_image).convert('RGB').resize((image_size,image_size)), dtype=np.float32))
212
+
213
+ style = xp.asarray(style, dtype=xp.float32)
214
+
215
+ style_b = xp.zeros((batchsize,) + style.shape, dtype=xp.float32)
216
+
217
+ for i in range(batchsize):
218
+
219
+ style_b[i] = style
220
+
221
+ feature_s = vgg(Variable(style_b))
222
+
223
+ gram_s = [gram_matrix(y) for y in feature_s]
224
+
225
+
226
+
227
+ for epoch in range(n_epoch):
228
+
229
+ print('epoch', epoch)
230
+
231
+ for i in range(n_iter):
232
+
233
+ model.zerograds()
234
+
235
+ vgg.zerograds()
236
+
237
+
238
+
239
+ indices = range(i * batchsize, (i+1) * batchsize)
240
+
241
+ x = xp.zeros((batchsize, 3, image_size, image_size), dtype=xp.float32)
242
+
243
+ for j in range(batchsize):
244
+
245
+ x[j] = load_image(imagepaths[i*batchsize + j], image_size)
246
+
247
+
248
+
249
+ xc = Variable(x.copy(), volatile=True)
250
+
251
+ x = Variable(x)
252
+
253
+
254
+
255
+ y = model(x)
256
+
257
+
258
+
259
+ xc -= 120
260
+
261
+ y -= 120
262
+
263
+
264
+
265
+ feature = vgg(xc)
266
+
267
+ feature_hat = vgg(y)
268
+
269
+
270
+
271
+ L_feat = lambda_f * F.mean_squared_error(Variable(feature[2].data), feature_hat[2]) # compute for only the output of layer conv3_3
272
+
273
+
274
+
275
+ L_style = Variable(xp.zeros((), dtype=np.float32))
276
+
277
+ for f, f_hat, g_s in zip(feature, feature_hat, gram_s):
278
+
279
+ L_style += lambda_s * F.mean_squared_error(gram_matrix(f_hat), Variable(g_s.data))
280
+
281
+
282
+
283
+ L_tv = lambda_tv * total_variation(y)
284
+
285
+ L = L_feat + L_style + L_tv
286
+
287
+
288
+
289
+ print('(epoch {}) batch {}/{}... training loss is...{}'.format(epoch, i, n_iter, L.data))
290
+
291
+
292
+
293
+ L.backward()
294
+
295
+ O.update()
296
+
297
+
298
+
299
+ if args.checkpoint > 0 and i % args.checkpoint == 0:
300
+
301
+ serializers.save_npz('models/{}_{}_{}.model'.format(output, epoch, i), model)
302
+
303
+ serializers.save_npz('models/{}_{}_{}.state'.format(output, epoch, i), O)
304
+
305
+
306
+
307
+ print('save "style.model"')
308
+
309
+ serializers.save_npz('models/{}_{}.model'.format(output, epoch), model)
310
+
311
+ serializers.save_npz('models/{}_{}.state'.format(output, epoch), O)
312
+
313
+
314
+
315
+ serializers.save_npz('models/{}.model'.format(output), model)
316
+
317
+ serializers.save_npz('models/{}.state'.format(output), O)
318
+
53
319
  ```
54
320
 
55
321
  Traceback (most recent call last):