Error in exporting soundstream to onnx #254

Open
kalradivyanshu opened this issue Nov 22, 2023 · 14 comments

Comments

@kalradivyanshu

Has anyone exported the soundstream model to ONNX? I tried:

torch.onnx.export(soundstream, audio, "soundstream.onnx")

but it fails with

/home/divya/.local/lib/python3.8/site-packages/audiolm_pytorch/utils.py:11: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<@beartype(torch.onnx.utils.export) at 0x7fb5d0936820>", line 369, in export
  File "/home/divya/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/divya/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "<@beartype(torch.onnx.utils._model_to_graph) at 0x7fb5d093e700>", line 11, in _model_to_graph
  File "/home/divya/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/divya/.local/lib/python3.8/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/divya/.local/lib/python3.8/site-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/audiolm_pytorch/soundstream.py", line 827, in forward
    x = self.encoder_attn(x)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/audiolm_pytorch/soundstream.py", line 436, in forward
    x = attn(x, attn_bias = attn_bias) + x
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/local_attention/transformer.py", line 106, in forward
    out = self.attn_fn(q, k, v, mask = mask, attn_bias = attn_bias)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/divya/.local/lib/python3.8/site-packages/local_attention/local_attention.py", line 126, in forward
    (needed_pad, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
  File "/home/divya/.local/lib/python3.8/site-packages/local_attention/local_attention.py", line 126, in <lambda>
    (needed_pad, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
  File "/home/divya/.local/lib/python3.8/site-packages/local_attention/local_attention.py", line 37, in pad_to_multiple
    if m.is_integer():
AttributeError: 'Tensor' object has no attribute 'is_integer'

Any help would be really appreciated, thanks!

@kalradivyanshu changed the title from "Exporting soundstream to onnx" to "Error in exporting soundstream to onnx" on Nov 22, 2023
@kalradivyanshu
Author

Update: the error goes away with use_local_attn=False, but I am guessing that will have an adverse effect on the performance of the network?

@lucidrains
Owner

@kalradivyanshu ahh darn... can't wait until Tri Dao's FlashAttention CUDA kernel is standard. it has windowed (sliding window) local attention built in

@lucidrains
Owner

> Update: the error goes away with use_local_attn=False, but I am guessing that will have an adverse effect on the performance of the network?

some researchers have told me local attention does little, but i wouldn't bet against it

@lucidrains
Owner

@kalradivyanshu maybe i can offer a full attention causal layer in there? with flash attention, 8k is nothing these days
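The following back-of-the-envelope sketch reads "8k" as a sequence length of 8,192. It shows why full attention used to be costly and why a local window keeps it cheap. It counts the query-key scores one head would materialize if attention were computed naively.

Several inputs are assumptions:

- fp32 (4-byte) scores
- batch size 1
- a single look-back window for local attention
- the 128-position window from the attn_window_size used later in this thread

FlashAttention sidesteps the quadratic memory term by computing attention in tiles, without ever storing the full score matrix.

```python
def attention_score_entries(seq_len, window_size=None, look_backward=1):
    """Number of query-key scores one attention head computes.

    Full attention (window_size=None): every query scores every key,
    giving seq_len ** 2 entries.
    Windowed local attention (assumed layout): each query scores its own
    window plus `look_backward` previous windows, giving
    seq_len * window_size * (1 + look_backward) entries.
    """
    if window_size is None:
        return seq_len * seq_len
    return seq_len * window_size * (1 + look_backward)


def mib(entries, bytes_per_entry=4):
    """Memory in MiB needed to materialize `entries` scores (fp32 by default)."""
    return entries * bytes_per_entry / 2**20
```

For 8,192 positions, full attention gives 67,108,864 scores, or 256 MiB in fp32 per head and batch element if materialized. A 128-wide window with one look-back window gives 2,097,152 scores (8 MiB).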

@kalradivyanshu
Author

kalradivyanshu commented Nov 23, 2023

@lucidrains I was able to make the export work by changing one line in local-attention; I will attach the details when I get home.

I also wanted to ask about making the SoundStream model smaller. The checkpoint file is 100 MB right now, or 30 MB after quantizing. If I wanted to make it 3x smaller, which layers should I shrink for the least impact on accuracy/loss?

I basically want it to run on client devices, and downloading 30+ MB seems like a lot. I am only using SoundStream as an encoder/decoder.
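Before deciding which layers to shrink, it can help to see where the parameters actually are.

As rough arithmetic, assume the 100 MB checkpoint holds only float32 weights at 4 bytes each. That is an assumption, since checkpoints can also carry optimizer or other training state. Under it, the model has about 25 million parameters, and a 3x reduction means getting to roughly 8 million.

The sketch below reports parameter counts per top-level submodule, largest first. `param_report` is a hypothetical helper, not part of audiolm_pytorch. Training-only submodules that the encode/decode path never touches would also show up in this listing, and they would not need to ship to clients.

```python
from torch import nn


def param_report(model: nn.Module, bytes_per_param: int = 4):
    """Return (child_name, num_params, approx_MB) for each top-level child,
    sorted largest first. bytes_per_param=4 assumes float32 weights.
    Parameters registered directly on `model` (not in a child) are not listed.
    """
    rows = []
    for name, child in model.named_children():
        num_params = sum(p.numel() for p in child.parameters())
        rows.append((name, num_params, num_params * bytes_per_param / 1e6))
    return sorted(rows, key=lambda row: row[1], reverse=True)
```

Calling `param_report(soundstream)` on the model from this thread lists its top-level children with their approximate fp32 sizes. That gives a starting point for deciding which parts to narrow or drop.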

@lucidrains
Owner

@kalradivyanshu oh good to hear! would welcome a PR upstream, in the spirit of open source 🙏

i don't really know, as i haven't been following the research there

@kalradivyanshu
Author

kalradivyanshu commented Nov 23, 2023

Oh okay.
@lucidrains
So I basically changed this if statement: https://github.com/lucidrains/local-attention/blob/a415ac6a6078d07484a48a415dea66e853c994fc/local_attention/local_attention.py#L37

from:

    if m.is_integer():
        return False, tensor

to

    if (torch.is_tensor(m) and m.item().is_integer()) or (not torch.is_tensor(m) and m.is_integer()):
        return False, tensor

During ONNX export, m reaches this function as a 1-element tensor wrapping the value instead of a Python float, so I hacked this together. I have no idea whether it has any other impact.
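The patched condition can probably be collapsed into a single branch. `float()` accepts Python ints and floats, and also single-element PyTorch tensors, so `float(m).is_integer()` covers both cases.

This is a sketch; `is_whole` is a hypothetical helper name. Under tracing, `float()` on a tensor is recorded as a constant, just like `.item()`. The same caveat as in the TracerWarning therefore applies: the result is fixed for the input length used at export time.

```python
def is_whole(m) -> bool:
    """True if m holds an integral value.

    float() converts Python ints/floats and single-element torch tensors
    alike, so this replaces the two-branch is_tensor check.
    """
    return float(m).is_integer()
```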

If you think it has no impact, let me know if I should open a PR.

@lucidrains
Owner

@kalradivyanshu hmm, strange, m should never be a tensor
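Here is one likely explanation for m being a tensor. While tracing for ONNX export, values read from `tensor.shape` can come back as 0-dim tensors. A float division of the sequence length then yields a tensor rather than a Python float, which matches the AttributeError in the traceback.

The sketch below is a more trace-friendly `pad_to_multiple` that uses integer modulo instead of `is_integer()`. It is not the upstream implementation. It assumes the function only needs to return the `(needed_pad, tensor)` pair unpacked in local_attention.py. The `value` argument and the padding layout are also assumptions.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    """Right-pad `tensor` along `dim` so its length is a multiple of `multiple`.

    Returns (needed_pad, padded_tensor). Integer modulo replaces
    float division + is_integer(), so this also works when the length
    read from tensor.shape is a 0-dim tensor during tracing.
    """
    seqlen = tensor.shape[dim]
    remainder = seqlen % multiple
    if remainder == 0:  # under tracing this is frozen as a constant
        return False, tensor
    pad_len = int(multiple - remainder)
    # F.pad takes (left, right) pairs starting from the LAST dimension,
    # so emit (0, 0) for every dimension that comes after `dim`.
    dims_after = (-dim - 1) if dim < 0 else (tensor.ndim - dim - 1)
    pad = (0, 0) * dims_after + (0, pad_len)
    return True, F.pad(tensor, pad, value=value)
```

The padding length is still baked into the exported graph as a constant. An ONNX model exported this way is therefore likely tied to the input length used at export time.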

@lucidrains
Owner

@kalradivyanshu is your converted onnx model working ok with this change?

@kalradivyanshu
Author

kalradivyanshu commented Nov 23, 2023

@lucidrains

from audiolm_pytorch import SoundStream
import torch

soundstream = SoundStream(
    codebook_size = 4096,
    rq_num_quantizers = 8,
    rq_groups = 2,                      # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
    use_lookup_free_quantizer = False,  # whether to use residual lookup free quantization
    use_finite_scalar_quantizer = True, # whether to use residual finite scalar quantization
    attn_window_size = 128,             # local attention receptive field at bottleneck
    attn_depth = 2                      # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
).cpu()

soundstream.eval()

audio = torch.randn(10080).cpu()
torch.onnx.export(soundstream, audio, "soundstream.onnx", input_names = ["input"], output_names=["output"])

This is my export code.
I haven't tested the ONNX model yet; I will get back to you.
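To check the exported model against PyTorch, run both on the same input and compare the arrays numerically. The sketch below does that; `compare_outputs` is a hypothetical helper, and its tolerances are assumptions to tune.

The trace bakes in shape-dependent branches (see the TracerWarning). It is therefore worth testing first with the same length used at export, 10080 samples.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-4):
    """Compare a PyTorch reference output with an ONNX Runtime output.

    Both arguments are array-likes, e.g. tensor.detach().numpy() and an
    array returned by InferenceSession.run. Returns (ok, max_abs_diff).
    """
    ref = np.asarray(reference, dtype=np.float64)
    got = np.asarray(candidate, dtype=np.float64)
    if ref.shape != got.shape:
        # A shape mismatch already means the export is not equivalent.
        return False, float("inf")
    max_abs_diff = float(np.max(np.abs(ref - got))) if ref.size else 0.0
    return bool(np.allclose(ref, got, rtol=rtol, atol=atol)), max_abs_diff
```

Suggested sources for the two arrays:

- **ONNX side**, with onnxruntime installed: `onnxruntime.InferenceSession("soundstream.onnx").run(None, {"input": audio.numpy()})[0]`.
- **PyTorch side**: `soundstream(audio)` under `torch.no_grad()`. This assumes the forward call returns a single tensor in this configuration.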

@lucidrains
Owner

@kalradivyanshu oh i see, ok yea let me know if the onnx model is ok

wow, you are using an unpublished feature! (residual FSQ)

does it work?

@kalradivyanshu
Author

I am only exporting a randomly initialized, untrained SoundStream to test whether it can be exported to ONNX at all. I will try with a trained one and get back to you.

@lucidrains
Owner

ahh, you haven't actually trained it yet, gotcha

@kalradivyanshu
Author

No clue about residual FSQ yet; I will have to see and get back to you.
