
Fix: Read Lua weights with PyTorch > 1.0 #26

@ariel415el


`from torch.utils.serialization import load_lua` doesn't work in current PyTorch versions, because the `torch.utils.serialization` module was removed in PyTorch 1.0.

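Here is a possible fix using the `torchfile` package (`pip install torchfile`), which reads Lua Torch `.t7` files and returns tensors as NumPy arrays: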

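```python
import torchfile  # pip install torchfile


class pytorch_lua_wrapper:
    def __init__(self, lua_path):
        # torchfile returns tensors as NumPy arrays, not torch tensors
        self.lua_model = torchfile.load(lua_path)

    def get(self, idx):
        # 0-based index into the saved container's `modules` list
        return self.lua_model._obj.modules[idx]._obj
```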

Now you can replace this line:

```python
vgg1 = load_lua(args.vgg1)
```

with

```python
vgg1 = pytorch_lua_wrapper(args.vgg1)
```

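and this line:

```python
self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())
```

with

```python
# torchfile gives a NumPy array, so convert it to a float32 tensor first
self.conv1.weight = torch.nn.Parameter(torch.from_numpy(vgg1.get(0).weight).float())
```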
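If the model copies several layers this way (and their biases as well as their weights), the NumPy-to-Parameter conversion can live in one helper. The following is a minimal sketch, not part of torchfile or PyTorch: `copy_lua_conv` is a hypothetical name, and it assumes each object returned by `vgg1.get(idx)` exposes `weight` and `bias` as NumPy arrays and that the target layer is a `torch.nn.Conv2d`. It also checks shapes, so a mismatched layer index fails loudly instead of silently loading the wrong tensor. A `SimpleNamespace` stands in for a real torchfile module in the demo.

```python
from types import SimpleNamespace

import numpy as np
import torch


def _to_float_tensor(array):
    # torch.tensor copies the data, so the Parameter does not share memory
    # with the arrays torchfile loaded; dtype is forced to float32.
    return torch.tensor(array, dtype=torch.float32)


def copy_lua_conv(conv, lua_module):
    """Hypothetical helper: copy weight (and bias, if both exist) into conv.

    `lua_module` is assumed to look like pytorch_lua_wrapper.get(idx):
    an object whose `weight` / `bias` attributes are NumPy arrays.
    """
    weight = _to_float_tensor(lua_module.weight)
    if weight.shape != conv.weight.shape:
        raise ValueError(
            f"weight shape {tuple(weight.shape)} does not match "
            f"{tuple(conv.weight.shape)}"
        )
    conv.weight = torch.nn.Parameter(weight)

    # torchfile objects return None for missing keys, SimpleNamespace raises
    # AttributeError; getattr with a default handles both.
    lua_bias = getattr(lua_module, "bias", None)
    if conv.bias is not None and lua_bias is not None:
        bias = _to_float_tensor(lua_bias)
        if bias.shape != conv.bias.shape:
            raise ValueError(
                f"bias shape {tuple(bias.shape)} does not match "
                f"{tuple(conv.bias.shape)}"
            )
        conv.bias = torch.nn.Parameter(bias)


if __name__ == "__main__":
    # Stand-in for vgg1.get(0): a 3 -> 64 conv with 3x3 kernels, stored as float64.
    fake = SimpleNamespace(weight=np.random.rand(64, 3, 3, 3), bias=np.random.rand(64))
    conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
    copy_lua_conv(conv1, fake)
    print(conv1.weight.dtype, tuple(conv1.weight.shape))
```

Assuming `self.conv1` is a `torch.nn.Conv2d`, the replacement line above would then become `copy_lua_conv(self.conv1, vgg1.get(0))`.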
