YJIT: Avoid writing return value to memory in leave

Previously, at the end of `leave` we did
`*caller_cfp->sp = return_value`, like the interpreter.
With future changes that leave the SP field uninitialized for C frames,
this will become problematic. For cases like returning from
this will become problematic. For cases like returning from
`rb_funcall()`, the return value was written above the stack and
never read anyway (callers use the copy in the return register).

Leave the return value in a register at the end of `leave` and have the
code at `cfp->jit_return` decide what to do with it. This avoids the
unnecessary memory write mentioned above. For JIT-to-JIT returns, it goes
through `asm.stack_push()` and benefits from register allocation for
stack temporaries.
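
For illustration, a standalone toy model of the old and new conventions (plain Rust, not YJIT's actual codegen; `Frame`, `leave_old`, `leave_new`, and `return_landing` are invented names):

    // Standalone toy model of the convention change; this is not YJIT code.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct Value(u64);

    struct Frame {
        sp: usize, // caller's stack-top index, like cfp->sp
    }

    // Old scheme: `leave` itself stores the return value through the caller's SP,
    // like the interpreter, even when the caller is C code that only reads the
    // return register and never looks at that slot.
    fn leave_old(stack: &mut [Option<Value>], caller: &Frame, retval: Value) -> Value {
        stack[caller.sp] = Some(retval); // *caller_cfp->sp = return_value
        retval                           // also handed back in the "register"
    }

    // New scheme: `leave` only leaves the value in the "register"; the code sitting
    // at cfp->jit_return decides whether it also belongs on the VM stack.
    fn leave_new(retval: Value) -> Value {
        retval
    }

    // JIT-to-JIT return landing: reload SP for the current frame and push the value.
    fn return_landing(stack: &mut [Option<Value>], caller: &mut Frame, retval: Value) {
        stack[caller.sp] = Some(retval);
        caller.sp += 1;
    }

    fn main() {
        // Old scheme: the memory write happens even for a C caller such as rb_funcall().
        let mut old_stack = vec![None; 8];
        let old_caller = Frame { sp: 3 };
        leave_old(&mut old_stack, &old_caller, Value(7));
        assert_eq!(old_stack[3], Some(Value(7))); // written, but a C caller never reads it

        // New scheme: returning to C code uses only the register copy...
        let mut stack = vec![None; 8];
        let mut caller = Frame { sp: 3 };
        let reg = leave_new(Value(42));
        assert_eq!(stack[caller.sp], None); // no write above the caller's stack top

        // ...and returning to JIT code lets the landing code push the value itself.
        return_landing(&mut stack, &mut caller, reg);
        assert_eq!((stack[3], caller.sp), (Some(Value(42)), 4));
    }

The real implementation does the equivalent with the backend IR, as the diff below shows.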

Mostly flat on benchmarks, with maybe some marginal speed improvements.

Co-authored-by: Takashi Kokubun <takashikkbn@gmail.com>
Author: Alan Wu
Date:   2023-09-29 22:29:24 -04:00
parent a5cc6341c0
commit 41a6e4bdf9

3 changed files with 111 additions and 41 deletions


@@ -22,6 +22,7 @@ pub const SP: Opnd = _SP;
 pub const C_ARG_OPNDS: [Opnd; 6] = _C_ARG_OPNDS;
 pub const C_RET_OPND: Opnd = _C_RET_OPND;
+pub use crate::backend::current::{Reg, C_RET_REG};

 // Memory operand base
 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
@@ -955,6 +956,7 @@ pub struct SideExitContext {
     pub stack_size: u8,
     pub sp_offset: i8,
     pub reg_temps: RegTemps,
+    pub is_return_landing: bool,
 }

 impl SideExitContext {
@@ -965,6 +967,7 @@ impl SideExitContext {
             stack_size: ctx.get_stack_size(),
             sp_offset: ctx.get_sp_offset(),
             reg_temps: ctx.get_reg_temps(),
+            is_return_landing: ctx.is_return_landing(),
         };
         if cfg!(debug_assertions) {
             // Assert that we're not losing any mandatory metadata
@@ -979,6 +982,9 @@ impl SideExitContext {
         ctx.set_stack_size(self.stack_size);
         ctx.set_sp_offset(self.sp_offset);
         ctx.set_reg_temps(self.reg_temps);
+        if self.is_return_landing {
+            ctx.set_as_return_landing();
+        }
         ctx
     }
 }


@@ -444,6 +444,12 @@ fn gen_exit(exit_pc: *mut VALUE, asm: &mut Assembler) {
         asm_comment!(asm, "exit to interpreter on {}", insn_name(opcode as usize));
     }

+    if asm.ctx.is_return_landing() {
+        asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
+        let top = asm.stack_push(Type::Unknown);
+        asm.mov(top, C_RET_OPND);
+    }
+
     // Spill stack temps before returning to the interpreter
     asm.spill_temps();
@@ -636,13 +642,18 @@ fn gen_leave_exception(ocb: &mut OutlinedCb) -> CodePtr {
     let code_ptr = ocb.get_write_ptr();
     let mut asm = Assembler::new();

+    // gen_leave() leaves the return value in C_RET_OPND before coming here.
+    let ruby_ret_val = asm.live_reg_opnd(C_RET_OPND);
+
     // Every exit to the interpreter should be counted
     gen_counter_incr(&mut asm, Counter::leave_interp_return);

-    asm_comment!(asm, "increment SP of the caller");
-    let sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
+    asm_comment!(asm, "push return value through cfp->sp");
+    let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
+    let sp = asm.load(cfp_sp);
+    asm.mov(Opnd::mem(64, sp, 0), ruby_ret_val);
     let new_sp = asm.add(sp, SIZEOF_VALUE.into());
-    asm.mov(sp, new_sp);
+    asm.mov(cfp_sp, new_sp);

     asm_comment!(asm, "exit from exception");
     asm.cpop_into(SP);
@@ -872,6 +883,18 @@ pub fn gen_single_block(
         asm_comment!(asm, "reg_temps: {:08b}", asm.ctx.get_reg_temps().as_u8());
     }

+    if asm.ctx.is_return_landing() {
+        // Continuation of the end of gen_leave().
+        // Reload REG_SP for the current frame and transfer the return value
+        // to the stack top.
+        asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
+        let top = asm.stack_push(Type::Unknown);
+        asm.mov(top, C_RET_OPND);
+
+        asm.ctx.clear_return_landing();
+    }
+
     // For each instruction to compile
     // NOTE: could rewrite this loop with a std::iter::Iterator
     while insn_idx < iseq_size {
@@ -6535,17 +6558,14 @@ fn gen_send_iseq(
     // The callee might change locals through Kernel#binding and other means.
     asm.ctx.clear_local_types();

-    // Pop arguments and receiver in return context, push the return value
-    // After the return, sp_offset will be 1. The codegen for leave writes
-    // the return value in case of JIT-to-JIT return.
+    // Pop arguments and receiver in return context and
+    // mark it as a continuation of gen_leave()
     let mut return_asm = Assembler::new();
     return_asm.ctx = asm.ctx.clone();
     return_asm.stack_pop(sp_offset.try_into().unwrap());
-    let return_val = return_asm.stack_push(Type::Unknown);
-    // The callee writes a return value on stack. Update reg_temps accordingly.
-    return_asm.ctx.dealloc_temp_reg(return_val.stack_idx());
-    return_asm.ctx.set_sp_offset(1);
+    return_asm.ctx.set_sp_offset(0); // We set SP on the caller's frame above
     return_asm.ctx.reset_chain_depth();
+    return_asm.ctx.set_as_return_landing();

     // Write the JIT return address on the callee frame
     gen_branch(
@@ -7745,15 +7765,15 @@ fn gen_leave(
     // Load the return value
     let retval_opnd = asm.stack_pop(1);

-    // Move the return value into the C return register for gen_leave_exit()
+    // Move the return value into the C return register
     asm.mov(C_RET_OPND, retval_opnd);

-    // Reload REG_SP for the caller and write the return value.
-    // Top of the stack is REG_SP[0] since the caller has sp_offset=1.
-    asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
-    asm.mov(Opnd::mem(64, SP, 0), C_RET_OPND);
-
-    // Jump to the JIT return address on the frame that was just popped
+    // Jump to the JIT return address on the frame that was just popped.
+    // There are a few possible jump targets:
+    //   - gen_leave_exit() and gen_leave_exception(), for C callers
+    //   - Return context set up by gen_send_iseq()
+    // We don't write the return value to stack memory like the interpreter here.
+    // Each jump target do it as necessary.
     let offset_to_jit_return =
         -(RUBY_SIZEOF_CONTROL_FRAME as i32) + RUBY_OFFSET_CFP_JIT_RETURN;
     asm.jmp_opnd(Opnd::mem(64, CFP, offset_to_jit_return));


@@ -450,8 +450,11 @@ pub struct Context {
     /// Bitmap of which stack temps are in a register
     reg_temps: RegTemps,

-    // Depth of this block in the sidechain (eg: inline-cache chain)
-    chain_depth: u8,
+    /// Fields packed into u8
+    /// - Lower 7 bits: Depth of this block in the sidechain (eg: inline-cache chain)
+    /// - Top bit: Whether this code is the target of a JIT-to-JIT Ruby return
+    ///   ([Self::is_return_landing])
+    chain_depth_return_landing: u8,

     // Local variable types we keep track of
     local_types: [Type; MAX_LOCAL_TYPES],
@@ -1402,7 +1405,7 @@ fn find_block_version(blockid: BlockId, ctx: &Context) -> Option<BlockRef> {
 /// Produce a generic context when the block version limit is hit for a blockid
 pub fn limit_block_versions(blockid: BlockId, ctx: &Context) -> Context {
     // Guard chains implement limits separately, do nothing
-    if ctx.chain_depth > 0 {
+    if ctx.get_chain_depth() > 0 {
         return ctx.clone();
     }
@@ -1610,6 +1613,9 @@ impl Context {
         generic_ctx.stack_size = self.stack_size;
         generic_ctx.sp_offset = self.sp_offset;
         generic_ctx.reg_temps = self.reg_temps;
+        if self.is_return_landing() {
+            generic_ctx.set_as_return_landing();
+        }

         generic_ctx
     }
@@ -1640,15 +1646,30 @@ impl Context {
     }

     pub fn get_chain_depth(&self) -> u8 {
-        self.chain_depth
+        self.chain_depth_return_landing & 0x7f
     }

     pub fn reset_chain_depth(&mut self) {
-        self.chain_depth = 0;
+        self.chain_depth_return_landing &= 0x80;
     }

     pub fn increment_chain_depth(&mut self) {
-        self.chain_depth += 1;
+        if self.get_chain_depth() == 0x7f {
+            panic!("max block version chain depth reached!");
+        }
+        self.chain_depth_return_landing += 1;
+    }
+
+    pub fn set_as_return_landing(&mut self) {
+        self.chain_depth_return_landing |= 0x80;
+    }
+
+    pub fn clear_return_landing(&mut self) {
+        self.chain_depth_return_landing &= 0x7f;
+    }
+
+    pub fn is_return_landing(&self) -> bool {
+        self.chain_depth_return_landing & 0x80 > 0
     }

     /// Get an operand for the adjusted stack pointer address
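
As an aside, the 7-bit depth plus top-bit flag packing documented above can be modeled in isolation. A minimal standalone sketch (not the actual `Context` type; `Packed` is an invented name, using the same `0x7f`/`0x80` masks as the hunk above):

    // Standalone model of the u8 packing behind chain_depth_return_landing:
    // lower 7 bits hold the sidechain depth, the top bit is the return-landing flag.
    #[derive(Default, Debug)]
    struct Packed(u8);

    impl Packed {
        fn get_chain_depth(&self) -> u8 { self.0 & 0x7f }
        fn reset_chain_depth(&mut self) { self.0 &= 0x80 } // keeps the flag bit
        fn increment_chain_depth(&mut self) {
            // Mirrors the panic in increment_chain_depth() above; since the depth
            // is below 0x7f here, the +1 can never carry into the flag bit.
            assert!(self.get_chain_depth() < 0x7f, "max block version chain depth reached!");
            self.0 += 1;
        }
        fn set_as_return_landing(&mut self) { self.0 |= 0x80 }
        fn clear_return_landing(&mut self) { self.0 &= 0x7f }
        fn is_return_landing(&self) -> bool { self.0 & 0x80 > 0 }
    }

    fn main() {
        let mut p = Packed::default();
        p.set_as_return_landing();
        p.increment_chain_depth();
        p.increment_chain_depth();
        assert_eq!(p.get_chain_depth(), 2);
        assert!(p.is_return_landing());
        p.reset_chain_depth();          // depth goes back to 0...
        assert!(p.is_return_landing()); // ...but the flag survives
        p.clear_return_landing();
        assert_eq!(p.0, 0);
        println!("{:?}", p);
    }

This is why `reset_chain_depth` masks with `0x80` instead of zeroing: resetting the depth preserves the return-landing flag, while `clear_return_landing` preserves the depth.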
@@ -1845,13 +1866,17 @@ impl Context {
         let src = self;

         // Can only lookup the first version in the chain
-        if dst.chain_depth != 0 {
+        if dst.get_chain_depth() != 0 {
             return TypeDiff::Incompatible;
         }

         // Blocks with depth > 0 always produce new versions
         // Sidechains cannot overlap
-        if src.chain_depth != 0 {
+        if src.get_chain_depth() != 0 {
+            return TypeDiff::Incompatible;
+        }
+
+        if src.is_return_landing() != dst.is_return_landing() {
             return TypeDiff::Incompatible;
         }
@@ -2496,6 +2521,9 @@ fn branch_stub_hit_body(branch_ptr: *const c_void, target_idx: u32, ec: EcPtr) -
         let running_iseq = rb_cfp_get_iseq(cfp);
         let reconned_pc = rb_iseq_pc_at_idx(running_iseq, target_blockid.idx.into());
         let reconned_sp = original_interp_sp.offset(target_ctx.sp_offset.into());
+        // Unlike in the interpreter, our `leave` doesn't write to the caller's
+        // SP -- we do it in the returned-to code. Account for this difference.
+        let reconned_sp = reconned_sp.add(target_ctx.is_return_landing().into());

         assert_eq!(running_iseq, target_blockid.iseq as _, "each stub expects a particular iseq");
@@ -2632,10 +2660,16 @@ fn gen_branch_stub(
     asm.set_reg_temps(ctx.reg_temps);
     asm_comment!(asm, "branch stub hit");

+    if asm.ctx.is_return_landing() {
+        asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
+        let top = asm.stack_push(Type::Unknown);
+        asm.mov(top, C_RET_OPND);
+    }
+
     // Save caller-saved registers before C_ARG_OPNDS get clobbered.
     // Spill all registers for consistency with the trampoline.
-    for &reg in caller_saved_temp_regs().iter() {
-        asm.cpush(reg);
+    for &reg in caller_saved_temp_regs() {
+        asm.cpush(Opnd::Reg(reg));
     }

     // Spill temps to the VM stack as well for jit.peek_at_stack()
@@ -2676,7 +2710,7 @@ pub fn gen_branch_stub_hit_trampoline(ocb: &mut OutlinedCb) -> CodePtr {
     // Since this trampoline is static, it allows code GC inside
     // branch_stub_hit() to free stubs without problems.
     asm_comment!(asm, "branch_stub_hit() trampoline");
-    let jump_addr = asm.ccall(
+    let stub_hit_ret = asm.ccall(
         branch_stub_hit as *mut u8,
         vec![
             C_ARG_OPNDS[0],
@@ -2684,28 +2718,41 @@ pub fn gen_branch_stub_hit_trampoline(ocb: &mut OutlinedCb) -> CodePtr {
             EC,
         ]
     );
+    let jump_addr = asm.load(stub_hit_ret);

     // Restore caller-saved registers for stack temps
-    for &reg in caller_saved_temp_regs().iter().rev() {
-        asm.cpop_into(reg);
+    for &reg in caller_saved_temp_regs().rev() {
+        asm.cpop_into(Opnd::Reg(reg));
     }

     // Jump to the address returned by the branch_stub_hit() call
     asm.jmp_opnd(jump_addr);

+    // HACK: popping into C_RET_REG clobbers the return value of branch_stub_hit() we need to jump
+    // to, so we need a scratch register to preserve it. This extends the live range of the C
+    // return register so we get something else for the return value.
+    let _ = asm.live_reg_opnd(stub_hit_ret);
+
     asm.compile(ocb, None);

     code_ptr
 }

 /// Return registers to be pushed and popped on branch_stub_hit.
-/// The return value may include an extra register for x86 alignment.
-fn caller_saved_temp_regs() -> Vec<Opnd> {
-    let mut regs = Assembler::get_temp_regs().to_vec();
-    if regs.len() % 2 == 1 {
-        regs.push(*regs.last().unwrap()); // x86 alignment
+fn caller_saved_temp_regs() -> impl Iterator<Item = &'static Reg> + DoubleEndedIterator {
+    let temp_regs = Assembler::get_temp_regs().iter();
+    let len = temp_regs.len();
+    // The return value gen_leave() leaves in C_RET_REG
+    // needs to survive the branch_stub_hit() call.
+    let regs = temp_regs.chain(std::iter::once(&C_RET_REG));
+
+    // On x86_64, maintain 16-byte stack alignment
+    if cfg!(target_arch = "x86_64") && len % 2 == 0 {
+        static ONE_MORE: [Reg; 1] = [C_RET_REG];
+        regs.chain(ONE_MORE.iter())
+    } else {
+        regs.chain(&[])
     }
-    regs.iter().map(|&reg| Opnd::Reg(reg)).collect()
 }

 impl Assembler
@@ -2832,16 +2879,13 @@ pub fn defer_compilation(
     asm: &mut Assembler,
     ocb: &mut OutlinedCb,
 ) {
-    if asm.ctx.chain_depth != 0 {
+    if asm.ctx.get_chain_depth() != 0 {
         panic!("Double defer!");
     }

     let mut next_ctx = asm.ctx.clone();

-    if next_ctx.chain_depth == u8::MAX {
-        panic!("max block version chain depth reached!");
-    }
-
-    next_ctx.chain_depth += 1;
+    next_ctx.increment_chain_depth();

     let branch = new_pending_branch(jit, BranchGenFn::JumpToTarget0(Cell::new(BranchShape::Default)));