diff --git a/bootstraptest/test_method.rb b/bootstraptest/test_method.rb index d1d1f57d55..3929d69da4 100644 --- a/bootstraptest/test_method.rb +++ b/bootstraptest/test_method.rb @@ -1176,3 +1176,153 @@ assert_equal 'ok', %q{ foo foo }, '[Bug #20178]' + +assert_equal 'ok', %q{ + def bar(x); x; end + def foo(...); bar(...); end + foo('ok') +} + +assert_equal 'ok', %q{ + def bar(x); x; end + def foo(z, ...); bar(...); end + foo(1, 'ok') +} + +assert_equal 'ok', %q{ + def bar(x, y); x; end + def foo(...); bar("ok", ...); end + foo(1) +} + +assert_equal 'ok', %q{ + def bar(x); x; end + def foo(...); 1.times { return bar(...) }; end + foo("ok") +} + +assert_equal 'ok', %q{ + def bar(x); x; end + def foo(...); x = nil; 1.times { x = bar(...) }; x; end + foo("ok") +} + +assert_equal 'ok', %q{ + def bar(x); yield; end + def foo(...); bar(...); end + foo(1) { "ok" } +} + +assert_equal 'ok', %q{ + def baz(x); x; end + def bar(...); baz(...); end + def foo(...); bar(...); end + foo("ok") +} + +assert_equal '[1, 2, 3, 4]', %q{ + def baz(a, b, c, d); [a, b, c, d]; end + def bar(...); baz(1, ...); end + def foo(...); bar(2, ...); end + foo(3, 4) +} + +assert_equal 'ok', %q{ + class Foo; def self.foo(x); x; end; end + class Bar < Foo; def self.foo(...); super; end; end + Bar.foo('ok') +} + +assert_equal 'ok', %q{ + class Foo; def self.foo(x); x; end; end + class Bar < Foo; def self.foo(...); super(...); end; end + Bar.foo('ok') +} + +assert_equal 'ok', %q{ + class Foo; def self.foo(x, y); x + y; end; end + class Bar < Foo; def self.foo(...); super("o", ...); end; end + Bar.foo('k') +} + +assert_equal 'ok', %q{ + def bar(a); a; end + def foo(...); lambda { bar(...) }; end + foo("ok").call +} + +assert_equal 'ok', %q{ + class Foo; def self.foo(x, y); x + y; end; end + class Bar < Foo; def self.y(&b); b; end; def self.foo(...); y { super("o", ...) }; end; end + Bar.foo('k').call +} + +assert_equal 'ok', %q{ + def baz(n); n; end + def foo(...); bar = baz(...); lambda { lambda { bar } }; end + foo("ok").call.call +} + +assert_equal 'ok', %q{ + class A; def self.foo(...); new(...); end; attr_reader :b; def initialize(a, b:"ng"); @a = a; @b = b; end end + A.foo(1).b + A.foo(1, b: "ok").b +} + +assert_equal 'ok', %q{ + class A; def initialize; @a = ["ok"]; end; def first(...); @a.first(...); end; end + def call x; x.first; end + def call1 x; x.first(1); end + call(A.new) + call1(A.new).first +} + +assert_equal 'ok', %q{ + class A; def foo; yield("o"); end; end + class B < A; def foo(...); super { |x| yield(x + "k") }; end; end + B.new.foo { |x| x } +} + +assert_equal "[1, 2, 3, 4]", %q{ + def foo(*b) = b + + def forward(...) + splat = [1,2,3] + foo(*splat, ...) + end + + forward(4) +} + +assert_equal "[1, 2, 3, 4]", %q{ +class A + def foo(*b) = b +end + +class B < A + def foo(...) + splat = [1,2,3] + super(*splat, ...) + end +end + +B.new.foo(4) +} + +assert_equal 'ok', %q{ + class A; attr_reader :iv; def initialize(...) = @iv = "ok"; end + A.new("foo", bar: []).iv +} + +assert_equal 'ok', %q{ +def foo(a, b) = a + b +def bar(...) = foo(...) +bar(1, 2) +bar(1, 2) +begin + bar(1, 2, 3) + "ng" +rescue ArgumentError + "ok" +end +} diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index 015d04b8b0..30c23d625f 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -4772,6 +4772,22 @@ assert_equal 'foo', %q{ entry(true) } +assert_equal 'ok', %q{ + def ok + :ok + end + + def delegator(...) + ok(...) + end + + def caller + send(:delegator) + end + + caller +} + assert_equal '[:ok, :ok, :ok]', %q{ def identity(x) = x def foo(x, _) = x diff --git a/compile.c b/compile.c index 5e7c4289d5..117a1f6e41 100644 --- a/compile.c +++ b/compile.c @@ -487,7 +487,7 @@ static int iseq_setup_insn(rb_iseq_t *iseq, LINK_ANCHOR *const anchor); static int iseq_optimize(rb_iseq_t *iseq, LINK_ANCHOR *const anchor); static int iseq_insns_unification(rb_iseq_t *iseq, LINK_ANCHOR *const anchor); -static int iseq_set_local_table(rb_iseq_t *iseq, const rb_ast_id_table_t *tbl); +static int iseq_set_local_table(rb_iseq_t *iseq, const rb_ast_id_table_t *tbl, const NODE *const node_args); static int iseq_set_exception_local_table(rb_iseq_t *iseq); static int iseq_set_arguments(rb_iseq_t *iseq, LINK_ANCHOR *const anchor, const NODE *const node); @@ -869,12 +869,12 @@ rb_iseq_compile_node(rb_iseq_t *iseq, const NODE *node) if (node == 0) { NO_CHECK(COMPILE(ret, "nil", node)); - iseq_set_local_table(iseq, 0); + iseq_set_local_table(iseq, 0, 0); } /* assume node is T_NODE */ else if (nd_type_p(node, NODE_SCOPE)) { /* iseq type of top, method, class, block */ - iseq_set_local_table(iseq, RNODE_SCOPE(node)->nd_tbl); + iseq_set_local_table(iseq, RNODE_SCOPE(node)->nd_tbl, (NODE *)RNODE_SCOPE(node)->nd_args); iseq_set_arguments(iseq, ret, (NODE *)RNODE_SCOPE(node)->nd_args); switch (ISEQ_BODY(iseq)->type) { @@ -1439,7 +1439,7 @@ new_callinfo(rb_iseq_t *iseq, ID mid, int argc, unsigned int flag, struct rb_cal { VM_ASSERT(argc >= 0); - if (!(flag & (VM_CALL_ARGS_SPLAT | VM_CALL_ARGS_BLOCKARG | VM_CALL_KW_SPLAT)) && + if (!(flag & (VM_CALL_ARGS_SPLAT | VM_CALL_ARGS_BLOCKARG | VM_CALL_KW_SPLAT | VM_CALL_FORWARDING)) && kw_arg == NULL && !has_blockiseq) { flag |= VM_CALL_ARGS_SIMPLE; } @@ -2034,6 +2034,13 @@ iseq_set_arguments(rb_iseq_t *iseq, LINK_ANCHOR *const optargs, const NODE *cons } block_id = args->block_arg; + bool optimized_forward = (args->forwarding && args->pre_args_num == 0 && !args->opt_args); + + if (optimized_forward) { + rest_id = 0; + block_id = 0; + } + if (args->opt_args) { const rb_node_opt_arg_t *node = args->opt_args; LABEL *label; @@ -2090,7 +2097,7 @@ iseq_set_arguments(rb_iseq_t *iseq, LINK_ANCHOR *const optargs, const NODE *cons if (args->kw_args) { arg_size = iseq_set_arguments_keywords(iseq, optargs, args, arg_size); } - else if (args->kw_rest_arg) { + else if (args->kw_rest_arg && !optimized_forward) { ID kw_id = iseq->body->local_table[arg_size]; struct rb_iseq_param_keyword *keyword = ZALLOC_N(struct rb_iseq_param_keyword, 1); keyword->rest_start = arg_size++; @@ -2111,6 +2118,13 @@ iseq_set_arguments(rb_iseq_t *iseq, LINK_ANCHOR *const optargs, const NODE *cons iseq_set_use_block(iseq); } + // Only optimize specifically methods like this: `foo(...)` + if (optimized_forward) { + body->param.flags.use_block = 1; + body->param.flags.forwardable = TRUE; + arg_size = 1; + } + iseq_calc_param_size(iseq); body->param.size = arg_size; @@ -2140,13 +2154,26 @@ iseq_set_arguments(rb_iseq_t *iseq, LINK_ANCHOR *const optargs, const NODE *cons } static int -iseq_set_local_table(rb_iseq_t *iseq, const rb_ast_id_table_t *tbl) +iseq_set_local_table(rb_iseq_t *iseq, const rb_ast_id_table_t *tbl, const NODE *const node_args) { unsigned int size = tbl ? tbl->size : 0; + unsigned int offset = 0; + + if (node_args) { + struct rb_args_info *args = &RNODE_ARGS(node_args)->nd_ainfo; + + // If we have a function that only has `...` as the parameter, + // then its local table should only be `...` + // FIXME: I think this should be fixed in the AST rather than special case here. + if (args->forwarding && args->pre_args_num == 0 && !args->opt_args) { + size -= 3; + offset += 3; + } + } if (size > 0) { ID *ids = (ID *)ALLOC_N(ID, size); - MEMCPY(ids, tbl->ids, ID, size); + MEMCPY(ids, tbl->ids + offset, ID, size); ISEQ_BODY(iseq)->local_table = ids; } ISEQ_BODY(iseq)->local_table_size = size; @@ -4146,7 +4173,7 @@ iseq_specialized_instruction(rb_iseq_t *iseq, INSN *iobj) } } - if ((vm_ci_flag(ci) & VM_CALL_ARGS_BLOCKARG) == 0 && blockiseq == NULL) { + if ((vm_ci_flag(ci) & (VM_CALL_ARGS_BLOCKARG | VM_CALL_FORWARDING)) == 0 && blockiseq == NULL) { iobj->insn_id = BIN(opt_send_without_block); iobj->operand_size = insn_len(iobj->insn_id) - 1; } @@ -6319,9 +6346,33 @@ setup_args(rb_iseq_t *iseq, LINK_ANCHOR *const args, const NODE *argn, unsigned int dup_rest = 1; DECL_ANCHOR(arg_block); INIT_ANCHOR(arg_block); - NO_CHECK(COMPILE(arg_block, "block", RNODE_BLOCK_PASS(argn)->nd_body)); - *flag |= VM_CALL_ARGS_BLOCKARG; + if (RNODE_BLOCK_PASS(argn)->forwarding && ISEQ_BODY(ISEQ_BODY(iseq)->local_iseq)->param.flags.forwardable) { + int idx = ISEQ_BODY(ISEQ_BODY(iseq)->local_iseq)->local_table_size;// - get_local_var_idx(iseq, idDot3); + + RUBY_ASSERT(nd_type_p(RNODE_BLOCK_PASS(argn)->nd_head, NODE_ARGSPUSH)); + const NODE * arg_node = + RNODE_ARGSPUSH(RNODE_BLOCK_PASS(argn)->nd_head)->nd_head; + + int argc = 0; + + // Only compile leading args: + // foo(x, y, ...) + // ^^^^ + if (nd_type_p(arg_node, NODE_ARGSCAT)) { + argc += setup_args_core(iseq, args, RNODE_ARGSCAT(arg_node)->nd_head, dup_rest, flag, keywords); + } + + *flag |= VM_CALL_FORWARDING; + + ADD_GETLOCAL(args, argn, idx, get_lvar_level(iseq)); + return INT2FIX(argc); + } + else { + *flag |= VM_CALL_ARGS_BLOCKARG; + + NO_CHECK(COMPILE(arg_block, "block", RNODE_BLOCK_PASS(argn)->nd_body)); + } if (LIST_INSN_SIZE_ONE(arg_block)) { LINK_ELEMENT *elem = FIRST_ELEMENT(arg_block); @@ -6353,7 +6404,7 @@ build_postexe_iseq(rb_iseq_t *iseq, LINK_ANCHOR *ret, const void *ptr) ADD_INSN1(ret, body, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); ADD_CALL_WITH_BLOCK(ret, body, id_core_set_postexe, argc, block); RB_OBJ_WRITTEN(iseq, Qundef, (VALUE)block); - iseq_set_local_table(iseq, 0); + iseq_set_local_table(iseq, 0, 0); } static void @@ -9465,6 +9516,13 @@ compile_super(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, i ADD_GETLOCAL(args, node, idx, lvar_level); } + /* forward ... */ + if (local_body->param.flags.forwardable) { + flag |= VM_CALL_FORWARDING; + int idx = local_body->local_table_size - get_local_var_idx(liseq, idDot3); + ADD_GETLOCAL(args, node, idx, lvar_level); + } + if (local_body->param.flags.has_opt) { /* optional arguments */ int j; @@ -13030,7 +13088,8 @@ ibf_dump_iseq_each(struct ibf_dump *dump, const rb_iseq_t *iseq) (body->param.flags.ruby2_keywords << 9) | (body->param.flags.anon_rest << 10) | (body->param.flags.anon_kwrest << 11) | - (body->param.flags.use_block << 12); + (body->param.flags.use_block << 12) | + (body->param.flags.forwardable << 13) ; #if IBF_ISEQ_ENABLE_LOCAL_BUFFER # define IBF_BODY_OFFSET(x) (x) @@ -13247,6 +13306,7 @@ ibf_load_iseq_each(struct ibf_load *load, rb_iseq_t *iseq, ibf_offset_t offset) load_body->param.flags.anon_rest = (param_flags >> 10) & 1; load_body->param.flags.anon_kwrest = (param_flags >> 11) & 1; load_body->param.flags.use_block = (param_flags >> 12) & 1; + load_body->param.flags.forwardable = (param_flags >> 13) & 1; load_body->param.size = param_size; load_body->param.lead_num = param_lead_num; load_body->param.opt_num = param_opt_num; diff --git a/insns.def b/insns.def index f7df92cf06..6c96f252c2 100644 --- a/insns.def +++ b/insns.def @@ -867,10 +867,22 @@ send // attr rb_snum_t sp_inc = sp_inc_of_sendish(cd->ci); // attr rb_snum_t comptime_sp_inc = sp_inc_of_sendish(ci); { - VALUE bh = vm_caller_setup_arg_block(ec, GET_CFP(), cd->ci, blockiseq, false); + struct rb_forwarding_call_data adjusted_cd; + struct rb_callinfo adjusted_ci; + + CALL_DATA _cd = cd; + + VALUE bh = vm_caller_setup_args(GET_EC(), GET_CFP(), &cd, blockiseq, 0, &adjusted_cd, &adjusted_ci); + val = vm_sendish(ec, GET_CFP(), cd, bh, mexp_search_method); JIT_EXEC(ec, val); + if (vm_ci_flag(_cd->ci) & VM_CALL_FORWARDING) { + if (_cd->cc != cd->cc && vm_cc_markable(cd->cc)) { + RB_OBJ_WRITE(GET_ISEQ(), &_cd->cc, cd->cc); + } + } + if (UNDEF_P(val)) { RESTORE_REGS(); NEXT_INSN(); @@ -994,10 +1006,21 @@ invokesuper // attr rb_snum_t sp_inc = sp_inc_of_sendish(cd->ci); // attr rb_snum_t comptime_sp_inc = sp_inc_of_sendish(ci); { - VALUE bh = vm_caller_setup_arg_block(ec, GET_CFP(), cd->ci, blockiseq, true); + CALL_DATA _cd = cd; + struct rb_forwarding_call_data adjusted_cd; + struct rb_callinfo adjusted_ci; + + VALUE bh = vm_caller_setup_args(GET_EC(), GET_CFP(), &cd, blockiseq, 1, &adjusted_cd, &adjusted_ci); + val = vm_sendish(ec, GET_CFP(), cd, bh, mexp_search_super); JIT_EXEC(ec, val); + if (vm_ci_flag(_cd->ci) & VM_CALL_FORWARDING) { + if (_cd->cc != cd->cc && vm_cc_markable(cd->cc)) { + RB_OBJ_WRITE(GET_ISEQ(), &_cd->cc, cd->cc); + } + } + if (UNDEF_P(val)) { RESTORE_REGS(); NEXT_INSN(); diff --git a/iseq.c b/iseq.c index 05d52b61b4..9c94ce355d 100644 --- a/iseq.c +++ b/iseq.c @@ -2449,6 +2449,7 @@ rb_insn_operand_intern(const rb_iseq_t *iseq, CALL_FLAG(KWARG); CALL_FLAG(KW_SPLAT); CALL_FLAG(KW_SPLAT_MUT); + CALL_FLAG(FORWARDING); CALL_FLAG(OPT_SEND); /* maybe not reachable */ rb_ary_push(ary, rb_ary_join(flags, rb_str_new2("|"))); } @@ -3495,6 +3496,17 @@ rb_iseq_parameters(const rb_iseq_t *iseq, int is_proc) CONST_ID(req, "req"); CONST_ID(opt, "opt"); + + if (body->param.flags.forwardable) { + // [[:rest, :*], [:keyrest, :**], [:block, :&]] + CONST_ID(rest, "rest"); + CONST_ID(keyrest, "keyrest"); + CONST_ID(block, "block"); + rb_ary_push(args, rb_ary_new_from_args(2, ID2SYM(rest), ID2SYM(idMULT))); + rb_ary_push(args, rb_ary_new_from_args(2, ID2SYM(keyrest), ID2SYM(idPow))); + rb_ary_push(args, rb_ary_new_from_args(2, ID2SYM(block), ID2SYM(idAnd))); + } + if (is_proc) { for (i = 0; i < body->param.lead_num; i++) { PARAM_TYPE(opt); diff --git a/lib/ruby_vm/rjit/insn_compiler.rb b/lib/ruby_vm/rjit/insn_compiler.rb index 9e4b28f87a..4151ff6db3 100644 --- a/lib/ruby_vm/rjit/insn_compiler.rb +++ b/lib/ruby_vm/rjit/insn_compiler.rb @@ -1435,6 +1435,10 @@ module RubyVM::RJIT mid = C.vm_ci_mid(cd.ci) calling = build_calling(ci: cd.ci, block_handler: blockiseq) + if calling.flags & C::VM_CALL_FORWARDING != 0 + return CantCompile + end + # vm_sendish cme, comptime_recv_klass = jit_search_method(jit, ctx, asm, mid, calling) if cme == CantCompile @@ -4622,6 +4626,11 @@ module RubyVM::RJIT end end + # Don't compile forwardable iseqs + if iseq.body.param.flags.forwardable + return CantCompile + end + # We will not have CantCompile from here. if block_arg diff --git a/parse.y b/parse.y index 0a339c5959..3a21e2f881 100644 --- a/parse.y +++ b/parse.y @@ -12253,6 +12253,7 @@ static rb_node_block_pass_t * rb_node_block_pass_new(struct parser_params *p, NODE *nd_body, const YYLTYPE *loc) { rb_node_block_pass_t *n = NODE_NEWNODE(NODE_BLOCK_PASS, rb_node_block_pass_t, loc); + n->forwarding = 0; n->nd_head = 0; n->nd_body = nd_body; @@ -15084,6 +15085,7 @@ new_args_forward_call(struct parser_params *p, NODE *leading, const YYLTYPE *loc #endif rb_node_block_pass_t *block = NEW_BLOCK_PASS(NEW_LVAR(idFWD_BLOCK, loc), loc); NODE *args = leading ? rest_arg_append(p, leading, rest, argsloc) : NEW_SPLAT(rest, loc); + block->forwarding = TRUE; #ifndef FORWARD_ARGS_WITH_RUBY2_KEYWORDS args = arg_append(p, args, new_hash(p, kwrest, loc), loc); #endif diff --git a/prism_compile.c b/prism_compile.c index 7fefe7c0e4..9307b2449c 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -1561,7 +1561,17 @@ pm_setup_args_core(const pm_arguments_node_t *arguments_node, const pm_node_t *b break; } case PM_FORWARDING_ARGUMENTS_NODE: { + if (ISEQ_BODY(ISEQ_BODY(iseq)->local_iseq)->param.flags.forwardable) { + *flags |= VM_CALL_FORWARDING; + + pm_local_index_t mult_local = pm_lookup_local_index(iseq, scope_node, PM_CONSTANT_DOT3, 0); + PUSH_GETLOCAL(ret, location, mult_local.index, mult_local.level); + + break; + } + orig_argc += 2; + *flags |= VM_CALL_ARGS_SPLAT | VM_CALL_ARGS_SPLAT_MUT | VM_CALL_ARGS_BLOCKARG | VM_CALL_KW_SPLAT; // Forwarding arguments nodes are treated as foo(*, **, &) @@ -6693,6 +6703,15 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, int argc = 0; int depth = get_lvar_level(iseq); + if (ISEQ_BODY(ISEQ_BODY(iseq)->local_iseq)->param.flags.forwardable) { + flag |= VM_CALL_FORWARDING; + pm_local_index_t mult_local = pm_lookup_local_index(iseq, scope_node, PM_CONSTANT_DOT3, 0); + PUSH_GETLOCAL(ret, location, mult_local.index, mult_local.level); + PUSH_INSN2(ret, location, invokesuper, new_callinfo(iseq, 0, 0, flag, NULL, block != NULL), block); + if (popped) PUSH_INSN(ret, location, pop); + return; + } + if (local_body->param.flags.has_lead) { /* required arguments */ for (int i = 0; i < local_body->param.lead_num; i++) { @@ -8300,7 +8319,14 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, // When we have a `...` as the keyword_rest, it's a forwarding_parameter_node and // we need to leave space for 4 locals: *, **, &, ... if (PM_NODE_TYPE_P(parameters_node->keyword_rest, PM_FORWARDING_PARAMETER_NODE)) { - table_size += 4; + // Only optimize specifically methods like this: `foo(...)` + if (requireds_list->size == 0 && optionals_list->size == 0 && keywords_list->size == 0) { + ISEQ_BODY(iseq)->param.flags.forwardable = TRUE; + table_size += 1; + } + else { + table_size += 4; + } } else { const pm_keyword_rest_parameter_node_t *kw_rest = (const pm_keyword_rest_parameter_node_t *) parameters_node->keyword_rest; @@ -8654,29 +8680,31 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, // def foo(...) // ^^^ case PM_FORWARDING_PARAMETER_NODE: { - body->param.rest_start = local_index; - body->param.flags.has_rest = true; + if (!ISEQ_BODY(iseq)->param.flags.forwardable) { + body->param.rest_start = local_index; + body->param.flags.has_rest = true; - // Add the leading * - pm_insert_local_special(idMULT, local_index++, index_lookup_table, local_table_for_iseq); + // Add the leading * + pm_insert_local_special(idMULT, local_index++, index_lookup_table, local_table_for_iseq); - // Add the kwrest ** - RUBY_ASSERT(!body->param.flags.has_kw); + // Add the kwrest ** + RUBY_ASSERT(!body->param.flags.has_kw); - // There are no keywords declared (in the text of the program) - // but the forwarding node implies we support kwrest (**) - body->param.flags.has_kw = false; - body->param.flags.has_kwrest = true; - body->param.keyword = keyword = ZALLOC_N(struct rb_iseq_param_keyword, 1); + // There are no keywords declared (in the text of the program) + // but the forwarding node implies we support kwrest (**) + body->param.flags.has_kw = false; + body->param.flags.has_kwrest = true; + body->param.keyword = keyword = ZALLOC_N(struct rb_iseq_param_keyword, 1); - keyword->rest_start = local_index; + keyword->rest_start = local_index; - pm_insert_local_special(idPow, local_index++, index_lookup_table, local_table_for_iseq); + pm_insert_local_special(idPow, local_index++, index_lookup_table, local_table_for_iseq); - body->param.block_start = local_index; - body->param.flags.has_block = true; + body->param.block_start = local_index; + body->param.flags.has_block = true; - pm_insert_local_special(idAnd, local_index++, index_lookup_table, local_table_for_iseq); + pm_insert_local_special(idAnd, local_index++, index_lookup_table, local_table_for_iseq); + } pm_insert_local_special(idDot3, local_index++, index_lookup_table, local_table_for_iseq); break; } @@ -8820,7 +8848,15 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, } scope_node->index_lookup_table = index_lookup_table; iseq_calc_param_size(iseq); - iseq_set_local_table(iseq, local_table_for_iseq); + + if (ISEQ_BODY(iseq)->param.flags.forwardable) { + // We're treating `...` as a parameter so that frame + // pushing won't clobber it. + ISEQ_BODY(iseq)->param.size += 1; + } + + // FIXME: args? + iseq_set_local_table(iseq, local_table_for_iseq, 0); scope_node->local_table_for_iseq_size = local_table_for_iseq->size; //********STEP 5************ diff --git a/proc.c b/proc.c index 1a67a63663..09b288fecd 100644 --- a/proc.c +++ b/proc.c @@ -1050,7 +1050,7 @@ rb_iseq_min_max_arity(const rb_iseq_t *iseq, int *max) { *max = ISEQ_BODY(iseq)->param.flags.has_rest == FALSE ? ISEQ_BODY(iseq)->param.lead_num + ISEQ_BODY(iseq)->param.opt_num + ISEQ_BODY(iseq)->param.post_num + - (ISEQ_BODY(iseq)->param.flags.has_kw == TRUE || ISEQ_BODY(iseq)->param.flags.has_kwrest == TRUE) + (ISEQ_BODY(iseq)->param.flags.has_kw == TRUE || ISEQ_BODY(iseq)->param.flags.has_kwrest == TRUE || ISEQ_BODY(iseq)->param.flags.forwardable == TRUE) : UNLIMITED_ARGUMENTS; return ISEQ_BODY(iseq)->param.lead_num + ISEQ_BODY(iseq)->param.post_num + (ISEQ_BODY(iseq)->param.flags.has_kw && ISEQ_BODY(iseq)->param.keyword->required_num > 0); } diff --git a/rjit_c.rb b/rjit_c.rb index 1444cd6bbc..8c9615c6bb 100644 --- a/rjit_c.rb +++ b/rjit_c.rb @@ -420,6 +420,7 @@ module RubyVM::RJIT # :nodoc: all C::VM_CALL_ARGS_BLOCKARG = Primitive.cexpr! %q{ SIZET2NUM(VM_CALL_ARGS_BLOCKARG) } C::VM_CALL_ARGS_SPLAT = Primitive.cexpr! %q{ SIZET2NUM(VM_CALL_ARGS_SPLAT) } C::VM_CALL_FCALL = Primitive.cexpr! %q{ SIZET2NUM(VM_CALL_FCALL) } + C::VM_CALL_FORWARDING = Primitive.cexpr! %q{ SIZET2NUM(VM_CALL_FORWARDING) } C::VM_CALL_KWARG = Primitive.cexpr! %q{ SIZET2NUM(VM_CALL_KWARG) } C::VM_CALL_KW_SPLAT = Primitive.cexpr! %q{ SIZET2NUM(VM_CALL_KW_SPLAT) } C::VM_CALL_KW_SPLAT_MUT = Primitive.cexpr! %q{ SIZET2NUM(VM_CALL_KW_SPLAT_MUT) } @@ -1098,6 +1099,7 @@ module RubyVM::RJIT # :nodoc: all anon_rest: [CType::BitField.new(1, 2), 10], anon_kwrest: [CType::BitField.new(1, 3), 11], use_block: [CType::BitField.new(1, 4), 12], + forwardable: [CType::BitField.new(1, 5), 13], ), Primitive.cexpr!("OFFSETOF(((struct rb_iseq_constant_body *)NULL)->param, flags)")], size: [CType::Immediate.parse("unsigned int"), Primitive.cexpr!("OFFSETOF(((struct rb_iseq_constant_body *)NULL)->param, size)")], lead_num: [CType::Immediate.parse("int"), Primitive.cexpr!("OFFSETOF(((struct rb_iseq_constant_body *)NULL)->param, lead_num)")], diff --git a/rubyparser.h b/rubyparser.h index c98f4a91a5..62f742eca1 100644 --- a/rubyparser.h +++ b/rubyparser.h @@ -850,6 +850,7 @@ typedef struct RNode_BLOCK_PASS { struct RNode *nd_head; struct RNode *nd_body; + unsigned int forwarding: 1; } rb_node_block_pass_t; typedef struct RNode_DEFN { diff --git a/test/prism/locals_test.rb b/test/prism/locals_test.rb index 27fdfc90ef..3c45d8b08b 100644 --- a/test/prism/locals_test.rb +++ b/test/prism/locals_test.rb @@ -169,7 +169,11 @@ module Prism sorted << AnonymousLocal if params.keywords.any? if params.keyword_rest.is_a?(ForwardingParameterNode) - sorted.push(:*, :**, :&, :"...") + if sorted.length == 0 + sorted.push(:"...") + else + sorted.push(:*, :**, :&, :"...") + end elsif params.keyword_rest.is_a?(KeywordRestParameterNode) sorted << (params.keyword_rest.name || :**) end diff --git a/test/ruby/test_allocation.rb b/test/ruby/test_allocation.rb index fbe0548899..65302e275d 100644 --- a/test/ruby/test_allocation.rb +++ b/test/ruby/test_allocation.rb @@ -630,20 +630,20 @@ class TestAllocation < Test::Unit::TestCase check_allocations(<<~RUBY) def self.argument_forwarding(...); end - check_allocations(1, 1, "argument_forwarding(1, a: 2#{block})") - check_allocations(1, 0, "argument_forwarding(1, *empty_array, a: 2#{block})") - check_allocations(1, 1, "argument_forwarding(1, a:2, **empty_hash#{block})") - check_allocations(1, 1, "argument_forwarding(1, **empty_hash, a: 2#{block})") + check_allocations(0, 0, "argument_forwarding(1, a: 2#{block})") + check_allocations(0, 0, "argument_forwarding(1, *empty_array, a: 2#{block})") + check_allocations(0, 1, "argument_forwarding(1, a:2, **empty_hash#{block})") + check_allocations(0, 1, "argument_forwarding(1, **empty_hash, a: 2#{block})") - check_allocations(1, 0, "argument_forwarding(1, **nil#{block})") - check_allocations(1, 0, "argument_forwarding(1, **empty_hash#{block})") - check_allocations(1, 0, "argument_forwarding(1, **hash1#{block})") - check_allocations(1, 0, "argument_forwarding(1, *empty_array, **hash1#{block})") - check_allocations(1, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") - check_allocations(1, 1, "argument_forwarding(1, **empty_hash, **hash1#{block})") + check_allocations(0, 0, "argument_forwarding(1, **nil#{block})") + check_allocations(0, 0, "argument_forwarding(1, **empty_hash#{block})") + check_allocations(0, 0, "argument_forwarding(1, **hash1#{block})") + check_allocations(0, 0, "argument_forwarding(1, *empty_array, **hash1#{block})") + check_allocations(0, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") + check_allocations(0, 1, "argument_forwarding(1, **empty_hash, **hash1#{block})") - check_allocations(1, 0, "argument_forwarding(1, *empty_array#{block})") - check_allocations(1, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") + check_allocations(0, 0, "argument_forwarding(1, *empty_array#{block})") + check_allocations(0, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") check_allocations(1, 0, "argument_forwarding(1, *empty_array, *empty_array, **empty_hash#{block})") check_allocations(0, 0, "argument_forwarding(*array1, a: 2#{block})") @@ -676,20 +676,20 @@ class TestAllocation < Test::Unit::TestCase check_allocations(<<~RUBY) def self.argument_forwarding(...); t(...) end; def self.t(...) end - check_allocations(1, 1, "argument_forwarding(1, a: 2#{block})") - check_allocations(1, 0, "argument_forwarding(1, *empty_array, a: 2#{block})") - check_allocations(1, 1, "argument_forwarding(1, a:2, **empty_hash#{block})") - check_allocations(1, 1, "argument_forwarding(1, **empty_hash, a: 2#{block})") + check_allocations(0, 0, "argument_forwarding(1, a: 2#{block})") + check_allocations(0, 0, "argument_forwarding(1, *empty_array, a: 2#{block})") + check_allocations(0, 1, "argument_forwarding(1, a:2, **empty_hash#{block})") + check_allocations(0, 1, "argument_forwarding(1, **empty_hash, a: 2#{block})") - check_allocations(1, 0, "argument_forwarding(1, **nil#{block})") - check_allocations(1, 0, "argument_forwarding(1, **empty_hash#{block})") - check_allocations(1, 0, "argument_forwarding(1, **hash1#{block})") - check_allocations(1, 0, "argument_forwarding(1, *empty_array, **hash1#{block})") - check_allocations(1, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") - check_allocations(1, 1, "argument_forwarding(1, **empty_hash, **hash1#{block})") + check_allocations(0, 0, "argument_forwarding(1, **nil#{block})") + check_allocations(0, 0, "argument_forwarding(1, **empty_hash#{block})") + check_allocations(0, 0, "argument_forwarding(1, **hash1#{block})") + check_allocations(0, 0, "argument_forwarding(1, *empty_array, **hash1#{block})") + check_allocations(0, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") + check_allocations(0, 1, "argument_forwarding(1, **empty_hash, **hash1#{block})") - check_allocations(1, 0, "argument_forwarding(1, *empty_array#{block})") - check_allocations(1, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") + check_allocations(0, 0, "argument_forwarding(1, *empty_array#{block})") + check_allocations(0, 1, "argument_forwarding(1, **hash1, **empty_hash#{block})") check_allocations(1, 0, "argument_forwarding(1, *empty_array, *empty_array, **empty_hash#{block})") check_allocations(0, 0, "argument_forwarding(*array1, a: 2#{block})") diff --git a/test/ruby/test_iseq.rb b/test/ruby/test_iseq.rb index 267c0c83dc..c946d588c1 100644 --- a/test/ruby/test_iseq.rb +++ b/test/ruby/test_iseq.rb @@ -95,6 +95,16 @@ class TestISeq < Test::Unit::TestCase assert_equal(42, ISeq.load_from_binary(iseq.to_binary).eval) end + def test_forwardable + iseq = compile(<<~EOF, __LINE__+1) + Class.new { + def bar(a, b); a + b; end + def foo(...); bar(...); end + } + EOF + assert_equal(42, ISeq.load_from_binary(iseq.to_binary).eval.new.foo(40, 2)) + end + def test_super_with_block iseq = compile(<<~EOF, __LINE__+1) def (Object.new).touch(*) # :nodoc: diff --git a/test/ruby/test_method.rb b/test/ruby/test_method.rb index 5301b51650..a355f86a17 100644 --- a/test/ruby/test_method.rb +++ b/test/ruby/test_method.rb @@ -1624,6 +1624,20 @@ class TestMethod < Test::Unit::TestCase RUBY end + def test_super_with_splat + c = Class.new { + attr_reader :x + + def initialize(*args) + @x, _ = args + end + } + b = Class.new(c) { def initialize(...) = super } + a = Class.new(b) { def initialize(*args) = super } + obj = a.new(1, 2, 3) + assert_equal 1, obj.x + end + def test_warn_unused_block assert_in_out_err '-w', <<-'RUBY' do |_out, err, _status| def foo = nil diff --git a/test/ruby/test_yjit.rb b/test/ruby/test_yjit.rb index 44b935758d..90f9f8ad48 100644 --- a/test/ruby/test_yjit.rb +++ b/test/ruby/test_yjit.rb @@ -1586,21 +1586,19 @@ class TestYJIT < Test::Unit::TestCase end def test_kw_splat_nil - assert_compiles(<<~'RUBY', result: %i[ok ok ok], no_send_fallbacks: true) + assert_compiles(<<~'RUBY', result: %i[ok ok], no_send_fallbacks: true) def id(x) = x def kw_fw(arg, **) = id(arg, **) - def fw(...) = id(...) - def use = [fw(:ok), kw_fw(:ok), :ok.itself(**nil)] + def use = [kw_fw(:ok), :ok.itself(**nil)] use RUBY end def test_empty_splat - assert_compiles(<<~'RUBY', result: %i[ok ok], no_send_fallbacks: true) + assert_compiles(<<~'RUBY', result: :ok, no_send_fallbacks: true) def foo = :ok - def fw(...) = foo(...) - def use(empty) = [foo(*empty), fw] + def use(empty) = foo(*empty) use([]) RUBY diff --git a/tool/rjit/bindgen.rb b/tool/rjit/bindgen.rb index 2fc4576216..4105052dd2 100755 --- a/tool/rjit/bindgen.rb +++ b/tool/rjit/bindgen.rb @@ -447,6 +447,7 @@ generator = BindingGenerator.new( VM_CALL_ARGS_BLOCKARG VM_CALL_ARGS_SPLAT VM_CALL_FCALL + VM_CALL_FORWARDING VM_CALL_KWARG VM_CALL_KW_SPLAT VM_CALL_KW_SPLAT_MUT diff --git a/tool/ruby_vm/views/_sp_inc_helpers.erb b/tool/ruby_vm/views/_sp_inc_helpers.erb index d0b0bd79ef..740fe10142 100644 --- a/tool/ruby_vm/views/_sp_inc_helpers.erb +++ b/tool/ruby_vm/views/_sp_inc_helpers.erb @@ -18,7 +18,7 @@ sp_inc_of_sendish(const struct rb_callinfo *ci) * 3. Pop receiver. * 4. Push return value. */ - const int argb = (vm_ci_flag(ci) & VM_CALL_ARGS_BLOCKARG) ? 1 : 0; + const int argb = (vm_ci_flag(ci) & (VM_CALL_ARGS_BLOCKARG | VM_CALL_FORWARDING)) ? 1 : 0; const int argc = vm_ci_argc(ci); const int recv = 1; const int retn = 1; diff --git a/vm.c b/vm.c index 550ded3209..c7ccbc9550 100644 --- a/vm.c +++ b/vm.c @@ -961,7 +961,14 @@ vm_make_env_each(const rb_execution_context_t * const ec, rb_control_frame_t *co local_size = VM_ENV_DATA_SIZE; } else { - local_size = ISEQ_BODY(cfp->iseq)->local_table_size + VM_ENV_DATA_SIZE; + local_size = ISEQ_BODY(cfp->iseq)->local_table_size; + if (ISEQ_BODY(cfp->iseq)->param.flags.forwardable && VM_ENV_LOCAL_P(cfp->ep)) { + int ci_offset = local_size - ISEQ_BODY(cfp->iseq)->param.size + VM_ENV_DATA_SIZE; + + CALL_INFO ci = (CALL_INFO)VM_CF_LEP(cfp)[-ci_offset]; + local_size += vm_ci_argc(ci); + } + local_size += VM_ENV_DATA_SIZE; } /* diff --git a/vm_args.c b/vm_args.c index 5593f95695..6d31187ea4 100644 --- a/vm_args.c +++ b/vm_args.c @@ -1057,3 +1057,56 @@ vm_caller_setup_arg_block(const rb_execution_context_t *ec, rb_control_frame_t * } } } + +static void vm_adjust_stack_forwarding(const struct rb_execution_context_struct *ec, struct rb_control_frame_struct *cfp, CALL_INFO callers_info, VALUE splat); + +static VALUE +vm_caller_setup_args(const rb_execution_context_t *ec, rb_control_frame_t *reg_cfp, + CALL_DATA *cd, const rb_iseq_t *blockiseq, const int is_super, + struct rb_forwarding_call_data *adjusted_cd, struct rb_callinfo *adjusted_ci) +{ + CALL_INFO site_ci = (*cd)->ci; + VALUE bh = Qundef; + + if (vm_ci_flag(site_ci) & VM_CALL_FORWARDING) { + RUBY_ASSERT(ISEQ_BODY(ISEQ_BODY(GET_ISEQ())->local_iseq)->param.flags.forwardable); + CALL_INFO caller_ci = (CALL_INFO)TOPN(0); + + unsigned int forwarding_argc = vm_ci_argc(site_ci); + VALUE splat = Qfalse; + + if (vm_ci_flag(site_ci) & VM_CALL_ARGS_SPLAT) { + // If we're called with args_splat, the top 1 should be an array + splat = TOPN(1); + forwarding_argc += (RARRAY_LEN(splat) - 1); + } + + // Need to setup the block in case of e.g. `super { :block }` + if (is_super && blockiseq) { + bh = vm_caller_setup_arg_block(ec, GET_CFP(), site_ci, blockiseq, is_super); + } + else { + bh = VM_ENV_BLOCK_HANDLER(GET_LEP()); + } + + vm_adjust_stack_forwarding(ec, GET_CFP(), caller_ci, splat); + + *adjusted_ci = VM_CI_ON_STACK( + vm_ci_mid(site_ci), + (vm_ci_flag(caller_ci) | (vm_ci_flag(site_ci) & (VM_CALL_FCALL | VM_CALL_FORWARDING))), + forwarding_argc + vm_ci_argc(caller_ci), + vm_ci_kwarg(caller_ci) + ); + + adjusted_cd->cd.ci = adjusted_ci; + adjusted_cd->cd.cc = (*cd)->cc; + adjusted_cd->caller_ci = caller_ci; + + *cd = &adjusted_cd->cd; + } + else { + bh = vm_caller_setup_arg_block(ec, GET_CFP(), site_ci, blockiseq, is_super); + } + + return bh; +} diff --git a/vm_callinfo.h b/vm_callinfo.h index 71ab9fe3fa..ccc28eb527 100644 --- a/vm_callinfo.h +++ b/vm_callinfo.h @@ -26,6 +26,7 @@ enum vm_call_flag_bits { VM_CALL_OPT_SEND_bit, // internal flag VM_CALL_KW_SPLAT_MUT_bit, // kw splat hash can be modified (to avoid allocating a new one) VM_CALL_ARGS_SPLAT_MUT_bit, // args splat can be modified (to avoid allocating a new one) + VM_CALL_FORWARDING_bit, // m(...) VM_CALL__END }; @@ -42,6 +43,7 @@ enum vm_call_flag_bits { #define VM_CALL_OPT_SEND (0x01 << VM_CALL_OPT_SEND_bit) #define VM_CALL_KW_SPLAT_MUT (0x01 << VM_CALL_KW_SPLAT_MUT_bit) #define VM_CALL_ARGS_SPLAT_MUT (0x01 << VM_CALL_ARGS_SPLAT_MUT_bit) +#define VM_CALL_FORWARDING (0x01 << VM_CALL_FORWARDING_bit) struct rb_callinfo_kwarg { int keyword_len; @@ -240,6 +242,21 @@ vm_ci_new_runtime_(ID mid, unsigned int flag, unsigned int argc, const struct rb #define VM_CALLINFO_NOT_UNDER_GC IMEMO_FL_USER0 +static inline bool +vm_ci_markable(const struct rb_callinfo *ci) +{ + if (! ci) { + return false; /* or true? This is Qfalse... */ + } + else if (vm_ci_packed_p(ci)) { + return true; + } + else { + VM_ASSERT(IMEMO_TYPE_P(ci, imemo_callinfo)); + return ! FL_ANY_RAW((VALUE)ci, VM_CALLINFO_NOT_UNDER_GC); + } +} + #define VM_CI_ON_STACK(mid_, flags_, argc_, kwarg_) \ (struct rb_callinfo) { \ .flags = T_IMEMO | \ @@ -318,6 +335,8 @@ vm_cc_new(VALUE klass, *((struct rb_callable_method_entry_struct **)&cc->cme_) = (struct rb_callable_method_entry_struct *)cme; *((vm_call_handler *)&cc->call_) = call; + VM_ASSERT(RB_TYPE_P(klass, T_CLASS) || RB_TYPE_P(klass, T_ICLASS)); + switch (type) { case cc_type_normal: break; diff --git a/vm_core.h b/vm_core.h index 30e06c48ed..15ca4571af 100644 --- a/vm_core.h +++ b/vm_core.h @@ -427,6 +427,7 @@ struct rb_iseq_constant_body { unsigned int anon_rest: 1; unsigned int anon_kwrest: 1; unsigned int use_block: 1; + unsigned int forwardable: 1; } flags; unsigned int size; diff --git a/vm_insnhelper.c b/vm_insnhelper.c index 46737145ca..a74f5096f6 100644 --- a/vm_insnhelper.c +++ b/vm_insnhelper.c @@ -2526,6 +2526,15 @@ vm_base_ptr(const rb_control_frame_t *cfp) if (cfp->iseq && VM_FRAME_RUBYFRAME_P(cfp)) { VALUE *bp = prev_cfp->sp + ISEQ_BODY(cfp->iseq)->local_table_size + VM_ENV_DATA_SIZE; + + if (ISEQ_BODY(cfp->iseq)->param.flags.forwardable && VM_ENV_LOCAL_P(cfp->ep)) { + int lts = ISEQ_BODY(cfp->iseq)->local_table_size; + int params = ISEQ_BODY(cfp->iseq)->param.size; + + CALL_INFO ci = (CALL_INFO)cfp->ep[-(VM_ENV_DATA_SIZE + (lts - params))]; // skip EP stuff, CI should be last local + bp += vm_ci_argc(ci); + } + if (ISEQ_BODY(cfp->iseq)->type == ISEQ_TYPE_METHOD || VM_FRAME_BMETHOD_P(cfp)) { /* adjust `self' */ bp += 1; @@ -2594,6 +2603,7 @@ rb_simple_iseq_p(const rb_iseq_t *iseq) ISEQ_BODY(iseq)->param.flags.has_kw == FALSE && ISEQ_BODY(iseq)->param.flags.has_kwrest == FALSE && ISEQ_BODY(iseq)->param.flags.accepts_no_kwarg == FALSE && + ISEQ_BODY(iseq)->param.flags.forwardable == FALSE && ISEQ_BODY(iseq)->param.flags.has_block == FALSE; } @@ -2606,6 +2616,7 @@ rb_iseq_only_optparam_p(const rb_iseq_t *iseq) ISEQ_BODY(iseq)->param.flags.has_kw == FALSE && ISEQ_BODY(iseq)->param.flags.has_kwrest == FALSE && ISEQ_BODY(iseq)->param.flags.accepts_no_kwarg == FALSE && + ISEQ_BODY(iseq)->param.flags.forwardable == FALSE && ISEQ_BODY(iseq)->param.flags.has_block == FALSE; } @@ -2617,6 +2628,7 @@ rb_iseq_only_kwparam_p(const rb_iseq_t *iseq) ISEQ_BODY(iseq)->param.flags.has_post == FALSE && ISEQ_BODY(iseq)->param.flags.has_kw == TRUE && ISEQ_BODY(iseq)->param.flags.has_kwrest == FALSE && + ISEQ_BODY(iseq)->param.flags.forwardable == FALSE && ISEQ_BODY(iseq)->param.flags.has_block == FALSE; } @@ -3053,7 +3065,7 @@ vm_callee_setup_arg(rb_execution_context_t *ec, struct rb_calling_info *calling, argument_arity_error(ec, iseq, calling->argc, lead_num, lead_num); } - VM_ASSERT(ci == calling->cd->ci); + //VM_ASSERT(ci == calling->cd->ci); VM_ASSERT(cc == calling->cc); if (vm_call_iseq_optimizable_p(ci, cc)) { @@ -3140,9 +3152,125 @@ vm_callee_setup_arg(rb_execution_context_t *ec, struct rb_calling_info *calling, } } + // Called iseq is using ... param + // def foo(...) # <- iseq for foo will have "forwardable" + // + // We want to set the `...` local to the caller's CI + // foo(1, 2) # <- the ci for this should end up as `...` + // + // So hopefully the stack looks like: + // + // => 1 + // => 2 + // => * + // => ** + // => & + // => ... # <- points at `foo`s CI + // => cref_or_me + // => specval + // => type + // + if (ISEQ_BODY(iseq)->param.flags.forwardable) { + if ((vm_ci_flag(ci) & VM_CALL_FORWARDING)) { + struct rb_forwarding_call_data * forward_cd = (struct rb_forwarding_call_data *)calling->cd; + if (vm_ci_argc(ci) != vm_ci_argc(forward_cd->caller_ci)) { + ci = vm_ci_new_runtime( + vm_ci_mid(ci), + vm_ci_flag(ci), + vm_ci_argc(ci), + vm_ci_kwarg(ci)); + } else { + ci = forward_cd->caller_ci; + } + } + // C functions calling iseqs will stack allocate a CI, + // so we need to convert it to heap allocated + if (!vm_ci_markable(ci)) { + ci = vm_ci_new_runtime( + vm_ci_mid(ci), + vm_ci_flag(ci), + vm_ci_argc(ci), + vm_ci_kwarg(ci)); + } + argv[param_size - 1] = (VALUE)ci; + return 0; + } + return setup_parameters_complex(ec, iseq, calling, ci, argv, arg_setup_method); } +static void +vm_adjust_stack_forwarding(const struct rb_execution_context_struct *ec, struct rb_control_frame_struct *cfp, CALL_INFO callers_info, VALUE splat) +{ + // This case is when the caller is using a ... parameter. + // For example `bar(...)`. The call info will have VM_CALL_FORWARDING + // In this case the caller's caller's CI will be on the stack. + // + // For example: + // + // def bar(a, b); a + b; end + // def foo(...); bar(...); end + // foo(1, 2) # <- this CI will be on the stack when we call `bar(...)` + // + // Stack layout will be: + // + // > 1 + // > 2 + // > CI for foo(1, 2) + // > cref_or_me + // > specval + // > type + // > receiver + // > CI for foo(1, 2), via `getlocal ...` + // > ( SP points here ) + const VALUE * lep = VM_CF_LEP(cfp); + + // We'll need to copy argc args to this SP + int argc = vm_ci_argc(callers_info); + + const rb_iseq_t *iseq; + + // If we're in an escaped environment (lambda for example), get the iseq + // from the captured env. + if (VM_ENV_FLAGS(lep, VM_ENV_FLAG_ESCAPED)) { + rb_env_t * env = (rb_env_t *)lep[VM_ENV_DATA_INDEX_ENV]; + iseq = env->iseq; + } + else { // Otherwise use the lep to find the caller + iseq = rb_vm_search_cf_from_ep(ec, cfp, lep)->iseq; + } + + // Our local storage is below the args we need to copy + int local_size = ISEQ_BODY(iseq)->local_table_size + argc; + + const VALUE * from = lep - (local_size + VM_ENV_DATA_SIZE - 1); // 2 for EP values + VALUE * to = cfp->sp - 1; // clobber the CI + + if (RTEST(splat)) { + to -= 1; // clobber the splat array + CHECK_VM_STACK_OVERFLOW0(cfp, to, RARRAY_LEN(splat)); + MEMCPY(to, RARRAY_CONST_PTR(splat), VALUE, RARRAY_LEN(splat)); + to += RARRAY_LEN(splat); + } + + CHECK_VM_STACK_OVERFLOW0(cfp, to, argc); + MEMCPY(to, from, VALUE, argc); + cfp->sp = to + argc; + + // Stack layout should now be: + // + // > 1 + // > 2 + // > CI for foo(1, 2) + // > cref_or_me + // > specval + // > type + // > receiver + // > 1 + // > 2 + // > ( SP points here ) +} + static VALUE vm_call_iseq_setup(rb_execution_context_t *ec, rb_control_frame_t *cfp, struct rb_calling_info *calling) { @@ -3150,8 +3278,15 @@ vm_call_iseq_setup(rb_execution_context_t *ec, rb_control_frame_t *cfp, struct r const struct rb_callcache *cc = calling->cc; const rb_iseq_t *iseq = def_iseq_ptr(vm_cc_cme(cc)->def); - const int param_size = ISEQ_BODY(iseq)->param.size; - const int local_size = ISEQ_BODY(iseq)->local_table_size; + int param_size = ISEQ_BODY(iseq)->param.size; + int local_size = ISEQ_BODY(iseq)->local_table_size; + + // Setting up local size and param size + if (ISEQ_BODY(iseq)->param.flags.forwardable) { + local_size = local_size + vm_ci_argc(calling->cd->ci); + param_size = param_size + vm_ci_argc(calling->cd->ci); + } + const int opt_pc = vm_callee_setup_arg(ec, calling, iseq, cfp->sp - calling->argc, param_size, local_size); return vm_call_iseq_setup_2(ec, cfp, calling, opt_pc, param_size, local_size); } @@ -3661,7 +3796,7 @@ vm_call_cfunc_other(rb_execution_context_t *ec, rb_control_frame_t *reg_cfp, str return vm_call_cfunc_with_frame_(ec, reg_cfp, calling, argc, argv, stack_bottom); } else { - CC_SET_FASTPATH(calling->cc, vm_call_cfunc_with_frame, !rb_splat_or_kwargs_p(ci) && !calling->kw_splat); + CC_SET_FASTPATH(calling->cc, vm_call_cfunc_with_frame, !rb_splat_or_kwargs_p(ci) && !calling->kw_splat && !(vm_ci_flag(ci) & VM_CALL_FORWARDING)); return vm_call_cfunc_with_frame(ec, reg_cfp, calling); } @@ -4543,7 +4678,7 @@ vm_call_method_each_type(rb_execution_context_t *ec, rb_control_frame_t *cfp, st rb_check_arity(calling->argc, 1, 1); - const unsigned int aset_mask = (VM_CALL_ARGS_SPLAT | VM_CALL_KW_SPLAT | VM_CALL_KWARG); + const unsigned int aset_mask = (VM_CALL_ARGS_SPLAT | VM_CALL_KW_SPLAT | VM_CALL_KWARG | VM_CALL_FORWARDING); if (vm_cc_markable(cc)) { vm_cc_attr_index_initialize(cc, INVALID_SHAPE_ID); @@ -4577,7 +4712,7 @@ vm_call_method_each_type(rb_execution_context_t *ec, rb_control_frame_t *cfp, st CALLER_SETUP_ARG(cfp, calling, ci, 0); rb_check_arity(calling->argc, 0, 0); vm_cc_attr_index_initialize(cc, INVALID_SHAPE_ID); - const unsigned int ivar_mask = (VM_CALL_ARGS_SPLAT | VM_CALL_KW_SPLAT); + const unsigned int ivar_mask = (VM_CALL_ARGS_SPLAT | VM_CALL_KW_SPLAT | VM_CALL_FORWARDING); VM_CALL_METHOD_ATTR(v, vm_call_ivar(ec, cfp, calling), CC_SET_FASTPATH(cc, vm_call_ivar, !(vm_ci_flag(ci) & ivar_mask))); @@ -4796,13 +4931,18 @@ vm_search_super_method(const rb_control_frame_t *reg_cfp, struct rb_call_data *c ID mid = me->def->original_id; - // update iseq. really? (TODO) - cd->ci = vm_ci_new_runtime(mid, - vm_ci_flag(cd->ci), - vm_ci_argc(cd->ci), - vm_ci_kwarg(cd->ci)); + if (!vm_ci_markable(cd->ci)) { + VM_FORCE_WRITE((const VALUE *)&cd->ci->mid, (VALUE)mid); + } + else { + // update iseq. really? (TODO) + cd->ci = vm_ci_new_runtime(mid, + vm_ci_flag(cd->ci), + vm_ci_argc(cd->ci), + vm_ci_kwarg(cd->ci)); - RB_OBJ_WRITTEN(reg_cfp->iseq, Qundef, cd->ci); + RB_OBJ_WRITTEN(reg_cfp->iseq, Qundef, cd->ci); + } const struct rb_callcache *cc; @@ -5737,8 +5877,20 @@ VALUE rb_vm_send(rb_execution_context_t *ec, rb_control_frame_t *reg_cfp, CALL_DATA cd, ISEQ blockiseq) { stack_check(ec); - VALUE bh = vm_caller_setup_arg_block(ec, GET_CFP(), cd->ci, blockiseq, false); + + struct rb_forwarding_call_data adjusted_cd; + struct rb_callinfo adjusted_ci; + CALL_DATA _cd = cd; + + VALUE bh = vm_caller_setup_args(ec, GET_CFP(), &cd, blockiseq, false, &adjusted_cd, &adjusted_ci); VALUE val = vm_sendish(ec, GET_CFP(), cd, bh, mexp_search_method); + + if (vm_ci_flag(_cd->ci) & VM_CALL_FORWARDING) { + if (_cd->cc != cd->cc && vm_cc_markable(cd->cc)) { + RB_OBJ_WRITE(GET_ISEQ(), &_cd->cc, cd->cc); + } + } + VM_EXEC(ec, val); return val; } @@ -5757,8 +5909,19 @@ VALUE rb_vm_invokesuper(rb_execution_context_t *ec, rb_control_frame_t *reg_cfp, CALL_DATA cd, ISEQ blockiseq) { stack_check(ec); - VALUE bh = vm_caller_setup_arg_block(ec, GET_CFP(), cd->ci, blockiseq, true); + struct rb_forwarding_call_data adjusted_cd; + struct rb_callinfo adjusted_ci; + CALL_DATA _cd = cd; + + VALUE bh = vm_caller_setup_args(ec, GET_CFP(), &cd, blockiseq, true, &adjusted_cd, &adjusted_ci); VALUE val = vm_sendish(ec, GET_CFP(), cd, bh, mexp_search_super); + + if (vm_ci_flag(_cd->ci) & VM_CALL_FORWARDING) { + if (_cd->cc != cd->cc && vm_cc_markable(cd->cc)) { + RB_OBJ_WRITE(GET_ISEQ(), &_cd->cc, cd->cc); + } + } + VM_EXEC(ec, val); return val; } diff --git a/vm_insnhelper.h b/vm_insnhelper.h index 926700b90a..3b0958d2e1 100644 --- a/vm_insnhelper.h +++ b/vm_insnhelper.h @@ -66,6 +66,11 @@ typedef enum call_type { CALL_FCALL_KW } call_type; +struct rb_forwarding_call_data { + struct rb_call_data cd; + CALL_INFO caller_ci; +}; + #if VM_COLLECT_USAGE_DETAILS enum vm_regan_regtype { VM_REGAN_PC = 0, @@ -258,8 +263,8 @@ THROW_DATA_CONSUMED_SET(struct vm_throw_data *obj) static inline bool vm_call_cacheable(const struct rb_callinfo *ci, const struct rb_callcache *cc) { - return (vm_ci_flag(ci) & VM_CALL_FCALL) || - METHOD_ENTRY_VISI(vm_cc_cme(cc)) != METHOD_VISI_PROTECTED; + return !(vm_ci_flag(ci) & VM_CALL_FORWARDING) && ((vm_ci_flag(ci) & VM_CALL_FCALL) || + METHOD_ENTRY_VISI(vm_cc_cme(cc)) != METHOD_VISI_PROTECTED); } /* If this returns true, an optimized function returned by `vm_call_iseq_setup_func` can be used as a fastpath. */ diff --git a/vm_method.c b/vm_method.c index 4cf03fafde..4f82efbf00 100644 --- a/vm_method.c +++ b/vm_method.c @@ -1187,6 +1187,7 @@ rb_check_overloaded_cme(const rb_callable_method_entry_t *cme, const struct rb_c { if (UNLIKELY(cme->def->iseq_overload) && (vm_ci_flag(ci) & (VM_CALL_ARGS_SIMPLE)) && + (!(vm_ci_flag(ci) & VM_CALL_FORWARDING)) && (int)vm_ci_argc(ci) == ISEQ_BODY(method_entry_iseqptr(cme))->param.lead_num) { VM_ASSERT(cme->def->type == VM_METHOD_TYPE_ISEQ); // iseq_overload is marked only on ISEQ methods diff --git a/yjit.c b/yjit.c index 9f68e363ef..56cc892e5b 100644 --- a/yjit.c +++ b/yjit.c @@ -701,6 +701,12 @@ rb_get_iseq_flags_accepts_no_kwarg(const rb_iseq_t *iseq) return iseq->body->param.flags.accepts_no_kwarg; } +bool +rb_get_iseq_flags_forwardable(const rb_iseq_t *iseq) +{ + return iseq->body->param.flags.forwardable; +} + const rb_seq_param_keyword_struct * rb_get_iseq_body_param_keyword(const rb_iseq_t *iseq) { diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs index e76d9b3063..62c7ff2c79 100644 --- a/yjit/bindgen/src/main.rs +++ b/yjit/bindgen/src/main.rs @@ -431,6 +431,7 @@ fn main() { .allowlist_function("rb_get_iseq_flags_ambiguous_param0") .allowlist_function("rb_get_iseq_flags_accepts_no_kwarg") .allowlist_function("rb_get_iseq_flags_ruby2_keywords") + .allowlist_function("rb_get_iseq_flags_forwardable") .allowlist_function("rb_get_iseq_body_local_table_size") .allowlist_function("rb_get_iseq_body_param_keyword") .allowlist_function("rb_get_iseq_body_param_size") diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index 755e64c244..26234aec41 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -7118,6 +7118,8 @@ fn gen_send_iseq( let kw_splat = flags & VM_CALL_KW_SPLAT != 0; let splat_call = flags & VM_CALL_ARGS_SPLAT != 0; + let forwarding_call = unsafe { rb_get_iseq_flags_forwardable(iseq) }; + // For computing offsets to callee locals let num_params = unsafe { get_iseq_body_param_size(iseq) as i32 }; let num_locals = unsafe { get_iseq_body_local_table_size(iseq) as i32 }; @@ -7160,9 +7162,15 @@ fn gen_send_iseq( exit_if_supplying_kw_and_has_no_kw(asm, supplying_kws, doing_kw_call)?; exit_if_supplying_kws_and_accept_no_kwargs(asm, supplying_kws, iseq)?; exit_if_doing_kw_and_splat(asm, doing_kw_call, flags)?; - exit_if_wrong_number_arguments(asm, arg_setup_block, opts_filled, flags, opt_num, iseq_has_rest)?; + if !forwarding_call { + exit_if_wrong_number_arguments(asm, arg_setup_block, opts_filled, flags, opt_num, iseq_has_rest)?; + } exit_if_doing_kw_and_opts_missing(asm, doing_kw_call, opts_missing)?; exit_if_has_rest_and_optional_and_block(asm, iseq_has_rest, opt_num, iseq, block_arg)?; + if forwarding_call && flags & VM_CALL_OPT_SEND != 0 { + gen_counter_incr(asm, Counter::send_iseq_send_forwarding); + return None; + } let block_arg_type = exit_if_unsupported_block_arg_type(jit, asm, block_arg)?; // Bail if we can't drop extra arguments for a yield by just popping them @@ -7671,25 +7679,26 @@ fn gen_send_iseq( } } - // Nil-initialize missing optional parameters - nil_fill( - "nil-initialize missing optionals", - { - let begin = -argc + required_num + opts_filled; - let end = -argc + required_num + opt_num; + if !forwarding_call { + // Nil-initialize missing optional parameters + nil_fill( + "nil-initialize missing optionals", + { + let begin = -argc + required_num + opts_filled; + let end = -argc + required_num + opt_num; - begin..end - }, - asm - ); - // Nil-initialize the block parameter. It's the last parameter local - if iseq_has_block_param { - let block_param = asm.ctx.sp_opnd(-argc + num_params - 1); - asm.store(block_param, Qnil.into()); - } - // Nil-initialize non-parameter locals - nil_fill( - "nil-initialize locals", + begin..end + }, + asm + ); + // Nil-initialize the block parameter. It's the last parameter local + if iseq_has_block_param { + let block_param = asm.ctx.sp_opnd(-argc + num_params - 1); + asm.store(block_param, Qnil.into()); + } + // Nil-initialize non-parameter locals + nil_fill( + "nil-initialize locals", { let begin = -argc + num_params; let end = -argc + num_locals; @@ -7697,7 +7706,13 @@ fn gen_send_iseq( begin..end }, asm - ); + ); + } + + if forwarding_call { + assert_eq!(1, num_params); + asm.mov(asm.stack_opnd(-1), VALUE(ci as usize).into()); + } // Points to the receiver operand on the stack unless a captured environment is used let recv = match captured_opnd { @@ -7716,7 +7731,13 @@ fn gen_send_iseq( jit_save_pc(jit, asm); // Adjust the callee's stack pointer - let callee_sp = asm.lea(asm.ctx.sp_opnd(-argc + num_locals + VM_ENV_DATA_SIZE as i32)); + let callee_sp = if forwarding_call { + let offs = num_locals + VM_ENV_DATA_SIZE as i32; + asm.lea(asm.ctx.sp_opnd(offs)) + } else { + let offs = -argc + num_locals + VM_ENV_DATA_SIZE as i32; + asm.lea(asm.ctx.sp_opnd(offs)) + }; let specval = if let Some(prev_ep) = prev_ep { // We've already side-exited if the callee expects a block, so we @@ -8519,6 +8540,14 @@ fn gen_send_general( return Some(EndBlock); } + let ci_flags = unsafe { vm_ci_flag(ci) }; + + // Dynamic stack layout. No good way to support without inlining. + if ci_flags & VM_CALL_FORWARDING != 0 { + gen_counter_incr(asm, Counter::send_iseq_forwarding); + return None; + } + let recv_idx = argc + if flags & VM_CALL_ARGS_BLOCKARG != 0 { 1 } else { 0 }; let comptime_recv = jit.peek_at_stack(&asm.ctx, recv_idx as isize); let comptime_recv_klass = comptime_recv.class_of(); @@ -9286,6 +9315,10 @@ fn gen_invokesuper_specialized( gen_counter_incr(asm, Counter::invokesuper_kw_splat); return None; } + if ci_flags & VM_CALL_FORWARDING != 0 { + gen_counter_incr(asm, Counter::invokesuper_forwarding); + return None; + } // Ensure we haven't rebound this method onto an incompatible class. // In the interpreter we try to avoid making this check by performing some diff --git a/yjit/src/cruby.rs b/yjit/src/cruby.rs index 53586cb4f4..b069137664 100644 --- a/yjit/src/cruby.rs +++ b/yjit/src/cruby.rs @@ -714,6 +714,7 @@ mod manual_defs { pub const VM_CALL_ARGS_SIMPLE: u32 = 1 << VM_CALL_ARGS_SIMPLE_bit; pub const VM_CALL_ARGS_SPLAT: u32 = 1 << VM_CALL_ARGS_SPLAT_bit; pub const VM_CALL_ARGS_BLOCKARG: u32 = 1 << VM_CALL_ARGS_BLOCKARG_bit; + pub const VM_CALL_FORWARDING: u32 = 1 << VM_CALL_FORWARDING_bit; pub const VM_CALL_FCALL: u32 = 1 << VM_CALL_FCALL_bit; pub const VM_CALL_KWARG: u32 = 1 << VM_CALL_KWARG_bit; pub const VM_CALL_KW_SPLAT: u32 = 1 << VM_CALL_KW_SPLAT_bit; diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 1f128f5e7e..994fef28b7 100644 --- a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -675,7 +675,8 @@ pub const VM_CALL_ZSUPER_bit: vm_call_flag_bits = 9; pub const VM_CALL_OPT_SEND_bit: vm_call_flag_bits = 10; pub const VM_CALL_KW_SPLAT_MUT_bit: vm_call_flag_bits = 11; pub const VM_CALL_ARGS_SPLAT_MUT_bit: vm_call_flag_bits = 12; -pub const VM_CALL__END: vm_call_flag_bits = 13; +pub const VM_CALL_FORWARDING_bit: vm_call_flag_bits = 13; +pub const VM_CALL__END: vm_call_flag_bits = 14; pub type vm_call_flag_bits = u32; #[repr(C)] pub struct rb_callinfo_kwarg { @@ -1184,6 +1185,7 @@ extern "C" { pub fn rb_get_iseq_flags_has_block(iseq: *const rb_iseq_t) -> bool; pub fn rb_get_iseq_flags_ambiguous_param0(iseq: *const rb_iseq_t) -> bool; pub fn rb_get_iseq_flags_accepts_no_kwarg(iseq: *const rb_iseq_t) -> bool; + pub fn rb_get_iseq_flags_forwardable(iseq: *const rb_iseq_t) -> bool; pub fn rb_get_iseq_body_param_keyword( iseq: *const rb_iseq_t, ) -> *const rb_seq_param_keyword_struct; diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs index 244ccfd55f..8df5bd7ee3 100644 --- a/yjit/src/stats.rs +++ b/yjit/src/stats.rs @@ -378,6 +378,7 @@ make_counters! { send_iseq_block_arg_type, send_iseq_clobbering_block_arg, send_iseq_complex_discard_extras, + send_iseq_forwarding, send_iseq_leaf_builtin_block_arg_block_param, send_iseq_kw_splat_non_nil, send_iseq_kwargs_mismatch, @@ -385,6 +386,7 @@ make_counters! { send_iseq_has_no_kw, send_iseq_accepts_no_kwarg, send_iseq_materialized_block, + send_iseq_send_forwarding, send_iseq_splat_not_array, send_iseq_splat_with_kw, send_iseq_missing_optional_kw, @@ -414,6 +416,7 @@ make_counters! { send_optimized_block_arg, invokesuper_defined_class_mismatch, + invokesuper_forwarding, invokesuper_kw_splat, invokesuper_kwarg, invokesuper_megamorphic,