Skip to content

Commit

Permalink
Fix ONNX compatible unfold
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova authored Dec 5, 2024
1 parent e2828ff commit 43c21d0
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,36 @@ class PatchingSpec:
# An ONNX-export-compatible version of `tensor.unfold`. Without this, we get:
# torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.
# See https://github.com/pytorch/pytorch/issues/81871 for more information
def onnx_compatible_unfold(self, dimension, size, step):
num_patches = (self.size(dimension) - size) // step + 1
return torch.stack(
[self[:, i : i + size, :] for i in range(0, num_patches * step, step)],
dim=1,
).transpose(3, 2)
def onnx_compatible_unfold(input_tensor, dimension, size, step):
"""
Custom implementation of torch.unfold without using torch.unfold.
Args:
input_tensor (torch.Tensor): The input tensor.
dimension (int): The dimension to unfold.
size (int): The size of each slice.
step (int): The step size between slices.
Returns:
torch.Tensor: The unfolded tensor.
"""
# Compute the shape of the unfolded output
input_size = input_tensor.size(dimension)
num_slices = (input_size - size) // step + 1

# Permute dimension to the end for easier indexing
input_tensor = input_tensor.transpose(dimension, -1)

# Extract slices
slices = []
for i in range(num_slices):
start = i * step
end = start + size
slices.append(input_tensor[..., start:end])

# Stack slices and permute dimensions back
result = torch.stack(slices, dim=-2).transpose(dimension, -2)
return result


UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
Expand Down

0 comments on commit 43c21d0

Please sign in to comment.