ZJIT: Create more ergonomic type profiling API (#13339)

This commit is contained in:
Max Bernstein 2025-05-16 13:50:48 -04:00 committed by GitHub
parent eead83160b
commit d9248856d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
Notes: git 2025-05-16 17:51:02 +00:00
Merged-By: k0kubun <takashikkbn@gmail.com>

View File

@ -6,7 +6,7 @@
use crate::{
cruby::*,
options::{get_option, DumpHIR},
profile::{self, get_or_create_iseq_payload},
profile::{get_or_create_iseq_payload, IseqPayload},
state::ZJITState,
cast::IntoUsize,
};
@ -673,6 +673,7 @@ pub struct Function {
insn_types: Vec<Type>,
blocks: Vec<Block>,
entry_block: BlockId,
profiles: Option<ProfileOracle>,
}
impl Function {
@ -685,6 +686,7 @@ impl Function {
blocks: vec![Block::default()],
entry_block: BlockId(0),
param_types: vec![],
profiles: None,
}
}
@ -994,6 +996,20 @@ impl Function {
}
}
/// Return the interpreter-profiled type of the HIR instruction at the given ISEQ instruction
/// index, if it is known. This historical type record is not a guarantee and must be checked
/// with a GuardType or similar.
fn profiled_type_of_at(&self, insn: InsnId, iseq_insn_idx: usize) -> Option<Type> {
let Some(ref profiles) = self.profiles else { return None };
let Some(entries) = profiles.types.get(&iseq_insn_idx) else { return None };
for &(entry_insn, entry_type) in entries {
if self.union_find.borrow().find_const(entry_insn) == self.union_find.borrow().find_const(insn) {
return Some(entry_type);
}
}
None
}
/// Report whether `val` should be treated as a probable Fixnum: either `is_a` already
/// reports it as `types::Fixnum`, or the supplied `profiled_type` is a subtype of
/// `types::Fixnum`. The profiled type is only consulted when the first check fails.
fn likely_is_fixnum(&self, val: InsnId, profiled_type: Type) -> bool {
    self.is_a(val, types::Fixnum) || profiled_type.is_subtype(types::Fixnum)
}
@ -1003,20 +1019,16 @@ impl Function {
return self.push_insn(block, Insn::GuardType { val, guard_type: types::Fixnum, state });
}
fn arguments_likely_fixnums(&mut self, payload: &profile:: IseqPayload, left: InsnId, right: InsnId, state: InsnId) -> bool {
let mut left_profiled_type = types::BasicObject;
let mut right_profiled_type = types::BasicObject;
fn arguments_likely_fixnums(&mut self, left: InsnId, right: InsnId, state: InsnId) -> bool {
let frame_state = self.frame_state(state);
let insn_idx = frame_state.insn_idx;
if let Some([left_type, right_type]) = payload.get_operand_types(insn_idx as usize) {
left_profiled_type = *left_type;
right_profiled_type = *right_type;
}
let iseq_insn_idx = frame_state.insn_idx as usize;
let left_profiled_type = self.profiled_type_of_at(left, iseq_insn_idx).unwrap_or(types::BasicObject);
let right_profiled_type = self.profiled_type_of_at(right, iseq_insn_idx).unwrap_or(types::BasicObject);
self.likely_is_fixnum(left, left_profiled_type) && self.likely_is_fixnum(right, right_profiled_type)
}
fn try_rewrite_fixnum_op(&mut self, block: BlockId, orig_insn_id: InsnId, f: &dyn Fn(InsnId, InsnId) -> Insn, bop: u32, left: InsnId, right: InsnId, payload: &profile::IseqPayload, state: InsnId) {
if self.arguments_likely_fixnums(payload, left, right, state) {
fn try_rewrite_fixnum_op(&mut self, block: BlockId, orig_insn_id: InsnId, f: &dyn Fn(InsnId, InsnId) -> Insn, bop: u32, left: InsnId, right: InsnId, state: InsnId) {
if self.arguments_likely_fixnums(left, right, state) {
if bop == BOP_NEQ {
// For opt_neq, the interpreter checks that both neq and eq are unchanged.
self.push_insn(block, Insn::PatchPoint(Invariant::BOPRedefined { klass: INTEGER_REDEFINED_OP_FLAG, bop: BOP_EQ }));
@ -1026,6 +1038,7 @@ impl Function {
let right = self.coerce_to_fixnum(block, right, state);
let result = self.push_insn(block, f(left, right));
self.make_equal_to(orig_insn_id, result);
self.insn_types[result.0] = self.infer_type(result);
} else {
self.push_insn_id(block, orig_insn_id);
}
@ -1034,34 +1047,33 @@ impl Function {
/// Rewrite SendWithoutBlock opcodes into SendWithoutBlockDirect opcodes if we know the target
/// ISEQ statically. This removes run-time method lookups and opens the door for inlining.
fn optimize_direct_sends(&mut self) {
let payload = get_or_create_iseq_payload(self.iseq);
for block in self.rpo() {
let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
assert!(self.blocks[block.0].insns.is_empty());
for insn_id in old_insns {
match self.find(insn_id) {
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "+" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumAdd { left, right, state }, BOP_PLUS, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumAdd { left, right, state }, BOP_PLUS, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "-" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumSub { left, right, state }, BOP_MINUS, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumSub { left, right, state }, BOP_MINUS, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "*" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMult { left, right, state }, BOP_MULT, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMult { left, right, state }, BOP_MULT, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "/" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumDiv { left, right, state }, BOP_DIV, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumDiv { left, right, state }, BOP_DIV, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "%" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMod { left, right, state }, BOP_MOD, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMod { left, right, state }, BOP_MOD, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "==" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumEq { left, right }, BOP_EQ, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumEq { left, right }, BOP_EQ, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "!=" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumNeq { left, right }, BOP_NEQ, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumNeq { left, right }, BOP_NEQ, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "<" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLt { left, right }, BOP_LT, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLt { left, right }, BOP_LT, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "<=" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLe { left, right }, BOP_LE, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLe { left, right }, BOP_LE, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == ">" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGt { left, right }, BOP_GT, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGt { left, right }, BOP_GT, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == ">=" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGe { left, right }, BOP_GE, self_val, args[0], payload, state),
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGe { left, right }, BOP_GE, self_val, args[0], state),
Insn::SendWithoutBlock { mut self_val, call_info, cd, args, state } => {
let frame_state = self.frame_state(state);
let (klass, guard_equal_to) = if let Some(klass) = self.type_of(self_val).runtime_exact_ruby_class() {
@ -1069,8 +1081,8 @@ impl Function {
(klass, None)
} else {
// If we know that self is top-self from profile information, guard and use it to fold the lookup at compile-time.
match payload.get_operand_types(frame_state.insn_idx) {
Some([self_type, ..]) if self_type.is_top_self() => (self_type.exact_ruby_class().unwrap(), self_type.ruby_object()),
match self.profiled_type_of_at(self_val, frame_state.insn_idx) {
Some(self_type) if self_type.is_top_self() => (self_type.exact_ruby_class().unwrap(), self_type.ruby_object()),
_ => { self.push_insn_id(block, insn_id); continue; }
}
};
@ -1130,7 +1142,6 @@ impl Function {
fn reduce_to_ccall(
fun: &mut Function,
block: BlockId,
payload: &profile::IseqPayload,
self_type: Type,
send: Insn,
send_insn_id: InsnId,
@ -1142,7 +1153,6 @@ impl Function {
let call_info = unsafe { (*cd).ci };
let argc = unsafe { vm_ci_argc(call_info) };
let method_id = unsafe { rb_vm_ci_mid(call_info) };
let iseq_insn_idx = fun.frame_state(state).insn_idx;
// If we have info about the class of the receiver
//
@ -1152,10 +1162,10 @@ impl Function {
let (recv_class, guard_type) = if let Some(klass) = self_type.runtime_exact_ruby_class() {
(klass, None)
} else {
payload.get_operand_types(iseq_insn_idx)
.and_then(|types| types.get(argc as usize))
.and_then(|recv_type| recv_type.exact_ruby_class().and_then(|class| Some((class, Some(recv_type.unspecialized())))))
.ok_or(())?
let iseq_insn_idx = fun.frame_state(state).insn_idx;
let Some(recv_type) = fun.profiled_type_of_at(self_val, iseq_insn_idx) else { return Err(()) };
let Some(recv_class) = recv_type.exact_ruby_class() else { return Err(()) };
(recv_class, Some(recv_type.unspecialized()))
};
// Do method lookup
@ -1221,14 +1231,13 @@ impl Function {
Err(())
}
let payload = get_or_create_iseq_payload(self.iseq);
for block in self.rpo() {
let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
assert!(self.blocks[block.0].insns.is_empty());
for insn_id in old_insns {
if let send @ Insn::SendWithoutBlock { self_val, .. } = self.find(insn_id) {
let self_type = self.type_of(self_val);
if reduce_to_ccall(self, block, payload, self_type, send, insn_id).is_ok() {
if reduce_to_ccall(self, block, self_type, send, insn_id).is_ok() {
continue;
}
}
@ -1598,7 +1607,7 @@ impl FrameState {
}
/// Get a stack operand at idx
fn stack_topn(&mut self, idx: usize) -> Result<InsnId, ParseError> {
fn stack_topn(&self, idx: usize) -> Result<InsnId, ParseError> {
let idx = self.stack.len() - idx - 1;
self.stack.get(idx).ok_or_else(|| ParseError::StackUnderflow(self.clone())).copied()
}
@ -1717,8 +1726,42 @@ fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> {
Ok(())
}
/// IseqPayload keeps track of the HIR Types observed by the interpreter, but it is not useful
/// or correct to query it directly from inside the optimizer. ProfileOracle instead provides
/// an API to look up profiled type information by HIR InsnId at a given ISEQ instruction.
#[derive(Debug)]
struct ProfileOracle {
/// The ISEQ payload whose recorded operand types are read (via `get_operand_types`) when
/// the stack is profiled.
payload: &'static IseqPayload,
/// Map from ISEQ instruction index -> profiled type information at that ISEQ instruction
/// index. At a given ISEQ instruction, the interpreter has profiled the stack operands of
/// that instruction; each (InsnId, Type) pair in the list ties one profiled operand type to
/// the HIR instruction that was in that stack slot.
types: HashMap<usize, Vec<(InsnId, Type)>>,
}
impl ProfileOracle {
fn new(payload: &'static IseqPayload) -> Self {
Self { payload, types: Default::default() }
}
/// Map the interpreter-recorded types of the stack onto the HIR operands on our compile-time virtual stack
fn profile_stack(&mut self, state: &FrameState) {
let iseq_insn_idx = state.insn_idx;
let Some(operand_types) = self.payload.get_operand_types(iseq_insn_idx) else { return };
let entry = self.types.entry(iseq_insn_idx).or_insert_with(|| vec![]);
// operand_types is always going to be <= stack size (otherwise it would have an underflow
// at run-time) so use that to drive iteration.
for (idx, &insn_type) in operand_types.iter().rev().enumerate() {
let insn = state.stack_topn(idx).expect("Unexpected stack underflow in profiling");
entry.push((insn, insn_type))
}
}
}
/// Compile ISEQ into High-level IR
pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let payload = get_or_create_iseq_payload(iseq);
let mut profiles = ProfileOracle::new(payload);
let mut fun = Function::new(iseq);
// Compute a map of PC->Block by finding jump targets
let jump_targets = compute_jump_targets(iseq);
@ -1791,6 +1834,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let pc = unsafe { rb_iseq_pc_at_idx(iseq, insn_idx) };
state.pc = pc;
let exit_state = state.clone();
profiles.profile_stack(&exit_state);
// try_into() call below is unfortunate. Maybe pick i32 instead of usize for opcodes.
let opcode: u32 = unsafe { rb_iseq_opcode_at_pc(iseq, pc) }
@ -2061,6 +2105,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
None => {},
}
fun.profiles = Some(profiles);
Ok(fun)
}
@ -3058,8 +3103,8 @@ mod opt_tests {
bb0():
PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
v15:Fixnum[6] = Const Value(6)
Return v15
v14:Fixnum[6] = Const Value(6)
Return v14
"#]]);
}