タイトルにもあるようにTVMによるONNXモデルからcソースコードへの
変換を目指しています。
targetをcにしてrelayに通すと、module.get_source()でソースコードが出てくると言われたので、以下のようにpythonコードを書きました。
#sample1.py import tvm import onnx import tvm.relay as relay from tvm.contrib.download import download_testdata import numpy as np from PIL import Image model_path = 'resnet50-v2-7.onnx' onnx_model = onnx.load(model_path) img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true" img_path = download_testdata(img_url, "cat.png", module="data") img = Image.open(img_path).resize((224, 224)) img_ycbcr = img.convert("YCbCr") # convert to YCbCr img_y, img_cb, img_cr = img_ycbcr.split() x = np.array(img_y)[np.newaxis, np.newaxis, :, :] # target = 'llvm' target = "c" input_name = "1" shape_dict = {input_name: x.shape} mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) with tvm.transform.PassContext(opt_level=1): intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target) dtype = "float32" tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).asnumpy() from matplotlib import pyplot as plt out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode="L") out_cb = img_cb.resize(out_y.size, Image.BICUBIC) out_cr = img_cr.resize(out_y.size, Image.BICUBIC) result = Image.merge("YCbCr", [out_y, out_cb, out_cr]).convert("RGB") canvas = np.full((672, 672 * 2, 3), 255) canvas[0:224, 0:224, :] = np.asarray(img) canvas[:, 672:, :] = np.asarray(result) plt.imshow(canvas.astype(np.uint8)) plt.show() module.get_source()
python3 sample1.py
実行したところ、以下のエラー?が出力されました。
Traceback (most recent call last): File "sample1.py", line 23, in <module> mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) File "/mnt/c/programing/tvm/python/tvm/relay/frontend/onnx.py", line 5699, in from_onnx mod, params = g.from_onnx(graph, opset) File "/mnt/c/programing/tvm/python/tvm/relay/frontend/onnx.py", line 5367, in from_onnx self._check_user_inputs_in_outermost_graph_scope() File "/mnt/c/programing/tvm/python/tvm/relay/frontend/onnx.py", line 5444, in _check_user_inputs_in_outermost_graph_scope self._shape AssertionError: User specified the shape for inputs that weren't found in the graph: {'1': (1, 1, 224, 224)}
何でもいいのでここら辺に関する知識が欲しいです。
参考にすべきサイト等があればそれも提示していただけると嬉しいです。
よろしくお願いします。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。