From 7c5c7c725f7e0e802430b9eae88f78b89382ef63 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 5 Jul 2024 12:15:26 -0700 Subject: [PATCH] SpecDB: Add OutTensor specs for add.Tensor & add.Scalar (#5) Summary: Pull Request resolved: https://github.com/pytorch-labs/FACTO/pull/5 Added the out tensor spec for: - add.Tensor: promoted type of inputs must be castable to out dtype - add.Scalar: promoted type of inputs must be equal to out dtype Prompted by Jarvis's adoption of FACTO-based testing on D59247125 Reviewed By: zonglinpengmeta Differential Revision: D59402158 fbshipit-source-id: 784471e2b5494dbf82b5003c259ce75e9e63a323 --- specdb/db.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/specdb/db.py b/specdb/db.py index aa1dce6..5cd03d3 100644 --- a/specdb/db.py +++ b/specdb/db.py @@ -341,7 +341,16 @@ ), ], outspec=[ - OutArg(ArgType.Tensor), + OutArg( + ArgType.Tensor, + constraints=[ + cp.Dtype.In( + lambda deps: dt.can_cast_from( + torch.promote_types(deps[0].dtype, deps[1].dtype) + ) + ), + ], + ), ], ), Spec( @@ -373,7 +382,16 @@ ), ], outspec=[ - OutArg(ArgType.Tensor), + OutArg( + ArgType.Tensor, + constraints=[ + cp.Dtype.Eq( + lambda deps: ( + fn.promote_type_with_scalar(deps[0].dtype, deps[1]) + ) + ), + ], + ) ], ), Spec(