Skip to content

Commit

Permalink
[LayoutOpt]: Add a greedy heuristic to ameliorate solver time in lieu…
Browse files Browse the repository at this point in the history
… of runtime performance.
  • Loading branch information
vidsinghal committed Oct 1, 2023
1 parent 540b2b7 commit 209a90d
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 35 deletions.
5 changes: 3 additions & 2 deletions gibbon-compiler/src/Gibbon/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ passes config@Config{dynflags} l0 = do
should_fuse = gopt Opt_Fusion dynflags
opt_layout_local = gopt Opt_Layout_Local dynflags
opt_layout_global = gopt Opt_Layout_Global dynflags
use_solver = gopt Opt_Layout_Use_Solver dynflags
tcProg3 = L3.tcProg isPacked
l0 <- go "freshen" freshNames l0
l0 <- goE0 "typecheck" L0.tcProg l0
Expand Down Expand Up @@ -644,12 +645,12 @@ passes config@Config{dynflags} l0 = do
-- Note: L1 -> L2
l1 <- if opt_layout_local
then do
after_layout_out <- goE1 "optimizeADTLayoutLocal" locallyOptimizeDataConLayout l1
after_layout_out <- goE1 "optimizeADTLayoutLocal" (locallyOptimizeDataConLayout (not use_solver)) l1
flatten_after_opt <- goE1 "L1.flatten2" flattenL1 after_layout_out
pure flatten_after_opt
else if opt_layout_global
then do
after_layout_out <- goE1 "optimizeADTLayoutGlobal" globallyOptimizeDataConLayout l1
after_layout_out <- goE1 "optimizeADTLayoutGlobal" (globallyOptimizeDataConLayout (not use_solver)) l1
flatten_after_opt <- goE1 "L1.flatten2" flattenL1 after_layout_out
pure flatten_after_opt
else return l1
Expand Down
4 changes: 3 additions & 1 deletion gibbon-compiler/src/Gibbon/DynFlags.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ data GeneralFlag
| Opt_SimpleWriteBarrier -- ^ Disables eliminate-indirection-chains optimization.
| Opt_Layout_Local -- ^ Optimize the layout of Algebraic data types locally
| Opt_Layout_Global -- ^ Optimize the layout of Algebraic data types globally
| Opt_Layout_Use_Solver -- ^ Use the Solver to optimize the layout of the data types.
deriving (Show,Read,Eq,Ord)

