I am working on the implementation from this site.
When I run training_jobs/mesh_classification/shrec_16/training_job_1.py, I get a NotImplementedError at line 201 of pd_mesh_net/utils/datasets.py:
primal_graph_xs = torch.empty([0, dataset[0][0].x.shape[1]])
As a test, I printed dataset and got back
Shrec2016DualPrimal(480, categories=['alien', 'ants', 'armadillo', 'bird1', 'bird2', 'camel', 'cat', 'centaur', 'dino_ske', 'dinosaur', 'dog1', 'dog2', 'flamingo', 'glasses', 'gorilla', 'hand', 'horse', 'lamp', 'laptop', 'man', 'myScissor', 'octopus', 'pliers', 'rabbit', 'santa', 'shark', 'snake', 'spiders', 'two_balls', 'woman'])
However, when I print dataset[0] or dataset[0][0], a NotImplementedError is raised. What does this mean?
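My guess (I am not sure about this) is that indexing falls through to a base-class method that is not implemented. The helper below is not part of pd_mesh_net; it is just a quick check I ran right after create_dataset() returned, assuming Shrec2016DualPrimal inherits from torch_geometric.data.Dataset:

import torch_geometric
from torch_geometric.data import Dataset


def check_dataset_indexing(dataset):
    """Hypothetical diagnostic helper (not part of pd_mesh_net).

    The base torch_geometric.data.Dataset.get() simply raises
    NotImplementedError, while print(dataset) only goes through __repr__,
    which would explain why the repr prints fine but dataset[0] fails.
    """
    print('torch_geometric version:', torch_geometric.__version__)
    print('class MRO:', type(dataset).__mro__)
    print('get() overridden:', type(dataset).get is not Dataset.get)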
While looking into this, I found that the shrec_16 data used here is written in the same way as this code. Does that mean the dataset is being returned as a string?
If so, how should I obtain dataset[0][0].x.shape[1]?
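To make the intent concrete: judging from the for-loop in compute_mean_and_std() in the code below, each element of the dataset should unpack into a 4-tuple whose first two entries are graph objects carrying a node-feature matrix x. What I would like to be able to do is roughly the following (this is just my understanding, not code from the repository):

# `dataset` is the Shrec2016DualPrimal object returned by create_dataset().
# Expected behaviour, inferred from the loop
#   for sample_idx, (primal_graph, dual_graph, _, _) in enumerate(dataset):
# in compute_mean_and_std() below; at the moment the first line already
# raises NotImplementedError.
primal_graph, dual_graph, _, _ = dataset[0]
num_primal_features = primal_graph.x.shape[1]   # the value needed at line 201
num_dual_features = dual_graph.x.shape[1]
print(num_primal_features, num_dual_features)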
Below is the code (apologies, it is long).
datasets.py
def create_dataset(dataset_name,
                   compute_node_feature_stats=True,
                   node_feature_stats_filename=None,
                   **dataset_params):

    if (dataset_name == 'shrec_16'):
        dataset = Shrec2016DualPrimal(**dataset_params)
    elif (dataset_name == 'cubes'):
        dataset = CubesDualPrimal(**dataset_params)
    elif (dataset_name == 'coseg'):
        dataset = CosegDualPrimal(**dataset_params)
    elif (dataset_name == 'human_seg'):
        dataset = HumanSegDualPrimal(**dataset_params)
    elif (dataset_name == 'vessel'):
        dataset = []
    else:
        raise KeyError(
            f"No known dataset can be generated with the name '{dataset_name}'."
        )

    node_statistics = None
    print(dataset)
    if (compute_node_feature_stats):
        dataset_params = dataset.input_parameters
        (primal_graph_mean, primal_graph_std, dual_graph_mean,
         dual_graph_std) = compute_mean_and_std(
             dataset=dataset,
             dataset_params=dataset_params,
             filename=node_feature_stats_filename)
        node_statistics = (primal_graph_mean, primal_graph_std, dual_graph_mean,
                           dual_graph_std)
    return dataset, node_statistics


def compute_mean_and_std(dataset=None, dataset_params=None, filename=None):

    if (dataset_params is not None):
        for param_keyword in ['mean', 'std']:
            for graph_keyword in ['primal', 'dual']:
                keyword = f"{graph_keyword}_{param_keyword}"
                if (keyword in dataset_params):
                    raise KeyError(
                        f"The parameters of the input dataset already contain "
                        f"an entry '{keyword}'. Exiting.")
    file_exists = False
    if (filename is not None):
        # Load the data from disk, if the file exists.
        if (os.path.exists(filename)):
            file_exists = True
    if (file_exists):
        assert (dataset_params is not None)
        assert (isinstance(dataset_params, dict))
        try:
            with open(filename, "rb") as f:
                data_from_disk = pkl.load(f)
        except IOError:
            raise IOError(f"Error loading cache mean-std file '{filename}'. "
                          "Exiting.")
        # Check that the file contains the mean and standard deviation.
        for keyword in ['primal', 'dual']:
            if (f'{keyword}_mean' not in data_from_disk):
                raise KeyError(
                    f"Cached file '{filename}' does not contain the mean of "
                    f"the {keyword}-graph node features. Exiting.")
            if (f'{keyword}_std' not in data_from_disk):
                raise KeyError(
                    f"Cached file '{filename}' does not contain the standard "
                    f"deviation of the {keyword}-graph node features. Exiting.")
        # Check that the size of the dataset is compatible.
        try:
            size_dataset_of_file = data_from_disk['dataset_size']
        except KeyError:
            raise KeyError(
                f"Cached file '{filename}' does not contain the dataset size. "
                f"Exiting.")
        current_dataset_size = len(dataset)
        if (size_dataset_of_file != current_dataset_size):
            warnings.warn("Please note that the current dataset has size "
                          f"{current_dataset_size}, whereas the cached file ("
                          f"'{filename}') was generated from a dataset of size "
                          f"{size_dataset_of_file}.")

        # Check that the parameters match.
        for param_name, param_value in dataset_params.items():
            if (param_name not in data_from_disk):
                raise KeyError(
                    f"Could not find dataset parameter {param_name} in the "
                    f"cached file '{filename}'. Please provide a different "
                    "filename.")
            else:
                if (data_from_disk[param_name] != param_value):
                    raise ValueError(
                        f"Cached file '{filename}' is incompatible with "
                        f"current dataset. Expected parameter {param_name} to "
                        f"be {param_value}, found "
                        f"{data_from_disk[param_name]}. Please provide a "
                        "different filename.")
        for cached_param_name in dataset_params.keys():
            if (cached_param_name in [
                    'primal_mean', 'primal_std', 'dual_mean', 'dual_std'
            ]):
                continue
            if (cached_param_name not in dataset_params):
                raise KeyError(
                    f"Cached file '{filename}' is incompatible with "
                    "current dataset, as it contains parameter "
                    f"{cached_param_name}, which is missing in the input "
                    "dataset. Please provide a different filename.")
        # Return the cached data.
        primal_graph_mean = data_from_disk['primal_mean']
        primal_graph_std = data_from_disk['primal_std']
        dual_graph_mean = data_from_disk['dual_mean']
        dual_graph_std = data_from_disk['dual_std']
    else:
        # Compute the mean and standard deviation of the node features from
        # scratch.
        primal_graph_xs = torch.empty([0, dataset[0][0].x.shape[1]])
        print('len', primal_graph_xs.size())
        dual_graph_xs = torch.empty([0, dataset[0][1].x.shape[1]])
        for sample_idx, (primal_graph, dual_graph, _, _) in enumerate(dataset):
            primal_graph_xs = torch.cat([primal_graph_xs, primal_graph.x])
            dual_graph_xs = torch.cat([dual_graph_xs, dual_graph.x])
        assert (len(dataset) == sample_idx + 1)
        primal_graph_mean = primal_graph_xs.mean(axis=0).numpy()
        primal_graph_std = primal_graph_xs.std(axis=0).numpy()
        dual_graph_mean = dual_graph_xs.mean(axis=0).numpy()
        dual_graph_std = dual_graph_xs.std(axis=0).numpy()
        assert (np.all(
            primal_graph_std > 10 * np.finfo(primal_graph_std.dtype).eps))
        assert (np.all(
            dual_graph_std > 10 * np.finfo(dual_graph_std.dtype).eps))

        if (filename is not None):
            # Save the values to file, together with the dataset parameters and
            # the dataset size, if required.
            if (dataset_params is None):
                dataset_params = {}
            output_values = {
                **dataset_params, 'primal_mean': primal_graph_mean,
                'primal_std': primal_graph_std,
                'dual_mean': dual_graph_mean,
                'dual_std': dual_graph_std,
                'dataset_size': sample_idx + 1
            }
            try:
                with open(filename, 'wb') as f:
                    pkl.dump(output_values, f)
            except IOError:
                raise IOError(
                    "Unable to save mean-std data to file at location "
                    f"{filename}.")

    return (primal_graph_mean, primal_graph_std, dual_graph_mean,
            dual_graph_std)
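For reference, the error can be reproduced without going through the whole training job. The call below is only a sketch: the contents of dataset_params are a placeholder for whatever keyword arguments training_job_1.py actually forwards via **dataset_params (dataset root directory and so on), which I have omitted here.

from pd_mesh_net.utils.datasets import create_dataset

# Placeholder: fill in with the same keyword arguments that
# training_job_1.py forwards to create_dataset() via **dataset_params.
dataset_params = {}

# With compute_node_feature_stats=True this call prints the repr shown
# above and then raises NotImplementedError from compute_mean_and_std()
# when it evaluates dataset[0][0].x.shape[1].
dataset, node_statistics = create_dataset(dataset_name='shrec_16',
                                          compute_node_feature_stats=True,
                                          **dataset_params)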
If anyone knows what is going on, I would appreciate an answer.
Thank you in advance.