Skip to content

Commit

Permalink
Implement TCO for try-catch-finally block
Browse files Browse the repository at this point in the history
Signed-off-by: HyukWoo Park <[email protected]>
  • Loading branch information
clover2123 authored and ksh8281 committed Nov 8, 2023
1 parent e706ea6 commit 4581040
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 39 deletions.
33 changes: 33 additions & 0 deletions src/interpreter/ByteCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ struct GlobalVariableAccessCacheItem;
#define FOR_EACH_BYTECODE_TCO_OP(F) \
F(CallReturn) \
F(TailRecursion) \
F(TailRecursionInTry) \
F(CallReturnWithReceiver) \
F(TailRecursionWithReceiver)
#else
Expand Down Expand Up @@ -1685,6 +1686,9 @@ class ControlFlowRecord : public gc {
NeedsReturn,
NeedsJump,
NeedsThrow,
#if defined(ENABLE_TCO)
NeedsRecursion,
#endif
};

ControlFlowRecord(const ControlFlowReason& reason, const Value& value, size_t count = 0, size_t outerLimitCount = SIZE_MAX)
Expand Down Expand Up @@ -1748,6 +1752,7 @@ class ControlFlowRecord : public gc {
union {
Value m_value;
size_t m_wordValue;
size_t m_calleeIndex;
};
// m_count is for saving tryStatementScopeCount of the context which contains
// the occurrence(departure point) of this controlflow (e.g. break;)
Expand Down Expand Up @@ -2056,6 +2061,33 @@ class TailRecursion : public ByteCode {
#endif
};

class TailRecursionInTry : public ByteCode {
public:
TailRecursionInTry(const ByteCodeLOC& loc, const size_t calleeIndex, const size_t argumentsStartIndex, const size_t resultIndex, const size_t argumentCount)
: ByteCode(Opcode::TailRecursionInTryOpcode, loc)
, m_calleeIndex(calleeIndex)
, m_argumentsStartIndex(argumentsStartIndex)
, m_resultIndex(resultIndex)
, m_argumentCount(argumentCount)
{
}
ByteCodeRegisterIndex m_calleeIndex;
ByteCodeRegisterIndex m_argumentsStartIndex;
ByteCodeRegisterIndex m_resultIndex;
uint16_t m_argumentCount;

#ifndef NDEBUG
void dump()
{
if (m_argumentCount) {
printf("tail recursion call in try r%u(r%u-r%u)", m_calleeIndex, m_argumentsStartIndex, m_argumentsStartIndex + m_argumentCount);
} else {
printf("tail recursion call in try r%u()", m_calleeIndex);
}
}
#endif
};

class CallReturnWithReceiver : public ByteCode {
public:
CallReturnWithReceiver(const ByteCodeLOC& loc, const size_t receiverIndex, const size_t calleeIndex, const size_t argumentsStartIndex, const size_t argumentCount)
Expand Down Expand Up @@ -2114,6 +2146,7 @@ class TailRecursionWithReceiver : public ByteCode {

COMPILE_ASSERT(sizeof(CallReturn) == sizeof(TailRecursion), "");
COMPILE_ASSERT(sizeof(CallReturnWithReceiver) == sizeof(TailRecursionWithReceiver), "");
COMPILE_ASSERT(sizeof(Call) == sizeof(TailRecursionInTry), "");
#endif

class CallComplexCase : public ByteCode {
Expand Down
12 changes: 12 additions & 0 deletions src/interpreter/ByteCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ ByteCodeGenerateContext::ByteCodeGenerateContext(InterpretedCodeBlock* codeBlock
, m_isHeadOfMemberExpression(false)
, m_forInOfVarBinding(false)
, m_isLeftBindingAffectedByRightExpression(false)
#if defined(ENABLE_TCO)
, m_tcoDisabled(false)
#endif
, m_registerStack(new std::vector<ByteCodeRegisterIndex>())
, m_lexicallyDeclaredNames(new std::vector<std::pair<size_t, AtomicString>>())
, m_positionToContinue(0)
Expand All @@ -60,7 +63,9 @@ ByteCodeGenerateContext::ByteCodeGenerateContext(InterpretedCodeBlock* codeBlock
, m_lexicalBlockIndex(0)
, m_classInfo()
, m_numeralLiteralData(numeralLiteralData) // should be NumeralLiteralVector
#if defined(ENABLE_TCO)
, m_returnRegister(SIZE_MAX)
#endif
#ifdef ESCARGOT_DEBUGGER
, m_breakpointContext(nullptr)
#endif /* ESCARGOT_DEBUGGER */
Expand Down Expand Up @@ -615,6 +620,13 @@ void ByteCodeGenerator::relocateByteCode(ByteCodeBlock* block)
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_argumentsStartIndex, stackBase, stackBaseWillBe, stackVariableSize);
break;
}
case TailRecursionInTryOpcode: {
TailRecursionInTry* cd = (TailRecursionInTry*)currentCode;
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_calleeIndex, stackBase, stackBaseWillBe, stackVariableSize);
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_argumentsStartIndex, stackBase, stackBaseWillBe, stackVariableSize);
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_resultIndex, stackBase, stackBaseWillBe, stackVariableSize);
break;
}
#endif
case CallComplexCaseOpcode: {
CallComplexCase* cd = (CallComplexCase*)currentCode;
Expand Down
23 changes: 13 additions & 10 deletions src/interpreter/ByteCodeGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ struct ByteCodeGenerateContext {
, m_isHeadOfMemberExpression(false)
, m_forInOfVarBinding(contextBefore.m_forInOfVarBinding)
, m_isLeftBindingAffectedByRightExpression(contextBefore.m_isLeftBindingAffectedByRightExpression)
#if defined(ENABLE_TCO)
, m_tcoDisabled(contextBefore.m_tcoDisabled)
#endif
, m_registerStack(contextBefore.m_registerStack)
, m_lexicallyDeclaredNames(contextBefore.m_lexicallyDeclaredNames)
, m_positionToContinue(contextBefore.m_positionToContinue)
Expand All @@ -97,7 +100,9 @@ struct ByteCodeGenerateContext {
, m_lexicalBlockIndex(contextBefore.m_lexicalBlockIndex)
, m_classInfo(contextBefore.m_classInfo)
, m_numeralLiteralData(contextBefore.m_numeralLiteralData) // should be NumeralLiteralVector
#if defined(ENABLE_TCO)
, m_returnRegister(contextBefore.m_returnRegister)
#endif
#ifdef ESCARGOT_DEBUGGER
, m_breakpointContext(contextBefore.m_breakpointContext)
#endif /* ESCARGOT_DEBUGGER */
Expand Down Expand Up @@ -297,21 +302,13 @@ struct ByteCodeGenerateContext {
return m_recursiveStatementStack.size();
}

bool inTryStatement() const
{
for (size_t i = 0; i < m_recursiveStatementStack.size(); i++) {
if (m_recursiveStatementStack[i].first == Try) {
return true;
}
}
return false;
}

#if defined(ENABLE_TCO)
void setReturnRegister(size_t dstRegister)
{
ASSERT(m_returnRegister != dstRegister);
m_returnRegister = dstRegister;
}
#endif

#ifndef NDEBUG
void checkAllDataUsed()
Expand Down Expand Up @@ -359,6 +356,9 @@ struct ByteCodeGenerateContext {
bool m_isHeadOfMemberExpression : 1;
bool m_forInOfVarBinding : 1;
bool m_isLeftBindingAffectedByRightExpression : 1; // x = delete x; or x = eval("var x"), 1;
#if defined(ENABLE_TCO)
bool m_tcoDisabled : 1; // disable tail call optimizaiton (TCO) for some conditions
#endif

std::shared_ptr<std::vector<ByteCodeRegisterIndex>> m_registerStack;
std::shared_ptr<std::vector<std::pair<size_t, AtomicString>>> m_lexicallyDeclaredNames;
Expand Down Expand Up @@ -389,7 +389,10 @@ struct ByteCodeGenerateContext {
std::map<size_t, size_t> m_complexCaseStatementPositions;
void* m_numeralLiteralData; // should be NumeralLiteralVector

#if defined(ENABLE_TCO)
size_t m_returnRegister; // for tail call optimizaiton (TCO)
#endif

#ifdef ESCARGOT_DEBUGGER
ByteCodeBreakpointContext* m_breakpointContext;
#endif /* ESCARGOT_DEBUGGER */
Expand Down
75 changes: 71 additions & 4 deletions src/interpreter/ByteCodeInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ Value Interpreter::interpret(ExecutionState* state, ByteCodeBlock* byteCodeBlock

// set programCounter
programCounter = reinterpret_cast<size_t>(byteCodeBlock->m_code.data());
state->m_programCounter = &programCounter;
ASSERT(state->m_programCounter == &programCounter);

NEXT_INSTRUCTION();
}
Expand Down Expand Up @@ -1619,10 +1619,42 @@ Value Interpreter::interpret(ExecutionState* state, ByteCodeBlock* byteCodeBlock

// set programCounter
programCounter = reinterpret_cast<size_t>(byteCodeBlock->m_code.data());
state->m_programCounter = &programCounter;
ASSERT(state->m_programCounter == &programCounter);

NEXT_INSTRUCTION();
}

// TCO : tail recursion case in catch or finally block
DEFINE_OPCODE(TailRecursionInTry)
:
{
TailRecursionInTry* code = (TailRecursionInTry*)programCounter;
const Value& callee = registerFile[code->m_calleeIndex];

if (UNLIKELY((callee != Value(state->resolveCallee())) || (state->m_argc != code->m_argumentCount))) {
// goto slow path
code->changeOpcode(Opcode::CallOpcode);
if (UNLIKELY(!callee.isPointerValue())) {
ErrorObject::throwBuiltinError(*state, ErrorCode::TypeError, ErrorObject::Messages::NOT_Callable);
}

// Return F.[[Call]](V, argumentsList).
registerFile[code->m_resultIndex] = callee.asPointerValue()->call(*state, Value(), code->m_argumentCount, &registerFile[code->m_argumentsStartIndex]);

ADD_PROGRAM_COUNTER(Call);
NEXT_INSTRUCTION();
}

ASSERT(callee.isPointerValue() && callee.asPointerValue()->isScriptFunctionObject());
ASSERT(callee.asPointerValue()->asScriptFunctionObject()->codeBlock() == byteCodeBlock->codeBlock());
ASSERT(state->m_argc == code->m_argumentCount);
ASSERT((state->rareData() != nullptr) && state->rareData()->m_controlFlowRecord && state->rareData()->m_controlFlowRecord->size());

// postpone recursion call
// because we need to close the current interpreter routine which is called inside try operation
state->rareData()->m_controlFlowRecord->back() = new ControlFlowRecord(ControlFlowRecord::NeedsRecursion, callee, code->m_argumentCount, code->m_argumentsStartIndex);
return Value(Value::EmptyValue);
}
#endif

#ifdef ESCARGOT_DEBUGGER
Expand Down Expand Up @@ -3181,13 +3213,48 @@ NEVER_INLINE Value InterpreterSlowPath::tryOperation(ExecutionState*& state, siz
ASSERT_NOT_REACHED();
// never get here. but I add return statement for removing compile warning
return Value(Value::EmptyValue);
} else {
ASSERT(record->reason() == ControlFlowRecord::NeedsReturn);
} else if (record->reason() == ControlFlowRecord::NeedsReturn) {
record->m_count--;
if (record->count()) {
state->rareData()->m_controlFlowRecord->back() = record;
}
return record->value();
} else {
#if defined(ENABLE_TCO)
// NeedsRecursion should be allocated in one of catch or finally block
ASSERT(record->reason() == ControlFlowRecord::NeedsRecursion);
ASSERT(!inPauserScope && !inPauserResumeProcess);
ASSERT(code->m_hasCatch || code->m_hasFinalizer);
ASSERT(record->m_value == state->resolveCallee());

Value callee = record->m_value;
size_t argCount = record->m_count;
size_t argStartIndex = record->m_outerLimitCount;
if (argCount) {
// At the start of tail call, we need to allocate a buffer for arguments
// because recursive tail call reuses this buffer
if (UNLIKELY(!state->initTCO())) {
Value* newArgs = (Value*)GC_MALLOC(sizeof(Value) * argCount);
state->setTCOArguments(newArgs);
}

// its safe to overwrite arguments because old arguments are no longer necessary
for (size_t i = 0; i < state->m_argc; i++) {
state->m_argv[i] = registerFile[argStartIndex + i];
}
}

// set this value
registerFile[byteCodeBlock->m_requiredOperandRegisterNumber] = state->inStrictMode() ? Value() : state->context()->globalObjectProxy();

// set programCounter
programCounter = reinterpret_cast<size_t>(byteCodeBlock->m_code.data());
ASSERT(state->m_programCounter == &programCounter);

return Value(Value::EmptyValue);
#else
RELEASE_ASSERT_NOT_REACHED();
#endif
}
} else {
programCounter = jumpTo(codeBuffer, code->m_finallyEndPosition);
Expand Down
26 changes: 20 additions & 6 deletions src/parser/ast/CallExpressionNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,15 @@ class CallExpressionNode : public ExpressionNode {
if (dstRegister == context->m_returnRegister) {
// Try tail recursion optimization (TCO)
isTailCall = true;
if (context->m_codeBlock->isTailRecursionTarget()) {
codeBlock->pushCode(TailRecursionWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
bool isTailRecursion = context->m_codeBlock->isTailRecursionTarget();
if (UNLIKELY(context->tryCatchWithBlockStatementCount())) {
codeBlock->pushCode(CallWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(CallReturnWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
if (isTailRecursion) {
codeBlock->pushCode(TailRecursionWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(CallReturnWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
}
}
} else {
codeBlock->pushCode(CallWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
Expand All @@ -429,10 +434,19 @@ class CallExpressionNode : public ExpressionNode {
if (dstRegister == context->m_returnRegister) {
// Try tail recursion optimization (TCO)
isTailCall = true;
if (context->m_codeBlock->isTailRecursionTarget()) {
codeBlock->pushCode(TailRecursion(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
bool isTailRecursion = context->m_codeBlock->isTailRecursionTarget();
if (UNLIKELY(context->tryCatchWithBlockStatementCount())) {
if (isTailRecursion) {
codeBlock->pushCode(TailRecursionInTry(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(Call(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
}
} else {
codeBlock->pushCode(CallReturn(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
if (isTailRecursion) {
codeBlock->pushCode(TailRecursion(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(CallReturn(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
}
}
} else {
codeBlock->pushCode(Call(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
Expand Down
36 changes: 24 additions & 12 deletions src/parser/ast/ReturnStatementNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,19 @@ class ReturnStatementNode : public StatementNode {
ByteCodeRegisterIndex index;
if (m_argument) {
index = m_argument->getRegister(codeBlock, context);
#if defined(ENABLE_TCO)
if (!context->m_tcoDisabled && (context->tryCatchWithBlockStatementCount() == 1) && ((context->m_recursiveStatementStack.back().first == ByteCodeGenerateContext::Catch) || (context->m_recursiveStatementStack.back().first == ByteCodeGenerateContext::Finally))) {
// consider tail recursion (TCO) for catch, finally block within depth 1
bool isTailCall = false;
context->setReturnRegister(index);
m_argument->generateTCOExpressionByteCode(codeBlock, context, index, isTailCall);
context->setReturnRegister(SIZE_MAX);
} else {
m_argument->generateExpressionByteCode(codeBlock, context, index);
}
#else
m_argument->generateExpressionByteCode(codeBlock, context, index);
#endif
codeBlock->pushCode(ReturnFunctionSlowCase(ByteCodeLOC(m_loc.index), index), context, this->m_loc.index);
} else {
index = context->getRegister();
Expand All @@ -64,27 +76,27 @@ class ReturnStatementNode : public StatementNode {
size_t r;
if (m_argument) {
r = m_argument->getRegister(codeBlock, context);
if (context->tryCatchWithBlockStatementCount() == 0) {
// consider tail recursion (TCO)
context->setReturnRegister(r);

#if defined(ENABLE_TCO)
m_argument->generateTCOExpressionByteCode(codeBlock, context, r, isTailCall);
// consider tail recursion (TCO)
ASSERT(!context->m_tcoDisabled);
context->setReturnRegister(r);
m_argument->generateTCOExpressionByteCode(codeBlock, context, r, isTailCall);
context->setReturnRegister(SIZE_MAX);
#else
m_argument->generateExpressionByteCode(codeBlock, context, r);
m_argument->generateExpressionByteCode(codeBlock, context, r);
#endif
context->setReturnRegister(SIZE_MAX);
} else {
m_argument->generateExpressionByteCode(codeBlock, context, r);
}
} else {
r = context->getRegister();
codeBlock->pushCode(LoadLiteral(ByteCodeLOC(m_loc.index), r, Value()), context, this->m_loc.index);
}

if (!isTailCall || (m_argument->type() != CallExpression)) {
// skip End bytecode only if it directly returns the result of tail call
#if defined(ENABLE_TCO)
if (!isTailCall || (m_argument->type() != CallExpression))
// skip End bytecode only if it directly returns the result of tail call
#endif
codeBlock->pushCode(End(ByteCodeLOC(m_loc.index), r), context, this->m_loc.index);
}

context->giveUpRegister();
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/parser/ast/TaggedTemplateExpressionNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class TaggedTemplateExpressionNode : public ExpressionNode {
return m_convertedExpression->generateExpressionByteCode(codeBlock, context, dstRegister);
}

#if defined(ENABLE_TCO)
virtual void generateTCOExpressionByteCode(ByteCodeBlock* codeBlock, ByteCodeGenerateContext* context, ByteCodeRegisterIndex dstRegister, bool& isTailCall) override
{
return m_convertedExpression->generateTCOExpressionByteCode(codeBlock, context, dstRegister, isTailCall);
}
#endif

virtual void iterateChildren(const std::function<void(Node* node)>& fn) override
{
fn(this);
Expand Down
Loading

0 comments on commit 4581040

Please sign in to comment.