-- | Exactly like GHC's ddump flags.
Expand Down Expand Up @@ -118,7 +119,8 @@ dynflagsParser = DynFlags <$> (S.fromList <$> many gflagsParser) <*> (S.fromList
flag' Opt_NoEagerPromote (long "no-eager-promote" <> help "Disable eager promotion.") <|>
flag' Opt_SimpleWriteBarrier (long "simple-write-barrier" <> help "Disables eliminate-indirection-chains optimization.") <|>
flag' Opt_Layout_Local (long "opt-layout-local" <> help "Optimizes the Layout of Algebraic data types locally") <|>
flag' Opt_Layout_Global (long "opt-layout-global" <> help "Optimizes the Layout of Algebraic data types globally")
flag' Opt_Layout_Global (long "opt-layout-global" <> help "Optimizes the Layout of Algebraic data types globally") <|>
flag' Opt_Layout_Use_Solver (long "opt-layout-use-solver" <> help "Use the solver instead of a Greedy Heuristic")


dflagsParser :: Parser DebugFlag
Expand Down
77 changes: 70 additions & 7 deletions gibbon-compiler/src/Gibbon/Passes/AccessPatternsAnalysis.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Gibbon.Passes.AccessPatternsAnalysis
( generateAccessGraphs,
getGreedyOrder,
FieldMap,
DataConAccessMap,
)
Expand Down Expand Up @@ -57,7 +58,7 @@ generateAccessGraphs
topologicallySortedNodes =
P.map nodeFromVertex topologicallySortedVertices
map = backtrackVariablesToDataConFields topologicallySortedNodes dcons
edges =
edges = S.toList $ S.fromList $
( constructFieldGraph
Nothing
nodeFromVertex
Expand All @@ -69,9 +70,71 @@ generateAccessGraphs
dcons
accessMapsList = zipWith (\x y -> (x, y)) [dcons] [edges]
accessMaps = M.fromList accessMapsList
in M.insert funName accessMaps fieldMap --dbgTraceIt (sdoc (edges, map))
in M.insert funName accessMaps fieldMap --dbgTraceIt (sdoc topologicallySortedVertices) dbgTraceIt ("\n") dbgTraceIt (sdoc (topologicallySortedVertices, edges)) dbgTraceIt ("\n")
Nothing -> error "generateAccessGraphs: no CFG for function found!"



getGreedyOrder :: [((Integer, Integer), Integer)] -> Int -> [Integer]
getGreedyOrder edges fieldLength =
if edges == []
then P.map P.toInteger [0 .. (fieldLength - 1)]
else
let partial_order = greedyOrderOfVertices edges
completeOrder = P.foldl (\lst val -> if S.member val (S.fromList lst) then lst
else lst ++ [val]
) partial_order [0 .. (fieldLength - 1)]
in dbgTraceIt (sdoc (edges, completeOrder)) P.map P.toInteger completeOrder

greedyOrderOfVertices :: [((Integer, Integer), Integer)] -> [Int]
greedyOrderOfVertices ee = let edges' = P.map (\((a, b), c) -> ((P.fromIntegral a, P.fromIntegral b), P.fromIntegral c)) ee
bounds = (\e -> let v = P.foldr (\((i, j), _) s -> S.insert j (S.insert i s)) S.empty e
mini = minimum v
maxi = maximum v
in (mini, maxi)
) edges'
edgesWithoutWeight = P.map fst edges'
graph = buildG bounds edgesWithoutWeight
weightMap = P.foldr (\(e, w) mm -> M.insert e w mm) M.empty edges'
v'' = greedyOrderOfVerticesHelper graph (topSort graph) weightMap S.empty
in v'' -- dbgTraceIt (sdoc (v'', (M.elems weightMap)))


greedyOrderOfVerticesHelper :: Graph -> [Int] -> M.Map (Int, Int) Int -> S.Set Int -> [Int]
greedyOrderOfVerticesHelper graph vertices' weightMap visited = case vertices' of
[] -> []
x:xs -> if S.member x visited
then greedyOrderOfVerticesHelper graph xs weightMap visited
else let successors = reachable graph x
removeCurr = S.toList $ S.delete x (S.fromList successors)
orderedSucc = orderedSuccsByWeight removeCurr x weightMap visited
visited' = P.foldr S.insert S.empty orderedSucc
v'' = greedyOrderOfVerticesHelper graph xs weightMap visited'
in if successors == [x]
then orderedSucc ++ v'' --dbgTraceIt (sdoc (v'', orderedSucc))
else [x] ++ orderedSucc ++ v''

orderedSuccsByWeight :: [Int] -> Int -> M.Map (Int, Int) Int -> S.Set Int -> [Int]
orderedSuccsByWeight s i weightMap visited = case s of
[] -> []
_ -> let vertexWithMaxWeight = P.foldr (\v (v', maxx) -> let w = M.lookup (i, v) weightMap
in case w of
Nothing -> (-1, -1)
Just w' -> if w' > maxx
then (v, w')
else (v', maxx)
) (-1, -1) s
in if fst vertexWithMaxWeight == -1
then []
else
let removeVertexWithMaxWeight = S.toList $ S.delete (fst vertexWithMaxWeight) (S.fromList s)
in if S.member (fst vertexWithMaxWeight) visited
then orderedSuccsByWeight removeVertexWithMaxWeight i weightMap visited
else fst vertexWithMaxWeight : orderedSuccsByWeight removeVertexWithMaxWeight i weightMap visited --dbgTraceIt (sdoc (s, removeVertexWithMaxWeight, vertexWithMaxWeight))




backtrackVariablesToDataConFields ::
(FreeVars (e l d), Ord l, Ord d, Ord (e l d), Out d, Out l) =>
[(((PreExp e l d), Integer), Integer, [Integer])] ->
Expand All @@ -81,9 +144,9 @@ backtrackVariablesToDataConFields graph dcon =
case graph of
[] -> M.empty
x : xs ->
let newMap = processVertex graph x M.empty dcon
let newMap = processVertex graph x M.empty dcon
mlist = M.toList (newMap)
m = backtrackVariablesToDataConFields xs dcon
m = backtrackVariablesToDataConFields xs dcon
mlist' = M.toList m
newMap' = M.fromList (mlist ++ mlist')
in newMap'
Expand All @@ -93,21 +156,21 @@ processVertex ::
[(((PreExp e l d), Integer), Integer, [Integer])] ->
(((PreExp e l d), Integer), Integer, [Integer]) ->
VariableMap ->
DataCon ->
DataCon ->
VariableMap
processVertex graph node map dataCon =
case node of
((expression, likelihood), id, succ) ->
case expression of
DataConE loc dcon args ->
if dcon == dataCon
then
then
let freeVariables = L.concat (P.map (\x -> S.toList (gFreeVars x)) args)
maybeIndexes = P.map (getDataConIndexFromVariable graph) freeVariables
mapList = M.toList map
newMapList = P.zipWith (\x y -> (x, y)) freeVariables maybeIndexes
in M.fromList (mapList ++ newMapList)
else map
else map
_ -> map

getDataConIndexFromVariable ::
Expand Down
68 changes: 46 additions & 22 deletions gibbon-compiler/src/Gibbon/Passes/OptimizeADTLayout.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
{-# HLINT ignore "Redundant lambda" #-}
{-# HLINT ignore "Use tuple-section" #-}
module Gibbon.Passes.OptimizeADTLayout
( optimizeADTLayout,
(
globallyOptimizeDataConLayout,
locallyOptimizeDataConLayout
)
Expand Down Expand Up @@ -34,6 +34,7 @@ import Gibbon.Passes.AccessPatternsAnalysis
( DataConAccessMap,
FieldMap,
generateAccessGraphs,
getGreedyOrder
)
import Gibbon.Passes.CallGraph
( ProducersMap (..),
Expand Down Expand Up @@ -66,11 +67,11 @@ import Gibbon.Passes.Flatten (flattenL1)
type FieldOrder = M.Map DataCon [Integer]

-- TODO: Make FieldOrder an argument passed to shuffleDataCon function.
optimizeADTLayout ::
Prog1 ->
PassM Prog1
optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} =
do
--optimizeADTLayout ::
-- Prog1 ->
-- PassM Prog1
--optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} =
--do
-- let list_pair_func_dcon =
-- concatMap ( \fn@(FunDef {funName, funMeta = FunMeta {funOptLayout = layout}}) ->
-- case layout of
Expand Down Expand Up @@ -124,25 +125,25 @@ optimizeADTLayout prg@Prog{ddefs, fundefs, mainExp} =
-- p
-- pure prg'
--prg' <- runUntilFixPoint prg
globallyOptimizeDataConLayout prg
--globallyOptimizeDataConLayout prg
--pure prg'
--generateCopyFunctionsForFunctionsThatUseOptimizedVariable (toVar funcName) (dcon ++ "Optimized") fieldorder prg'
--_ -> error "OptimizeFieldOrder: handle user constraints"


locallyOptimizeDataConLayout :: Prog1 -> PassM Prog1
locallyOptimizeDataConLayout prg1 = do
runUntilFixPoint prg1
locallyOptimizeDataConLayout :: Bool -> Prog1 -> PassM Prog1
locallyOptimizeDataConLayout useGreedy prg1 = do
runUntilFixPoint useGreedy prg1



runUntilFixPoint :: Prog1 -> PassM Prog1
runUntilFixPoint prog1 = do
prog1' <- producerConsumerLayoutOptimization prog1
runUntilFixPoint :: Bool -> Prog1 -> PassM Prog1
runUntilFixPoint useGreedy prog1 = do
prog1' <- producerConsumerLayoutOptimization prog1 useGreedy
prog1'' <- flattenL1 prog1'
if prog1 == prog1''
then return prog1
else runUntilFixPoint prog1''
else runUntilFixPoint useGreedy prog1''


dataConsInFunBody :: Exp1 -> S.Set DataCon
Expand Down Expand Up @@ -172,8 +173,8 @@ dataConsInFunBody funBody = case funBody of
MapE {} -> error "getGeneratedVariable: TODO MapE"
FoldE {} -> error "getGeneratedVariable: TODO FoldE"

producerConsumerLayoutOptimization :: Prog1 -> PassM Prog1
producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do
producerConsumerLayoutOptimization :: Prog1 -> Bool -> PassM Prog1
producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} useGreedy = do
-- TODO: make a custom function name printer that guarantees that functions starting with _ are auto-generated.
let funsToOptimize = P.concatMap (\FunDef{funName} -> ([funName | not $ isInfixOf "_" (fromVar funName)])
) $ M.elems fundefs
Expand All @@ -193,7 +194,7 @@ producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do
Just x -> x
Nothing -> error "producerConsumerLayoutOptimization: expected a function definition!!"
let fieldOrder = getAccessGraph f dcon
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder useGreedy
case result of
Nothing -> pure pr --dbgTraceIt (sdoc (result, fname, fieldOrder))
Just (ddefs', fundef', fieldorder) -> let fundefs' = M.delete fname fds
Expand All @@ -207,8 +208,8 @@ producerConsumerLayoutOptimization prg@Prog{ddefs, fundefs, mainExp} = do
P.foldrM lambda prg linearizeDcons --dbgTraceIt (sdoc linearizeDcons)


globallyOptimizeDataConLayout :: Prog1 -> PassM Prog1
globallyOptimizeDataConLayout prg@Prog{ddefs, fundefs, mainExp} = do
globallyOptimizeDataConLayout :: Bool -> Prog1 -> PassM Prog1
globallyOptimizeDataConLayout useGreedy prg@Prog{ddefs, fundefs, mainExp} = do
-- TODO: make a custom function name printer that guarantees that functions starting with _ are auto-generated.
let funsToOptimize = P.concatMap (\FunDef{funName} -> ([funName | not $ isInfixOf "_" (fromVar funName)])
) $ M.elems fundefs
Expand Down Expand Up @@ -261,7 +262,7 @@ globallyOptimizeDataConLayout prg@Prog{ddefs, fundefs, mainExp} = do
let fd = case maybeFd of
Just x -> x
Nothing -> error "globallyOptimizeDataConLayout: expected a function definition!!"
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder
let result = optimizeFunctionWRTDataCon dd fd dcon (fromVar newSymDcon) fieldOrder useGreedy
case result of
Nothing -> pure pr
Just (ddefs', fundef', fieldorder) -> let fundefs' = M.delete fname fds
Expand Down Expand Up @@ -491,12 +492,16 @@ getAccessGraph




-- getGreedyFieldOrder :: Int -> DataCon -> FieldMap

optimizeFunctionWRTDataCon ::
DDefs1 ->
FunDef1 ->
DataCon ->
DataCon ->
FieldMap ->
Bool ->
Maybe (DDefs1, FunDef1, FieldOrder)
optimizeFunctionWRTDataCon
ddefs
Expand All @@ -508,7 +513,9 @@ optimizeFunctionWRTDataCon
}
datacon
newDcon
fieldMap =
fieldMap
useGreedy = case useGreedy of
False ->
let field_len = P.length $ snd . snd $ lkp' ddefs datacon
fieldorder =
optimizeDataConOrderFunc
Expand All @@ -531,7 +538,24 @@ optimizeFunctionWRTDataCon
fundef' = shuffleDataConFunBody True fieldorder fundef newDcon
in Just (newDDefs, fundef', fieldorder) --dbgTraceIt (sdoc order) -- dbgTraceIt (sdoc fieldorder)
_ -> error "more than one"

True ->
let field_len = P.length $ snd . snd $ lkp' ddefs datacon
edges' = case (M.lookup funName fieldMap) of
Just d -> case (M.lookup datacon d) of
Nothing -> error ""
Just e -> e
Nothing -> error ""
greedy_order = getGreedyOrder edges' field_len
fieldorder = M.insert datacon greedy_order M.empty
in case M.toList fieldorder of
[] -> Nothing --dbgTraceIt (sdoc fieldorder) dbgTraceIt (sdoc greedy_order)
[(dcon, order)] -> let orignal_order = [0..(P.length order - 1)]
in if orignal_order == P.map P.fromInteger order
then Nothing
else let newDDefs = optimizeDataCon (dcon, order) ddefs newDcon
fundef' = shuffleDataConFunBody True fieldorder fundef newDcon
in Just (newDDefs, fundef', fieldorder) --dbgTraceIt (sdoc order) -- dbgTraceIt (sdoc fieldorder) dbgTraceIt (sdoc greedy_order)
_ -> error "more than one"

changeCallNameInRecFunction ::
Var -> FunDef1 -> FunDef1
Expand Down
6 changes: 3 additions & 3 deletions gibbon-compiler/tests/test-gibbon-examples.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -896,21 +896,21 @@ tests:
run-modes: ["gibbon2", "gibbon3", "pointer"]

- name: layout1ContentSearchRunPipeline.hs
test-flags: ["--no-gc", "--opt-layout-local"]
test-flags: ["--no-gc", "--opt-layout-local", "--opt-layout-use-solver"]
dir: examples/layout_bench
answer-file: examples/layout_bench/layout1ContentSearchRunPipeline.ans
failing: [interp1,pointer,gibbon1, gibbon3]
run-modes: ["gibbon2"]

- name: manyFuncs.hs
test-flags: ["--no-gc", "--opt-layout-local"]
test-flags: ["--no-gc", "--opt-layout-local", "--opt-layout-use-solver"]
dir: examples/layout_bench
answer-file: examples/layout_bench/manyFuncsLocal.ans
failing: [interp1,pointer,gibbon1, gibbon3]
run-modes: ["gibbon2"]

- name: manyFuncs.hs
test-flags: ["--no-gc", "--opt-layout-global"]
test-flags: ["--no-gc", "--opt-layout-global", "--opt-layout-use-solver"]
dir: examples/layout_bench
answer-file: examples/layout_bench/manyFuncsGlobal.ans
failing: [interp1,pointer,gibbon1, gibbon3]
Expand Down

0 comments on commit 209a90d

Please sign in to comment.