YJIT: Compress BranchGenFn and BranchShape (#7401)

* YJIT: Compress BranchGenFn and BranchShape

* YJIT: Derive Debug for Branch

* YJIT: Capitalize BranchGenFn names

Co-authored-by: Maxime Chevalier-Boisvert <maxime.chevalierboisvert@shopify.com>
Co-authored-by: Alan Wu <alansi.xingwu@shopify.com>
This commit is contained in:
Takashi Kokubun 2023-02-28 10:04:28 -08:00 committed by GitHub
parent 67ad831b5f
commit 966adfb799
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
Notes: git 2023-02-28 18:04:49 +00:00
Merged-By: k0kubun <takashikkbn@gmail.com>
2 changed files with 127 additions and 157 deletions

View File

@ -1802,42 +1802,6 @@ fn gen_checkkeyword(
KeepCompiling
}
fn gen_jnz_to_target0(
asm: &mut Assembler,
target0: CodePtr,
_target1: Option<CodePtr>,
shape: BranchShape,
) {
match shape {
BranchShape::Next0 | BranchShape::Next1 => unreachable!(),
BranchShape::Default => asm.jnz(target0.into()),
}
}
fn gen_jz_to_target0(
asm: &mut Assembler,
target0: CodePtr,
_target1: Option<CodePtr>,
shape: BranchShape,
) {
match shape {
BranchShape::Next0 | BranchShape::Next1 => unreachable!(),
BranchShape::Default => asm.jz(Target::CodePtr(target0)),
}
}
fn gen_jbe_to_target0(
asm: &mut Assembler,
target0: CodePtr,
_target1: Option<CodePtr>,
shape: BranchShape,
) {
match shape {
BranchShape::Next0 | BranchShape::Next1 => unreachable!(),
BranchShape::Default => asm.jbe(Target::CodePtr(target0)),
}
}
// Generate a jump to a stub that recompiles the current YARV instruction on failure.
// When depth_limit is exceeded, generate a jump to a side exit.
fn jit_chain_guard(
@ -1850,9 +1814,9 @@ fn jit_chain_guard(
side_exit: Target,
) {
let target0_gen_fn = match jcc {
JCC_JNE | JCC_JNZ => gen_jnz_to_target0,
JCC_JZ | JCC_JE => gen_jz_to_target0,
JCC_JBE | JCC_JNA => gen_jbe_to_target0,
JCC_JNE | JCC_JNZ => BranchGenFn::JNZToTarget0,
JCC_JZ | JCC_JE => BranchGenFn::JZToTarget0,
JCC_JBE | JCC_JNA => BranchGenFn::JBEToTarget0,
};
if (ctx.get_chain_depth() as i32) < depth_limit {
@ -1865,7 +1829,7 @@ fn jit_chain_guard(
gen_branch(jit, asm, ocb, bid, &deeper, None, None, target0_gen_fn);
} else {
target0_gen_fn(asm, side_exit.unwrap_code_ptr(), None, BranchShape::Default);
target0_gen_fn.call(asm, side_exit.unwrap_code_ptr(), None);
}
}
@ -3498,27 +3462,6 @@ fn gen_opt_case_dispatch(
}
}
fn gen_branchif_branch(
asm: &mut Assembler,
target0: CodePtr,
target1: Option<CodePtr>,
shape: BranchShape,
) {
assert!(target1 != None);
match shape {
BranchShape::Next0 => {
asm.jz(target1.unwrap().into());
}
BranchShape::Next1 => {
asm.jnz(target0.into());
}
BranchShape::Default => {
asm.jnz(target0.into());
asm.jmp(target1.unwrap().into());
}
}
}
fn gen_branchif(
jit: &mut JITState,
ctx: &mut Context,
@ -3565,29 +3508,13 @@ fn gen_branchif(
ctx,
Some(next_block),
Some(ctx),
gen_branchif_branch,
BranchGenFn::BranchIf(BranchShape::Default),
);
}
EndBlock
}
fn gen_branchunless_branch(
asm: &mut Assembler,
target0: CodePtr,
target1: Option<CodePtr>,
shape: BranchShape,
) {
match shape {
BranchShape::Next0 => asm.jnz(target1.unwrap().into()),
BranchShape::Next1 => asm.jz(target0.into()),
BranchShape::Default => {
asm.jz(target0.into());
asm.jmp(target1.unwrap().into());
}
}
}
fn gen_branchunless(
jit: &mut JITState,
ctx: &mut Context,
@ -3635,29 +3562,13 @@ fn gen_branchunless(
ctx,
Some(next_block),
Some(ctx),
gen_branchunless_branch,
BranchGenFn::BranchUnless(BranchShape::Default),
);
}
EndBlock
}
fn gen_branchnil_branch(
asm: &mut Assembler,
target0: CodePtr,
target1: Option<CodePtr>,
shape: BranchShape,
) {
match shape {
BranchShape::Next0 => asm.jne(target1.unwrap().into()),
BranchShape::Next1 => asm.je(target0.into()),
BranchShape::Default => {
asm.je(target0.into());
asm.jmp(target1.unwrap().into());
}
}
}
fn gen_branchnil(
jit: &mut JITState,
ctx: &mut Context,
@ -3702,7 +3613,7 @@ fn gen_branchnil(
ctx,
Some(next_block),
Some(ctx),
gen_branchnil_branch,
BranchGenFn::BranchNil(BranchShape::Default),
);
}
@ -5954,15 +5865,7 @@ fn gen_send_iseq(
&return_ctx,
None,
None,
|asm, target0, _target1, shape| {
match shape {
BranchShape::Default => {
asm.comment("update cfp->jit_return");
asm.mov(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_JIT_RETURN), Opnd::const_ptr(target0.raw_ptr()));
}
_ => unreachable!()
}
},
BranchGenFn::JITReturn,
);
// Directly jump to the entry point of the callee

