YJIT: Add support for rest with option and splat args (#7698)

This commit is contained in:
Jimmy Miller 2023-04-13 19:21:02 -04:00 committed by GitHub
parent f7d41b9d7b
commit 08413f982c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
Notes: git 2023-04-13 23:21:23 +00:00
Merged-By: k0kubun <takashikkbn@gmail.com>
3 changed files with 170 additions and 40 deletions

View File

@ -3873,3 +3873,43 @@ assert_equal '2', %q{
entry # call branch_stub_hit (spill temps) entry # call branch_stub_hit (spill temps)
entry # doesn't call branch_stub_hit (not spill temps) entry # doesn't call branch_stub_hit (not spill temps)
} }
# Test rest and optional_params
assert_equal '[true, true, true, true]', %q{
def my_func(stuff, base=nil, sort=true, *args)
[stuff, base, sort, args]
end
def calling_my_func
results = []
results << (my_func("test") == ["test", nil, true, []])
results << (my_func("test", :base) == ["test", :base, true, []])
results << (my_func("test", :base, false) == ["test", :base, false, []])
results << (my_func("test", :base, false, "other", "other") == ["test", :base, false, ["other", "other"]])
results
end
calling_my_func
}
# Test rest and optional_params and splat
assert_equal '[true, true, true, true, true]', %q{
def my_func(stuff, base=nil, sort=true, *args)
[stuff, base, sort, args]
end
def calling_my_func
results = []
splat = ["test"]
results << (my_func(*splat) == ["test", nil, true, []])
splat = [:base]
results << (my_func("test", *splat) == ["test", :base, true, []])
splat = [:base, false]
results << (my_func("test", *splat) == ["test", :base, false, []])
splat = [:base, false, "other", "other"]
results << (my_func("test", *splat) == ["test", :base, false, ["other", "other"]])
splat = ["test", :base, false, "other", "other"]
results << (my_func(*splat) == ["test", :base, false, ["other", "other"]])
results
end
calling_my_func
}

View File

