export to onnx #51

Open
ghost opened this issue Aug 12, 2020 · 3 comments

ghost commented Aug 12, 2020

Can the weights be exported to ONNX?


rtolps commented Dec 24, 2020

I'm wondering this too. I'm trying this in my own fork but I'm not very experienced and it's not quite working.


rtolps commented Dec 27, 2020

Update:
I made a script to convert the model to ONNX:

import torch
import torch.onnx
from torch import nn

from utils import *
from dataset import ImageFolder
from networks import *  # provides ResnetGenerator

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.genA2B = ResnetGenerator(input_nc=3, output_nc=3,
                                      ngf=64, n_blocks=4, img_size=256, light=True).to('cpu')

    # Note: this forward() is not used by the export call below,
    # which exports model.genA2B directly.
    def forward(self, x):
        out = self.genA2B(x)
        # align_corners must stay unset with mode='nearest';
        # passing it makes interpolate() raise a ValueError
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = nn.functional.softmax(out, dim=1)
        return out

model = Model()
# Load the checkpoint on CPU; the step number in the filename is a guess
params = torch.load('/content/Cats2dogs_ONNX/results/cat2dog/model/cat2dog_params_0002000.pt',
                    map_location='cpu')
model.genA2B.load_state_dict(params['genA2B'])
model.genA2B.eval()
random_input = torch.randn(3, 3, 256, 256, dtype=torch.float32)
# you can add however many inputs your model or task requires

input_names = ["real_A"]
output_names = ["fake_A2B"]

torch.onnx.export(model.genA2B, random_input, 'model.onnx', verbose=False,
                  input_names=input_names, output_names=output_names,
                  opset_version=11)

However, there is a problem: the torch.var() calls in the model are not yet supported by ONNX export in any opset version. Does anyone know a workaround for replacing the torch.var calls with something else so that the model still works?


rtolps commented Dec 27, 2020

Update: I fixed it! Just change every torch.var(...) call to torch.std(...) ** 2 and the model should export to ONNX.
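For anyone applying the same substitution, here is a minimal sketch of what it could look like. The helper names `var_via_std` and `var_via_mean` are hypothetical, and the `dim=(2, 3)` reduction over the spatial axes is only an assumption about how the model calls torch.var. The second helper is an alternative that uses only mean, subtraction, and sum, in case torch.std is also a problem for a given exporter version. Both use the default unbiased (n - 1) estimate, matching torch.var's default.

```python
import torch


def var_via_std(x, dim, keepdim=False):
    """Hypothetical stand-in for torch.var(x, dim=dim, keepdim=keepdim).

    Squares the standard deviation, which equals the variance by definition.
    """
    return torch.std(x, dim=dim, keepdim=keepdim) ** 2


def var_via_mean(x, dim, keepdim=False):
    """Hypothetical alternative built only from mean, subtraction, and sum."""
    # Number of elements reduced over, needed for the unbiased (n - 1) divisor.
    n = 1
    for d in dim:
        n *= x.shape[d]
    mean = x.mean(dim=dim, keepdim=True)
    squared_dev = ((x - mean) ** 2).sum(dim=dim, keepdim=keepdim)
    return squared_dev / (n - 1)


# Quick check against torch.var on a random NCHW tensor (shape is arbitrary).
x = torch.randn(2, 3, 8, 8)
reference = torch.var(x, dim=(2, 3), keepdim=True)
print(torch.allclose(var_via_std(x, dim=(2, 3), keepdim=True), reference, atol=1e-6))
print(torch.allclose(var_via_mean(x, dim=(2, 3), keepdim=True), reference, atol=1e-6))
```

It is worth running a comparison like this before re-exporting, so that any difference in the ONNX output can be traced to the export itself rather than to the substitution.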
