Skip to content

Commit

Permalink
Cache memory data in instance
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
Zoltan Herczeg committed Sep 27, 2024
1 parent 4e655cb commit 8e3ae56
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 96 deletions.
46 changes: 4 additions & 42 deletions src/jit/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@ CompileContext::CompileContext(Module* module, JITCompiler* compiler)
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
, shuffleOffset(0)
#endif /* SLJIT_CONFIG_X86 */
, stackTmpStart(0)
, stackMemoryStart(sizeof(sljit_sw))
, stackTmpStart(sizeof(sljit_sw))
, nextTryBlock(0)
, currentTryBlock(InstanceConstData::globalTryBlock)
, trapBlocksStart(0)
Expand All @@ -233,7 +232,9 @@ CompileContext::CompileContext(Module* module, JITCompiler* compiler)
{
// Compiler is not initialized yet.
size_t offset = Instance::alignedSize();
globalsStart = offset + sizeof(void*) * module->numberOfMemoryTypes();
size_t numberOfMemoryTypes = module->numberOfMemoryTypes();
targetBuffersStart = offset + numberOfMemoryTypes * sizeof(void*);
globalsStart = targetBuffersStart + Memory::TargetBuffer::sizeInPointers(numberOfMemoryTypes) * sizeof(void*);
tableStart = globalsStart + module->numberOfGlobalTypes() * sizeof(void*);
functionsStart = tableStart + module->numberOfTableTypes() * sizeof(void*);

Expand Down Expand Up @@ -1020,7 +1021,6 @@ JITCompiler::JITCompiler(Module* module, uint32_t JITFlags)
, m_savedIntegerRegCount(0)
, m_savedFloatRegCount(0)
, m_stackTmpSize(0)
, m_useMemory0(false)
{
if (module->m_jitModule != nullptr) {
ASSERT(module->m_jitModule->m_instanceConstData != nullptr);
Expand All @@ -1038,10 +1038,6 @@ void JITCompiler::compileFunction(JITFunction* jitFunc, bool isExternal)

m_functionList.push_back(FunctionList(jitFunc, isExternal, m_branchTableSize));

sljit_uw stackTmpStart = m_context.stackMemoryStart + (m_useMemory0 ? sizeof(Memory::TargetBuffer) : 0);
// Align data.
m_context.stackTmpStart = static_cast<sljit_sw>((stackTmpStart + sizeof(sljit_sw) - 1) & ~(sizeof(sljit_sw) - 1));

if (m_compiler == nullptr) {
// First compiled function.
m_compiler = sljit_create_compiler(nullptr);
Expand Down Expand Up @@ -1482,7 +1478,6 @@ void JITCompiler::clear()
m_last = nullptr;
m_branchTableSize = 0;
m_stackTmpSize = 0;
m_useMemory0 = false;
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
m_context.shuffleOffset = 0;
#endif /* SLJIT_CONFIG_X86 */
Expand Down Expand Up @@ -1530,24 +1525,6 @@ void JITCompiler::emitProlog()
(m_savedIntegerRegCount + 2) | SLJIT_ENTER_FLOAT(m_savedFloatRegCount), m_context.stackTmpStart + m_stackTmpSize);

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_SP), kContextOffset, SLJIT_R0, 0);
if (hasMemory0()) {
sljit_sw stackMemoryStart = m_context.stackMemoryStart;
ASSERT(m_context.stackTmpStart >= stackMemoryStart + static_cast<sljit_sw>(sizeof(Memory::TargetBuffer)));

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(kInstanceReg), Instance::alignedSize());

sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R1, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_targetBuffers));
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_sizeInByte) + WORD_LOW_OFFSET);
sljit_get_local_base(m_compiler, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_targetBuffers), stackMemoryStart);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R0, 0, SLJIT_MEM1(SLJIT_R0), offsetof(Memory, m_buffer));

#if (defined SLJIT_32BIT_ARCHITECTURE && SLJIT_32BIT_ARCHITECTURE)
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_HIGH_OFFSET, SLJIT_IMM, 0);
#endif /* SLJIT_32BIT_ARCHITECTURE */
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, prev), SLJIT_R1, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET, SLJIT_R2, 0);
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, buffer), SLJIT_R0, 0);
}

m_context.branchTableOffset = 0;
size_t size = func.branchTableSize * sizeof(sljit_up);
Expand All @@ -1568,19 +1545,6 @@ void JITCompiler::emitProlog()
}
}

