Skip to content

Commit

Permalink
SpecDB: Add spec: topk
Browse files Browse the repository at this point in the history
Reviewed By: digantdesai

Differential Revision: D59936969

fbshipit-source-id: 751275065e896be7f8fc9ce35f2ba2851813377b
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Jul 23, 2024
1 parent bac333f commit a7036d6
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions specdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3953,6 +3953,45 @@
],
outspec=[OutArg(ArgType.Tensor)],
),
Spec(
op="topk.default", # (Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
inspec=[
InPosArg(
ArgType.Tensor,
name="self",
constraints=[cp.Dtype.Ne(lambda deps: torch.bool)],
),
InPosArg(
ArgType.Length,
name="k",
deps=[0, 2],
constraints=[
cp.Value.Ge(lambda deps: 0),
cp.Value.Le(lambda deps: fn.safe_size(deps[0], deps[1])),
],
),
InPosArg(
ArgType.Dim,
name="dim",
deps=[0],
constraints=DimDefault,
),
InPosArg(ArgType.Bool, name="largest"),
InPosArg(ArgType.Bool, name="sorted"),
],
outspec=[
OutArg(
ArgType.Tensor,
name="values",
constraints=[cp.Dtype.Eq(lambda deps: deps[0].dtype)],
),
OutArg(
ArgType.Tensor,
name="indices",
constraints=[cp.Dtype.Eq(lambda deps: torch.long)],
),
],
),
Spec(
op="transpose_copy.int", # (Tensor self, int dim0, int dim1) -> Tensor
inspec=[
Expand Down

0 comments on commit a7036d6

Please sign in to comment.