Skip to content

Commit

Permalink
Add condition for dim consideration
Browse files Browse the repository at this point in the history
  • Loading branch information
shivam096 committed Jan 24, 2025
1 parent a254c04 commit ca6853c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
6 changes: 5 additions & 1 deletion tests/fixtures/misc/checker/logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
# logsumexp
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True))
y = torch.log(torch.sum(torch.exp(2.5 + x), 1))
y = torch.log(torch.sum(torch.exp(1),1,1))
y = torch.log(torch.sum(torch.exp(1),1,dim=1))

# not logsumexp
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5)
y = torch.log(torch.sum(torch.exp(x) + 2.5, 1))
y = torch.log(2 + x)
y = torch.sum(torch.log(torch.exp(x)), 1)
y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True))
y = torch.log(torch.sum(torch.exp(2.5)))
y = torch.log(torch.sum(torch.exp(2.5))) # this should not be flagged as the second argument is missing for sum function call
y = torch.sum(torch.log(torch.exp(x)), dim=1)
y = torch.sum(torch.log(torch.exp(x)), dim=None)
2 changes: 2 additions & 0 deletions tests/fixtures/misc/checker/logsumexp.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
6:5 TOR108 Use numerically stabilized `torch.logsumexp`.
7:5 TOR108 Use numerically stabilized `torch.logsumexp`.
8:5 TOR108 Use numerically stabilized `torch.logsumexp`.
9:5 TOR108 Use numerically stabilized `torch.logsumexp`.
18 changes: 11 additions & 7 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ def visit_Call(self, node):
node.args[0].value.args[0].value
)
== "torch.exp"
) and len(node.args[0].value.args) > 1 and node.args[0].value.args[1].value is not None:
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=None,
)
):
if len(node.args[0].value.args) > 1 and (
node.args[0].value.args[1].value is not None
or self.has_specific_arg(node.args[0].value, "dim", -1)
):
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=None,
)

0 comments on commit ca6853c

Please sign in to comment.