From 360551d830468585338130293291a9e052de2fec Mon Sep 17 00:00:00 2001 From: Olle Fredriksson Date: Fri, 31 May 2024 23:19:13 +0200 Subject: [PATCH] Implement simple reference counting scheme --- src/LowToLLVM.hs | 151 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 118 insertions(+), 33 deletions(-) diff --git a/src/LowToLLVM.hs b/src/LowToLLVM.hs index 548df78..13d660f 100644 --- a/src/LowToLLVM.hs +++ b/src/LowToLLVM.hs @@ -198,6 +198,12 @@ declareLLVMGlobal name decl = ------------------------------------------------------------------------------- +data StackAllocation = StackAllocation + { saved :: !Var + , reference :: !Var + , size :: !Operand + } + saveStack :: Assembler Var saveStack = do declareLLVMGlobal "llvm.stacksave" "declare ptr @llvm.stacksave()" @@ -205,10 +211,88 @@ saveStack = do emitInstruction $ varName var <> " = call ptr @llvm.stacksave()" pure var -restoreStack :: Var -> Assembler () -restoreStack var = do +restoreStack :: StackAllocation -> Assembler () +restoreStack stackAllocation = do declareLLVMGlobal "llvm.stackrestore" "declare void @llvm.stackrestore(ptr)" - emitInstruction $ "call void @llvm.stackrestore" <> parens ["ptr " <> varName var] + decreaseReferenceCounts stackAllocation.size stackAllocation.reference + emitInstruction $ "call void @llvm.stackrestore" <> parens ["ptr " <> varName stackAllocation.saved] + +increaseReferenceCount :: Representation -> Operand -> Assembler () +increaseReferenceCount repr o = + case (pointerType repr.pointers, nonPointerType repr.nonPointerBytes) of + (Nothing, _) -> pure () + (Just _, npType) -> do + declareLLVMGlobal "sixten_increase_reference_count" "declare void @sixten_increase_reference_count(i64)" + o' <- case npType of + Nothing -> pure o + Just _ -> do + pointers <- freshVar "pointers" + emitInstruction $ + varName pointers + <> " = extractvalue " + <> typedOperand (PassBy.Value repr, o) + <> ", 0" + pure $ Local pointers + case repr.pointers of + 1 -> + emitInstruction $ + "call void @sixten_increase_reference_count" + <> parens ["i64 " <> operand o'] + n -> do + forM_ [0 .. n - 1] \i -> do + extractedPointer <- freshVar "extracted_pointer" + emitInstruction $ + varName extractedPointer + <> " = extractvalue " + <> typedOperand (PassBy.Value repr, o') + <> ", " + <> Builder.word32Dec i + emitInstruction $ + "call void @sixten_increase_reference_count" + <> parens ["i64 " <> varName extractedPointer] + +decreaseReferenceCounts :: Operand -> Var -> Assembler () +decreaseReferenceCounts size reference = do + declareLLVMGlobal "sixten_decrease_reference_counts" "declare void @sixten_decrease_reference_counts(ptr, i32)" + (pointers, _) <- extractSizeParts (PassBy.Value Representation.type_, size) + (pointersPointer, _) <- extractParts (PassBy.Reference, Local reference) + emitInstruction $ + "call void @sixten_decrease_reference_counts" + <> parens ["ptr " <> operand pointersPointer, "i32 " <> varName pointers] + +decreaseReferenceCount :: Representation -> Operand -> Assembler () +decreaseReferenceCount repr o = + case (pointerType repr.pointers, nonPointerType repr.nonPointerBytes) of + (Nothing, _) -> pure () + (Just _, npType) -> do + declareLLVMGlobal "sixten_decrease_reference_count" "declare void @sixten_decrease_reference_count(i64)" + o' <- case npType of + Nothing -> pure o + Just _ -> do + pointers <- freshVar "pointers" + emitInstruction $ + varName pointers + <> " = extractvalue " + <> typedOperand (PassBy.Value repr, o) + <> ", 0" + pure $ Local pointers + case repr.pointers of + 1 -> + emitInstruction $ + "call void @sixten_decrease_reference_count" + <> parens ["i64 " <> operand o'] + n -> do + forM_ [0 .. n - 1] \i -> do + extractedPointer <- freshVar "extracted_pointer" + emitInstruction $ + varName extractedPointer + <> " = extractvalue " + <> typedOperand (PassBy.Value repr, o') + <> ", " + <> Builder.word32Dec i + emitInstruction $ + "call void @sixten_decrease_reference_count" + <> parens ["i64 " <> varName extractedPointer] ------------------------------------------------------------------------------- @@ -322,14 +406,20 @@ assembleTerm -> Maybe Name -> PassBy -> Syntax.Term v - -> Assembler (Operand, Maybe Var) + -> Assembler (Operand, Maybe StackAllocation) assembleTerm env nameSuggestion passBy = \case Syntax.Operand o -> do - (_, o') <- assembleOperand env o + (passOperandBy, o') <- assembleOperand env o + case passOperandBy of + PassBy.Value repr -> increaseReferenceCount repr o' + PassBy.Reference -> pure () pure (o', Nothing) Syntax.Let passLetBy name term body -> do (termResult, termStack) <- assembleTerm env (Just name) passLetBy term (bodyResult, bodyStack) <- assembleTerm (env Index.Seq.:> (passLetBy, termResult)) nameSuggestion passBy body + case passLetBy of + PassBy.Value repr -> decreaseReferenceCount repr termResult + PassBy.Reference -> pure () mapM_ restoreStack termStack mapM_ restoreStack bodyStack pure (bodyResult, Nothing) @@ -448,7 +538,15 @@ assembleTerm env nameSuggestion passBy = \case <> ", i32 " <> varName pointers result <- constructTuple (fromMaybe "alloca_result" nameSuggestion) "ptr" allocaBytes "ptr" nonPointerPointer - pure (Local result, Just stack) + pure + ( Local result + , Just + StackAllocation + { saved = stack + , reference = result + , size = snd $ size' + } + ) Syntax.HeapAllocate constr size -> do declareLLVMGlobal "sixten_heap_allocate" "declare i64 @sixten_heap_allocate(i64, i32, i32)" var <- freshVar $ fromMaybe "heap_allocation" nameSuggestion @@ -514,36 +612,20 @@ assembleTerm env nameSuggestion passBy = \case src' <- assembleOperand env src size' <- assembleOperand env size (pointers, nonPointerBytes) <- extractSizeParts size' - (dstPointerPointer, dstNonPointerPointer) <- extractParts dst' - (srcPointerPointer, srcNonPointerPointer) <- extractParts src' - declareLLVMGlobal "llvm.memcpy.p0.p0.i32" "declare void @llvm.memcpy.p0.p0.i32(ptr, ptr, i32, i1)" - pointerBytes <- freshVar "pointer_bytes" - emitInstruction $ - varName pointerBytes - <> " = mul i32 " - <> varName pointers - <> ", " - <> Builder.intDec Representation.wordBytes + declareLLVMGlobal "sixten_copy" "declare void @sixten_copy({ptr, ptr}, {ptr, ptr}, i32, i32)" emitInstruction $ - "call void @llvm.memcpy.p0.p0.i32" + "call void @sixten_copy" <> parens - [ "ptr " <> operand dstPointerPointer - , "ptr " <> operand srcPointerPointer - , "i32 " <> varName pointerBytes - , "i1 0" -- isvolatile - ] - emitInstruction $ - "call void @llvm.memcpy.p0.p0.i32" - <> parens - [ "ptr " <> operand dstNonPointerPointer - , "ptr " <> operand srcNonPointerPointer + [ typedOperand dst' + , typedOperand src' + , "i32 " <> varName pointers , "i32 " <> varName nonPointerBytes - , "i1 0" -- isvolatile ] pure (Constant "undef", Nothing) Syntax.Store dst src repr -> do dst' <- assembleOperand env dst src' <- assembleOperand env src + increaseReferenceCount repr $ snd src' (dstPointerPointer, dstNonPointerPointer) <- extractParts dst' case (pointerType repr.pointers, nonPointerType repr.nonPointerBytes) of (Nothing, Nothing) -> pure () @@ -559,23 +641,26 @@ assembleTerm env nameSuggestion passBy = \case Syntax.Load src repr -> do src' <- assembleOperand env src (srcPointerPointer, srcNonPointerPointer) <- extractParts src' - case (pointerType repr.pointers, nonPointerType repr.nonPointerBytes) of - (Nothing, Nothing) -> pure (Constant "undef", Nothing) + result <- case (pointerType repr.pointers, nonPointerType repr.nonPointerBytes) of + (Nothing, Nothing) -> pure $ Constant "undef" (Just p, Nothing) -> do result <- freshVar $ fromMaybe "load_result" nameSuggestion emitInstruction $ varName result <> " = load " <> p <> ", ptr " <> operand srcPointerPointer - pure (Local result, Nothing) + pure $ Local result (Nothing, Just np) -> do result <- freshVar $ fromMaybe "load_result" nameSuggestion emitInstruction $ varName result <> " = load " <> np <> ", ptr " <> operand srcNonPointerPointer - pure (Local result, Nothing) + pure $ Local result (Just p, Just np) -> do pointers <- freshVar "load_pointers" nonPointers <- freshVar "load_non_pointers" emitInstruction $ varName pointers <> " = load " <> p <> ", ptr " <> operand srcPointerPointer emitInstruction $ varName nonPointers <> " = load " <> np <> ", ptr " <> operand srcNonPointerPointer result <- constructTuple (fromMaybe "load_result" nameSuggestion) p pointers np nonPointers - pure (Local result, Nothing) + pure $ Local result + + increaseReferenceCount repr result + pure (result, Nothing) assembleOperand :: Environment v -> Syntax.Operand v -> Assembler (PassBy, Operand) assembleOperand env = \case