void JITCompiler::emitRestoreMemories()
{
if (!hasMemory0()) {
return;
}

sljit_sw stackMemoryStart = m_context.stackMemoryStart;

sljit_emit_op1(m_compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_MEM1(kInstanceReg), Instance::alignedSize());
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_SP), stackMemoryStart + offsetof(Memory::TargetBuffer, prev));
sljit_emit_op1(m_compiler, SLJIT_MOV_P, SLJIT_MEM1(SLJIT_R1), offsetof(Memory, m_targetBuffers), SLJIT_R2, 0);
}

void JITCompiler::emitEpilog()
{
FunctionList& func = m_functionList.back();
Expand All @@ -1598,7 +1562,6 @@ void JITCompiler::emitEpilog()
m_context.earlyReturns.clear();
}

emitRestoreMemories();
sljit_emit_return(m_compiler, SLJIT_MOV_P, SLJIT_R0, 0);

m_context.emitSlowCases(m_compiler);
Expand Down Expand Up @@ -1661,7 +1624,6 @@ void JITCompiler::emitEpilog()
sljit_emit_op_dst(m_compiler, SLJIT_GET_RETURN_ADDRESS, SLJIT_R1, 0);
sljit_emit_icall(m_compiler, SLJIT_CALL, SLJIT_ARGS2(W, W, W), SLJIT_IMM, GET_FUNC_ADDR(sljit_sw, getTrapHandler));

emitRestoreMemories();
sljit_emit_return_to(m_compiler, SLJIT_R0, 0);

while (trapJumpIndex < trapJumps.size()) {
Expand Down
17 changes: 0 additions & 17 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,14 +1022,12 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::Load32Opcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDst;
compiler->useMemory0();
requiredInit = OTLoadI32;
break;
}
case ByteCode::Load64Opcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDst;
compiler->useMemory0();
requiredInit = OTLoadI64;
break;
}
Expand All @@ -1049,7 +1047,6 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::I64Load32UOpcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDstValue;
compiler->useMemory0();
if (requiredInit == OTNone) {
requiredInit = OTLoadI64;
}
Expand All @@ -1072,7 +1069,6 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::V128Load64ZeroOpcode: {
group = Instruction::Load;
paramType = ParamTypes::ParamSrcDstValue;
compiler->useMemory0();

if (opcode == ByteCode::F32LoadOpcode)
requiredInit = OTLoadF32;
Expand All @@ -1089,7 +1085,6 @@ static void compileFunction(JITCompiler* compiler)
SIMDMemoryLoad* loadOperation = reinterpret_cast<SIMDMemoryLoad*>(byteCode);
Instruction* instr = compiler->append(byteCode, Instruction::LoadLaneSIMD, opcode, 2, 1);
instr->setRequiredRegsDescriptor(OTLoadLaneV128);
compiler->useMemory0();

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(loadOperation->src0Offset());
Expand All @@ -1100,14 +1095,12 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::Store32Opcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2;
compiler->useMemory0();
requiredInit = OTStoreI32;
break;
}
case ByteCode::Store64Opcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2;
compiler->useMemory0();
requiredInit = OTStoreI64;
break;
}
Expand All @@ -1127,7 +1120,6 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::I64StoreOpcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2Value;
compiler->useMemory0();
if (requiredInit == OTNone) {
requiredInit = OTStoreI64;
}
Expand All @@ -1138,7 +1130,6 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::V128StoreOpcode: {
group = Instruction::Store;
paramType = ParamTypes::ParamSrc2Value;
compiler->useMemory0();

if (opcode == ByteCode::F32StoreOpcode)
requiredInit = OTStoreF32;
Expand All @@ -1155,7 +1146,6 @@ static void compileFunction(JITCompiler* compiler)
SIMDMemoryStore* storeOperation = reinterpret_cast<SIMDMemoryStore*>(byteCode);
Instruction* instr = compiler->append(byteCode, Instruction::Store, opcode, 2, 0);
instr->setRequiredRegsDescriptor(OTStoreV128);
compiler->useMemory0();

Operand* operands = instr->operands();
operands[0] = STACK_OFFSET(storeOperation->src0Offset());
Expand Down Expand Up @@ -1330,7 +1320,6 @@ static void compileFunction(JITCompiler* compiler)

Instruction* instr = compiler->append(byteCode, Instruction::Memory, opcode, 0, 1);
instr->setRequiredRegsDescriptor(OTPutI32);
compiler->useMemory0();

*instr->operands() = STACK_OFFSET(memorySize->dstOffset());
break;
Expand Down Expand Up @@ -1874,7 +1863,6 @@ static void compileFunction(JITCompiler* compiler)
compiler->increaseStackTmpSize(8);
}
#endif /* SLJIT_32BIT_ARCHITECTURE */
compiler->useMemory0();
if (requiredInit == OTNone) {
requiredInit = OTLoadI64;
}
Expand All @@ -1899,7 +1887,6 @@ static void compileFunction(JITCompiler* compiler)
compiler->increaseStackTmpSize(8);
}
#endif /* SLJIT_32BIT_ARCHITECTURE */
compiler->useMemory0();
if (requiredInit == OTNone) {
requiredInit = OTStoreI64;
}
Expand Down Expand Up @@ -1965,7 +1952,6 @@ static void compileFunction(JITCompiler* compiler)
AtomicRmw* atomicRmw = reinterpret_cast<AtomicRmw*>(byteCode);
Operand* operands = instr->operands();
instr->setRequiredRegsDescriptor(requiredInit != OTNone ? requiredInit : OTAtomicRmwI64);
compiler->useMemory0();

