diff --git a/src/interpreter/ByteCode.h b/src/interpreter/ByteCode.h index 2179e686f..f4abd90f7 100644 --- a/src/interpreter/ByteCode.h +++ b/src/interpreter/ByteCode.h @@ -140,6 +140,7 @@ struct GlobalVariableAccessCacheItem; #define FOR_EACH_BYTECODE_TCO_OP(F) \ F(CallReturn) \ F(TailRecursion) \ + F(TailRecursionInTry) \ F(CallReturnWithReceiver) \ F(TailRecursionWithReceiver) #else @@ -1685,6 +1686,9 @@ class ControlFlowRecord : public gc { NeedsReturn, NeedsJump, NeedsThrow, +#if defined(ENABLE_TCO) + NeedsRecursion, +#endif }; ControlFlowRecord(const ControlFlowReason& reason, const Value& value, size_t count = 0, size_t outerLimitCount = SIZE_MAX) @@ -1748,6 +1752,7 @@ class ControlFlowRecord : public gc { union { Value m_value; size_t m_wordValue; + size_t m_calleeIndex; }; // m_count is for saving tryStatementScopeCount of the context which contains // the occurrence(departure point) of this controlflow (e.g. break;) @@ -2056,6 +2061,33 @@ class TailRecursion : public ByteCode { #endif }; +class TailRecursionInTry : public ByteCode { +public: + TailRecursionInTry(const ByteCodeLOC& loc, const size_t calleeIndex, const size_t argumentsStartIndex, const size_t resultIndex, const size_t argumentCount) + : ByteCode(Opcode::TailRecursionInTryOpcode, loc) + , m_calleeIndex(calleeIndex) + , m_argumentsStartIndex(argumentsStartIndex) + , m_resultIndex(resultIndex) + , m_argumentCount(argumentCount) + { + } + ByteCodeRegisterIndex m_calleeIndex; + ByteCodeRegisterIndex m_argumentsStartIndex; + ByteCodeRegisterIndex m_resultIndex; + uint16_t m_argumentCount; + +#ifndef NDEBUG + void dump() + { + if (m_argumentCount) { + printf("tail recursion call in try r%u(r%u-r%u)", m_calleeIndex, m_argumentsStartIndex, m_argumentsStartIndex + m_argumentCount); + } else { + printf("tail recursion call in try r%u()", m_calleeIndex); + } + } +#endif +}; + class CallReturnWithReceiver : public ByteCode { public: CallReturnWithReceiver(const ByteCodeLOC& loc, const size_t receiverIndex, const size_t calleeIndex, const size_t argumentsStartIndex, const size_t argumentCount) @@ -2114,6 +2146,7 @@ class TailRecursionWithReceiver : public ByteCode { COMPILE_ASSERT(sizeof(CallReturn) == sizeof(TailRecursion), ""); COMPILE_ASSERT(sizeof(CallReturnWithReceiver) == sizeof(TailRecursionWithReceiver), ""); +COMPILE_ASSERT(sizeof(Call) == sizeof(TailRecursionInTry), ""); #endif class CallComplexCase : public ByteCode { diff --git a/src/interpreter/ByteCodeGenerator.cpp b/src/interpreter/ByteCodeGenerator.cpp index 1a8cafcc1..d5eb33886 100644 --- a/src/interpreter/ByteCodeGenerator.cpp +++ b/src/interpreter/ByteCodeGenerator.cpp @@ -52,6 +52,9 @@ ByteCodeGenerateContext::ByteCodeGenerateContext(InterpretedCodeBlock* codeBlock , m_isHeadOfMemberExpression(false) , m_forInOfVarBinding(false) , m_isLeftBindingAffectedByRightExpression(false) +#if defined(ENABLE_TCO) + , m_tcoDisabled(false) +#endif , m_registerStack(new std::vector()) , m_lexicallyDeclaredNames(new std::vector>()) , m_positionToContinue(0) @@ -60,7 +63,9 @@ ByteCodeGenerateContext::ByteCodeGenerateContext(InterpretedCodeBlock* codeBlock , m_lexicalBlockIndex(0) , m_classInfo() , m_numeralLiteralData(numeralLiteralData) // should be NumeralLiteralVector +#if defined(ENABLE_TCO) , m_returnRegister(SIZE_MAX) +#endif #ifdef ESCARGOT_DEBUGGER , m_breakpointContext(nullptr) #endif /* ESCARGOT_DEBUGGER */ @@ -615,6 +620,13 @@ void ByteCodeGenerator::relocateByteCode(ByteCodeBlock* block) ASSIGN_STACKINDEX_IF_NEEDED(cd->m_argumentsStartIndex, stackBase, stackBaseWillBe, stackVariableSize); break; } + case TailRecursionInTryOpcode: { + TailRecursionInTry* cd = (TailRecursionInTry*)currentCode; + ASSIGN_STACKINDEX_IF_NEEDED(cd->m_calleeIndex, stackBase, stackBaseWillBe, stackVariableSize); + ASSIGN_STACKINDEX_IF_NEEDED(cd->m_argumentsStartIndex, stackBase, stackBaseWillBe, stackVariableSize); + ASSIGN_STACKINDEX_IF_NEEDED(cd->m_resultIndex, stackBase, stackBaseWillBe, stackVariableSize); + break; + } #endif case CallComplexCaseOpcode: { CallComplexCase* cd = (CallComplexCase*)currentCode; diff --git a/src/interpreter/ByteCodeGenerator.h b/src/interpreter/ByteCodeGenerator.h index ff45aa5ef..c087c5417 100644 --- a/src/interpreter/ByteCodeGenerator.h +++ b/src/interpreter/ByteCodeGenerator.h @@ -88,6 +88,9 @@ struct ByteCodeGenerateContext { , m_isHeadOfMemberExpression(false) , m_forInOfVarBinding(contextBefore.m_forInOfVarBinding) , m_isLeftBindingAffectedByRightExpression(contextBefore.m_isLeftBindingAffectedByRightExpression) +#if defined(ENABLE_TCO) + , m_tcoDisabled(contextBefore.m_tcoDisabled) +#endif , m_registerStack(contextBefore.m_registerStack) , m_lexicallyDeclaredNames(contextBefore.m_lexicallyDeclaredNames) , m_positionToContinue(contextBefore.m_positionToContinue) @@ -97,7 +100,9 @@ struct ByteCodeGenerateContext { , m_lexicalBlockIndex(contextBefore.m_lexicalBlockIndex) , m_classInfo(contextBefore.m_classInfo) , m_numeralLiteralData(contextBefore.m_numeralLiteralData) // should be NumeralLiteralVector +#if defined(ENABLE_TCO) , m_returnRegister(contextBefore.m_returnRegister) +#endif #ifdef ESCARGOT_DEBUGGER , m_breakpointContext(contextBefore.m_breakpointContext) #endif /* ESCARGOT_DEBUGGER */ @@ -297,21 +302,13 @@ struct ByteCodeGenerateContext { return m_recursiveStatementStack.size(); } - bool inTryStatement() const - { - for (size_t i = 0; i < m_recursiveStatementStack.size(); i++) { - if (m_recursiveStatementStack[i].first == Try) { - return true; - } - } - return false; - } - +#if defined(ENABLE_TCO) void setReturnRegister(size_t dstRegister) { ASSERT(m_returnRegister != dstRegister); m_returnRegister = dstRegister; } +#endif #ifndef NDEBUG void checkAllDataUsed() @@ -359,6 +356,9 @@ struct ByteCodeGenerateContext { bool m_isHeadOfMemberExpression : 1; bool m_forInOfVarBinding : 1; bool m_isLeftBindingAffectedByRightExpression : 1; // x = delete x; or x = eval("var x"), 1; +#if defined(ENABLE_TCO) + bool m_tcoDisabled : 1; // disable tail call optimizaiton (TCO) for some conditions +#endif std::shared_ptr> m_registerStack; std::shared_ptr>> m_lexicallyDeclaredNames; @@ -389,7 +389,10 @@ struct ByteCodeGenerateContext { std::map m_complexCaseStatementPositions; void* m_numeralLiteralData; // should be NumeralLiteralVector +#if defined(ENABLE_TCO) size_t m_returnRegister; // for tail call optimizaiton (TCO) +#endif + #ifdef ESCARGOT_DEBUGGER ByteCodeBreakpointContext* m_breakpointContext; #endif /* ESCARGOT_DEBUGGER */ diff --git a/src/interpreter/ByteCodeInterpreter.cpp b/src/interpreter/ByteCodeInterpreter.cpp index 567423ef5..2bb4eda18 100644 --- a/src/interpreter/ByteCodeInterpreter.cpp +++ b/src/interpreter/ByteCodeInterpreter.cpp @@ -1569,7 +1569,7 @@ Value Interpreter::interpret(ExecutionState* state, ByteCodeBlock* byteCodeBlock // set programCounter programCounter = reinterpret_cast(byteCodeBlock->m_code.data()); - state->m_programCounter = &programCounter; + ASSERT(state->m_programCounter == &programCounter); NEXT_INSTRUCTION(); } @@ -1619,10 +1619,42 @@ Value Interpreter::interpret(ExecutionState* state, ByteCodeBlock* byteCodeBlock // set programCounter programCounter = reinterpret_cast(byteCodeBlock->m_code.data()); - state->m_programCounter = &programCounter; + ASSERT(state->m_programCounter == &programCounter); NEXT_INSTRUCTION(); } + + // TCO : tail recursion case in catch or finally block + DEFINE_OPCODE(TailRecursionInTry) + : + { + TailRecursionInTry* code = (TailRecursionInTry*)programCounter; + const Value& callee = registerFile[code->m_calleeIndex]; + + if (UNLIKELY((callee != Value(state->resolveCallee())) || (state->m_argc != code->m_argumentCount))) { + // goto slow path + code->changeOpcode(Opcode::CallOpcode); + if (UNLIKELY(!callee.isPointerValue())) { + ErrorObject::throwBuiltinError(*state, ErrorCode::TypeError, ErrorObject::Messages::NOT_Callable); + } + + // Return F.[[Call]](V, argumentsList). + registerFile[code->m_resultIndex] = callee.asPointerValue()->call(*state, Value(), code->m_argumentCount, ®isterFile[code->m_argumentsStartIndex]); + + ADD_PROGRAM_COUNTER(Call); + NEXT_INSTRUCTION(); + } + + ASSERT(callee.isPointerValue() && callee.asPointerValue()->isScriptFunctionObject()); + ASSERT(callee.asPointerValue()->asScriptFunctionObject()->codeBlock() == byteCodeBlock->codeBlock()); + ASSERT(state->m_argc == code->m_argumentCount); + ASSERT((state->rareData() != nullptr) && state->rareData()->m_controlFlowRecord && state->rareData()->m_controlFlowRecord->size()); + + // postpone recursion call + // because we need to close the current interpreter routine which is called inside try operation + state->rareData()->m_controlFlowRecord->back() = new ControlFlowRecord(ControlFlowRecord::NeedsRecursion, callee, code->m_argumentCount, code->m_argumentsStartIndex); + return Value(Value::EmptyValue); + } #endif #ifdef ESCARGOT_DEBUGGER @@ -3181,13 +3213,48 @@ NEVER_INLINE Value InterpreterSlowPath::tryOperation(ExecutionState*& state, siz ASSERT_NOT_REACHED(); // never get here. but I add return statement for removing compile warning return Value(Value::EmptyValue); - } else { - ASSERT(record->reason() == ControlFlowRecord::NeedsReturn); + } else if (record->reason() == ControlFlowRecord::NeedsReturn) { record->m_count--; if (record->count()) { state->rareData()->m_controlFlowRecord->back() = record; } return record->value(); + } else { +#if defined(ENABLE_TCO) + // NeedsRecursion should be allocated in one of catch or finally block + ASSERT(record->reason() == ControlFlowRecord::NeedsRecursion); + ASSERT(!inPauserScope && !inPauserResumeProcess); + ASSERT(code->m_hasCatch || code->m_hasFinalizer); + ASSERT(record->m_value == state->resolveCallee()); + + Value callee = record->m_value; + size_t argCount = record->m_count; + size_t argStartIndex = record->m_outerLimitCount; + if (argCount) { + // At the start of tail call, we need to allocate a buffer for arguments + // because recursive tail call reuses this buffer + if (UNLIKELY(!state->initTCO())) { + Value* newArgs = (Value*)GC_MALLOC(sizeof(Value) * argCount); + state->setTCOArguments(newArgs); + } + + // its safe to overwrite arguments because old arguments are no longer necessary + for (size_t i = 0; i < state->m_argc; i++) { + state->m_argv[i] = registerFile[argStartIndex + i]; + } + } + + // set this value + registerFile[byteCodeBlock->m_requiredOperandRegisterNumber] = state->inStrictMode() ? Value() : state->context()->globalObjectProxy(); + + // set programCounter + programCounter = reinterpret_cast(byteCodeBlock->m_code.data()); + ASSERT(state->m_programCounter == &programCounter); + + return Value(Value::EmptyValue); +#else + RELEASE_ASSERT_NOT_REACHED(); +#endif } } else { programCounter = jumpTo(codeBuffer, code->m_finallyEndPosition); diff --git a/src/parser/ast/CallExpressionNode.h b/src/parser/ast/CallExpressionNode.h index fac7dafd6..629d0b14f 100644 --- a/src/parser/ast/CallExpressionNode.h +++ b/src/parser/ast/CallExpressionNode.h @@ -417,10 +417,15 @@ class CallExpressionNode : public ExpressionNode { if (dstRegister == context->m_returnRegister) { // Try tail recursion optimization (TCO) isTailCall = true; - if (context->m_codeBlock->isTailRecursionTarget()) { - codeBlock->pushCode(TailRecursionWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + bool isTailRecursion = context->m_codeBlock->isTailRecursionTarget(); + if (UNLIKELY(context->tryCatchWithBlockStatementCount())) { + codeBlock->pushCode(CallWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index); } else { - codeBlock->pushCode(CallReturnWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + if (isTailRecursion) { + codeBlock->pushCode(TailRecursionWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + } else { + codeBlock->pushCode(CallReturnWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + } } } else { codeBlock->pushCode(CallWithReceiver(ByteCodeLOC(m_loc.index), receiverIndex, calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index); @@ -429,10 +434,19 @@ class CallExpressionNode : public ExpressionNode { if (dstRegister == context->m_returnRegister) { // Try tail recursion optimization (TCO) isTailCall = true; - if (context->m_codeBlock->isTailRecursionTarget()) { - codeBlock->pushCode(TailRecursion(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + bool isTailRecursion = context->m_codeBlock->isTailRecursionTarget(); + if (UNLIKELY(context->tryCatchWithBlockStatementCount())) { + if (isTailRecursion) { + codeBlock->pushCode(TailRecursionInTry(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index); + } else { + codeBlock->pushCode(Call(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index); + } } else { - codeBlock->pushCode(CallReturn(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + if (isTailRecursion) { + codeBlock->pushCode(TailRecursion(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + } else { + codeBlock->pushCode(CallReturn(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, m_arguments.size()), context, this->m_loc.index); + } } } else { codeBlock->pushCode(Call(ByteCodeLOC(m_loc.index), calleeIndex, argumentsStartIndex, dstRegister, m_arguments.size()), context, this->m_loc.index); diff --git a/src/parser/ast/ReturnStatementNode.h b/src/parser/ast/ReturnStatementNode.h index b978fd3e9..0dab3bfe9 100644 --- a/src/parser/ast/ReturnStatementNode.h +++ b/src/parser/ast/ReturnStatementNode.h @@ -51,7 +51,19 @@ class ReturnStatementNode : public StatementNode { ByteCodeRegisterIndex index; if (m_argument) { index = m_argument->getRegister(codeBlock, context); +#if defined(ENABLE_TCO) + if (!context->m_tcoDisabled && (context->tryCatchWithBlockStatementCount() == 1) && ((context->m_recursiveStatementStack.back().first == ByteCodeGenerateContext::Catch) || (context->m_recursiveStatementStack.back().first == ByteCodeGenerateContext::Finally))) { + // consider tail recursion (TCO) for catch, finally block within depth 1 + bool isTailCall = false; + context->setReturnRegister(index); + m_argument->generateTCOExpressionByteCode(codeBlock, context, index, isTailCall); + context->setReturnRegister(SIZE_MAX); + } else { + m_argument->generateExpressionByteCode(codeBlock, context, index); + } +#else m_argument->generateExpressionByteCode(codeBlock, context, index); +#endif codeBlock->pushCode(ReturnFunctionSlowCase(ByteCodeLOC(m_loc.index), index), context, this->m_loc.index); } else { index = context->getRegister(); @@ -64,27 +76,27 @@ class ReturnStatementNode : public StatementNode { size_t r; if (m_argument) { r = m_argument->getRegister(codeBlock, context); - if (context->tryCatchWithBlockStatementCount() == 0) { - // consider tail recursion (TCO) - context->setReturnRegister(r); + #if defined(ENABLE_TCO) - m_argument->generateTCOExpressionByteCode(codeBlock, context, r, isTailCall); + // consider tail recursion (TCO) + ASSERT(!context->m_tcoDisabled); + context->setReturnRegister(r); + m_argument->generateTCOExpressionByteCode(codeBlock, context, r, isTailCall); + context->setReturnRegister(SIZE_MAX); #else - m_argument->generateExpressionByteCode(codeBlock, context, r); + m_argument->generateExpressionByteCode(codeBlock, context, r); #endif - context->setReturnRegister(SIZE_MAX); - } else { - m_argument->generateExpressionByteCode(codeBlock, context, r); - } } else { r = context->getRegister(); codeBlock->pushCode(LoadLiteral(ByteCodeLOC(m_loc.index), r, Value()), context, this->m_loc.index); } - if (!isTailCall || (m_argument->type() != CallExpression)) { - // skip End bytecode only if it directly returns the result of tail call +#if defined(ENABLE_TCO) + if (!isTailCall || (m_argument->type() != CallExpression)) + // skip End bytecode only if it directly returns the result of tail call +#endif codeBlock->pushCode(End(ByteCodeLOC(m_loc.index), r), context, this->m_loc.index); - } + context->giveUpRegister(); } } diff --git a/src/parser/ast/TaggedTemplateExpressionNode.h b/src/parser/ast/TaggedTemplateExpressionNode.h index 0457ea368..8ed2cc2b6 100644 --- a/src/parser/ast/TaggedTemplateExpressionNode.h +++ b/src/parser/ast/TaggedTemplateExpressionNode.h @@ -52,6 +52,13 @@ class TaggedTemplateExpressionNode : public ExpressionNode { return m_convertedExpression->generateExpressionByteCode(codeBlock, context, dstRegister); } +#if defined(ENABLE_TCO) + virtual void generateTCOExpressionByteCode(ByteCodeBlock* codeBlock, ByteCodeGenerateContext* context, ByteCodeRegisterIndex dstRegister, bool& isTailCall) override + { + return m_convertedExpression->generateTCOExpressionByteCode(codeBlock, context, dstRegister, isTailCall); + } +#endif + virtual void iterateChildren(const std::function& fn) override { fn(this); diff --git a/src/parser/ast/TryStatementNode.h b/src/parser/ast/TryStatementNode.h index 21e3e7893..fadfc65dc 100644 --- a/src/parser/ast/TryStatementNode.h +++ b/src/parser/ast/TryStatementNode.h @@ -58,13 +58,19 @@ class TryStatementNode : public StatementNode { ctx.tryCatchBodyPos = codeBlock->lastCodePosition(); } - static void generateTryHandlerStatementStartByteCode(ByteCodeBlock *codeBlock, ByteCodeGenerateContext *context, Node *self, TryStatementByteCodeContext &ctx, CatchClauseNode *handler) + static void generateTryHandlerStatementStartByteCode(ByteCodeBlock *codeBlock, ByteCodeGenerateContext *context, Node *self, TryStatementByteCodeContext &ctx, CatchClauseNode *handler, bool hasFinalizer) { context->m_recursiveStatementStack.pop_back(); context->m_recursiveStatementStack.push_back(std::make_pair(ByteCodeGenerateContext::Catch, ctx.tryStartPosition)); codeBlock->peekCode(ctx.tryStartPosition)->m_hasCatch = true; codeBlock->peekCode(ctx.tryStartPosition)->m_catchPosition = codeBlock->currentCodeSize(); +#if defined(ENABLE_TCO) + // if try statement has a finally block, TCO should be disabled for this catch block + bool beforeTCODisabled = context->m_tcoDisabled; + context->m_tcoDisabled |= hasFinalizer; +#endif + // catch paramter block size_t lexicalBlockIndexBefore = context->m_lexicalBlockIndex; ByteCodeBlock::ByteCodeLexicalBlockContext blockContext; @@ -116,6 +122,9 @@ class TryStatementNode : public StatementNode { codeBlock->pushCode(CloseLexicalEnvironment(ByteCodeLOC(self->loc().index)), context, self->m_loc.index); context->m_recursiveStatementStack.pop_back(); context->m_recursiveStatementStack.push_back(std::make_pair(ByteCodeGenerateContext::Try, ctx.tryStartPosition)); +#if defined(ENABLE_TCO) + context->m_tcoDisabled = beforeTCODisabled; +#endif } static void generateTryFinalizerStatementStartByteCode(ByteCodeBlock *codeBlock, ByteCodeGenerateContext *context, Node *self, TryStatementByteCodeContext &ctx, bool hasFinalizer) @@ -159,7 +168,7 @@ class TryStatementNode : public StatementNode { generateTryStatementBodyEndByteCode(codeBlock, context, this, ctx); if (m_handler) { - generateTryHandlerStatementStartByteCode(codeBlock, context, this, ctx, m_handler); + generateTryHandlerStatementStartByteCode(codeBlock, context, this, ctx, m_handler, m_finalizer != nullptr); } generateTryFinalizerStatementStartByteCode(codeBlock, context, this, ctx, m_finalizer != nullptr); diff --git a/tools/test/test262/excludelist.orig.xml b/tools/test/test262/excludelist.orig.xml index 272273c7e..d40eda4d4 100644 --- a/tools/test/test262/excludelist.orig.xml +++ b/tools/test/test262/excludelist.orig.xml @@ -5166,8 +5166,6 @@ TODO TODO TODO - TODO - TODO TODO TODO TODO @@ -5359,9 +5357,6 @@ TODO TODO TODO - TODO - TODO - TODO TODO TODO TODO