From a7036d6f1f179283da5dfb1de8b38baa0372ed63 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 22 Jul 2024 17:55:05 -0700 Subject: [PATCH] SpecDB: Add spec: topk Reviewed By: digantdesai Differential Revision: D59936969 fbshipit-source-id: 751275065e896be7f8fc9ce35f2ba2851813377b --- specdb/db.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/specdb/db.py b/specdb/db.py index 99f044d..726feb5 100644 --- a/specdb/db.py +++ b/specdb/db.py @@ -3953,6 +3953,45 @@ ], outspec=[OutArg(ArgType.Tensor)], ), + Spec( + op="topk.default", # (Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + inspec=[ + InPosArg( + ArgType.Tensor, + name="self", + constraints=[cp.Dtype.Ne(lambda deps: torch.bool)], + ), + InPosArg( + ArgType.Length, + name="k", + deps=[0, 2], + constraints=[ + cp.Value.Ge(lambda deps: 0), + cp.Value.Le(lambda deps: fn.safe_size(deps[0], deps[1])), + ], + ), + InPosArg( + ArgType.Dim, + name="dim", + deps=[0], + constraints=DimDefault, + ), + InPosArg(ArgType.Bool, name="largest"), + InPosArg(ArgType.Bool, name="sorted"), + ], + outspec=[ + OutArg( + ArgType.Tensor, + name="values", + constraints=[cp.Dtype.Eq(lambda deps: deps[0].dtype)], + ), + OutArg( + ArgType.Tensor, + name="indices", + constraints=[cp.Dtype.Eq(lambda deps: torch.long)], + ), + ], + ), Spec( op="transpose_copy.int", # (Tensor self, int dim0, int dim1) -> Tensor inspec=[