operands[0] = STACK_OFFSET(atomicRmw->src0Offset());
operands[1] = STACK_OFFSET(atomicRmw->src1Offset());
Expand Down Expand Up @@ -1997,7 +1983,6 @@ static void compileFunction(JITCompiler* compiler)
AtomicRmwCmpxchg* atomicRmwCmpxchg = reinterpret_cast<AtomicRmwCmpxchg*>(byteCode);
Operand* operands = instr->operands();
instr->setRequiredRegsDescriptor(requiredInit != OTNone ? requiredInit : OTAtomicRmwCmpxchgI64);
compiler->useMemory0();

operands[0] = STACK_OFFSET(atomicRmwCmpxchg->src0Offset());
operands[1] = STACK_OFFSET(atomicRmwCmpxchg->src1Offset());
Expand All @@ -2017,7 +2002,6 @@ static void compileFunction(JITCompiler* compiler)
Operand* operands = instr->operands();
instr->setRequiredRegsDescriptor(requiredInit != OTNone ? requiredInit : OTAtomicWaitI32);
compiler->increaseStackTmpSize(16);
compiler->useMemory0();

operands[0] = STACK_OFFSET(memoryAtomicWait->src0Offset());
operands[1] = STACK_OFFSET(memoryAtomicWait->src1Offset());
Expand All @@ -2032,7 +2016,6 @@ static void compileFunction(JITCompiler* compiler)
MemoryAtomicNotify* memoryAtomicNotify = reinterpret_cast<MemoryAtomicNotify*>(byteCode);
Operand* operands = instr->operands();
instr->setRequiredRegsDescriptor(OTAtomicNotify);
compiler->useMemory0();

operands[0] = STACK_OFFSET(memoryAtomicNotify->src0Offset());
operands[1] = STACK_OFFSET(memoryAtomicNotify->src1Offset());
Expand Down
14 changes: 1 addition & 13 deletions src/jit/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,11 +585,11 @@ struct CompileContext {
#if (defined SLJIT_CONFIG_X86 && SLJIT_CONFIG_X86)
uintptr_t shuffleOffset;
#endif /* SLJIT_CONFIG_X86 */
size_t targetBuffersStart;
size_t globalsStart;
size_t tableStart;
size_t functionsStart;
sljit_sw stackTmpStart;
sljit_sw stackMemoryStart;
size_t nextTryBlock;
size_t currentTryBlock;
size_t trapBlocksStart;
Expand Down Expand Up @@ -761,16 +761,6 @@ class JITCompiler {
}
}

void useMemory0()
{
m_useMemory0 = true;
}

bool hasMemory0()
{
return m_useMemory0;
}

