diff --git a/gibbon-compiler/src/Gibbon/Compiler.hs b/gibbon-compiler/src/Gibbon/Compiler.hs index b122f05d6..3014d1d38 100644 --- a/gibbon-compiler/src/Gibbon/Compiler.hs +++ b/gibbon-compiler/src/Gibbon/Compiler.hs @@ -601,6 +601,7 @@ passes config@Config{dynflags} l0 = do should_fuse = gopt Opt_Fusion dynflags opt_layout_local = gopt Opt_Layout_Local dynflags opt_layout_global = gopt Opt_Layout_Global dynflags + use_solver = gopt Opt_Layout_Use_Solver dynflags tcProg3 = L3.tcProg isPacked l0 <- go "freshen" freshNames l0 l0 <- goE0 "typecheck" L0.tcProg l0 @@ -644,12 +645,12 @@ passes config@Config{dynflags} l0 = do -- Note: L1 -> L2 l1 <- if opt_layout_local then do - after_layout_out <- goE1 "optimizeADTLayoutLocal" locallyOptimizeDataConLayout l1 + after_layout_out <- goE1 "optimizeADTLayoutLocal" (locallyOptimizeDataConLayout (not use_solver)) l1 flatten_after_opt <- goE1 "L1.flatten2" flattenL1 after_layout_out pure flatten_after_opt else if opt_layout_global then do - after_layout_out <- goE1 "optimizeADTLayoutGlobal" globallyOptimizeDataConLayout l1 + after_layout_out <- goE1 "optimizeADTLayoutGlobal" (globallyOptimizeDataConLayout (not use_solver)) l1 flatten_after_opt <- goE1 "L1.flatten2" flattenL1 after_layout_out pure flatten_after_opt else return l1 diff --git a/gibbon-compiler/src/Gibbon/DynFlags.hs b/gibbon-compiler/src/Gibbon/DynFlags.hs index 274bce796..f887db342 100644 --- a/gibbon-compiler/src/Gibbon/DynFlags.hs +++ b/gibbon-compiler/src/Gibbon/DynFlags.hs @@ -41,6 +41,7 @@ data GeneralFlag | Opt_SimpleWriteBarrier -- ^ Disables eliminate-indirection-chains optimization. | Opt_Layout_Local -- ^ Optimize the layout of Algebraic data types locally | Opt_Layout_Global -- ^ Optimize the layout of Algebraic data types globally + | Opt_Layout_Use_Solver -- ^ Use the Solver to optimize the layout of the data types. deriving (Show,Read,Eq,Ord) -- | Exactly like GHC's ddump flags. @@ -118,7 +119,8 @@ dynflagsParser = DynFlags <$> (S.fromList <$> many gflagsParser) <*> (S.fromList flag' Opt_NoEagerPromote (long "no-eager-promote" <> help "Disable eager promotion.") <|> flag' Opt_SimpleWriteBarrier (long "simple-write-barrier" <> help "Disables eliminate-indirection-chains optimization.") <|> flag' Opt_Layout_Local (long "opt-layout-local" <> help "Optimizes the Layout of Algebraic data types locally") <|> - flag' Opt_Layout_Global (long "opt-layout-global" <> help "Optimizes the Layout of Algebraic data types globally") + flag' Opt_Layout_Global (long "opt-layout-global" <> help "Optimizes the Layout of Algebraic data types globally") <|> + flag' Opt_Layout_Use_Solver (long "opt-layout-use-solver" <> help "Use the solver instead of a Greedy Heuristic") dflagsParser :: Parser DebugFlag diff --git a/gibbon-compiler/src/Gibbon/Passes/AccessPatternsAnalysis.hs b/gibbon-compiler/src/Gibbon/Passes/AccessPatternsAnalysis.hs index 4aaf5be22..4f60b9e97 100644 --- a/gibbon-compiler/src/Gibbon/Passes/AccessPatternsAnalysis.hs +++ b/gibbon-compiler/src/Gibbon/Passes/AccessPatternsAnalysis.hs @@ -1,5 +1,6 @@ module Gibbon.Passes.AccessPatternsAnalysis ( generateAccessGraphs, + getGreedyOrder, FieldMap, DataConAccessMap, ) @@ -57,7 +58,7 @@ generateAccessGraphs topologicallySortedNodes = P.map nodeFromVertex topologicallySortedVertices map = backtrackVariablesToDataConFields topologicallySortedNodes dcons - edges = + edges = S.toList $ S.fromList $ ( constructFieldGraph Nothing nodeFromVertex @@ -69,9 +70,71 @@ generateAccessGraphs dcons accessMapsList = zipWith (\x y -> (x, y)) [dcons] [edges] accessMaps = M.fromList accessMapsList - in M.insert funName accessMaps fieldMap --dbgTraceIt (sdoc (edges, map)) + in M.insert funName accessMaps fieldMap --dbgTraceIt (sdoc topologicallySortedVertices) dbgTraceIt ("\n") dbgTraceIt (sdoc (topologicallySortedVertices, edges)) dbgTraceIt ("\n") Nothing -> error "generateAccessGraphs: no CFG for function found!" + + +getGreedyOrder :: [((Integer, Integer), Integer)] -> Int -> [Integer] +getGreedyOrder edges fieldLength = + if edges == [] + then P.map P.toInteger [0 .. (fieldLength - 1)] + else + let partial_order = greedyOrderOfVertices edges + completeOrder = P.foldl (\lst val -> if S.member val (S.fromList lst) then lst + else lst ++ [val] + ) partial_order [0 .. (fieldLength - 1)] + in dbgTraceIt (sdoc (edges, completeOrder)) P.map P.toInteger completeOrder + +greedyOrderOfVertices :: [((Integer, Integer), Integer)] -> [Int] +greedyOrderOfVertices ee = let edges' = P.map (\((a, b), c) -> ((P.fromIntegral a, P.fromIntegral b), P.fromIntegral c)) ee + bounds = (\e -> let v = P.foldr (\((i, j), _) s -> S.insert j (S.insert i s)) S.empty e + mini = minimum v + maxi = maximum v + in (mini, maxi) + ) edges' + edgesWithoutWeight = P.map fst edges' + graph = buildG bounds edgesWithoutWeight + weightMap = P.foldr (\(e, w) mm -> M.insert e w mm) M.empty edges' + v'' = greedyOrderOfVerticesHelper graph (topSort graph) weightMap S.empty + in v'' -- dbgTraceIt (sdoc (v'', (M.elems weightMap))) + + +greedyOrderOfVerticesHelper :: Graph -> [Int] -> M.Map (Int, Int) Int -> S.Set Int -> [Int] +greedyOrderOfVerticesHelper graph vertices' weightMap visited = case vertices' of + [] -> [] + x:xs -> if S.member x visited + then greedyOrderOfVerticesHelper graph xs weightMap visited + else let successors = reachable graph x + removeCurr = S.toList $ S.delete x (S.fromList successors) + orderedSucc = orderedSuccsByWeight removeCurr x weightMap visited + visited' = P.foldr S.insert S.empty orderedSucc + v'' = greedyOrderOfVerticesHelper graph xs weightMap visited' + in if successors == [x] + then orderedSucc ++ v'' --dbgTraceIt (sdoc (v'', orderedSucc)) + else [x] ++ orderedSucc ++ v'' + +orderedSuccsByWeight :: [Int] -> Int -> M.Map (Int, Int) Int -> S.Set Int -> [Int] +orderedSuccsByWeight s i weightMap visited = case s of + [] -> [] + _ -> let vertexWithMaxWeight = P.foldr (\v (v', maxx) -> let w = M.lookup (i, v) weightMap + in case w of + Nothing -> (-1, -1) + Just w' -> if w' > maxx + then (v, w') + else (v', maxx) + ) (-1, -1) s + in if fst vertexWithMaxWeight == -1 + then [] + else + let removeVertexWithMaxWeight = S.toList $ S.delete (fst vertexWithMaxWeight) (S.fromList s) + in if S.member (fst vertexWithMaxWeight) visited + then orderedSuccsByWeight removeVertexWithMaxWeight i weightMap visited + else fst vertexWithMaxWeight : orderedSuccsByWeight removeVertexWithMaxWeight i weightMap visited --dbgTraceIt (sdoc (s, removeVertexWithMaxWeight, vertexWithMaxWeight)) + + + + backtrackVariablesToDataConFields :: (FreeVars (e l d), Ord l, Ord d, Ord (e l d), Out d, Out l) => [(((PreExp e l d), Integer), Integer, [Integer])] -> @@ -81,9 +144,9 @@ backtrackVariablesToDataConFields graph dcon = case graph of [] -> M.empty x : xs -> - let newMap = processVertex graph x M.empty dcon + let newMap = processVertex graph x M.empty dcon mlist = M.toList (newMap) - m = backtrackVariablesToDataConFields xs dcon + m = backtrackVariablesToDataConFields xs dcon mlist' = M.toList m newMap' = M.fromList (mlist ++ mlist') in newMap' @@ -93,7 +156,7 @@ processVertex :: [(((PreExp e l d), Integer), Integer, [Integer])] -> (((PreExp e l d), Integer), Integer, [Integer]) -> VariableMap -> - DataCon -> + DataCon -> VariableMap processVertex graph node map dataCon = case node of @@ -101,13 +164,13 @@ processVertex graph node map dataCon = case expression of DataConE loc dcon args -> if dcon == dataCon - then + then let freeVariables = L.concat (P.map (\x -> S.toList (gFreeVars x)) args) maybeIndexes = P.map (getDataConIndexFromVariable graph) freeVariables mapList = M.toList map newMapList = P.zipWith (\x y -> (x, y)) freeVariables maybeIndexes in M.fromList (mapList ++ newMapList) - else map + else map _ -> map getDataConIndexFromVariable :: diff --git a/gibbon-compiler/src/Gibbon/Passes/OptimizeADTLayout.hs b/gibbon-compiler/src/Gibbon/Passes/OptimizeADTLayout.hs index 92f8fe75c..01997f7d3 100644 --- a/gibbon-compiler/src/Gibbon/Passes/OptimizeADTLayout.hs +++ b/gibbon-compiler/src/Gibbon/Passes/OptimizeADTLayout.hs @@ -4,7 +4,7 @@ {-# HLINT ignore "Redundant lambda" #-} {-# HLINT ignore "Use tuple-section" #-} module Gibbon.Passes.OptimizeADTLayout - ( optimizeADTLayout, + ( globallyOptimizeDataConLayout, locallyOptimizeDataConLayout ) @@ -34,6 +34,7 @@ import Gibbon.Passes.AccessPatternsAnalysis ( DataConAccessMap, FieldMap, generateAccessGraphs, + getGreedyOrder ) import Gibbon.Passes.CallGraph ( ProducersMap (..), @@ -66,11 +67,11 @@ import Gibbon.Passes.Flatten (flattenL1) type FieldOrder = M.Map DataCon [Integer] -- TODO: Make FieldOrder an argument passed to shuffleDataCon function. -optimizeADTLayout :: - Prog1 -> - PassM Prog1 -optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} = - do +--optimizeADTLayout :: +-- Prog1 -> +-- PassM Prog1 +--optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} = + --do -- let list_pair_func_dcon = -- concatMap ( \fn@(FunDef {funName, funMeta = FunMeta {funOptLayout = layout}}) -> -- case layout of @@ -124,25 +125,25 @@ optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} = -- p -- pure prg' --prg' <- runUntilFixPoint prg - globallyOptimizeDataConLayout prg + --globallyOptimizeDataConLayout prg --pure prg' --generateCopyFunctionsForFunctionsThatUseOptimizedVariable (toVar funcName) (dcon ++ "Optimized") fieldorder prg' --_ -> error "OptimizeFieldOrder: handle user constraints" -locallyOptimizeDataConLayout :: Prog1 -> PassM Prog1 -locallyOptimizeDataConLayout prg1 = do - runUntilFixPoint prg1 +locallyOptimizeDataConLayout :: Bool -> Prog1 -> PassM Prog1 +locallyOptimizeDataConLayout useGreedy prg1 = do + runUntilFixPoint useGreedy prg1 -runUntilFixPoint :: Prog1 -> PassM Prog1 -runUntilFixPoint prog1 = do - prog1' <- producerConsumerLayoutOptimization prog1 +runUntilFixPoint :: Bool -> Prog1 -> PassM Prog1 +runUntilFixPoint useGreedy prog1 = do + prog1' <- producerConsumerLayoutOptimization prog1 useGreedy prog1'' <- flattenL1 prog1' if prog1 == prog1'' then return prog1 - else runUntilFixPoint prog1'' + else runUntilFixPoint useGreedy prog1'' dataConsInFunBody :: Exp1 -> S.Set DataCon @@ -172,8 +173,8 @@ dataConsInFunBody funBody = case funBody of MapE {} -> error "getGeneratedVariable: TODO MapE" FoldE {} -> error "getGeneratedVariable: TODO FoldE" -producerConsumerLayoutOptimization :: Prog1 -> PassM Prog1 -producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do +producerConsumerLayoutOptimization :: Prog1 -> Bool -> PassM Prog1 +producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} useGreedy = do -- TODO: make a custom function name printer that guarantees that functions starting with _ are auto-generated. let funsToOptimize = P.concatMap (\FunDef{funName} -> ([funName | not $ isInfixOf "_" (fromVar funName)]) ) $ M.elems fundefs @@ -193,7 +194,7 @@ producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do Just x -> x Nothing -> error "producerConsumerLayoutOptimization: expected a function definition!!" let fieldOrder = getAccessGraph f dcon - let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder + let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder useGreedy case result of Nothing -> pure pr --dbgTraceIt (sdoc (result, fname, fieldOrder)) Just (ddefs', fundef', fieldorder) -> let fundefs' = M.delete fname fds @@ -207,8 +208,8 @@ producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do P.foldrM lambda prg linearizeDcons --dbgTraceIt (sdoc linearizeDcons) -globallyOptimizeDataConLayout :: Prog1 -> PassM Prog1 -globallyOptimizeDataConLayout prg@Prog{ddefs, fundefs, mainExp} = do +globallyOptimizeDataConLayout :: Bool -> Prog1 -> PassM Prog1 +globallyOptimizeDataConLayout useGreedy prg@Prog{ddefs, fundefs, mainExp} = do -- TODO: make a custom function name printer that guarantees that functions starting with _ are auto-generated. let funsToOptimize = P.concatMap (\FunDef{funName} -> ([funName | not $ isInfixOf "_" (fromVar funName)]) ) $ M.elems fundefs @@ -261,7 +262,7 @@ globallyOptimizeDataConLayout prg@Prog{ddefs, fundefs, mainExp} = do let fd = case maybeFd of Just x -> x Nothing -> error "globallyOptimizeDataConLayout: expected a function definition!!" - let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder + let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder useGreedy case result of Nothing -> pure pr Just (ddefs', fundef', fieldorder) -> let fundefs' = M.delete fname fds @@ -491,12 +492,16 @@ getAccessGraph + +-- getGreedyFieldOrder :: Int -> DataCon -> FieldMap + optimizeFunctionWRTDataCon :: DDefs1 -> FunDef1 -> DataCon -> DataCon -> FieldMap -> + Bool -> Maybe (DDefs1, FunDef1, FieldOrder) optimizeFunctionWRTDataCon ddefs @@ -508,7 +513,9 @@ optimizeFunctionWRTDataCon } datacon newDcon - fieldMap = + fieldMap + useGreedy = case useGreedy of + False -> let field_len = P.length $ snd . snd $ lkp' ddefs datacon fieldorder = optimizeDataConOrderFunc @@ -531,7 +538,24 @@ optimizeFunctionWRTDataCon fundef' = shuffleDataConFunBody True fieldorder fundef newDcon in Just (newDDefs, fundef', fieldorder) --dbgTraceIt (sdoc order) -- dbgTraceIt (sdoc fieldorder) _ -> error "more than one" - + True -> + let field_len = P.length $ snd . snd $ lkp' ddefs datacon + edges' = case (M.lookup funName fieldMap) of + Just d -> case (M.lookup datacon d) of + Nothing -> error "" + Just e -> e + Nothing -> error "" + greedy_order = getGreedyOrder edges' field_len + fieldorder = M.insert datacon greedy_order M.empty + in case M.toList fieldorder of + [] -> Nothing --dbgTraceIt (sdoc fieldorder) dbgTraceIt (sdoc greedy_order) + [(dcon, order)] -> let orignal_order = [0..(P.length order - 1)] + in if orignal_order == P.map P.fromInteger order + then Nothing + else let newDDefs = optimizeDataCon (dcon, order) ddefs newDcon + fundef' = shuffleDataConFunBody True fieldorder fundef newDcon + in Just (newDDefs, fundef', fieldorder) --dbgTraceIt (sdoc order) -- dbgTraceIt (sdoc fieldorder) dbgTraceIt (sdoc greedy_order) + _ -> error "more than one" changeCallNameInRecFunction :: Var -> FunDef1 -> FunDef1 diff --git a/gibbon-compiler/tests/test-gibbon-examples.yaml b/gibbon-compiler/tests/test-gibbon-examples.yaml index 64a754ba1..ff123bb0d 100644 --- a/gibbon-compiler/tests/test-gibbon-examples.yaml +++ b/gibbon-compiler/tests/test-gibbon-examples.yaml @@ -896,21 +896,21 @@ tests: run-modes: ["gibbon2", "gibbon3", "pointer"] - name: layout1ContentSearchRunPipeline.hs - test-flags: ["--no-gc", "--opt-layout-local"] + test-flags: ["--no-gc", "--opt-layout-local", "--opt-layout-use-solver"] dir: examples/layout_bench answer-file: examples/layout_bench/layout1ContentSearchRunPipeline.ans failing: [interp1,pointer,gibbon1, gibbon3] run-modes: ["gibbon2"] - name: manyFuncs.hs - test-flags: ["--no-gc", "--opt-layout-local"] + test-flags: ["--no-gc", "--opt-layout-local", "--opt-layout-use-solver"] dir: examples/layout_bench answer-file: examples/layout_bench/manyFuncsLocal.ans failing: [interp1,pointer,gibbon1, gibbon3] run-modes: ["gibbon2"] - name: manyFuncs.hs - test-flags: ["--no-gc", "--opt-layout-global"] + test-flags: ["--no-gc", "--opt-layout-global", "--opt-layout-use-solver"] dir: examples/layout_bench answer-file: examples/layout_bench/manyFuncsGlobal.ans failing: [interp1,pointer,gibbon1, gibbon3]