torch.jit.script does not work with Tensorized Models #33

Open
hello-fri-end opened this issue Dec 1, 2023 · 1 comment

Comments

@hello-fri-end

Minimal Code:

import torch
from torch.nn import Module
from tltorch import FactorizedConv

class Test(Module):
    def __init__(self):
        super(Test, self).__init__()
        self.layer = FactorizedConv(3, 4, 3, factorization='tucker', order=3)

def main():
    # Instantiate the model and try to script it
    model = Test()
    scripted_module = torch.jit.script(model)

if __name__ == "__main__":
    main()

Error:

Traceback (most recent call last):
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 27, in <module>
    main()
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 24, in main
    save_model(model)
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 8, in save_model
    scripted_module = torch.jit.script(model)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 572, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 899, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 87, in make_stub_from_method
    return make_stub(func, method_name)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 71, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 372, in get_jit_def
    return build_def(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 422, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 448, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/tltorch/factorized_tensors/core.py", line 259
    def forward(self, indices=None, **kwargs):
                                     ~~~~~~~ <--- HERE
        """To use a tensor factorization within a network, use ``tensor.forward``, or, equivalently, ``tensor()`

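The main issue is that torch.jit.script does not support functions that take a variable number of arguments (*args or **kwargs) or keyword-only arguments with defaults. The forward method of the factorized tensors (tltorch/factorized_tensors/core.py, line 259 in the traceback above) is declared as forward(self, indices=None, **kwargs), so scripting any model that contains a factorized/tensorized layer fails on the **kwargs parameter.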
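
To make the rule in the error message concrete, here is a rough sketch that flags the kinds of parameters it names, using only the standard library. The helper unscriptable_params and the two stand-in classes are hypothetical and are not part of torch or tltorch. The check only approximates the rule as worded in the error message and is not TorchScript's actual frontend logic.

```python
import inspect


def unscriptable_params(func):
    """Return the parameter names that match the rule in the error message.

    Flags *args, **kwargs, and keyword-only parameters that have defaults.
    Hypothetical helper for illustration; not part of torch or tltorch.
    """
    flagged = []
    for name, param in inspect.signature(func).parameters.items():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL,
                          inspect.Parameter.VAR_KEYWORD):
            # Variable number of positional or keyword arguments.
            flagged.append(name)
        elif (param.kind is inspect.Parameter.KEYWORD_ONLY
              and param.default is not inspect.Parameter.empty):
            # Keyword-only argument with a default value.
            flagged.append(name)
    return flagged


# Stand-in with the same signature as the forward shown in the traceback.
class WithKwargs:
    def forward(self, indices=None, **kwargs):
        return indices


# The same method with the catch-all keyword arguments removed.
class WithoutKwargs:
    def forward(self, indices=None):
        return indices


print(unscriptable_params(WithKwargs.forward))     # ['kwargs']
print(unscriptable_params(WithoutKwargs.forward))  # []
```

Running a check like this over the forward methods of a model's submodules may help locate which signatures would need to change before scripting, though whether changing them is sufficient for tltorch is not established here.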

@JeanKossaifi
Member

So does removing the **kwargs fix the issue? Would you be able to open a small PR if so?