void setModuleFunction(ModuleFunction* moduleFunction)
{
m_moduleFunction = moduleFunction;
Expand Down Expand Up @@ -817,7 +807,6 @@ class JITCompiler {
// Backend operations.
void emitProlog();
void emitEpilog();
void emitRestoreMemories();

#if !defined(NDEBUG)
static const char* m_byteCodeNames[];
Expand All @@ -841,7 +830,6 @@ class JITCompiler {
uint8_t m_savedIntegerRegCount;
uint8_t m_savedFloatRegCount;
uint8_t m_stackTmpSize;
bool m_useMemory0;

std::vector<TryBlock> m_tryBlocks;
std::vector<FunctionList> m_functionList;
Expand Down
23 changes: 11 additions & 12 deletions src/jit/MemoryInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ struct MemAddress {
void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_uw offset, sljit_u32 size)
{
CompileContext* context = CompileContext::get(compiler);
sljit_sw stackMemoryStart = context->stackMemoryStart;
sljit_sw targetBufferOffset = context->targetBuffersStart;

ASSERT(context->compiler->hasMemory0());
ASSERT(!(options & LoadInteger) || baseReg != sourceReg);
ASSERT(!(options & LoadInteger) || offsetReg != sourceReg);
#if defined(ENABLE_EXTENDED_FEATURES)
Expand Down Expand Up @@ -107,8 +106,8 @@ void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_u

if (offset + size <= context->initialMemorySize) {
ASSERT(baseReg != 0);
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, buffer));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(kInstanceReg),
targetBufferOffset + offsetof(Memory::TargetBuffer, buffer));
memArg.arg = SLJIT_MEM1(baseReg);
memArg.argw = offset;
load(compiler);
Expand All @@ -124,12 +123,12 @@ void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_u

ASSERT(baseReg != 0 && offsetReg != 0);
/* The sizeInByte is always a 32 bit number on 32 bit systems. */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(kInstanceReg),
targetBufferOffset + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);

sljit_emit_op1(compiler, SLJIT_MOV, offsetReg, 0, SLJIT_IMM, static_cast<sljit_sw>(offset + size));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, buffer));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(kInstanceReg),
targetBufferOffset + offsetof(Memory::TargetBuffer, buffer));

load(compiler);

Expand Down Expand Up @@ -162,13 +161,13 @@ void MemAddress::check(sljit_compiler* compiler, Operand* offsetOperand, sljit_u

if (context->initialMemorySize != context->maximumMemorySize) {
/* The sizeInByte is always a 32 bit number on 32 bit systems. */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(kInstanceReg),
targetBufferOffset + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET);
offset += size;
}

sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(SLJIT_SP),
stackMemoryStart + offsetof(Memory::TargetBuffer, buffer));
sljit_emit_op1(compiler, SLJIT_MOV_P, baseReg, 0, SLJIT_MEM1(kInstanceReg),
targetBufferOffset + offsetof(Memory::TargetBuffer, buffer));

load(compiler);

Expand Down
5 changes: 2 additions & 3 deletions src/jit/MemoryUtilInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,12 @@ static void emitMemory(sljit_compiler* compiler, Instruction* instr)
switch (opcode) {
case ByteCode::MemorySizeOpcode: {
ASSERT(!(instr->info() & Instruction::kIsCallback));
ASSERT(context->compiler->hasMemory0());

JITArg dstArg(params);

/* The sizeInByte is always a 32 bit number on 32 bit systems. */
sljit_emit_op2(compiler, SLJIT_LSHR, dstArg.arg, dstArg.argw, SLJIT_MEM1(SLJIT_SP),
context->stackMemoryStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET, SLJIT_IMM, 16);
sljit_emit_op2(compiler, SLJIT_LSHR, dstArg.arg, dstArg.argw, SLJIT_MEM1(kInstanceReg),
context->targetBuffersStart + offsetof(Memory::TargetBuffer, sizeInByte) + WORD_LOW_OFFSET, SLJIT_IMM, 16);
return;
}
case ByteCode::MemoryInitOpcode:
Expand Down
17 changes: 15 additions & 2 deletions src/runtime/Instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ namespace Walrus {
Instance* Instance::newInstance(Module* module)
{
// Must follow the order in Module::instantiate.
size_t numberOfRefs = module->numberOfMemoryTypes() + module->numberOfGlobalTypes()
+ module->numberOfTableTypes() + module->numberOfFunctions() + module->numberOfTagTypes();

size_t numberOfRefs = module->numberOfMemoryTypes()
+ Memory::TargetBuffer::sizeInPointers(module->numberOfMemoryTypes())
+ module->numberOfGlobalTypes() + module->numberOfTableTypes()
+ module->numberOfFunctions() + module->numberOfTagTypes();

void* result = malloc(alignedSize() + numberOfRefs * sizeof(void*));

Expand Down Expand Up @@ -60,6 +63,16 @@ Instance::Instance(Module* module)
module->store()->appendInstance(this);
}

Instance::~Instance()
{
size_t size = m_module->numberOfMemoryTypes();
Memory::TargetBuffer* targetBuffers = reinterpret_cast<Memory::TargetBuffer*>(alignedEnd() + m_module->numberOfMemoryTypes());

for (size_t i = 0; i < size; i++) {
targetBuffers[i].deque(m_memories[i]);
}
}

Optional<ExportType*> Instance::resolveExportType(std::string& name)
{
for (auto me : m_module->exports()) {
Expand Down
Loading

0 comments on commit 8e3ae56

Please sign in to comment.