Skip to content

Commit

Permalink
Promote in convolution (pytorch#5727)
Browse files Browse the repository at this point in the history
Currently this model:

```
m = timm.create_model("efficientformerv2_s0",
        pretrained=True, scriptable=True).eval()`
```

will fail to convert to stablehlo because we use mixed shape in Conv.
this is allowed in HLO but not in mhlo.

Workaroudn by manually promoting
  • Loading branch information
qihqi authored and chunnienc committed Dec 14, 2023
1 parent c4db786 commit 7a0513c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torch_xla/csrc/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ xla::XlaOp BuildConvolutionOverrideableBias(
xla::XlaOp bias_broadcast =
xla::Transpose(xla::Broadcast(bias, broadcast_sizes),
BiasTransposePermutation(broadcast_sizes.size() + 1));
return conv + bias_broadcast;
auto promoted = XlaHelpers::Promote(conv, bias_broadcast);
return promoted.first + promoted.second;
}

ConvGrads BuildConvolutionBackwardOverrideable(
Expand Down

0 comments on commit 7a0513c

Please sign in to comment.