[simd/jit]: Implement more V128 loading instructions (#111 from haoyu-zc/jit-load-more)
titzer authored Aug 15, 2023
2 parents fc1e38a + 49bb085 commit 7ff3e34
Showing 3 changed files with 76 additions and 18 deletions.
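For orientation, the handlers below cover the Wasm SIMD load-splat, load-lane, load-extend, and load-zero shapes. As a standalone reference illustration (a hypothetical helper in Virgil assuming little-endian memory, not code from this commit), load-splat reads one scalar from linear memory and replicates it into every lane of the result:

	// Reference semantics of v128.load16_splat (illustration only): read one
	// little-endian 16-bit value from linear memory and replicate it into all 8 lanes.
	def load16_splat_ref(mem: Array<byte>, addr: int) -> Array<u16> {
		var lane = u16.view(mem[addr]) | (u16.view(mem[addr + 1]) << 8);
		var lanes = Array<u16>.new(8);
		for (i < 8) lanes[i] = lane;
		return lanes;
	}

The JIT handlers in this commit emit the equivalent effect with a scalar load followed by a splat or lane-insert instruction, as shown in the diffs below.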
10 changes: 5 additions & 5 deletions src/engine/x86-64/X86_64Interpreter.v3
@@ -2248,7 +2248,7 @@ class X86_64InterpreterGen(ic: X86_64InterpreterCode, w: DataWriter) {
genTagUpdate(BpTypeCode.V128.code);
endHandler();
}
def genSplatLoad(opcode: Opcode,
def genLoadSplat(opcode: Opcode,
load_memarg: (X86_64Gpr, X86_64Addr, X86_64Gpr) -> void,
masm_emit: (X86_64Xmmr, X86_64Gpr) -> void) {
bindHandler(opcode);
@@ -2747,10 +2747,10 @@ class X86_64InterpreterGen(ic: X86_64InterpreterCode, w: DataWriter) {
genSplat(Opcode.F32X4_SPLAT, asm.q.movd_r_m, masm.emit_f32x4_splat(_, _, r_xmm1));
genSplat(Opcode.F64X2_SPLAT, asm.q.movq_r_m, masm.emit_f64x2_splat(_, _, r_xmm1));
// V128 load_splat
genSplatLoad(Opcode.V128_LOAD_8_SPLAT, load_memarg8, masm.emit_i8x16_splat(_, _, r_xmm1));
genSplatLoad(Opcode.V128_LOAD_16_SPLAT, load_memarg16, masm.emit_i16x8_splat);
genSplatLoad(Opcode.V128_LOAD_32_SPLAT, load_memarg32, masm.emit_i32x4_splat);
genSplatLoad(Opcode.V128_LOAD_64_SPLAT, load_memarg64, masm.emit_i64x2_splat);
genLoadSplat(Opcode.V128_LOAD_8_SPLAT, load_memarg8, masm.emit_i8x16_splat(_, _, r_xmm1));
genLoadSplat(Opcode.V128_LOAD_16_SPLAT, load_memarg16, masm.emit_i16x8_splat);
genLoadSplat(Opcode.V128_LOAD_32_SPLAT, load_memarg32, masm.emit_i32x4_splat);
genLoadSplat(Opcode.V128_LOAD_64_SPLAT, load_memarg64, masm.emit_i64x2_splat);
}
}
def genSimdBinop<T>(opcode: Opcode, f: (X86_64Xmmr, X86_64Xmmr) -> T) {
9 changes: 6 additions & 3 deletions src/engine/x86-64/X86_64MacroAssembler.v3
@@ -141,10 +141,13 @@ class X86_64MacroAssembler extends MacroAssembler {
ABS, V128 => asm.movdqu_s_m(X(dst), X86_64Addr.new(b, t.0, 1, t.1));
}
}
def emit_v128_load_lane_r_m<T>(dst: Reg, base: Reg, index: Reg, offset: u32, asm_mov_r_m: (X86_64Gpr, X86_64Addr) -> T) {
var b = G(base), t = handle_large_offset(index, offset);
def emit_v128_load_lane_r_m<T>(dst: Reg, addr: X86_64Addr, asm_mov_r_m: (X86_64Gpr, X86_64Addr) -> T) {
recordCurSourceLoc();
asm_mov_r_m(G(dst), X86_64Addr.new(b, t.0, 1, t.1));
asm_mov_r_m(G(dst), addr);
}
def decode_memarg_addr(base: Reg, index: Reg, offset: u32) -> X86_64Addr {
var t = handle_large_offset(index, offset);
return X86_64Addr.new(G(base), t.0, 1, t.1);
}
def emit_storeb_r_r_r_i(kind: ValueKind, val: Reg, base: Reg, index: Reg, offset: u32) {
var t = handle_large_offset(index, offset);
75 changes: 65 additions & 10 deletions src/engine/x86-64/X86_64SinglePassCompiler.v3
@@ -639,10 +639,22 @@ class X86_64SinglePassCompiler extends SinglePassCompiler {
def visit_F64X2_CONVERT_LOW_I32X4_U() { do_op1_x_gtmp_xtmp(ValueKind.V128, mmasm.emit_f64x2_convert_low_i32x4_u); }
def visit_F64X2_PROMOTE_LOW_F32X4() { do_op1_x_x(ValueKind.V128, asm.cvtps2pd_s_s); }

def visit_V128_LOAD_8_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, asm.q.movb_r_m, asm.pinsrb_s_r_i); }
def visit_V128_LOAD_16_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, asm.q.movw_r_m, asm.pinsrw_s_r_i); }
def visit_V128_LOAD_32_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, asm.q.movd_r_m, asm.pinsrd_s_r_i); }
def visit_V128_LOAD_64_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, asm.q.movq_r_m, asm.pinsrq_s_r_i); }
def visit_V128_LOAD_8_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, loadMemarg_b, asm.pinsrb_s_r_i); }
def visit_V128_LOAD_16_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, loadMemarg_w, asm.pinsrw_s_r_i); }
def visit_V128_LOAD_32_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, loadMemarg_d, asm.pinsrd_s_r_i); }
def visit_V128_LOAD_64_LANE(imm: MemArg, lane: byte) { visit_V128_LOAD_LANE(imm, lane, loadMemarg_q, asm.pinsrq_s_r_i); }
def visit_V128_LOAD_8X8_S(imm: MemArg) { visit_V128_LOAD_EXTEND(imm, asm.pmovsxbw_s_m); }
def visit_V128_LOAD_8X8_U(imm: MemArg) { visit_V128_LOAD_EXTEND(imm, asm.pmovzxbw_s_m); }
def visit_V128_LOAD_16X4_S(imm: MemArg) { visit_V128_LOAD_EXTEND(imm, asm.pmovsxwd_s_m); }
def visit_V128_LOAD_16X4_U(imm: MemArg) { visit_V128_LOAD_EXTEND(imm, asm.pmovzxwd_s_m); }
def visit_V128_LOAD_32X2_S(imm: MemArg) { visit_V128_LOAD_EXTEND(imm, asm.pmovsxdq_s_m); }
def visit_V128_LOAD_32X2_U(imm: MemArg) { visit_V128_LOAD_EXTEND(imm, asm.pmovzxdq_s_m); }
def visit_V128_LOAD_32_ZERO(imm: MemArg) { visit_V128_LOAD_ZERO(imm, loadMemarg_d, asm.pinsrd_s_r_i); }
def visit_V128_LOAD_64_ZERO(imm: MemArg) { visit_V128_LOAD_ZERO(imm, loadMemarg_q, asm.pinsrq_s_r_i); }
def visit_V128_LOAD_8_SPLAT(imm: MemArg) { visit_V128_LOAD_SPLAT(imm, loadMemarg_b, mmasm.emit_i8x16_splat(_, _, X(allocTmp(ValueKind.V128)))); }
def visit_V128_LOAD_16_SPLAT(imm: MemArg) { visit_V128_LOAD_SPLAT(imm, loadMemarg_w, mmasm.emit_i16x8_splat); }
def visit_V128_LOAD_32_SPLAT(imm: MemArg) { visit_V128_LOAD_SPLAT(imm, loadMemarg_d, mmasm.emit_i32x4_splat); }
def visit_V128_LOAD_64_SPLAT(imm: MemArg) { visit_V128_LOAD_SPLAT(imm, loadMemarg_q, mmasm.emit_i64x2_splat); }

def visit_I8X16_REPLACELANE(lane: byte) { visit_V128_REPLACE_LANE(lane, asm.pinsrb_s_r_i); }
def visit_I16X8_REPLACELANE(lane: byte) { visit_V128_REPLACE_LANE(lane, asm.pinsrw_s_r_i); }
@@ -726,8 +738,8 @@ class X86_64SinglePassCompiler extends SinglePassCompiler {
state.push(a.kindFlagsMatching(ValueKind.V128, IN_REG), a.reg, 0);
}

private def visit_V128_LOAD_LANE<T>(imm: MemArg, lane: byte, asm_mov_r_m: (X86_64Gpr, X86_64Addr) -> T, asm_pins_s_r_i: (X86_64Xmmr, X86_64Gpr, byte) -> T) {
var sv = popRegToOverwrite(), r = X(sv.reg);
// Decode memarg and return the mem address and trap reason if any
private def decodeMemarg(imm: MemArg) -> (X86_64Addr, TrapReason) {
var base_reg = regs.mem0_base;
if (imm.memory_index != 0) {
// XXX: cache the base register for memories > 0
@@ -743,14 +755,30 @@
var offset = imm.offset;
if (iv.isConst()) {
var sum = u64.view(offset) + u32.view(iv.const); // fold offset calculation
if (sum > u32.max) return emitTrap(TrapReason.MEM_OUT_OF_BOUNDS); // statically OOB
if (sum > u32.max) return (null, TrapReason.MEM_OUT_OF_BOUNDS);
offset = u32.view(sum);
} else {
index_reg = ensureReg(iv, state.sp);
}
var lane_val = allocTmp(ValueKind.I64); // XXX: can reuse index reg if frequency == 1 and ValueKind.I32
mmasm.emit_v128_load_lane_r_m(lane_val, base_reg, index_reg, u32.!(offset), asm_mov_r_m);
asm_pins_s_r_i(r, G(lane_val), lane);
return (mmasm.decode_memarg_addr(base_reg, index_reg, u32.!(offset)), TrapReason.NONE);
}
// Utilities to load a memarg into a register
private def loadMemarg<T>(dst: Reg, imm: MemArg, asm_mov_r_m: (X86_64Gpr, X86_64Addr) -> T) {
def t = decodeMemarg(imm);
if (t.1 != TrapReason.NONE) return emitTrap(t.1);
def addr = t.0;
mmasm.emit_v128_load_lane_r_m(dst, addr, asm_mov_r_m);
}
private def loadMemarg_b(dst: Reg, imm: MemArg) { loadMemarg(dst, imm, asm.q.movb_r_m); }
private def loadMemarg_w(dst: Reg, imm: MemArg) { loadMemarg(dst, imm, asm.q.movw_r_m); }
private def loadMemarg_d(dst: Reg, imm: MemArg) { loadMemarg(dst, imm, asm.q.movd_r_m); }
private def loadMemarg_q(dst: Reg, imm: MemArg) { loadMemarg(dst, imm, asm.q.movq_r_m); }

private def visit_V128_LOAD_LANE<T>(imm: MemArg, lane: byte, loadMem: (Reg, MemArg) -> void, asm_pins_s_r_i: (X86_64Xmmr, X86_64Gpr, byte) -> T) {
var sv = popRegToOverwrite(), r = X(sv.reg);
var val = allocTmp(ValueKind.I64);
loadMem(val, imm);
asm_pins_s_r_i(r, G(val), lane);
state.push(sv.kindFlagsMatching(ValueKind.V128, IN_REG), sv.reg, 0);
}

Expand All @@ -776,6 +804,33 @@ class X86_64SinglePassCompiler extends SinglePassCompiler {
state.push(SpcConsts.kindToFlags(kind) | IN_REG, d, 0);
}

private def visit_V128_LOAD_EXTEND<T>(imm: MemArg, asm_pmov_s_m: (X86_64Xmmr, X86_64Addr) -> T) {
var d = allocRegTos(ValueKind.V128);
var val = allocTmp(ValueKind.I64);
def t = decodeMemarg(imm);
if (t.1 != TrapReason.NONE) return emitTrap(t.1);
def addr = t.0;
asm_pmov_s_m(X(d), addr);
state.push(KIND_V128 | IN_REG, d, 0);
}

private def visit_V128_LOAD_ZERO<T>(imm: MemArg, loadMem: (Reg, MemArg) -> void, asm_pins_s_r_i: (X86_64Xmmr, X86_64Gpr, byte) -> T) {
var val = allocTmp(ValueKind.I64);
var d = allocRegTos(ValueKind.V128);
loadMem(val, imm);
mmasm.emit_v128_zero(X(d));
asm_pins_s_r_i(X(d), G(val), 0);
state.push(KIND_V128 | IN_REG, d, 0);
}

private def visit_V128_LOAD_SPLAT<T>(imm: MemArg, loadMem: (Reg, MemArg) -> void, masm_splat: (X86_64Xmmr, X86_64Gpr) -> void) {
var val = allocTmp(ValueKind.I64);
var d = allocRegTos(ValueKind.V128);
loadMem(val, imm);
masm_splat(X(d), G(val));
state.push(KIND_V128 | IN_REG, d, 0);
}

// r1 = op(r1)
private def do_op1_r<T>(kind: ValueKind, emit: (X86_64Gpr -> T)) -> bool {
var sv = popRegToOverwrite(), r = G(sv.reg);
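Aside (not part of the diff): the constant-folding path in decodeMemarg above classifies an access as statically out of bounds when the folded effective offset no longer fits in u32. A minimal sketch with a hypothetical helper and example values:

	// Illustration of the static bounds classification used in decodeMemarg:
	// fold a constant index into the memarg offset and check for u32 overflow.
	def foldsToOOB(offset: u32, const_index: u32) -> bool {
		var sum = u64.view(offset) + u64.view(const_index);
		return sum > u32.max; // 32-bit effective offset overflowed -> MEM_OUT_OF_BOUNDS
	}
	// e.g. foldsToOOB(0xFFFFFFF0, 0x20) == true  -> compile an unconditional trap
	//      foldsToOOB(0x1000, 0x20)     == false -> fold to a single offset of 0x1020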
