Skip to content

Commit

Permalink
SpecDB: Add OutTensor specs for add.Tensor & add.Scalar (#5)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5

Added the out tensor spec for:
- add.Tensor: promoted type of inputs must be castable to out dtype
- add.Scalar: promoted type of inputs must be equal to out dtype

Reviewed By: zonglinpengmeta

Differential Revision: D59402158
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Jul 5, 2024
1 parent 221e7e7 commit 279ad5c
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions specdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,16 @@
),
],
outspec=[
OutArg(ArgType.Tensor),
OutArg(
ArgType.Tensor,
constraints=[
cp.Dtype.In(
lambda deps: dt.can_cast_from(
torch.promote_types(deps[0].dtype, deps[1].dtype)
)
),
],
),
],
),
Spec(
Expand Down Expand Up @@ -373,7 +382,16 @@
),
],
outspec=[
OutArg(ArgType.Tensor),
OutArg(
ArgType.Tensor,
constraints=[
cp.Dtype.Eq(
lambda deps: (
fn.promote_type_with_scalar(deps[0].dtype, deps[1])
)
),
],
)
],
),
Spec(
Expand Down

0 comments on commit 279ad5c

Please sign in to comment.