質問編集履歴

1

該当コードを追加しました。

2021/11/27 21:48

投稿

h_proc
h_proc

スコア68

test CHANGED
File without changes
test CHANGED
@@ -66,6 +66,324 @@
66
66
 
67
67
 
68
68
 
69
+ 以下長いですが、コードになります。
70
+
71
+ ```datasets.py
72
+
73
def create_dataset(dataset_name,
                   compute_node_feature_stats=True,
                   node_feature_stats_filename=None,
                   **dataset_params):
    """Create the dataset registered under `dataset_name`.

    Args:
        dataset_name: Name of the dataset to create. Recognized values are
            'shrec_16', 'cubes', 'coseg', 'human_seg' and 'vessel'.
        compute_node_feature_stats: If True, the mean and standard deviation
            of the node features are obtained through `compute_mean_and_std`
            and returned together with the dataset.
        node_feature_stats_filename: Forwarded to `compute_mean_and_std` as
            its `filename` argument.
        **dataset_params: Keyword arguments forwarded to the constructor of
            the selected dataset class (ignored by the 'vessel' placeholder).

    Returns:
        A tuple `(dataset, node_statistics)`, where `node_statistics` is the
        tuple `(primal_graph_mean, primal_graph_std, dual_graph_mean,
        dual_graph_std)`, or None if `compute_node_feature_stats` is False.

    Raises:
        KeyError: If `dataset_name` is not one of the recognized names.
        AttributeError: If the statistics are requested for a dataset that
            does not expose `input_parameters` (currently the case for the
            'vessel' placeholder).
    """
    # The class names are looked up lazily, branch by branch, so only the
    # class of the requested dataset needs to be resolvable.
    if dataset_name == 'shrec_16':
        dataset = Shrec2016DualPrimal(**dataset_params)
    elif dataset_name == 'cubes':
        dataset = CubesDualPrimal(**dataset_params)
    elif dataset_name == 'coseg':
        dataset = CosegDualPrimal(**dataset_params)
    elif dataset_name == 'human_seg':
        dataset = HumanSegDualPrimal(**dataset_params)
    elif dataset_name == 'vessel':
        # NOTE(review): placeholder. An empty list carries no samples and no
        # `input_parameters`; replace it with a real dataset class before
        # requesting node-feature statistics.
        dataset = []
    else:
        raise KeyError(
            f"No known dataset can be generated with the name '{dataset_name}'."
        )

    if not compute_node_feature_stats:
        return dataset, None

    # Fail with an actionable message instead of the bare
    # "'list' object has no attribute 'input_parameters'".
    if not hasattr(dataset, 'input_parameters'):
        raise AttributeError(
            f"Dataset '{dataset_name}' (type {type(dataset).__name__}) has no "
            "attribute 'input_parameters', so its node-feature statistics "
            "cannot be computed. Pass compute_node_feature_stats=False or "
            "provide a dataset class that exposes 'input_parameters'.")

    node_statistics = compute_mean_and_std(
        dataset=dataset,
        dataset_params=dataset.input_parameters,
        filename=node_feature_stats_filename)

    return dataset, tuple(node_statistics)
136
+
137
+
138
+
139
+
140
+
141
def compute_mean_and_std(dataset=None, dataset_params=None, filename=None):
    """Return mean and standard deviation of primal/dual node features.

    If `filename` points to an existing file, the statistics are loaded from
    it after checking that the file is compatible with `dataset_params`.
    Otherwise they are computed from `dataset` and, if `filename` is given,
    saved to it together with the dataset parameters and the dataset size.

    Args:
        dataset: Sized, indexable and iterable collection of samples. Each
            sample unpacks into four items, the first two being the primal
            and the dual graph; each graph has an attribute `x` that is a
            tensor with at least two dimensions, whose rows are concatenated
            across samples.
        dataset_params: Optional dict of dataset parameters. Must not contain
            the keys 'primal_mean', 'primal_std', 'dual_mean' or 'dual_std'.
            Required (as a dict) when an existing cache file is loaded.
        filename: Optional path of the cache file.

    Returns:
        Tuple `(primal_graph_mean, primal_graph_std, dual_graph_mean,
        dual_graph_std)`. When computed from scratch, each element is a numpy
        array obtained by reducing over the first axis of the concatenated
        node features; when loaded, the values stored in the cache file.

    Raises:
        KeyError: If `dataset_params` already contains a statistics entry, if
            the cache file lacks a statistics entry, the dataset size or one
            of the dataset parameters, or if the cache file contains a
            parameter that is missing from `dataset_params`.
        ValueError: If a parameter stored in the cache file differs from the
            one in `dataset_params`, or if the statistics have to be computed
            from an empty dataset.
        IOError: If the cache file cannot be read or written.
    """
    # Same check order as the nested mean/std x primal/dual loops.
    stat_keywords = ('primal_mean', 'dual_mean', 'primal_std', 'dual_std')

    if dataset_params is not None:
        for keyword in stat_keywords:
            if keyword in dataset_params:
                raise KeyError(
                    f"The parameters of the input dataset already contain "
                    f"an entry '{keyword}'. Exiting.")

    if filename is not None and os.path.exists(filename):
        # Load the data from disk, since the file exists.
        assert dataset_params is not None
        assert isinstance(dataset_params, dict)
        try:
            with open(filename, "rb") as f:
                # NOTE(review): unpickling can execute arbitrary code; only
                # point `filename` at cache files from a trusted source.
                data_from_disk = pkl.load(f)
        except IOError as err:
            raise IOError(f"Error loading cache mean-std file '{filename}'. "
                          "Exiting.") from err

        # Check that the file contains the mean and standard deviation.
        for keyword in ['primal', 'dual']:
            if f'{keyword}_mean' not in data_from_disk:
                raise KeyError(
                    f"Cached file '{filename}' does not contain the mean of "
                    f"the {keyword}-graph node features. Exiting.")
            if f'{keyword}_std' not in data_from_disk:
                raise KeyError(
                    f"Cached file '{filename}' does not contain the standard "
                    f"deviation of the {keyword}-graph node features. Exiting.")

        # Check that the size of the dataset is compatible.
        try:
            size_dataset_of_file = data_from_disk['dataset_size']
        except KeyError:
            raise KeyError(
                f"Cached file '{filename}' does not contain the dataset size. "
                f"Exiting.") from None
        current_dataset_size = len(dataset)
        if size_dataset_of_file != current_dataset_size:
            # A size mismatch is only reported, not treated as an error.
            warnings.warn("Please note that the current dataset has size "
                          f"{current_dataset_size}, whereas the cached file ("
                          f"'{filename}') was generated from a dataset of size "
                          f"{size_dataset_of_file}.")

        # Check that every input parameter is in the file with equal value.
        for param_name, param_value in dataset_params.items():
            if param_name not in data_from_disk:
                raise KeyError(
                    f"Could not find dataset parameter {param_name} in the "
                    f"cached file '{filename}'. Please provide a different "
                    "filename.")
            if data_from_disk[param_name] != param_value:
                raise ValueError(
                    f"Cached file '{filename}' is incompatible with "
                    f"current dataset. Expected parameter {param_name} to "
                    f"be {param_value}, found "
                    f"{data_from_disk[param_name]}. Please provide a "
                    "different filename.")

        # Conversely, check that the file holds no parameter that the input
        # dataset lacks. The keys of the *file* must be iterated here
        # (iterating `dataset_params` makes the check vacuous); the entries
        # written by this function itself are not dataset parameters.
        non_param_keys = {*stat_keywords, 'dataset_size'}
        for cached_param_name in data_from_disk:
            if cached_param_name in non_param_keys:
                continue
            if cached_param_name not in dataset_params:
                raise KeyError(
                    f"Cached file '{filename}' is incompatible with "
                    "current dataset, as it contains parameter "
                    f"{cached_param_name}, which is missing in the input "
                    "dataset. Please provide a different filename.")

        # Return the cached data.
        return (data_from_disk['primal_mean'], data_from_disk['primal_std'],
                data_from_disk['dual_mean'], data_from_disk['dual_std'])

    # Compute the mean and standard deviation of the node features from
    # scratch.
    if len(dataset) == 0:
        raise ValueError(
            "Cannot compute the mean and standard deviation of the node "
            "features of an empty dataset.")

    first_sample = dataset[0]
    # Gather the per-sample features and concatenate once at the end, rather
    # than re-copying a growing tensor for every sample. The empty seed
    # tensors keep the result identical to sample-by-sample concatenation.
    primal_chunks = [torch.empty([0, first_sample[0].x.shape[1]])]
    dual_chunks = [torch.empty([0, first_sample[1].x.shape[1]])]
    num_samples = 0
    for primal_graph, dual_graph, _, _ in dataset:
        primal_chunks.append(primal_graph.x)
        dual_chunks.append(dual_graph.x)
        num_samples += 1
    assert len(dataset) == num_samples

    primal_graph_xs = torch.cat(primal_chunks)
    dual_graph_xs = torch.cat(dual_chunks)

    primal_graph_mean = primal_graph_xs.mean(axis=0).numpy()
    primal_graph_std = primal_graph_xs.std(axis=0).numpy()
    dual_graph_mean = dual_graph_xs.mean(axis=0).numpy()
    dual_graph_std = dual_graph_xs.std(axis=0).numpy()

    # Every feature must have a standard deviation clearly above machine
    # precision.
    assert np.all(
        primal_graph_std > 10 * np.finfo(primal_graph_std.dtype).eps)
    assert np.all(dual_graph_std > 10 * np.finfo(dual_graph_std.dtype).eps)

    if filename is not None:
        # Save the values to file, together with the dataset parameters and
        # the dataset size.
        if dataset_params is None:
            dataset_params = {}
        output_values = {
            **dataset_params,
            'primal_mean': primal_graph_mean,
            'primal_std': primal_graph_std,
            'dual_mean': dual_graph_mean,
            'dual_std': dual_graph_std,
            'dataset_size': num_samples,
        }
        try:
            with open(filename, 'wb') as f:
                pkl.dump(output_values, f)
        except IOError as err:
            raise IOError(
                "Unable to save mean-std data to file at location "
                f"{filename}.") from err

    return (primal_graph_mean, primal_graph_std, dual_graph_mean,
            dual_graph_std)
380
+
381
+
382
+
383
+ ```
384
+
385
+
386
+
69
387
  分かる方がいらっしゃいましたら回答いただけますと幸いです。
70
388
 
71
389
  よろしくお願い致します。