diff --git a/booster/library/Booster/Pattern/Rewrite.hs b/booster/library/Booster/Pattern/Rewrite.hs index 56bdc0c778..45c55db166 100644 --- a/booster/library/Booster/Pattern/Rewrite.hs +++ b/booster/library/Booster/Pattern/Rewrite.hs @@ -65,6 +65,7 @@ import Booster.Pattern.Match ( MatchResult (MatchFailed, MatchIndeterminate, MatchSuccess), MatchType (Rewrite), SortError, + Substitution, matchTerms, ) import Booster.Pattern.Pretty @@ -153,7 +154,7 @@ data RewriteStepResult a = OnlyTrivial | AppliedRules a deriving (Eq, Show, Func rewriteStep :: LoggerMIO io => Pattern -> - RewriteT io (RewriteStepResult [(RewriteRule "Rewrite", Pattern)]) + RewriteT io (RewriteStepResult [(RewriteRule "Rewrite", Pattern, Substitution)]) rewriteStep pat = do def <- getDefinition let getIndex = @@ -175,18 +176,18 @@ rewriteStep pat = do -- return `OnlyTrivial` if all elements of a list are `(r, Nothing)`. If the list is empty or contains at least one `(r, Just p)`, -- return an `AppliedRules` list of `(r, p)` pairs. filterOutTrivial :: - [(RewriteRule "Rewrite", Maybe Pattern)] -> - RewriteStepResult [(RewriteRule "Rewrite", Pattern)] + [(RewriteRule "Rewrite", Maybe (Pattern, Substitution))] -> + RewriteStepResult [(RewriteRule "Rewrite", Pattern, Substitution)] filterOutTrivial = \case [] -> AppliedRules [] [(_, Nothing)] -> OnlyTrivial (_, Nothing) : xs -> filterOutTrivial xs - (rule, Just p) : xs -> AppliedRules $ (rule, p) : mapMaybe (\(r, mp) -> (r,) <$> mp) xs + (rule, Just (p, subst)) : xs -> AppliedRules $ (rule, p, subst) : mapMaybe (\(r, mp) -> (\(x, y) -> (r, x, y)) <$> mp) xs processGroups :: LoggerMIO io => [[RewriteRule "Rewrite"]] -> - RewriteT io [(RewriteRule "Rewrite", Maybe Pattern)] + RewriteT io [(RewriteRule "Rewrite", Maybe (Pattern, Substitution))] processGroups [] = pure [] processGroups (rules : lowerPriorityRules) = do -- try all rules of the priority group. This will immediately @@ -209,8 +210,10 @@ rewriteStep pat = do results -- compute remainder condition here from @nonTrivialResults@ and the remainder up to now. -- If the new remainder is bottom, then no lower priority rules apply - newRemainder = currentRemainder <> Set.fromList (mapMaybe (snd . snd) nonTrivialResultsWithPartialRemainders) - resultsWithoutRemainders = map (fmap (fmap fst)) results + newRemainder = + currentRemainder + <> Set.fromList (mapMaybe ((\(_, r, _) -> r) . snd) nonTrivialResultsWithPartialRemainders) + resultsWithoutRemainders = map (fmap (fmap (\(p, _, s) -> (p, s)))) results setRemainder newRemainder ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers withContext CtxRemainder $ logPretty' @mods (collapseAndBools . Set.toList $ newRemainder) @@ -272,7 +275,7 @@ applyRule :: LoggerMIO io => Pattern -> RewriteRule "Rewrite" -> - RewriteT io (Maybe (Maybe (Pattern, Maybe Predicate))) + RewriteT io (Maybe (Maybe (Pattern, Maybe Predicate, Substitution))) applyRule pat@Pattern{ceilConditions} rule = withRuleContext rule $ runRewriteRuleAppT $ @@ -436,11 +439,12 @@ applyRule pat@Pattern{ceilConditions} rule = ceilConditions withContext CtxSuccess $ do case unclearRequiresAfterSmt of - [] -> withPatternContext rewritten $ pure (rewritten, Nothing) + [] -> withPatternContext rewritten $ pure (rewritten, Nothing, subst) _ -> let rewritten' = rewritten{constraints = rewritten.constraints <> Set.fromList unclearRequiresAfterSmt} in withPatternContext rewritten' $ - pure (rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt) + pure + (rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt, subst) where failRewrite :: RewriteFailed "Rewrite" -> RewriteRuleAppT (RewriteT io) a failRewrite = lift . (throw) @@ -841,7 +845,7 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL -- We are stuck here not trivial because we didn't apply a single rule logMessage ("Rewrite stuck after simplification." :: Text) >> pure (RewriteStuck pat') pat'@Simplified{} -> logMessage ("Retrying with simplified pattern" :: Text) >> doSteps pat' - AppliedRules [(rule, nextPat)] -- applied single rule + AppliedRules [(rule, nextPat, _subst)] -- applied single rule -- cut-point rule, stop | labelOf rule `elem` cutLabels -> do simplify pat >>= \case @@ -880,34 +884,46 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL logMessage $ "Previous state found to be bottom after " <> showCounter counter pure $ RewriteTrivial pat' Simplified pat' -> - (catSimplified <$> mapM (\(r, nextPat) -> fmap (r,) <$> simplify (Unsimplified nextPat)) nextPats) >>= \case - [] -> withPatternContext pat' $ do - logMessage ("Rewrite trivial after pruning all branches" :: Text) - pure $ RewriteTrivial pat' - [(rule, nextPat')] -> withPatternContext pat' $ do - logMessage ("All but one branch pruned, continuing" :: Text) - emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat' - incrementCounter - doSteps (Simplified nextPat') - nextPats' -> do - emitRewriteTrace $ - RewriteBranchingStep pat' $ - NE.fromList $ - map (\(rule, _) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats' - unless (Set.null remainderPredicates) $ do - ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers - withContext CtxRemainder . withContext CtxDetail $ - logMessage - ( ("Uncovered remainder branch after rewriting with rules " :: Text) - <> ( Text.intercalate ", " $ map (\(r, _) -> getUniqueId $ uniqueId r) nextPats' - ) - ) - pure $ - RewriteBranch pat' $ - NE.fromList $ - map - (\(r, n) -> (ruleLabelOrLocT r, uniqueId r, n, Just (collapseAndBools . Set.toList $ r.requires))) - nextPats' + ( catSimplified + <$> mapM (\(r, nextPat, subst) -> fmap (r,,subst) <$> simplify (Unsimplified nextPat)) nextPats + ) + >>= \case + [] -> withPatternContext pat' $ do + logMessage ("Rewrite trivial after pruning all branches" :: Text) + pure $ RewriteTrivial pat' + [(rule, nextPat', _subst)] -> withPatternContext pat' $ do + logMessage ("All but one branch pruned, continuing" :: Text) + emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat' + incrementCounter + doSteps (Simplified nextPat') + nextPats' -> do + emitRewriteTrace $ + RewriteBranchingStep pat' $ + NE.fromList $ + map (\(rule, _, _subst) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats' + unless (Set.null remainderPredicates) $ do + ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers + withContext CtxRemainder . withContext CtxDetail $ + logMessage + ( ("Uncovered remainder branch after rewriting with rules " :: Text) + <> ( Text.intercalate ", " $ map (\(r, _, _subst) -> getUniqueId $ uniqueId r) nextPats' + ) + ) + pure $ + RewriteBranch pat' $ + NE.fromList $ + map + ( \(r, n, subst) -> + ( ruleLabelOrLocT r + , uniqueId r + , n + , Just + ( collapseAndBools $ + concatMap (splitBoolPredicates . coerce . substituteInTerm subst . coerce) r.requires + ) + ) + ) + nextPats' data RewriteStepsState = RewriteStepsState { counter :: !Natural