diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index aaf70e7b93..0e8c4a0a06 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -2494,6 +2494,23 @@ assert_equal '[:not_array, nil, nil]', %q{ expandarray_not_array(obj) } +assert_equal '[1, 2]', %q{ + class NilClass + private + def to_ary + [1, 2] + end + end + + def expandarray_redefined_nilclass + a, b = nil + [a, b] + end + + expandarray_redefined_nilclass + expandarray_redefined_nilclass +} unless rjit_enabled? + assert_equal '[1, 2, nil]', %q{ def expandarray_rhs_too_small a, b, c = [1, 2] diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index 7e44c3add0..4e23cf10b1 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -1637,18 +1637,6 @@ fn gen_expandarray( let array_opnd = asm.stack_opnd(0); - // If the array operand is nil, just push on nils - if asm.ctx.get_opnd_type(array_opnd.into()) == Type::Nil { - asm.stack_pop(1); // pop after using the type info - // special case for a, b = nil pattern - // push N nils onto the stack - for _ in 0..num { - let push_opnd = asm.stack_push(Type::Nil); - asm.mov(push_opnd, Qnil.into()); - } - return Some(KeepCompiling); - } - // Defer compilation so we can specialize on a runtime `self` if !jit.at_current_insn() { defer_compilation(jit, asm, ocb); @@ -1657,10 +1645,52 @@ fn gen_expandarray( let comptime_recv = jit.peek_at_stack(&asm.ctx, 0); - // If the comptime receiver is not an array, bail - if comptime_recv.class_of() != unsafe { rb_cArray } { - gen_counter_incr(asm, Counter::expandarray_comptime_not_array); - return None; + // If the comptime receiver is not an array + if !unsafe { RB_TYPE_P(comptime_recv, RUBY_T_ARRAY) } { + // at compile time, ensure to_ary is not defined + let target_cme = unsafe { rb_callable_method_entry_or_negative(comptime_recv.class_of(), ID!(to_ary)) }; + let cme_def_type = unsafe { get_cme_def_type(target_cme) }; + + // if to_ary is defined, return can't compile so to_ary can be called + if cme_def_type != VM_METHOD_TYPE_UNDEF { + gen_counter_incr(asm, Counter::expandarray_to_ary); + return None; + } + + // invalidate compile block if to_ary is later defined + jit.assume_method_lookup_stable(asm, ocb, target_cme); + + jit_guard_known_klass( + jit, + asm, + ocb, + comptime_recv.class_of(), + array_opnd, + array_opnd.into(), + comptime_recv, + SEND_MAX_DEPTH, + Counter::expandarray_not_array, + ); + + let opnd = asm.stack_pop(1); // pop after using the type info + + // If we don't actually want any values, then just keep going + if num == 0 { + return Some(KeepCompiling); + } + + // load opnd to avoid a race because we are also pushing onto the stack + let opnd = asm.load(opnd); + + for _ in 1..num { + let push_opnd = asm.stack_push(Type::Nil); + asm.mov(push_opnd, Qnil.into()); + } + + let push_opnd = asm.stack_push(Type::Unknown); + asm.mov(push_opnd, opnd); + + return Some(KeepCompiling); } // Get the compile-time array length diff --git a/yjit/src/cruby.rs b/yjit/src/cruby.rs index a43b053d3e..ac0bdf6885 100644 --- a/yjit/src/cruby.rs +++ b/yjit/src/cruby.rs @@ -781,6 +781,7 @@ pub(crate) mod ids { name: max content: b"max" name: hash content: b"hash" name: respond_to_missing content: b"respond_to_missing?" + name: to_ary content: b"to_ary" } } diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs index ba1479e152..bf918aeb75 100644 --- a/yjit/src/stats.rs +++ b/yjit/src/stats.rs @@ -469,7 +469,7 @@ make_counters! { expandarray_splat, expandarray_postarg, expandarray_not_array, - expandarray_comptime_not_array, + expandarray_to_ary, expandarray_chain_max_depth, // getblockparam