
remove batch from shape spec #18

Closed
wants to merge 6 commits

Conversation

sheiksadique (Collaborator)

Fixed the input and output shapes generated in nirtorch to ignore the batch dimension.

@sheiksadique sheiksadique linked an issue Oct 18, 2023 that may be closed by this pull request
@sheiksadique sheiksadique changed the title remove batch froms shape spec remove batch from shape spec Oct 18, 2023
```diff
@@ -49,7 +49,9 @@ def extract_nir_graph(

     # Convert the nodes and get indices
     nir_edges = []
-    nir_nodes = {"input": nir.Input(np.array(sample_data.shape))}
+    nir_nodes = {
+        "input": nir.Input(np.array(sample_data.shape[1:]))
```
Contributor

Would it make sense to add a flag `is_batched` and only remove the first dimension if this flag is true? I think we would always have batched input, so leaving it like this would also be fine with me.

Collaborator Author

In torch, the usual convention is to always include the batch dimension. So I would think it is safer to do this than to expect all other modules to set a flag indicating whether the input is batched.

Contributor

Sounds good to me!

Collaborator

I had to think about this for a bit. I don't think I understand the premise. Why do you have to modify the sample data? Can't the user just not include the batch dimension?

I'm asking because none of the PyTorch modules (linear, conv, ...) requires a batch dimension to evaluate them. Can't we just specify that whatever the user puts in, the user gets (with or without a batch dim)?
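For reference, a quick sketch of the point above: `torch.nn.Linear` applies to the last dimension only, so it accepts both batched and unbatched input (the shapes here are made up for illustration):

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 2)

# Unbatched input: a plain 1-D tensor of 4 features
out_unbatched = lin(torch.randn(4))   # shape: (2,)

# Batched input: 8 samples of 4 features each
out_batched = lin(torch.randn(8, 4))  # shape: (8, 2)

print(tuple(out_unbatched.shape), tuple(out_batched.shape))
```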

Collaborator Author

Wasn't even aware this was possible! Alright, I have an alternative solution: we could look only at the last necessary dimensions and ignore the other dims, perhaps?

Collaborator Author

Ok, I see: we can't do that, because we don't know what the dimensionality of the input or output is going to be in the first place.

@stevenabreu7 where do you suggest the `is_batched` flag should go?

@Jegp (Collaborator) · Oct 20, 2023

I'm not sure what problem we're addressing at the moment. Can I ask why the shape of the input isn't sufficient? Wouldn't this be solved by something like `extract_nir_graph(..., data.squeeze())`?
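A side note on the `squeeze()` suggestion: `squeeze()` removes *all* size-1 dimensions, not just a leading batch dimension, so it only coincides with `shape[1:]` when the batch dimension is the sole singleton. A small illustration with made-up shapes:

```python
import torch

# Batch of 1, with an additional singleton dimension inside the shape
data = torch.zeros(1, 3, 1, 4)

print(tuple(data.squeeze().shape))  # (3, 4): both size-1 dims removed
print(tuple(data.shape[1:]))        # (3, 1, 4): only the batch dim dropped
```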

tests/test_from_nir.py (comment resolved)
Jegp added a commit that referenced this pull request Oct 20, 2023
@Jegp (Collaborator) commented Oct 20, 2023

Merged into #13

Nice work!

@Jegp Jegp closed this Oct 20, 2023

Successfully merging this pull request may close these issues.

Input node retains batch dimension