Skip to content

Commit

Permalink
Don't run rewrites when graph is already measurable
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 20, 2024
1 parent f7f8bc5 commit 009959b
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
)
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import PromisedValuedRV, ValuedRV, valued_rv
from pymc.logprob.abstract import MeasurableOp, PromisedValuedRV, ValuedRV, valued_rv
from pymc.logprob.utils import DiracDelta
from pymc.pytensorf import toposort_replace

Expand Down Expand Up @@ -175,7 +175,7 @@ def remove_DiracDelta(fgraph, node):
specialization_ir_rewrites_db = EquilibriumDB()
specialization_ir_rewrites_db.name = "specialization_ir_rewrites_db"
logprob_rewrites_db.register(
"specialization_ir_rewrites_db", specialization_ir_rewrites_db, "basic"
"specialization_ir_rewrites_db", specialization_ir_rewrites_db, "basic", "specialize"
)


Expand Down Expand Up @@ -250,7 +250,12 @@ def construct_ir_fgraph(
toposort_replace(fgraph, replacements, reverse=True)

if ir_rewriter is None:
ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]))
if all(isinstance(rv.owner.op, MeasurableOp) for rv in ir_rv_values):
# All Ops are already measurable, only run specialize
include = ["specialize"]
else:
include = ["basic"]
ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=include))
ir_rewriter.rewrite(fgraph)

# Reintroduce original value variables
Expand Down

0 comments on commit 009959b

Please sign in to comment.