ZJIT: Create more ergonomic type profiling API (#13339)
This commit is contained in:
parent
eead83160b
commit
d9248856d2
Notes:
git
2025-05-16 17:51:02 +00:00
Merged-By: k0kubun <takashikkbn@gmail.com>
117
zjit/src/hir.rs
117
zjit/src/hir.rs
@ -6,7 +6,7 @@
|
||||
use crate::{
|
||||
cruby::*,
|
||||
options::{get_option, DumpHIR},
|
||||
profile::{self, get_or_create_iseq_payload},
|
||||
profile::{get_or_create_iseq_payload, IseqPayload},
|
||||
state::ZJITState,
|
||||
cast::IntoUsize,
|
||||
};
|
||||
@ -673,6 +673,7 @@ pub struct Function {
|
||||
insn_types: Vec<Type>,
|
||||
blocks: Vec<Block>,
|
||||
entry_block: BlockId,
|
||||
profiles: Option<ProfileOracle>,
|
||||
}
|
||||
|
||||
impl Function {
|
||||
@ -685,6 +686,7 @@ impl Function {
|
||||
blocks: vec![Block::default()],
|
||||
entry_block: BlockId(0),
|
||||
param_types: vec![],
|
||||
profiles: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -994,6 +996,20 @@ impl Function {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the interpreter-profiled type of the HIR instruction at the given ISEQ instruction
|
||||
/// index, if it is known. This historical type record is not a guarantee and must be checked
|
||||
/// with a GuardType or similar.
|
||||
fn profiled_type_of_at(&self, insn: InsnId, iseq_insn_idx: usize) -> Option<Type> {
|
||||
let Some(ref profiles) = self.profiles else { return None };
|
||||
let Some(entries) = profiles.types.get(&iseq_insn_idx) else { return None };
|
||||
for &(entry_insn, entry_type) in entries {
|
||||
if self.union_find.borrow().find_const(entry_insn) == self.union_find.borrow().find_const(insn) {
|
||||
return Some(entry_type);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn likely_is_fixnum(&self, val: InsnId, profiled_type: Type) -> bool {
|
||||
return self.is_a(val, types::Fixnum) || profiled_type.is_subtype(types::Fixnum);
|
||||
}
|
||||
@ -1003,20 +1019,16 @@ impl Function {
|
||||
return self.push_insn(block, Insn::GuardType { val, guard_type: types::Fixnum, state });
|
||||
}
|
||||
|
||||
fn arguments_likely_fixnums(&mut self, payload: &profile:: IseqPayload, left: InsnId, right: InsnId, state: InsnId) -> bool {
|
||||
let mut left_profiled_type = types::BasicObject;
|
||||
let mut right_profiled_type = types::BasicObject;
|
||||
fn arguments_likely_fixnums(&mut self, left: InsnId, right: InsnId, state: InsnId) -> bool {
|
||||
let frame_state = self.frame_state(state);
|
||||
let insn_idx = frame_state.insn_idx;
|
||||
if let Some([left_type, right_type]) = payload.get_operand_types(insn_idx as usize) {
|
||||
left_profiled_type = *left_type;
|
||||
right_profiled_type = *right_type;
|
||||
}
|
||||
let iseq_insn_idx = frame_state.insn_idx as usize;
|
||||
let left_profiled_type = self.profiled_type_of_at(left, iseq_insn_idx).unwrap_or(types::BasicObject);
|
||||
let right_profiled_type = self.profiled_type_of_at(right, iseq_insn_idx).unwrap_or(types::BasicObject);
|
||||
self.likely_is_fixnum(left, left_profiled_type) && self.likely_is_fixnum(right, right_profiled_type)
|
||||
}
|
||||
|
||||
fn try_rewrite_fixnum_op(&mut self, block: BlockId, orig_insn_id: InsnId, f: &dyn Fn(InsnId, InsnId) -> Insn, bop: u32, left: InsnId, right: InsnId, payload: &profile::IseqPayload, state: InsnId) {
|
||||
if self.arguments_likely_fixnums(payload, left, right, state) {
|
||||
fn try_rewrite_fixnum_op(&mut self, block: BlockId, orig_insn_id: InsnId, f: &dyn Fn(InsnId, InsnId) -> Insn, bop: u32, left: InsnId, right: InsnId, state: InsnId) {
|
||||
if self.arguments_likely_fixnums(left, right, state) {
|
||||
if bop == BOP_NEQ {
|
||||
// For opt_neq, the interpreter checks that both neq and eq are unchanged.
|
||||
self.push_insn(block, Insn::PatchPoint(Invariant::BOPRedefined { klass: INTEGER_REDEFINED_OP_FLAG, bop: BOP_EQ }));
|
||||
@ -1026,6 +1038,7 @@ impl Function {
|
||||
let right = self.coerce_to_fixnum(block, right, state);
|
||||
let result = self.push_insn(block, f(left, right));
|
||||
self.make_equal_to(orig_insn_id, result);
|
||||
self.insn_types[result.0] = self.infer_type(result);
|
||||
} else {
|
||||
self.push_insn_id(block, orig_insn_id);
|
||||
}
|
||||
@ -1034,34 +1047,33 @@ impl Function {
|
||||
/// Rewrite SendWithoutBlock opcodes into SendWithoutBlockDirect opcodes if we know the target
|
||||
/// ISEQ statically. This removes run-time method lookups and opens the door for inlining.
|
||||
fn optimize_direct_sends(&mut self) {
|
||||
let payload = get_or_create_iseq_payload(self.iseq);
|
||||
for block in self.rpo() {
|
||||
let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
|
||||
assert!(self.blocks[block.0].insns.is_empty());
|
||||
for insn_id in old_insns {
|
||||
match self.find(insn_id) {
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "+" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumAdd { left, right, state }, BOP_PLUS, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumAdd { left, right, state }, BOP_PLUS, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "-" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumSub { left, right, state }, BOP_MINUS, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumSub { left, right, state }, BOP_MINUS, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "*" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMult { left, right, state }, BOP_MULT, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMult { left, right, state }, BOP_MULT, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "/" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumDiv { left, right, state }, BOP_DIV, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumDiv { left, right, state }, BOP_DIV, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "%" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMod { left, right, state }, BOP_MOD, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumMod { left, right, state }, BOP_MOD, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "==" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumEq { left, right }, BOP_EQ, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumEq { left, right }, BOP_EQ, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "!=" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumNeq { left, right }, BOP_NEQ, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumNeq { left, right }, BOP_NEQ, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "<" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLt { left, right }, BOP_LT, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLt { left, right }, BOP_LT, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "<=" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLe { left, right }, BOP_LE, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumLe { left, right }, BOP_LE, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == ">" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGt { left, right }, BOP_GT, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGt { left, right }, BOP_GT, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == ">=" && args.len() == 1 =>
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGe { left, right }, BOP_GE, self_val, args[0], payload, state),
|
||||
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGe { left, right }, BOP_GE, self_val, args[0], state),
|
||||
Insn::SendWithoutBlock { mut self_val, call_info, cd, args, state } => {
|
||||
let frame_state = self.frame_state(state);
|
||||
let (klass, guard_equal_to) = if let Some(klass) = self.type_of(self_val).runtime_exact_ruby_class() {
|
||||
@ -1069,8 +1081,8 @@ impl Function {
|
||||
(klass, None)
|
||||
} else {
|
||||
// If we know that self is top-self from profile information, guard and use it to fold the lookup at compile-time.
|
||||
match payload.get_operand_types(frame_state.insn_idx) {
|
||||
Some([self_type, ..]) if self_type.is_top_self() => (self_type.exact_ruby_class().unwrap(), self_type.ruby_object()),
|
||||
match self.profiled_type_of_at(self_val, frame_state.insn_idx) {
|
||||
Some(self_type) if self_type.is_top_self() => (self_type.exact_ruby_class().unwrap(), self_type.ruby_object()),
|
||||
_ => { self.push_insn_id(block, insn_id); continue; }
|
||||
}
|
||||
};
|
||||
@ -1130,7 +1142,6 @@ impl Function {
|
||||
fn reduce_to_ccall(
|
||||
fun: &mut Function,
|
||||
block: BlockId,
|
||||
payload: &profile::IseqPayload,
|
||||
self_type: Type,
|
||||
send: Insn,
|
||||
send_insn_id: InsnId,
|
||||
@ -1142,7 +1153,6 @@ impl Function {
|
||||
let call_info = unsafe { (*cd).ci };
|
||||
let argc = unsafe { vm_ci_argc(call_info) };
|
||||
let method_id = unsafe { rb_vm_ci_mid(call_info) };
|
||||
let iseq_insn_idx = fun.frame_state(state).insn_idx;
|
||||
|
||||
// If we have info about the class of the receiver
|
||||
//
|
||||
@ -1152,10 +1162,10 @@ impl Function {
|
||||
let (recv_class, guard_type) = if let Some(klass) = self_type.runtime_exact_ruby_class() {
|
||||
(klass, None)
|
||||
} else {
|
||||
payload.get_operand_types(iseq_insn_idx)
|
||||
.and_then(|types| types.get(argc as usize))
|
||||
.and_then(|recv_type| recv_type.exact_ruby_class().and_then(|class| Some((class, Some(recv_type.unspecialized())))))
|
||||
.ok_or(())?
|
||||
let iseq_insn_idx = fun.frame_state(state).insn_idx;
|
||||
let Some(recv_type) = fun.profiled_type_of_at(self_val, iseq_insn_idx) else { return Err(()) };
|
||||
let Some(recv_class) = recv_type.exact_ruby_class() else { return Err(()) };
|
||||
(recv_class, Some(recv_type.unspecialized()))
|
||||
};
|
||||
|
||||
// Do method lookup
|
||||
@ -1221,14 +1231,13 @@ impl Function {
|
||||
Err(())
|
||||
}
|
||||
|
||||
let payload = get_or_create_iseq_payload(self.iseq);
|
||||
for block in self.rpo() {
|
||||
let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
|
||||
assert!(self.blocks[block.0].insns.is_empty());
|
||||
for insn_id in old_insns {
|
||||
if let send @ Insn::SendWithoutBlock { self_val, .. } = self.find(insn_id) {
|
||||
let self_type = self.type_of(self_val);
|
||||
if reduce_to_ccall(self, block, payload, self_type, send, insn_id).is_ok() {
|
||||
if reduce_to_ccall(self, block, self_type, send, insn_id).is_ok() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@ -1598,7 +1607,7 @@ impl FrameState {
|
||||
}
|
||||
|
||||
/// Get a stack operand at idx
|
||||
fn stack_topn(&mut self, idx: usize) -> Result<InsnId, ParseError> {
|
||||
fn stack_topn(&self, idx: usize) -> Result<InsnId, ParseError> {
|
||||
let idx = self.stack.len() - idx - 1;
|
||||
self.stack.get(idx).ok_or_else(|| ParseError::StackUnderflow(self.clone())).copied()
|
||||
}
|
||||
@ -1717,8 +1726,42 @@ fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful
|
||||
/// or correct to query from inside the optimizer. Instead, ProfileOracle provides an API to look
|
||||
/// up profiled type information by HIR InsnId at a given ISEQ instruction.
|
||||
#[derive(Debug)]
|
||||
struct ProfileOracle {
|
||||
payload: &'static IseqPayload,
|
||||
/// types is a map from ISEQ instruction indices -> profiled type information at that ISEQ
|
||||
/// instruction index. At a given ISEQ instruction, the interpreter has profiled the stack
|
||||
/// operands to a given ISEQ instruction, and this list of pairs of (InsnId, Type) map that
|
||||
/// profiling information into HIR instructions.
|
||||
types: HashMap<usize, Vec<(InsnId, Type)>>,
|
||||
}
|
||||
|
||||
impl ProfileOracle {
|
||||
fn new(payload: &'static IseqPayload) -> Self {
|
||||
Self { payload, types: Default::default() }
|
||||
}
|
||||
|
||||
/// Map the interpreter-recorded types of the stack onto the HIR operands on our compile-time virtual stack
|
||||
fn profile_stack(&mut self, state: &FrameState) {
|
||||
let iseq_insn_idx = state.insn_idx;
|
||||
let Some(operand_types) = self.payload.get_operand_types(iseq_insn_idx) else { return };
|
||||
let entry = self.types.entry(iseq_insn_idx).or_insert_with(|| vec![]);
|
||||
// operand_types is always going to be <= stack size (otherwise it would have an underflow
|
||||
// at run-time) so use that to drive iteration.
|
||||
for (idx, &insn_type) in operand_types.iter().rev().enumerate() {
|
||||
let insn = state.stack_topn(idx).expect("Unexpected stack underflow in profiling");
|
||||
entry.push((insn, insn_type))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile ISEQ into High-level IR
|
||||
pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
|
||||
let payload = get_or_create_iseq_payload(iseq);
|
||||
let mut profiles = ProfileOracle::new(payload);
|
||||
let mut fun = Function::new(iseq);
|
||||
// Compute a map of PC->Block by finding jump targets
|
||||
let jump_targets = compute_jump_targets(iseq);
|
||||
@ -1791,6 +1834,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
|
||||
let pc = unsafe { rb_iseq_pc_at_idx(iseq, insn_idx) };
|
||||
state.pc = pc;
|
||||
let exit_state = state.clone();
|
||||
profiles.profile_stack(&exit_state);
|
||||
|
||||
// try_into() call below is unfortunate. Maybe pick i32 instead of usize for opcodes.
|
||||
let opcode: u32 = unsafe { rb_iseq_opcode_at_pc(iseq, pc) }
|
||||
@ -2061,6 +2105,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
|
||||
None => {},
|
||||
}
|
||||
|
||||
fun.profiles = Some(profiles);
|
||||
Ok(fun)
|
||||
}
|
||||
|
||||
@ -3058,8 +3103,8 @@ mod opt_tests {
|
||||
bb0():
|
||||
PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
|
||||
PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
|
||||
v15:Fixnum[6] = Const Value(6)
|
||||
Return v15
|
||||
v14:Fixnum[6] = Const Value(6)
|
||||
Return v14
|
||||
"#]]);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user