View File

@ -400,9 +400,110 @@ pub enum BranchShape {
Default, // Neither target is next
}
// Branch code generation function signature
type BranchGenFn =
fn(cb: &mut Assembler, target0: CodePtr, target1: Option<CodePtr>, shape: BranchShape) -> ();
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BranchGenFn {
BranchIf(BranchShape),
BranchNil(BranchShape),
BranchUnless(BranchShape),
JumpToTarget0(BranchShape),
JNZToTarget0,
JZToTarget0,
JBEToTarget0,
JITReturn,
}
impl BranchGenFn {
pub fn call(self, asm: &mut Assembler, target0: CodePtr, target1: Option<CodePtr>) {
match self {
BranchGenFn::BranchIf(shape) => {
match shape {
BranchShape::Next0 => asm.jz(target1.unwrap().into()),
BranchShape::Next1 => asm.jnz(target0.into()),
BranchShape::Default => {
asm.jnz(target0.into());
asm.jmp(target1.unwrap().into());
}
}
}
BranchGenFn::BranchNil(shape) => {
match shape {
BranchShape::Next0 => asm.jne(target1.unwrap().into()),
BranchShape::Next1 => asm.je(target0.into()),
BranchShape::Default => {
asm.je(target0.into());
asm.jmp(target1.unwrap().into());
}
}
}
BranchGenFn::BranchUnless(shape) => {
match shape {
BranchShape::Next0 => asm.jnz(target1.unwrap().into()),
BranchShape::Next1 => asm.jz(target0.into()),
BranchShape::Default => {
asm.jz(target0.into());
asm.jmp(target1.unwrap().into());
}
}
}
BranchGenFn::JumpToTarget0(shape) => {
if shape == BranchShape::Next1 {
panic!("Branch shape Next1 not allowed in JumpToTarget0!");
}
if shape == BranchShape::Default {
asm.jmp(target0.into());
}
}
BranchGenFn::JNZToTarget0 => {
asm.jnz(target0.into())
}
BranchGenFn::JZToTarget0 => {
asm.jz(Target::CodePtr(target0))
}
BranchGenFn::JBEToTarget0 => {
asm.jbe(Target::CodePtr(target0))
}
BranchGenFn::JITReturn => {
asm.comment("update cfp->jit_return");
asm.mov(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_JIT_RETURN), Opnd::const_ptr(target0.raw_ptr()));
}
}
}
pub fn get_shape(self) -> BranchShape {
match self {
BranchGenFn::BranchIf(shape) |
BranchGenFn::BranchNil(shape) |
BranchGenFn::BranchUnless(shape) |
BranchGenFn::JumpToTarget0(shape) => shape,
BranchGenFn::JNZToTarget0 |
BranchGenFn::JZToTarget0 |
BranchGenFn::JBEToTarget0 |
BranchGenFn::JITReturn => BranchShape::Default,
}
}
pub fn set_shape(&mut self, new_shape: BranchShape) {
match self {
BranchGenFn::BranchIf(shape) |
BranchGenFn::BranchNil(shape) |
BranchGenFn::BranchUnless(shape) => {
*shape = new_shape;
}
BranchGenFn::JumpToTarget0(shape) => {
if new_shape == BranchShape::Next1 {
panic!("Branch shape Next1 not allowed in JumpToTarget0!");
}
*shape = new_shape;
}
BranchGenFn::JNZToTarget0 |
BranchGenFn::JZToTarget0 |
BranchGenFn::JBEToTarget0 |
BranchGenFn::JITReturn => {
assert_eq!(new_shape, BranchShape::Default);
}
}
}
}
/// A place that a branch could jump to
#[derive(Debug)]
@ -457,6 +558,7 @@ struct BranchStub {
/// Store info about an outgoing branch in a code segment
/// Note: care must be taken to minimize the size of branch objects
#[derive(Debug)]
struct Branch {
// Block this is attached to
block: BlockRef,
@ -470,22 +572,6 @@ struct Branch {
// Branch code generation function
gen_fn: BranchGenFn,
// Shape of the branch
shape: BranchShape,
}
impl std::fmt::Debug for Branch {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// TODO: expand this if needed. #[derive(Debug)] on Branch gave a
// strange error related to BranchGenFn
formatter
.debug_struct("Branch")
.field("start", &self.start_addr)
.field("end", &self.end_addr)
.field("targets", &self.targets)
.finish()
}
}
impl Branch {
@ -1754,11 +1840,10 @@ fn regenerate_branch(cb: &mut CodeBlock, branch: &mut Branch) {
// Generate the branch
let mut asm = Assembler::new();
asm.comment("regenerate_branch");
(branch.gen_fn)(
branch.gen_fn.call(
&mut asm,
branch.get_target_address(0).unwrap(),
branch.get_target_address(1),
branch.shape,
);
// Rewrite the branch
@ -1808,10 +1893,7 @@ fn make_branch_entry(block: &BlockRef, gen_fn: BranchGenFn) -> BranchRef {
targets: [None, None],
// Branch code generation function
gen_fn: gen_fn,
// Shape of the branch
shape: BranchShape::Default,
gen_fn,
};
// Add to the list of outgoing branches for the block
@ -1905,7 +1987,7 @@ fn branch_stub_hit_body(branch_ptr: *const c_void, target_idx: u32, ec: EcPtr) -
// If this block hasn't yet been compiled
if block.is_none() {
let branch_old_shape = branch.shape;
let branch_old_shape = branch.gen_fn.get_shape();
let mut branch_modified = false;
// If the new block can be generated right after the branch (at cb->write_pos)
@ -1914,7 +1996,7 @@ fn branch_stub_hit_body(branch_ptr: *const c_void, target_idx: u32, ec: EcPtr) -
assert!(branch.end_addr == branch.block.borrow().end_addr);
// Change the branch shape to indicate the target block will be placed next
branch.shape = target_branch_shape;
branch.gen_fn.set_shape(target_branch_shape);
// Rewrite the branch with the new, potentially more compact shape
regenerate_branch(cb, &mut branch);
@ -1934,7 +2016,7 @@ fn branch_stub_hit_body(branch_ptr: *const c_void, target_idx: u32, ec: EcPtr) -
if block.is_none() && branch_modified {
// We couldn't generate a new block for the branch, but we modified the branch.
// Restore the branch by regenerating it.
branch.shape = branch_old_shape;
branch.gen_fn.set_shape(branch_old_shape);
regenerate_branch(cb, &mut branch);
}
}
@ -1945,7 +2027,7 @@ fn branch_stub_hit_body(branch_ptr: *const c_void, target_idx: u32, ec: EcPtr) -
let mut block: RefMut<_> = block_rc.borrow_mut();
// Branch shape should reflect layout
assert!(!(branch.shape == target_branch_shape && Some(block.start_addr) != branch.end_addr));
assert!(!(branch.gen_fn.get_shape() == target_branch_shape && Some(block.start_addr) != branch.end_addr));
// Add this branch to the list of incoming branches for the target
block.push_incoming(branch_rc.clone());
@ -2145,28 +2227,13 @@ pub fn gen_branch(
// Call the branch generation function
asm.mark_branch_start(&branchref);
if let Some(dst_addr) = branch.get_target_address(0) {
gen_fn(asm, dst_addr, branch.get_target_address(1), BranchShape::Default);
gen_fn.call(asm, dst_addr, branch.get_target_address(1));
}
asm.mark_branch_end(&branchref);
}
fn gen_jump_branch(
asm: &mut Assembler,
target0: CodePtr,
_target1: Option<CodePtr>,
shape: BranchShape,
) {
if shape == BranchShape::Next1 {
panic!("Branch shape Next1 not allowed in gen_jump_branch!");
}
if shape == BranchShape::Default {
asm.jmp(target0.into());
}
}
pub fn gen_direct_jump(jit: &JITState, ctx: &Context, target0: BlockId, asm: &mut Assembler) {
let branchref = make_branch_entry(&jit.get_block(), gen_jump_branch);
let branchref = make_branch_entry(&jit.get_block(), BranchGenFn::JumpToTarget0(BranchShape::Default));
let mut branch = branchref.borrow_mut();
let mut new_target = BranchTarget::Stub(Box::new(BranchStub {
@ -2186,17 +2253,17 @@ pub fn gen_direct_jump(jit: &JITState, ctx: &Context, target0: BlockId, asm: &mu
new_target = BranchTarget::Block(blockref.clone());
branch.shape = BranchShape::Default;
branch.gen_fn.set_shape(BranchShape::Default);
// Call the branch generation function
asm.comment("gen_direct_jmp: existing block");
asm.mark_branch_start(&branchref);
gen_jump_branch(asm, block_addr, None, BranchShape::Default);
branch.gen_fn.call(asm, block_addr, None);
asm.mark_branch_end(&branchref);
} else {
// `None` in new_target.address signals gen_block_series() to compile the
// target block right after this one (fallthrough).
branch.shape = BranchShape::Next0;
branch.gen_fn.set_shape(BranchShape::Next0);
// The branch is effectively empty (a noop)
asm.comment("gen_direct_jmp: fallthrough");
@ -2226,7 +2293,7 @@ pub fn defer_compilation(
next_ctx.chain_depth += 1;
let block_rc = jit.get_block();
let branch_rc = make_branch_entry(&jit.get_block(), gen_jump_branch);
let branch_rc = make_branch_entry(&jit.get_block(), BranchGenFn::JumpToTarget0(BranchShape::Default));
let mut branch = branch_rc.borrow_mut();
let block = block_rc.borrow();
@ -2240,7 +2307,7 @@ pub fn defer_compilation(
asm.comment("defer_compilation");
asm.mark_branch_start(&branch_rc);
if let Some(dst_addr) = branch.get_target_address(0) {
gen_jump_branch(asm, dst_addr, None, BranchShape::Default);
branch.gen_fn.call(asm, dst_addr, None);
}
asm.mark_branch_end(&branch_rc);
@ -2432,7 +2499,7 @@ pub fn invalidate_block_version(blockref: &BlockRef) {
// The new block will no longer be adjacent.
// Note that we could be enlarging the branch and writing into the
// start of the block being invalidated.
branch.shape = BranchShape::Default;
branch.gen_fn.set_shape(BranchShape::Default);
}
// Rewrite the branch with the new jump target address