@ -14,6 +14,7 @@ use YARVOpnd::*;
use std::cell::Cell; use std::cell::Cell;
use std::cmp; use std::cmp;
use std::cmp::min;
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::CStr; use std::ffi::CStr;
use std::mem; use std::mem;
@ -4780,6 +4781,7 @@ fn gen_push_frame(
asm: &mut Assembler, asm: &mut Assembler,
set_sp_cfp: bool, // if true CFP and SP will be switched to the callee set_sp_cfp: bool, // if true CFP and SP will be switched to the callee
frame: ControlFrame, frame: ControlFrame,
rest_arg: Option<(i32, Opnd)>,
) { ) {
assert!(frame.local_size >= 0); assert!(frame.local_size >= 0);
@ -4880,6 +4882,13 @@ fn gen_push_frame(
} }
} }
if let Some((opts_missing, rest_arg)) = rest_arg {
// We want to set the rest_param just after the optional arguments
let index = opts_missing - num_locals - 3;
let offset = SIZEOF_VALUE_I32 * index;
asm.store(Opnd::mem(64, sp, offset), rest_arg);
}
if set_sp_cfp { if set_sp_cfp {
// Saving SP before calculating ep avoids a dependency on a register // Saving SP before calculating ep avoids a dependency on a register
// However this must be done after referencing frame.recv, which may be SP-relative // However this must be done after referencing frame.recv, which may be SP-relative
@ -5109,7 +5118,7 @@ fn gen_send_cfunc(
}, },
iseq: None, iseq: None,
local_size: 0, local_size: 0,
}); }, None);
if !kw_arg.is_null() { if !kw_arg.is_null() {
// Build a hash from all kwargs passed // Build a hash from all kwargs passed
@ -5475,11 +5484,17 @@ fn gen_send_iseq(
} }
let iseq_has_rest = unsafe { get_iseq_flags_has_rest(iseq) }; let iseq_has_rest = unsafe { get_iseq_flags_has_rest(iseq) };
if iseq_has_rest && captured_opnd.is_some() { if iseq_has_rest && captured_opnd.is_some() {
gen_counter_incr!(asm, send_iseq_has_rest_and_captured); gen_counter_incr!(asm, send_iseq_has_rest_and_captured);
return CantCompile; return CantCompile;
} }
if iseq_has_rest && flags & VM_CALL_OPT_SEND != 0 {
gen_counter_incr!(asm, send_iseq_has_rest_and_send);
return CantCompile;
}
if iseq_has_rest && unsafe { get_iseq_flags_has_kw(iseq) } && supplying_kws { if iseq_has_rest && unsafe { get_iseq_flags_has_kw(iseq) } && supplying_kws {
gen_counter_incr!(asm, send_iseq_has_rest_and_kw_supplied); gen_counter_incr!(asm, send_iseq_has_rest_and_kw_supplied);
return CantCompile; return CantCompile;
@ -5538,20 +5553,21 @@ fn gen_send_iseq(
}; };
// Arity handling and optional parameter setup // Arity handling and optional parameter setup
let opts_filled = argc - required_num - kw_arg_num; let mut opts_filled = argc - required_num - kw_arg_num;
let opt_num = unsafe { get_iseq_body_param_opt_num(iseq) }; let opt_num = unsafe { get_iseq_body_param_opt_num(iseq) };
let opts_missing: i32 = opt_num - opts_filled; // We have a rest argument so there could be more args
// than are required + optional. Those will go in rest.
// So we cap ops_filled at opt_num.
if iseq_has_rest {
opts_filled = min(opts_filled, opt_num);
}
let mut opts_missing: i32 = opt_num - opts_filled;
if doing_kw_call && flags & VM_CALL_ARGS_SPLAT != 0 { if doing_kw_call && flags & VM_CALL_ARGS_SPLAT != 0 {
gen_counter_incr!(asm, send_iseq_splat_with_kw); gen_counter_incr!(asm, send_iseq_splat_with_kw);
return CantCompile; return CantCompile;
} }
if iseq_has_rest && opt_num != 0 {
gen_counter_incr!(asm, send_iseq_has_rest_and_optional);
return CantCompile;
}
if opts_filled < 0 && flags & VM_CALL_ARGS_SPLAT == 0 { if opts_filled < 0 && flags & VM_CALL_ARGS_SPLAT == 0 {
// Too few arguments and no splat to make up for it // Too few arguments and no splat to make up for it
gen_counter_incr!(asm, send_iseq_arity_error); gen_counter_incr!(asm, send_iseq_arity_error);
@ -5571,6 +5587,11 @@ fn gen_send_iseq(
None None
}; };
if iseq_has_rest && opt_num != 0 && (unsafe { get_iseq_flags_has_block(iseq) } || block_arg) {
gen_counter_incr!(asm, send_iseq_has_rest_opt_and_block);
return CantCompile;
}
match block_arg_type { match block_arg_type {
Some(Type::Nil | Type::BlockParamProxy) => { Some(Type::Nil | Type::BlockParamProxy) => {
// We'll handle this later // We'll handle this later
@ -5692,7 +5713,8 @@ fn gen_send_iseq(
return CantCompile; return CantCompile;
} }
} }
let splat_array_length = if flags & VM_CALL_ARGS_SPLAT != 0 && !iseq_has_rest {
let splat_array_length = if flags & VM_CALL_ARGS_SPLAT != 0 {
let array = jit.peek_at_stack(&asm.ctx, if block_arg { 1 } else { 0 }) ; let array = jit.peek_at_stack(&asm.ctx, if block_arg { 1 } else { 0 }) ;
let array_length = if array == Qnil { let array_length = if array == Qnil {
0 0
@ -5700,15 +5722,67 @@ fn gen_send_iseq(
unsafe { rb_yjit_array_len(array) as u32} unsafe { rb_yjit_array_len(array) as u32}
}; };
if opt_num == 0 && required_num != array_length as i32 + argc - 1 { if opt_num == 0 && required_num != array_length as i32 + argc - 1 && !iseq_has_rest {
gen_counter_incr!(asm, send_iseq_splat_arity_error); gen_counter_incr!(asm, send_iseq_splat_arity_error);
return CantCompile; return CantCompile;
} }
if iseq_has_rest && opt_num > 0 {
// If we have a rest and option arugments
// we are going to set the pc_offset for where
// to jump in the called method.
// If the number of args change, that would need to
// change and we don't change that dynmically so we side exit.
// On a normal splat without rest and option args this is handled
// elsewhere depending on the case
asm.comment("Side exit if length doesn't not equal compile time length");
let array_len_opnd = get_array_len(asm, asm.stack_opnd(if block_arg { 1 } else { 0 }));
asm.cmp(array_len_opnd, array_length.into());
let exit = counted_exit!(jit, &mut asm.ctx, ocb, send_splatarray_length_not_equal);
asm.jne(exit);
}
Some(array_length) Some(array_length)
} else { } else {
None None
}; };
// If we have a rest, optional arguments, and a splat
// some of those splatted args will end up filling the
// optional arguments and some will potentially end up
// in the rest. This calculates how many are filled
// by the splat.
let opts_filled_with_splat: Option<i32> = {
if iseq_has_rest && opt_num > 0 {
splat_array_length.map(|len| {
let num_args = (argc - 1) + len as i32;
if num_args >= required_num {
min(num_args - required_num, opt_num)
} else {
0
}
})
} else {
None
}
};
// If we have optional arguments filled by the splat (see comment above)
// we need to set a few variables concerning optional arguments
// to their correct values, as well as set the pc_offset.
if let Some(filled) = opts_filled_with_splat {
opts_missing = opt_num - filled;
opts_filled = filled;
num_params -= opts_missing as u32;
// We are going to jump to the correct offset based on how many optional
// params are remaining.
unsafe {
let opt_table = get_iseq_body_param_opt_table(iseq);
start_pc_offset = (*opt_table.offset(filled as isize)).try_into().unwrap();
};
}
// We will not have CantCompile from here. You can use stack_pop / stack_pop. // We will not have CantCompile from here. You can use stack_pop / stack_pop.
match block_arg_type { match block_arg_type {
@ -5768,7 +5842,7 @@ fn gen_send_iseq(
} }
// Number of locals that are not parameters // Number of locals that are not parameters
let num_locals = unsafe { get_iseq_body_local_table_size(iseq) as i32 } - (num_params as i32); let num_locals = unsafe { get_iseq_body_local_table_size(iseq) as i32 } - (num_params as i32) + if iseq_has_rest && opt_num != 0 { 1 } else { 0 };
// Stack overflow check // Stack overflow check
// Note that vm_push_frame checks it against a decremented cfp, hence the multiply by 2. // Note that vm_push_frame checks it against a decremented cfp, hence the multiply by 2.
@ -5785,25 +5859,28 @@ fn gen_send_iseq(
if let Some(array_length) = splat_array_length { if let Some(array_length) = splat_array_length {
let remaining_opt = (opt_num as u32 + required_num as u32).saturating_sub(array_length + (argc as u32 - 1)); let remaining_opt = (opt_num as u32 + required_num as u32).saturating_sub(array_length + (argc as u32 - 1));
if opt_num > 0 { if !iseq_has_rest {
// We are going to jump to the correct offset based on how many optional if opt_num > 0 {
// params are remaining. // We are going to jump to the correct offset based on how many optional
unsafe { // params are remaining.
let opt_table = get_iseq_body_param_opt_table(iseq); unsafe {
let offset = (opt_num - remaining_opt as i32) as isize; let opt_table = get_iseq_body_param_opt_table(iseq);
start_pc_offset = (*opt_table.offset(offset)).try_into().unwrap(); let offset = (opt_num - remaining_opt as i32) as isize;
}; start_pc_offset = (*opt_table.offset(offset)).try_into().unwrap();
} };
// We are going to assume that the splat fills }
// all the remaining arguments. In the generated code
// we test if this is true and if not side exit.
argc = argc - 1 + array_length as i32 + remaining_opt as i32;
push_splat_args(array_length, jit, asm, ocb);
for _ in 0..remaining_opt { // We are going to assume that the splat fills
// We need to push nil for the optional arguments // all the remaining arguments. In the generated code
let stack_ret = asm.stack_push(Type::Unknown); // we test if this is true and if not side exit.
asm.mov(stack_ret, Qnil.into()); argc = argc - 1 + array_length as i32 + remaining_opt as i32;
push_splat_args(array_length, jit, asm, ocb);
for _ in 0..remaining_opt {
// We need to push nil for the optional arguments
let stack_ret = asm.stack_push(Type::Unknown);
asm.mov(stack_ret, Qnil.into());
}
} }
} }
@ -5826,10 +5903,11 @@ fn gen_send_iseq(
rb_ary_dup as *const u8, rb_ary_dup as *const u8,
vec![array], vec![array],
); );
if non_rest_arg_count > required_num { if non_rest_arg_count > required_num + opt_num {
// If we have more arguments than required, we need to prepend // If we have more arguments than required, we need to prepend
// the items from the stack onto the array. // the items from the stack onto the array.
let diff = (non_rest_arg_count - required_num) as u32; let diff = (non_rest_arg_count - required_num + opts_filled_with_splat.unwrap_or(0)) as u32;
// diff is >0 so no need to worry about null pointer // diff is >0 so no need to worry about null pointer
asm.comment("load pointer to array elements"); asm.comment("load pointer to array elements");
@ -5848,11 +5926,12 @@ fn gen_send_iseq(
asm.mov(stack_ret, array); asm.mov(stack_ret, array);
// We now should have the required arguments // We now should have the required arguments
// and an array of all the rest arguments // and an array of all the rest arguments
argc = required_num + 1; argc = required_num + opts_filled_with_splat.unwrap_or(0) + 1;
} else if non_rest_arg_count < required_num { } else if non_rest_arg_count < required_num + opt_num {
// If we have fewer arguments than required, we need to take some // If we have fewer arguments than required, we need to take some
// from the array and move them to the stack. // from the array and move them to the stack.
let diff = (required_num - non_rest_arg_count) as u32;
let diff = (required_num - non_rest_arg_count + opts_filled_with_splat.unwrap_or(0)) as u32;
// This moves the arguments onto the stack. But it doesn't modify the array. // This moves the arguments onto the stack. But it doesn't modify the array.
move_rest_args_to_stack(array, diff, jit, asm, ocb); move_rest_args_to_stack(array, diff, jit, asm, ocb);
@ -5864,17 +5943,17 @@ fn gen_send_iseq(
// We now should have the required arguments // We now should have the required arguments
// and an array of all the rest arguments // and an array of all the rest arguments
argc = required_num + 1; argc = required_num + opts_filled_with_splat.unwrap_or(0) + 1;
} else { } else {
// The arguments are equal so we can just push to the stack // The arguments are equal so we can just push to the stack
assert!(non_rest_arg_count == required_num); assert!(non_rest_arg_count == required_num + opt_num);
let stack_ret = asm.stack_push(Type::TArray); let stack_ret = asm.stack_push(Type::TArray);
asm.mov(stack_ret, array); asm.mov(stack_ret, array);
} }
} else { } else {
assert!(argc >= required_num); assert!(argc >= required_num);
let n = (argc - required_num) as u32; let n = (argc - required_num - opts_filled) as u32;
argc = required_num + 1; argc = required_num + opts_filled + 1;
// If n is 0, then elts is never going to be read, so we can just pass null // If n is 0, then elts is never going to be read, so we can just pass null
let values_ptr = if n == 0 { let values_ptr = if n == 0 {
Opnd::UImm(0) Opnd::UImm(0)
@ -6069,6 +6148,17 @@ fn gen_send_iseq(
// Spill stack temps to let the callee use them // Spill stack temps to let the callee use them
asm.spill_temps(); asm.spill_temps();
// If we have a rest param and optional parameters,
// we don't actually pass the rest parameter as an argument,
// instead we set its value in the callee's locals
let rest_param = if iseq_has_rest && opt_num != 0 {
argc -= 1;
let top = asm.stack_pop(1);
Some((opts_missing as i32, asm.load(top)))
} else {
None
};
// Points to the receiver operand on the stack unless a captured environment is used // Points to the receiver operand on the stack unless a captured environment is used
let recv = match captured_opnd { let recv = match captured_opnd {
Some(captured_opnd) => asm.load(Opnd::mem(64, captured_opnd, 0)), // captured->self Some(captured_opnd) => asm.load(Opnd::mem(64, captured_opnd, 0)), // captured->self
@ -6115,7 +6205,7 @@ fn gen_send_iseq(
iseq: Some(iseq), iseq: Some(iseq),
pc: None, // We are calling into jitted code, which will set the PC as necessary pc: None, // We are calling into jitted code, which will set the PC as necessary
local_size: num_locals local_size: num_locals
}); }, rest_param);
// No need to set cfp->pc since the callee sets it whenever calling into routines // No need to set cfp->pc since the callee sets it whenever calling into routines
// that could look at it through jit_save_pc(). // that could look at it through jit_save_pc().

View File

@ -277,7 +277,7 @@ make_counters! {
send_iseq_has_rest_and_captured, send_iseq_has_rest_and_captured,
send_iseq_has_rest_and_send, send_iseq_has_rest_and_send,
send_iseq_has_rest_and_kw_supplied, send_iseq_has_rest_and_kw_supplied,
send_iseq_has_rest_and_optional, send_iseq_has_rest_opt_and_block,
send_iseq_has_rest_and_splat_not_equal, send_iseq_has_rest_and_splat_not_equal,
send_is_a_class_mismatch, send_is_a_class_mismatch,
send_instance_of_class_mismatch, send_instance_of_class_mismatch,