forked from gmalivenko/onnx2keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
32 lines (26 loc) · 1.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import io
import torch
import onnx
from onnx2keras import onnx_to_keras, check_torch_keras_error
def torch2keras(model: torch.nn.Module, input_variable, verbose=True, change_ordering=False):
    """Convert a PyTorch module to a Keras model by round-tripping through ONNX.

    The model is exported with ``torch.onnx.export`` into an in-memory buffer,
    the buffer is read back with ``onnx.load``, and the resulting ONNX graph is
    handed to ``onnx_to_keras``.

    Args:
        model: The PyTorch module to convert.
        input_variable: Example input used to trace ``model``. Either a single
            array-like/tensor, or a tuple/list of them for a multi-input model.
            Each input is converted to a ``float32`` tensor; inputs that are
            already tensors keep their device.
        verbose: Forwarded to ``torch.onnx.export``.
        change_ordering: Forwarded unchanged to ``onnx_to_keras``.

    Returns:
        The Keras model returned by ``onnx_to_keras``.
    """
    # Use torch.as_tensor rather than the legacy torch.FloatTensor constructor.
    # The legacy constructor rejects tensors of another dtype or device, and it
    # treats a bare int as a *shape*, which yields uninitialized memory.
    if isinstance(input_variable, (tuple, list)):
        # Multi-input model: give every input a distinct ONNX input name.
        input_variable = tuple(torch.as_tensor(var, dtype=torch.float32) for var in input_variable)
        input_names = [f'test_in{i}' for i, _ in enumerate(input_variable)]
    else:
        input_variable = torch.as_tensor(input_variable, dtype=torch.float32)
        input_names = ['test_in']

    # Export to an in-memory buffer so no temporary file is created on disk.
    with io.BytesIO() as temp_f:
        torch.onnx.export(model, input_variable, temp_f, verbose=verbose, input_names=input_names,
                          output_names=['test_out'])
        temp_f.seek(0)  # rewind so onnx.load reads from the start of the buffer
        onnx_model = onnx.load(temp_f)

    # onnx_to_keras receives the same input names used at export time.
    k_model = onnx_to_keras(onnx_model, input_names, change_ordering=change_ordering)
    return k_model
def convert_and_test(model: torch.nn.Module,
                     input_variable,
                     verbose=True,
                     change_ordering=False,
                     epsilon=1e-5):
    """Convert ``model`` to Keras and compare the two models' outputs.

    Args:
        model: The PyTorch module to convert and check.
        input_variable: Example input, passed as-is to both ``torch2keras``
            (for tracing) and ``check_torch_keras_error`` (for comparison).
        verbose: Forwarded to ``torch2keras``.
        change_ordering: Forwarded to both ``torch2keras`` and
            ``check_torch_keras_error``.
        epsilon: Forwarded to ``check_torch_keras_error``; presumably the
            allowed tolerance — TODO confirm against onnx2keras.

    Returns:
        Whatever ``check_torch_keras_error`` returns for the model pair.
    """
    keras_model = torch2keras(
        model,
        input_variable,
        verbose=verbose,
        change_ordering=change_ordering,
    )
    # The caller's original input (not a converted copy) is used for the check.
    return check_torch_keras_error(
        model,
        keras_model,
        input_variable,
        change_ordering=change_ordering,
        epsilon=epsilon,
    )