Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TCO for try-catch-finally block #1274

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/interpreter/ByteCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ struct GlobalVariableAccessCacheItem;
#define FOR_EACH_BYTECODE_TCO_OP(F) \
F(CallReturn) \
F(TailRecursion) \
F(TailRecursionInTry) \
F(CallReturnWithReceiver) \
F(TailRecursionWithReceiver)
#else
Expand Down Expand Up @@ -1685,6 +1686,9 @@ class ControlFlowRecord : public gc {
NeedsReturn,
NeedsJump,
NeedsThrow,
#if defined(ENABLE_TCO)
NeedsRecursion,
#endif
};

ControlFlowRecord(const ControlFlowReason& reason, const Value& value, size_t count = 0, size_t outerLimitCount = SIZE_MAX)
Expand Down Expand Up @@ -1748,6 +1752,7 @@ class ControlFlowRecord : public gc {
union {
Value m_value;
size_t m_wordValue;
size_t m_calleeIndex;
};
// m_count is for saving tryStatementScopeCount of the context which contains
// the occurrence(departure point) of this controlflow (e.g. break;)
Expand Down Expand Up @@ -2056,6 +2061,33 @@ class TailRecursion : public ByteCode {
#endif
};

class TailRecursionInTry : public ByteCode {
public:
TailRecursionInTry(const ByteCodeLOC& loc, const size_t calleeIndex, const size_t argumentsStartIndex, const size_t resultIndex, const size_t argumentCount)
: ByteCode(Opcode::TailRecursionInTryOpcode, loc)
, m_calleeIndex(calleeIndex)
, m_argumentsStartIndex(argumentsStartIndex)
, m_resultIndex(resultIndex)
, m_argumentCount(argumentCount)
{
}
ByteCodeRegisterIndex m_calleeIndex;
ByteCodeRegisterIndex m_argumentsStartIndex;
ByteCodeRegisterIndex m_resultIndex;
uint16_t m_argumentCount;

#ifndef NDEBUG
void dump()
{
if (m_argumentCount) {
printf("tail recursion call in try r%u(r%u-r%u)", m_calleeIndex, m_argumentsStartIndex, m_argumentsStartIndex + m_argumentCount);
} else {
printf("tail recursion call in try r%u()", m_calleeIndex);
}
}
#endif
};

class CallReturnWithReceiver : public ByteCode {
public:
CallReturnWithReceiver(const ByteCodeLOC& loc, const size_t receiverIndex, const size_t calleeIndex, const size_t argumentsStartIndex, const size_t argumentCount)
Expand Down Expand Up @@ -2114,6 +2146,7 @@ class TailRecursionWithReceiver : public ByteCode {

COMPILE_ASSERT(sizeof(CallReturn) == sizeof(TailRecursion), "");
COMPILE_ASSERT(sizeof(CallReturnWithReceiver) == sizeof(TailRecursionWithReceiver), "");
COMPILE_ASSERT(sizeof(Call) == sizeof(TailRecursionInTry), "");
#endif

class CallComplexCase : public ByteCode {
Expand Down
12 changes: 12 additions & 0 deletions src/interpreter/ByteCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ ByteCodeGenerateContext::ByteCodeGenerateContext(InterpretedCodeBlock* codeBlock
, m_isHeadOfMemberExpression(false)
, m_forInOfVarBinding(false)
, m_isLeftBindingAffectedByRightExpression(false)
#if defined(ENABLE_TCO)
, m_tcoDisabled(false)
#endif
, m_registerStack(new std::vector<ByteCodeRegisterIndex>())
, m_lexicallyDeclaredNames(new std::vector<std::pair<size_t, AtomicString>>())
, m_positionToContinue(0)
Expand All @@ -60,7 +63,9 @@ ByteCodeGenerateContext::ByteCodeGenerateContext(InterpretedCodeBlock* codeBlock
, m_lexicalBlockIndex(0)
, m_classInfo()
, m_numeralLiteralData(numeralLiteralData) // should be NumeralLiteralVector
#if defined(ENABLE_TCO)
, m_returnRegister(SIZE_MAX)
#endif
#ifdef ESCARGOT_DEBUGGER
, m_breakpointContext(nullptr)
#endif /* ESCARGOT_DEBUGGER */
Expand Down Expand Up @@ -615,6 +620,13 @@ void ByteCodeGenerator::relocateByteCode(ByteCodeBlock* block)
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_argumentsStartIndex, stackBase, stackBaseWillBe, stackVariableSize);
break;
}
case TailRecursionInTryOpcode: {
TailRecursionInTry* cd = (TailRecursionInTry*)currentCode;
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_calleeIndex, stackBase, stackBaseWillBe, stackVariableSize);
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_argumentsStartIndex, stackBase, stackBaseWillBe, stackVariableSize);
ASSIGN_STACKINDEX_IF_NEEDED(cd->m_resultIndex, stackBase, stackBaseWillBe, stackVariableSize);
break;
}
#endif
case CallComplexCaseOpcode: {
CallComplexCase* cd = (CallComplexCase*)currentCode;
Expand Down
23 changes: 13 additions & 10 deletions src/interpreter/ByteCodeGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ struct ByteCodeGenerateContext {
, m_isHeadOfMemberExpression(false)
, m_forInOfVarBinding(contextBefore.m_forInOfVarBinding)
, m_isLeftBindingAffectedByRightExpression(contextBefore.m_isLeftBindingAffectedByRightExpression)
#if defined(ENABLE_TCO)
, m_tcoDisabled(contextBefore.m_tcoDisabled)
#endif
, m_registerStack(contextBefore.m_registerStack)
, m_lexicallyDeclaredNames(contextBefore.m_lexicallyDeclaredNames)
, m_positionToContinue(contextBefore.m_positionToContinue)
Expand All @@ -97,7 +100,9 @@ struct ByteCodeGenerateContext {
, m_lexicalBlockIndex(contextBefore.m_lexicalBlockIndex)
, m_classInfo(contextBefore.m_classInfo)
, m_numeralLiteralData(contextBefore.m_numeralLiteralData) // should be NumeralLiteralVector
#if defined(ENABLE_TCO)
, m_returnRegister(contextBefore.m_returnRegister)
#endif
#ifdef ESCARGOT_DEBUGGER
, m_breakpointContext(contextBefore.m_breakpointContext)
#endif /* ESCARGOT_DEBUGGER */
Expand Down Expand Up @@ -297,21 +302,13 @@ struct ByteCodeGenerateContext {
return m_recursiveStatementStack.size();
}

bool inTryStatement() const
{
for (size_t i = 0; i < m_recursiveStatementStack.size(); i++) {
if (m_recursiveStatementStack[i].first == Try) {
return true;
}
}
return false;
}

#if defined(ENABLE_TCO)
void setReturnRegister(size_t dstRegister)
{
ASSERT(m_returnRegister != dstRegister);
m_returnRegister = dstRegister;
}
#endif

#ifndef NDEBUG
void checkAllDataUsed()
Expand Down Expand Up @@ -359,6 +356,9 @@ struct ByteCodeGenerateContext {
bool m_isHeadOfMemberExpression : 1;
bool m_forInOfVarBinding : 1;
bool m_isLeftBindingAffectedByRightExpression : 1; // x = delete x; or x = eval("var x"), 1;
#if defined(ENABLE_TCO)
bool m_tcoDisabled : 1; // disable tail call optimizaiton (TCO) for some conditions
#endif

std::shared_ptr<std::vector<ByteCodeRegisterIndex>> m_registerStack;
std::shared_ptr<std::vector<std::pair<size_t, AtomicString>>> m_lexicallyDeclaredNames;
Expand Down Expand Up @@ -389,7 +389,10 @@ struct ByteCodeGenerateContext {
std::map<size_t, size_t> m_complexCaseStatementPositions;
void* m_numeralLiteralData; // should be NumeralLiteralVector

#if defined(ENABLE_TCO)
size_t m_returnRegister; // for tail call optimizaiton (TCO)
#endif

#ifdef ESCARGOT_DEBUGGER
ByteCodeBreakpointContext* m_breakpointContext;
#endif /* ESCARGOT_DEBUGGER */
Expand Down
75 changes: 71 additions & 4 deletions src/interpreter/ByteCodeInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ Value Interpreter::interpret(ExecutionState* state, ByteCodeBlock* byteCodeBlock

// set programCounter
programCounter = reinterpret_cast<size_t>(byteCodeBlock->m_code.data());
state->m_programCounter = &programCounter;
ASSERT(state->m_programCounter == &programCounter);

NEXT_INSTRUCTION();
}
Expand Down Expand Up @@ -1619,10 +1619,42 @@ Value Interpreter::interpret(ExecutionState* state, ByteCodeBlock* byteCodeBlock

// set programCounter
programCounter = reinterpret_cast<size_t>(byteCodeBlock->m_code.data());
state->m_programCounter = &programCounter;
ASSERT(state->m_programCounter == &programCounter);

NEXT_INSTRUCTION();
}

// TCO : tail recursion case in catch or finally block
DEFINE_OPCODE(TailRecursionInTry)
:
{
TailRecursionInTry* code = (TailRecursionInTry*)programCounter;
const Value& callee = registerFile[code->m_calleeIndex];

if (UNLIKELY((callee != Value(state->resolveCallee())) || (state->m_argc != code->m_argumentCount))) {
// goto slow path
code->changeOpcode(Opcode::CallOpcode);
if (UNLIKELY(!callee.isPointerValue())) {
ErrorObject::throwBuiltinError(*state, ErrorCode::TypeError, ErrorObject::Messages::NOT_Callable);
}

// Return F.[[Call]](V, argumentsList).
registerFile[code->m_resultIndex] = callee.asPointerValue()->call(*state, Value(), code->m_argumentCount, &registerFile[code->m_argumentsStartIndex]);

ADD_PROGRAM_COUNTER(Call);
NEXT_INSTRUCTION();
}

ASSERT(callee.isPointerValue() && callee.asPointerValue()->isScriptFunctionObject());
ASSERT(callee.asPointerValue()->asScriptFunctionObject()->codeBlock() == byteCodeBlock->codeBlock());
ASSERT(state->m_argc == code->m_argumentCount);
ASSERT((state->rareData() != nullptr) && state->rareData()->m_controlFlowRecord && state->rareData()->m_controlFlowRecord->size());

// postpone recursion call
// because we need to close the current interpreter routine which is called inside try operation
state->rareData()->m_controlFlowRecord->back() = new ControlFlowRecord(ControlFlowRecord::NeedsRecursion, callee, code->m_argumentCount, code->m_argumentsStartIndex);
return Value(Value::EmptyValue);
}
#endif

#ifdef ESCARGOT_DEBUGGER
Expand Down Expand Up @@ -3181,13 +3213,48 @@ NEVER_INLINE Value InterpreterSlowPath::tryOperation(ExecutionState*& state, siz
ASSERT_NOT_REACHED();
// never get here. but I add return statement for removing compile warning
return Value(Value::EmptyValue);
} else {
ASSERT(record->reason() == ControlFlowRecord::NeedsReturn);
} else if (record->reason() == ControlFlowRecord::NeedsReturn) {
record->m_count--;
if (record->count()) {
state->rareData()->m_controlFlowRecord->back() = record;
}
return record->value();
} else {
#if defined(ENABLE_TCO)
// NeedsRecursion should be allocated in one of catch or finally block
ASSERT(record->reason() == ControlFlowRecord::NeedsRecursion);
ASSERT(!inPauserScope && !inPauserResumeProcess);
ASSERT(code->m_hasCatch || code->m_hasFinalizer);
ASSERT(record->m_value == state->resolveCallee());

Value callee = record->m_value;
size_t argCount = record->m_count;
size_t argStartIndex = record->m_outerLimitCount;
if (argCount) {
// At the start of tail call, we need to allocate a buffer for arguments
// because recursive tail call reuses this buffer
if (UNLIKELY(!state->initTCO())) {
Value* newArgs = (Value*)GC_MALLOC(sizeof(Value) * argCount);
state->setTCOArguments(newArgs);
}

// its safe to overwrite arguments because old arguments are no longer necessary
for (size_t i = 0; i < state->m_argc; i++) {
state->m_argv[i] = registerFile[argStartIndex + i];
}
}

// set this value
registerFile[byteCodeBlock->m_requiredOperandRegisterNumber] = state->inStrictMode() ? Value() : state->context()->globalObjectProxy();

// set programCounter
programCounter = reinterpret_cast<size_t>(byteCodeBlock->m_code.data());
ASSERT(state->m_programCounter == &programCounter);

return Value(Value::EmptyValue);
#else
RELEASE_ASSERT_NOT_REACHED();
#endif
}
} else {
programCounter = jumpTo(codeBuffer, code->m_finallyEndPosition);
Expand Down
26 changes: 20 additions & 6 deletions src/parser/ast/CallExpressionNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,15 @@ class CallExpressionNode : public ExpressionNode {
if (dstRegister == context->m_returnRegister) {
// Try tail recursion optimization (TCO)
isTailCall = true;
if (context->m_codeBlock->isTailRecursionTarget()) {
codeBlock->pushCode(TailRecursionWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
bool isTailRecursion = context->m_codeBlock->isTailRecursionTarget();
if (UNLIKELY(context->tryCatchWithBlockStatementCount())) {
codeBlock->pushCode(CallWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(CallReturnWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
if (isTailRecursion) {
codeBlock->pushCode(TailRecursionWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(CallReturnWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
}
}
} else {
codeBlock->pushCode(CallWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
Expand All @@ -429,10 +434,19 @@ class CallExpressionNode : public ExpressionNode {
if (dstRegister == context->m_returnRegister) {
// Try tail recursion optimization (TCO)
isTailCall = true;
if (context->m_codeBlock->isTailRecursionTarget()) {
codeBlock->pushCode(TailRecursion(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
bool isTailRecursion = context->m_codeBlock->isTailRecursionTarget();
if (UNLIKELY(context->tryCatchWithBlockStatementCount())) {
if (isTailRecursion) {
codeBlock->pushCode(TailRecursionInTry(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(Call(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
}
} else {
codeBlock->pushCode(CallReturn(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
if (isTailRecursion) {
codeBlock->pushCode(TailRecursion(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
} else {
codeBlock->pushCode(CallReturn(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index);
}
}
} else {
codeBlock->pushCode(Call(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index);
Expand Down
36 changes: 24 additions & 12 deletions src/parser/ast/ReturnStatementNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,19 @@ class ReturnStatementNode : public StatementNode {
ByteCodeRegisterIndex index;
if (m_argument) {
index = m_argument->getRegister(codeBlock, context);
#if defined(ENABLE_TCO)
if (!context->m_tcoDisabled && (context->tryCatchWithBlockStatementCount() == 1) && ((context->m_recursiveStatementStack.back().first == ByteCodeGenerateContext::Catch) || (context->m_recursiveStatementStack.back().first == ByteCodeGenerateContext::Finally))) {
// consider tail recursion (TCO) for catch, finally block within depth 1
bool isTailCall = false;
context->setReturnRegister(index);
m_argument->generateTCOExpressionByteCode(codeBlock, context, index, isTailCall);
context->setReturnRegister(SIZE_MAX);
} else {
m_argument->generateExpressionByteCode(codeBlock, context, index);
}
#else
m_argument->generateExpressionByteCode(codeBlock, context, index);
#endif
codeBlock->pushCode(ReturnFunctionSlowCase(ByteCodeLOC(m_loc.index), index), context, this->m_loc.index);
} else {
index = context->getRegister();
Expand All @@ -64,27 +76,27 @@ class ReturnStatementNode : public StatementNode {
size_t r;
if (m_argument) {
r = m_argument->getRegister(codeBlock, context);
if (context->tryCatchWithBlockStatementCount() == 0) {
// consider tail recursion (TCO)
context->setReturnRegister(r);

#if defined(ENABLE_TCO)
m_argument->generateTCOExpressionByteCode(codeBlock, context, r, isTailCall);
// consider tail recursion (TCO)
ASSERT(!context->m_tcoDisabled);
context->setReturnRegister(r);
m_argument->generateTCOExpressionByteCode(codeBlock, context, r, isTailCall);
context->setReturnRegister(SIZE_MAX);
#else
m_argument->generateExpressionByteCode(codeBlock, context, r);
m_argument->generateExpressionByteCode(codeBlock, context, r);
#endif
context->setReturnRegister(SIZE_MAX);
} else {
m_argument->generateExpressionByteCode(codeBlock, context, r);
}
} else {
r = context->getRegister();
codeBlock->pushCode(LoadLiteral(ByteCodeLOC(m_loc.index), r, Value()), context, this->m_loc.index);
}

if (!isTailCall || (m_argument->type() != CallExpression)) {
// skip End bytecode only if it directly returns the result of tail call
#if defined(ENABLE_TCO)
if (!isTailCall || (m_argument->type() != CallExpression))
// skip End bytecode only if it directly returns the result of tail call
#endif
codeBlock->pushCode(End(ByteCodeLOC(m_loc.index), r), context, this->m_loc.index);
}

context->giveUpRegister();
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/parser/ast/TaggedTemplateExpressionNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class TaggedTemplateExpressionNode : public ExpressionNode {
return m_convertedExpression->generateExpressionByteCode(codeBlock, context, dstRegister);
}

#if defined(ENABLE_TCO)
virtual void generateTCOExpressionByteCode(ByteCodeBlock* codeBlock, ByteCodeGenerateContext* context, ByteCodeRegisterIndex dstRegister, bool& isTailCall) override
{
return m_convertedExpression->generateTCOExpressionByteCode(codeBlock, context, dstRegister, isTailCall);
}
#endif

virtual void iterateChildren(const std::function<void(Node* node)>& fn) override
{
fn(this);
Expand Down
Loading