ZJIT: Infer ArrayExact for the rest parameter

The rest parameter is always a rb_cArray, even when anonymous. (This is
different from kw_rest, which can be nil.)
This commit is contained in:
Alan Wu 2025-05-14 19:18:23 +09:00
parent 767e8e165a
commit 37d6de5331
Notes: git 2025-05-14 11:13:39 +00:00

View File

@ -654,6 +654,8 @@ impl<T: Copy + Into<usize> + PartialEq> UnionFind<T> {
pub struct Function {
// ISEQ this function refers to
iseq: *const rb_iseq_t,
// The types for the parameters of this function
param_types: Vec<Type>,
// TODO: get method name and source location from the ISEQ
@ -673,6 +675,7 @@ impl Function {
union_find: UnionFind::new().into(),
blocks: vec![Block::default()],
entry_block: BlockId(0),
param_types: vec![],
}
}
@ -913,9 +916,18 @@ impl Function {
fn infer_types(&mut self) {
// Reset all types
self.insn_types.fill(types::Empty);
for param in &self.blocks[self.entry_block.0].params {
// Fill parameter types
let entry_params = self.blocks[self.entry_block.0].params.iter();
let param_types = self.param_types.iter();
assert_eq!(
entry_params.len(),
entry_params.len(),
"param types should be initialized before type inference"
);
for (param, param_type) in std::iter::zip(entry_params, param_types) {
// We know that function parameters are BasicObject or some subclass
self.insn_types[param.0] = types::BasicObject;
self.insn_types[param.0] = *param_type;
}
let rpo = self.rpo();
// Walk the graph, computing types until fixpoint
@ -1712,6 +1724,14 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
// Iteratively fill out basic blocks using a queue
// TODO(max): Basic block arguments at edges
let mut queue = std::collections::VecDeque::new();
// Index of the rest parameter for comparison below
let rest_param_idx = if !iseq.is_null() && unsafe { get_iseq_flags_has_rest(iseq) } {
let opt_num = unsafe { get_iseq_body_param_opt_num(iseq) };
let lead_num = unsafe { get_iseq_body_param_lead_num(iseq) };
opt_num + lead_num
} else {
-1
};
// The HIR function will have the same number of parameter as the iseq so
// we properly handle calls from the interpreter. Roughly speaking, each
// item between commas in the source increase the parameter count by one,
@ -1723,6 +1743,13 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
} else {
entry_state.locals.push(fun.push_insn(fun.entry_block, Insn::Const { val: Const::Value(Qnil) }));
}
let mut param_type = types::BasicObject;
// Rest parameters are always ArrayExact
if let Ok(true) = c_int::try_from(idx).map(|idx| idx == rest_param_idx) {
param_type = types::ArrayExact;
}
fun.param_types.push(param_type);
}
queue.push_back((entry_state, fun.entry_block, /*insn_idx=*/0_u32));
@ -3111,6 +3138,18 @@ mod opt_tests {
"#]]);
}
#[test]
fn test_rest_param_get_bb_param() {
eval("
def rest(*array) = array
");
assert_optimized_method_hir("rest", expect![[r#"
fn rest:
bb0(v0:ArrayExact):
Return v0
"#]]);
}
#[test]
fn test_optimize_top_level_call_into_send_direct() {
eval("