
How do I export a Darts TCNModel to ONNX? #2617

Open
lmsasu opened this issue Dec 12, 2024 · 3 comments · May be fixed by #2620
Labels: question (Further information is requested)

Comments

@lmsasu commented Dec 12, 2024

I have a TCNModel that is trained on a time series:

from darts.models import TCNModel

model_air = TCNModel(
    input_chunk_length=13,
    output_chunk_length=12,
    n_epochs=3,
    dropout=0.1,
    dilation_base=2,
    weight_norm=True,
    kernel_size=5,
    num_filters=3,
    random_state=0,
    save_checkpoints=True,
    model_name=model_name,  # model_name is a string defined earlier in my script
    force_reset=True,
)

According to the darts documentation, the inner model_air.model, of type darts.models.forecasting.tcn_model._TCNModule (derived from PLPastCovariatesModule), can be serialized to ONNX via its to_onnx method. However, no matter what I try, the export fails:

model_air.model.to_onnx('model_air.onnx')

produces ValueError: Could not export to ONNX since neither input_sample nor model.example_input_array attribute is set.
Then

dummy_input = torch.randn(1, 13, 1)
model_air.model.to_onnx('model_air.onnx', input_sample=dummy_input)

outputs ValueError: not enough values to unpack (expected 2, got 1)
and

dummy_input = (torch.randn(1, 13, 1), None)
model_air.model.to_onnx('model_air.onnx', input_sample=dummy_input)

produces TypeError: _TCNModule.forward() takes 2 positional arguments but 3 were given.

What should I pass as input_sample to model_air.model.to_onnx()?

@lmsasu added the question and triage labels on Dec 12, 2024
@madtoinou (Collaborator) commented:

Hi @lmsasu,

This feature is not officially supported yet (I hope to get back to it sometime soon), but you can find a workaround in #1521.

Note that once the model is exported to this format, you will be responsible for generating the inputs yourself, so make sure to have a look at the Dataset class associated with the forecasting model class to get an idea of, for example, the expected order of the features.
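
For reference, an untested sketch of what such a workaround could look like for your snippet. The nested tuple is inferred from the two unpacking errors you quoted above; the shapes and the None placeholder for static covariates are assumptions, not an official API:

import torch

# Sketch only: judging from the errors above, _TCNModule.forward expects a
# single tuple argument (x, static_covariates). Wrapping the sample in an
# outer tuple keeps torch.onnx.export from unpacking it into two positional
# arguments (the TypeError seen earlier in this thread).
batch = torch.randn(1, 13, 1)  # (batch, input_chunk_length, n_features); univariate assumed
model_air.model.to_onnx(
    'model_air.onnx',
    input_sample=((batch, None),),  # None assumed: no static covariates
)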

@madtoinou removed the triage label on Dec 12, 2024
@lmsasu (Author) commented Dec 12, 2024

Hi @madtoinou, thanks for the reply; the workaround is helpful.
After recovering the _TCNModule model from ONNX and preparing the Dataset accordingly, how do I use it? The _TCNModule object is normally hosted by a TCNModel instance; should I create one with the same init parameters and overwrite its model attribute with the one recovered from ONNX?

@madtoinou (Collaborator) commented:

Hi @lmsasu,

Once you have exported the weights of the model to the ONNX format, you are technically outside of Darts' boundaries; you then need to load them either in an inference runtime or in another framework (see https://onnx.ai/get-started.html). If you want to stay in the Darts/Python environment, we would recommend using checkpoints instead of the ONNX format to export the model.
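
For illustration, running the exported file in onnxruntime would look roughly like this (a sketch, assuming a single input and a single output; the input name and shape depend on how the model was exported):

import numpy as np
import onnxruntime as ort

# Sketch: run the exported graph outside Darts.
sess = ort.InferenceSession('model_air.onnx')
input_name = sess.get_inputs()[0].name
x = np.random.randn(1, 13, 1).astype(np.float32)  # shape assumed from the snippet above
outputs = sess.run(None, {input_name: x})  # list of output arrays
prediction = outputs[0]

And since you trained with save_checkpoints=True, staying in Darts can be as simple as TCNModel.load_from_checkpoint(model_name, best=True).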

@madtoinou linked pull request #2620 on Dec 17, 2024 that may close this issue