From 41a6e4bdf9738e2cf1ea356422a429efeeb5a8f0 Mon Sep 17 00:00:00 2001
From: Alan Wu
Date: Fri, 29 Sep 2023 22:29:24 -0400
Subject: [PATCH] YJIT: Avoid writing return value to memory in `leave`

Previously, at the end of `leave` we did `*caller_cfp->sp = return_value`,
like the interpreter. With future changes that leave the SP field
uninitialized for C frames, this will become problematic. For cases like
returning from `rb_funcall()`, the return value was written above the
stack and never read anyway (callers use the copy in the return register).

Leave the return value in a register at the end of `leave` and have the
code at `cfp->jit_return` decide what to do with it. This avoids the
unnecessary memory write mentioned above. For JIT-to-JIT returns, it goes
through `asm.stack_push()` and benefits from register allocation for
stack temporaries.

Mostly flat on benchmarks, with maybe some marginal speed improvements.

Co-authored-by: Takashi Kokubun
---
 yjit/src/backend/ir.rs |  6 +++
 yjit/src/codegen.rs    | 54 +++++++++++++++++--------
 yjit/src/core.rs       | 92 +++++++++++++++++++++++++++++++-----------
 3 files changed, 111 insertions(+), 41 deletions(-)

diff --git a/yjit/src/backend/ir.rs b/yjit/src/backend/ir.rs
index 10e463885b..67b5547bf9 100644
--- a/yjit/src/backend/ir.rs
+++ b/yjit/src/backend/ir.rs
@@ -22,6 +22,7 @@
 pub const SP: Opnd = _SP;
 pub const C_ARG_OPNDS: [Opnd; 6] = _C_ARG_OPNDS;
 pub const C_RET_OPND: Opnd = _C_RET_OPND;
+pub use crate::backend::current::{Reg, C_RET_REG};
 
 // Memory operand base
 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
@@ -955,6 +956,7 @@ pub struct SideExitContext {
     pub stack_size: u8,
     pub sp_offset: i8,
     pub reg_temps: RegTemps,
+    pub is_return_landing: bool,
 }
 
 impl SideExitContext {
@@ -965,6 +967,7 @@
             stack_size: ctx.get_stack_size(),
             sp_offset: ctx.get_sp_offset(),
             reg_temps: ctx.get_reg_temps(),
+            is_return_landing: ctx.is_return_landing(),
         };
         if cfg!(debug_assertions) {
             // Assert that we're not losing any mandatory metadata
@@ -979,6 +982,9 @@
         ctx.set_stack_size(self.stack_size);
         ctx.set_sp_offset(self.sp_offset);
         ctx.set_reg_temps(self.reg_temps);
+        if self.is_return_landing {
+            ctx.set_as_return_landing();
+        }
         ctx
     }
 }
diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs
index 5de12e0420..e4fa1a3665 100644
--- a/yjit/src/codegen.rs
+++ b/yjit/src/codegen.rs
@@ -444,6 +444,12 @@ fn gen_exit(exit_pc: *mut VALUE, asm: &mut Assembler) {
         asm_comment!(asm, "exit to interpreter on {}", insn_name(opcode as usize));
     }
 
+    if asm.ctx.is_return_landing() {
+        asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
+        let top = asm.stack_push(Type::Unknown);
+        asm.mov(top, C_RET_OPND);
+    }
+
     // Spill stack temps before returning to the interpreter
     asm.spill_temps();
 
@@ -636,13 +642,18 @@ fn gen_leave_exception(ocb: &mut OutlinedCb) -> CodePtr {
     let code_ptr = ocb.get_write_ptr();
     let mut asm = Assembler::new();
 
+    // gen_leave() leaves the return value in C_RET_OPND before coming here.
+    let ruby_ret_val = asm.live_reg_opnd(C_RET_OPND);
+
     // Every exit to the interpreter should be counted
     gen_counter_incr(&mut asm, Counter::leave_interp_return);
 
-    asm_comment!(asm, "increment SP of the caller");
-    let sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
+    asm_comment!(asm, "push return value through cfp->sp");
+    let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
+    let sp = asm.load(cfp_sp);
+    asm.mov(Opnd::mem(64, sp, 0), ruby_ret_val);
     let new_sp = asm.add(sp, SIZEOF_VALUE.into());
-    asm.mov(sp, new_sp);
+    asm.mov(cfp_sp, new_sp);
 
     asm_comment!(asm, "exit from exception");
     asm.cpop_into(SP);
@@ -872,6 +883,18 @@ pub fn gen_single_block(
         asm_comment!(asm, "reg_temps: {:08b}", asm.ctx.get_reg_temps().as_u8());
     }
 
+    if asm.ctx.is_return_landing() {
+        // Continuation of the end of gen_leave().
+        // Reload REG_SP for the current frame and transfer the return value
+        // to the stack top.
+        asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
+
+        let top = asm.stack_push(Type::Unknown);
+        asm.mov(top, C_RET_OPND);
+
+        asm.ctx.clear_return_landing();
+    }
+
     // For each instruction to compile
     // NOTE: could rewrite this loop with a std::iter::Iterator
     while insn_idx < iseq_size {
@@ -6535,17 +6558,14 @@ fn gen_send_iseq(
     // The callee might change locals through Kernel#binding and other means.
     asm.ctx.clear_local_types();
 
-    // Pop arguments and receiver in return context, push the return value
-    // After the return, sp_offset will be 1. The codegen for leave writes
-    // the return value in case of JIT-to-JIT return.
+    // Pop arguments and receiver in return context and
+    // mark it as a continuation of gen_leave()
     let mut return_asm = Assembler::new();
     return_asm.ctx = asm.ctx.clone();
     return_asm.stack_pop(sp_offset.try_into().unwrap());
-    let return_val = return_asm.stack_push(Type::Unknown);
-    // The callee writes a return value on stack. Update reg_temps accordingly.
-    return_asm.ctx.dealloc_temp_reg(return_val.stack_idx());
-    return_asm.ctx.set_sp_offset(1);
+    return_asm.ctx.set_sp_offset(0); // We set SP on the caller's frame above
     return_asm.ctx.reset_chain_depth();
+    return_asm.ctx.set_as_return_landing();
 
     // Write the JIT return address on the callee frame
     gen_branch(
@@ -7745,15 +7765,15 @@ fn gen_leave(
     // Load the return value
     let retval_opnd = asm.stack_pop(1);
 
-    // Move the return value into the C return register for gen_leave_exit()
+    // Move the return value into the C return register
     asm.mov(C_RET_OPND, retval_opnd);
 
-    // Reload REG_SP for the caller and write the return value.
-    // Top of the stack is REG_SP[0] since the caller has sp_offset=1.
-    asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
-    asm.mov(Opnd::mem(64, SP, 0), C_RET_OPND);
-
-    // Jump to the JIT return address on the frame that was just popped
+    // Jump to the JIT return address on the frame that was just popped.
+    // There are a few possible jump targets:
+    // - gen_leave_exit() and gen_leave_exception(), for C callers
+    // - Return context set up by gen_send_iseq()
+    // We don't write the return value to stack memory like the interpreter here.
+    // Each jump target does it as necessary.
     let offset_to_jit_return = -(RUBY_SIZEOF_CONTROL_FRAME as i32) + RUBY_OFFSET_CFP_JIT_RETURN;
     asm.jmp_opnd(Opnd::mem(64, CFP, offset_to_jit_return));
 
diff --git a/yjit/src/core.rs b/yjit/src/core.rs
index a5943fb393..d2329455d0 100644
--- a/yjit/src/core.rs
+++ b/yjit/src/core.rs
@@ -450,8 +450,11 @@ pub struct Context {
     /// Bitmap of which stack temps are in a register
     reg_temps: RegTemps,
 
-    // Depth of this block in the sidechain (eg: inline-cache chain)
-    chain_depth: u8,
+    /// Fields packed into u8
+    /// - Lower 7 bits: Depth of this block in the sidechain (eg: inline-cache chain)
+    /// - Top bit: Whether this code is the target of a JIT-to-JIT Ruby return
+    ///   ([Self::is_return_landing])
+    chain_depth_return_landing: u8,
 
     // Local variable types we keep track of
     local_types: [Type; MAX_LOCAL_TYPES],
@@ -1402,7 +1405,7 @@ fn find_block_version(blockid: BlockId, ctx: &Context) -> Option<BlockRef> {
 /// Produce a generic context when the block version limit is hit for a blockid
 pub fn limit_block_versions(blockid: BlockId, ctx: &Context) -> Context {
     // Guard chains implement limits separately, do nothing
-    if ctx.chain_depth > 0 {
+    if ctx.get_chain_depth() > 0 {
         return ctx.clone();
     }
 
@@ -1610,6 +1613,9 @@
         generic_ctx.stack_size = self.stack_size;
         generic_ctx.sp_offset = self.sp_offset;
         generic_ctx.reg_temps = self.reg_temps;
+        if self.is_return_landing() {
+            generic_ctx.set_as_return_landing();
+        }
         generic_ctx
     }
 
@@ -1640,15 +1646,30 @@
     }
 
     pub fn get_chain_depth(&self) -> u8 {
-        self.chain_depth
+        self.chain_depth_return_landing & 0x7f
     }
 
     pub fn reset_chain_depth(&mut self) {
-        self.chain_depth = 0;
+        self.chain_depth_return_landing &= 0x80;
     }
 
     pub fn increment_chain_depth(&mut self) {
-        self.chain_depth += 1;
+        if self.get_chain_depth() == 0x7f {
+            panic!("max block version chain depth reached!");
+        }
+        self.chain_depth_return_landing += 1;
+    }
+
+    pub fn set_as_return_landing(&mut self) {
+        self.chain_depth_return_landing |= 0x80;
+    }
+
+    pub fn clear_return_landing(&mut self) {
+        self.chain_depth_return_landing &= 0x7f;
+    }
+
+    pub fn is_return_landing(&self) -> bool {
+        self.chain_depth_return_landing & 0x80 > 0
     }
 
     /// Get an operand for the adjusted stack pointer address
@@ -1845,13 +1866,17 @@
         let src = self;
 
         // Can only lookup the first version in the chain
-        if dst.chain_depth != 0 {
+        if dst.get_chain_depth() != 0 {
             return TypeDiff::Incompatible;
         }
 
        // Blocks with depth > 0 always produce new versions
        // Sidechains cannot overlap
-        if src.chain_depth != 0 {
+        if src.get_chain_depth() != 0 {
+            return TypeDiff::Incompatible;
+        }
+
+        if src.is_return_landing() != dst.is_return_landing() {
             return TypeDiff::Incompatible;
         }
 
@@ -2496,6 +2521,9 @@ fn branch_stub_hit_body(branch_ptr: *const c_void, target_idx: u32, ec: EcPtr) -
         let running_iseq = rb_cfp_get_iseq(cfp);
         let reconned_pc = rb_iseq_pc_at_idx(running_iseq, target_blockid.idx.into());
         let reconned_sp = original_interp_sp.offset(target_ctx.sp_offset.into());
+        // Unlike in the interpreter, our `leave` doesn't write to the caller's
+        // SP -- we do it in the returned-to code. Account for this difference.
+        let reconned_sp = reconned_sp.add(target_ctx.is_return_landing().into());
 
         assert_eq!(running_iseq, target_blockid.iseq as _, "each stub expects a particular iseq");
 
@@ -2632,10 +2660,16 @@ fn gen_branch_stub(
     asm.set_reg_temps(ctx.reg_temps);
     asm_comment!(asm, "branch stub hit");
 
+    if asm.ctx.is_return_landing() {
+        asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
+        let top = asm.stack_push(Type::Unknown);
+        asm.mov(top, C_RET_OPND);
+    }
+
     // Save caller-saved registers before C_ARG_OPNDS get clobbered.
     // Spill all registers for consistency with the trampoline.
-    for &reg in caller_saved_temp_regs().iter() {
-        asm.cpush(reg);
+    for &reg in caller_saved_temp_regs() {
+        asm.cpush(Opnd::Reg(reg));
     }
 
     // Spill temps to the VM stack as well for jit.peek_at_stack()
@@ -2676,7 +2710,7 @@ pub fn gen_branch_stub_hit_trampoline(ocb: &mut OutlinedCb) -> CodePtr {
     // Since this trampoline is static, it allows code GC inside
     // branch_stub_hit() to free stubs without problems.
     asm_comment!(asm, "branch_stub_hit() trampoline");
-    let jump_addr = asm.ccall(
+    let stub_hit_ret = asm.ccall(
         branch_stub_hit as *mut u8,
         vec![
             C_ARG_OPNDS[0],
@@ -2684,28 +2718,41 @@ pub fn gen_branch_stub_hit_trampoline(ocb: &mut OutlinedCb) -> CodePtr {
             EC,
         ]
     );
+    let jump_addr = asm.load(stub_hit_ret);
 
     // Restore caller-saved registers for stack temps
-    for &reg in caller_saved_temp_regs().iter().rev() {
-        asm.cpop_into(reg);
+    for &reg in caller_saved_temp_regs().rev() {
+        asm.cpop_into(Opnd::Reg(reg));
     }
 
     // Jump to the address returned by the branch_stub_hit() call
     asm.jmp_opnd(jump_addr);
 
+    // HACK: popping into C_RET_REG clobbers the return value of branch_stub_hit() that we need to
+    // jump to, so we need a scratch register to preserve it. This extends the live range of the C
+    // return register so we get something else for the return value.
+    let _ = asm.live_reg_opnd(stub_hit_ret);
+
     asm.compile(ocb, None);
 
     code_ptr
 }
 
 /// Return registers to be pushed and popped on branch_stub_hit.
-/// The return value may include an extra register for x86 alignment.
-fn caller_saved_temp_regs() -> Vec<Opnd> {
-    let mut regs = Assembler::get_temp_regs().to_vec();
-    if regs.len() % 2 == 1 {
-        regs.push(*regs.last().unwrap()); // x86 alignment
+fn caller_saved_temp_regs() -> impl Iterator<Item = &'static Reg> + DoubleEndedIterator {
+    let temp_regs = Assembler::get_temp_regs().iter();
+    let len = temp_regs.len();
+    // The return value gen_leave() leaves in C_RET_REG
+    // needs to survive the branch_stub_hit() call.
+    let regs = temp_regs.chain(std::iter::once(&C_RET_REG));
+
+    // On x86_64, maintain 16-byte stack alignment
+    if cfg!(target_arch = "x86_64") && len % 2 == 0 {
+        static ONE_MORE: [Reg; 1] = [C_RET_REG];
+        regs.chain(ONE_MORE.iter())
+    } else {
+        regs.chain(&[])
     }
-    regs.iter().map(|&reg| Opnd::Reg(reg)).collect()
 }
 
 impl Assembler
@@ -2832,16 +2879,13 @@ pub fn defer_compilation(
     asm: &mut Assembler,
     ocb: &mut OutlinedCb,
 ) {
-    if asm.ctx.chain_depth != 0 {
+    if asm.ctx.get_chain_depth() != 0 {
         panic!("Double defer!");
     }
 
     let mut next_ctx = asm.ctx.clone();
 
-    if next_ctx.chain_depth == u8::MAX {
-        panic!("max block version chain depth reached!");
-    }
-    next_ctx.chain_depth += 1;
+    next_ctx.increment_chain_depth();
 
     let branch = new_pending_branch(jit, BranchGenFn::JumpToTarget0(Cell::new(BranchShape::Default)));
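Illustrative sketch (not part of the patch): a minimal, standalone Rust model of how the patch
packs a 7-bit sidechain depth and a 1-bit return-landing flag into the single
`chain_depth_return_landing` byte on `Context`, using the same 0x7f/0x80 masks as the hunks
above. The `PackedDepth` type and the `main` driver are hypothetical names used only for this
example; the real accessors live on `Context` in yjit/src/core.rs.

    // Standalone model of the bit packing behind Context::chain_depth_return_landing.
    // `PackedDepth` is a hypothetical stand-in for the real field inside Context.
    #[derive(Default)]
    struct PackedDepth(u8);

    impl PackedDepth {
        // Lower 7 bits: depth of the block in the sidechain
        fn get_chain_depth(&self) -> u8 {
            self.0 & 0x7f
        }

        fn increment_chain_depth(&mut self) {
            if self.get_chain_depth() == 0x7f {
                panic!("max block version chain depth reached!");
            }
            // Adding 1 cannot carry into the top bit while depth < 0x7f
            self.0 += 1;
        }

        // Top bit: whether this code is the target of a JIT-to-JIT Ruby return
        fn set_as_return_landing(&mut self) {
            self.0 |= 0x80;
        }

        fn clear_return_landing(&mut self) {
            self.0 &= 0x7f;
        }

        fn is_return_landing(&self) -> bool {
            self.0 & 0x80 != 0
        }
    }

    fn main() {
        let mut packed = PackedDepth::default();
        packed.increment_chain_depth();
        packed.set_as_return_landing();
        assert_eq!(packed.get_chain_depth(), 1);
        assert!(packed.is_return_landing());
        packed.clear_return_landing();
        assert_eq!(packed.get_chain_depth(), 1); // depth survives clearing the flag
    }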