Skip to content

Commit

Permalink
Apply substitution to the rule predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
geo2a committed Jul 7, 2024
1 parent 5e03d41 commit 391b47c
Showing 1 changed file with 55 additions and 39 deletions.
94 changes: 55 additions & 39 deletions booster/library/Booster/Pattern/Rewrite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import Booster.Pattern.Match (
MatchResult (MatchFailed, MatchIndeterminate, MatchSuccess),
MatchType (Rewrite),
SortError,
Substitution,
matchTerms,
)
import Booster.Pattern.Pretty
Expand Down Expand Up @@ -153,7 +154,7 @@ data RewriteStepResult a = OnlyTrivial | AppliedRules a deriving (Eq, Show, Func
rewriteStep ::
LoggerMIO io =>
Pattern ->
RewriteT io (RewriteStepResult [(RewriteRule "Rewrite", Pattern)])
RewriteT io (RewriteStepResult [(RewriteRule "Rewrite", Pattern, Substitution)])
rewriteStep pat = do
def <- getDefinition
let getIndex =
Expand All @@ -175,18 +176,18 @@ rewriteStep pat = do
-- return `OnlyTrivial` if all elements of a list are `(r, Nothing)`. If the list is empty or contains at least one `(r, Just p)`,
-- return an `AppliedRules` list of `(r, p)` pairs.
filterOutTrivial ::
[(RewriteRule "Rewrite", Maybe Pattern)] ->
RewriteStepResult [(RewriteRule "Rewrite", Pattern)]
[(RewriteRule "Rewrite", Maybe (Pattern, Substitution))] ->
RewriteStepResult [(RewriteRule "Rewrite", Pattern, Substitution)]
filterOutTrivial = \case
[] -> AppliedRules []
[(_, Nothing)] -> OnlyTrivial
(_, Nothing) : xs -> filterOutTrivial xs
(rule, Just p) : xs -> AppliedRules $ (rule, p) : mapMaybe (\(r, mp) -> (r,) <$> mp) xs
(rule, Just (p, subst)) : xs -> AppliedRules $ (rule, p, subst) : mapMaybe (\(r, mp) -> (\(x, y) -> (r, x, y)) <$> mp) xs

processGroups ::
LoggerMIO io =>
[[RewriteRule "Rewrite"]] ->
RewriteT io [(RewriteRule "Rewrite", Maybe Pattern)]
RewriteT io [(RewriteRule "Rewrite", Maybe (Pattern, Substitution))]
processGroups [] = pure []
processGroups (rules : lowerPriorityRules) = do
-- try all rules of the priority group. This will immediately
Expand All @@ -209,8 +210,10 @@ rewriteStep pat = do
results
-- compute remainder condition here from @nonTrivialResults@ and the remainder up to now.
-- If the new remainder is bottom, then no lower priority rules apply
newRemainder = currentRemainder <> Set.fromList (mapMaybe (snd . snd) nonTrivialResultsWithPartialRemainders)
resultsWithoutRemainders = map (fmap (fmap fst)) results
newRemainder =
currentRemainder
<> Set.fromList (mapMaybe ((\(_, r, _) -> r) . snd) nonTrivialResultsWithPartialRemainders)
resultsWithoutRemainders = map (fmap (fmap (\(p, _, s) -> (p, s)))) results
setRemainder newRemainder
ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers
withContext CtxRemainder $ logPretty' @mods (collapseAndBools . Set.toList $ newRemainder)
Expand Down Expand Up @@ -272,7 +275,7 @@ applyRule ::
LoggerMIO io =>
Pattern ->
RewriteRule "Rewrite" ->
RewriteT io (Maybe (Maybe (Pattern, Maybe Predicate)))
RewriteT io (Maybe (Maybe (Pattern, Maybe Predicate, Substitution)))
applyRule pat@Pattern{ceilConditions} rule =
withRuleContext rule $
runRewriteRuleAppT $
Expand Down Expand Up @@ -436,11 +439,12 @@ applyRule pat@Pattern{ceilConditions} rule =
ceilConditions
withContext CtxSuccess $ do
case unclearRequiresAfterSmt of
[] -> withPatternContext rewritten $ pure (rewritten, Nothing)
[] -> withPatternContext rewritten $ pure (rewritten, Nothing, subst)
_ ->
let rewritten' = rewritten{constraints = rewritten.constraints <> Set.fromList unclearRequiresAfterSmt}
in withPatternContext rewritten' $
pure (rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt)
pure
(rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt, subst)
where
failRewrite :: RewriteFailed "Rewrite" -> RewriteRuleAppT (RewriteT io) a
failRewrite = lift . (throw)
Expand Down Expand Up @@ -841,7 +845,7 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
-- We are stuck here not trivial because we didn't apply a single rule
logMessage ("Rewrite stuck after simplification." :: Text) >> pure (RewriteStuck pat')
pat'@Simplified{} -> logMessage ("Retrying with simplified pattern" :: Text) >> doSteps pat'
AppliedRules [(rule, nextPat)] -- applied single rule
AppliedRules [(rule, nextPat, _subst)] -- applied single rule
-- cut-point rule, stop
| labelOf rule `elem` cutLabels -> do
simplify pat >>= \case
Expand Down Expand Up @@ -880,34 +884,46 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
logMessage $ "Previous state found to be bottom after " <> showCounter counter
pure $ RewriteTrivial pat'
Simplified pat' ->
(catSimplified <$> mapM (\(r, nextPat) -> fmap (r,) <$> simplify (Unsimplified nextPat)) nextPats) >>= \case
[] -> withPatternContext pat' $ do
logMessage ("Rewrite trivial after pruning all branches" :: Text)
pure $ RewriteTrivial pat'
[(rule, nextPat')] -> withPatternContext pat' $ do
logMessage ("All but one branch pruned, continuing" :: Text)
emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat'
incrementCounter
doSteps (Simplified nextPat')
nextPats' -> do
emitRewriteTrace $
RewriteBranchingStep pat' $
NE.fromList $
map (\(rule, _) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats'
unless (Set.null remainderPredicates) $ do
ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers
withContext CtxRemainder . withContext CtxDetail $
logMessage
( ("Uncovered remainder branch after rewriting with rules " :: Text)
<> ( Text.intercalate ", " $ map (\(r, _) -> getUniqueId $ uniqueId r) nextPats'
)
)
pure $
RewriteBranch pat' $
NE.fromList $
map
(\(r, n) -> (ruleLabelOrLocT r, uniqueId r, n, Just (collapseAndBools . Set.toList $ r.requires)))
nextPats'
( catSimplified
<$> mapM (\(r, nextPat, subst) -> fmap (r,,subst) <$> simplify (Unsimplified nextPat)) nextPats
)
>>= \case
[] -> withPatternContext pat' $ do
logMessage ("Rewrite trivial after pruning all branches" :: Text)
pure $ RewriteTrivial pat'
[(rule, nextPat', _subst)] -> withPatternContext pat' $ do
logMessage ("All but one branch pruned, continuing" :: Text)
emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat'
incrementCounter
doSteps (Simplified nextPat')
nextPats' -> do
emitRewriteTrace $
RewriteBranchingStep pat' $
NE.fromList $
map (\(rule, _, _subst) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats'
unless (Set.null remainderPredicates) $ do
ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers
withContext CtxRemainder . withContext CtxDetail $
logMessage
( ("Uncovered remainder branch after rewriting with rules " :: Text)
<> ( Text.intercalate ", " $ map (\(r, _, _subst) -> getUniqueId $ uniqueId r) nextPats'
)
)
pure $
RewriteBranch pat' $
NE.fromList $
map
( \(r, n, subst) ->
( ruleLabelOrLocT r
, uniqueId r
, n
, Just
( collapseAndBools $
concatMap (splitBoolPredicates . coerce . substituteInTerm subst . coerce) r.requires
)
)
)
nextPats'

data RewriteStepsState = RewriteStepsState
{ counter :: !Natural
Expand Down

0 comments on commit 391b47c

Please sign in to comment.