[PRISM] Properly compile MultiTargetNodes within parameters

If there are MultiTargetNodes within parameters, we need to
iterate over them and compile them individually correctly, once
the locals are all in the correct spaces. We need to add one
getlocal for the hidden variable, and then can recurse into the
MultiTargetNodes themselves
This commit is contained in:
Jemma Issroff 2023-12-11 15:17:48 -05:00
parent 5c8e1911ca
commit 69d60cc67b
2 changed files with 105 additions and 12 deletions

View File

@ -3709,6 +3709,13 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
PM_POP_IF_POPPED; PM_POP_IF_POPPED;
return; return;
} }
case PM_REQUIRED_PARAMETER_NODE: {
pm_required_parameter_node_t *required_parameter_node = (pm_required_parameter_node_t *)node;
int index = pm_lookup_local_index(iseq, scope_node, required_parameter_node->name);
ADD_SETLOCAL(ret, &dummy_line_node, index, 0);
return;
}
case PM_MULTI_TARGET_NODE: { case PM_MULTI_TARGET_NODE: {
pm_multi_target_node_t *cast = (pm_multi_target_node_t *) node; pm_multi_target_node_t *cast = (pm_multi_target_node_t *) node;
bool has_rest_expression = (cast->rest && bool has_rest_expression = (cast->rest &&
@ -3723,17 +3730,22 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
} }
} }
if (has_rest_expression) {
if (cast->rights.size) { if (cast->rights.size) {
ADD_INSN2(ret, &dummy_line_node, expandarray, INT2FIX(cast->rights.size), INT2FIX(3));
}
pm_node_t *expression = ((pm_splat_node_t *)cast->rest)->expression;
PM_COMPILE_NOT_POPPED(expression);
}
if (cast->rights.size) {
if (!has_rest_expression) {
ADD_INSN2(ret, &dummy_line_node, expandarray, INT2FIX(cast->rights.size), INT2FIX(2)); ADD_INSN2(ret, &dummy_line_node, expandarray, INT2FIX(cast->rights.size), INT2FIX(2));
}
for (size_t index = 0; index < cast->rights.size; index++) { for (size_t index = 0; index < cast->rights.size; index++) {
PM_COMPILE_NOT_POPPED(cast->rights.nodes[index]); PM_COMPILE_NOT_POPPED(cast->rights.nodes[index]);
} }
} }
if (has_rest_expression) {
pm_node_t *expression = ((pm_splat_node_t *)cast->rest)->expression;
PM_COMPILE_NOT_POPPED(expression);
}
return; return;
} }
case PM_MULTI_WRITE_NODE: { case PM_MULTI_WRITE_NODE: {
@ -4368,6 +4380,11 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
// those anonymous items temporary names (as below) // those anonymous items temporary names (as below)
int local_index = 0; int local_index = 0;
// We will assign these values now, if applicable, and use them for
// the ISEQs on these multis
int required_multis_hidden_index = 0;
int post_multis_hidden_index = 0;
// Here we figure out local table indices and insert them in to the // Here we figure out local table indices and insert them in to the
// index lookup table and local tables. // index lookup table and local tables.
// //
@ -4384,6 +4401,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
// def foo(a, (b, *c, d), e = 1, *f, g, (h, *i, j), k:, l: 1, **m, &n) // def foo(a, (b, *c, d), e = 1, *f, g, (h, *i, j), k:, l: 1, **m, &n)
// ^^^^^^^^^^ // ^^^^^^^^^^
case PM_MULTI_TARGET_NODE: { case PM_MULTI_TARGET_NODE: {
required_multis_hidden_index = local_index;
local = rb_make_temporary_id(local_index); local = rb_make_temporary_id(local_index);
local_table_for_iseq->ids[local_index] = local; local_table_for_iseq->ids[local_index] = local;
break; break;
@ -4453,6 +4471,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
// def foo(a, (b, *c, d), e = 1, *f, g, (h, *i, j), k:, l: 1, **m, &n) // def foo(a, (b, *c, d), e = 1, *f, g, (h, *i, j), k:, l: 1, **m, &n)
// ^^^^^^^^^^ // ^^^^^^^^^^
case PM_MULTI_TARGET_NODE: { case PM_MULTI_TARGET_NODE: {
post_multis_hidden_index = local_index;
local = rb_make_temporary_id(local_index); local = rb_make_temporary_id(local_index);
local_table_for_iseq->ids[local_index] = local; local_table_for_iseq->ids[local_index] = local;
break; break;
@ -4638,6 +4657,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
// Goal: fill in the names of the parameters in MultiTargetNodes // Goal: fill in the names of the parameters in MultiTargetNodes
// //
// Go through requireds again to set the multis // Go through requireds again to set the multis
if (requireds_list && requireds_list->size) { if (requireds_list && requireds_list->size) {
for (size_t i = 0; i < requireds_list->size; i++) { for (size_t i = 0; i < requireds_list->size; i++) {
// For each MultiTargetNode, we're going to have one // For each MultiTargetNode, we're going to have one
@ -4757,6 +4777,32 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
body->param.opt_table = (const VALUE *)opt_table; body->param.opt_table = (const VALUE *)opt_table;
} }
if (requireds_list && requireds_list->size) {
for (size_t i = 0; i < requireds_list->size; i++) {
// For each MultiTargetNode, we're going to have one
// additional anonymous local not represented in the locals table
// We want to account for this in our table size
pm_node_t *required = requireds_list->nodes[i];
if (PM_NODE_TYPE_P(required, PM_MULTI_TARGET_NODE)) {
ADD_GETLOCAL(ret, &dummy_line_node, table_size - required_multis_hidden_index, 0);
PM_COMPILE(required);
}
}
}
if (posts_list && posts_list->size) {
for (size_t i = 0; i < posts_list->size; i++) {
// For each MultiTargetNode, we're going to have one
// additional anonymous local not represented in the locals table
// We want to account for this in our table size
pm_node_t *post = posts_list->nodes[i];
if (PM_NODE_TYPE_P(post, PM_MULTI_TARGET_NODE)) {
ADD_GETLOCAL(ret, &dummy_line_node, table_size - post_multis_hidden_index, 0);
PM_COMPILE(post);
}
}
}
switch (body->type) { switch (body->type) {
case ISEQ_TYPE_BLOCK: { case ISEQ_TYPE_BLOCK: {
LABEL *start = ISEQ_COMPILE_DATA(iseq)->start_label = NEW_LABEL(0); LABEL *start = ISEQ_COMPILE_DATA(iseq)->start_label = NEW_LABEL(0);

View File

@ -505,6 +505,7 @@ module Prism
assert_prism_eval("a, (b, *c) = [1, [2, 3]]; c") assert_prism_eval("a, (b, *c) = [1, [2, 3]]; c")
assert_prism_eval("a, (b, *c) = 1, [2, 3]; c") assert_prism_eval("a, (b, *c) = 1, [2, 3]; c")
assert_prism_eval("a, (b, *) = 1, [2, 3]; b") assert_prism_eval("a, (b, *) = 1, [2, 3]; b")
assert_prism_eval("a, (b, *c, d) = 1, [2, 3, 4]; [a, b, c, d]")
assert_prism_eval("(a, (b, c, d, e), f, g), h = [1, [2, 3]], 4, 5, [6, 7]; c") assert_prism_eval("(a, (b, c, d, e), f, g), h = [1, [2, 3]], 4, 5, [6, 7]; c")
end end
@ -1175,12 +1176,58 @@ module Prism
assert_prism_eval("class PrismTestDefNode; def prism_test_def_node(*a) a end end; PrismTestDefNode.new.prism_test_def_node(1).inspect") assert_prism_eval("class PrismTestDefNode; def prism_test_def_node(*a) a end end; PrismTestDefNode.new.prism_test_def_node(1).inspect")
# block argument # block argument
assert_prism_eval(<<-CODE assert_prism_eval(<<-CODE)
def self.prism_test_def_node(&block) prism_test_def_node2(&block) end def self.prism_test_def_node(&block) prism_test_def_node2(&block) end
def self.prism_test_def_node2() yield 1 end def self.prism_test_def_node2() yield 1 end
prism_test_def_node2 {|a| a } prism_test_def_node2 {|a| a }
CODE CODE
# multi argument
assert_prism_eval(<<-CODE)
def self.prism_test_def_node(a, (b, *c, d))
[a, b, c, d]
end
prism_test_def_node("a", ["b", "c", "d"])
CODE
assert_prism_eval(<<-CODE)
def self.prism_test_def_node(a, (b, c, *))
[a, b, c]
end
prism_test_def_node("a", ["b", "c"])
CODE
assert_prism_eval(<<-CODE)
def self.prism_test_def_node(a, (*, b, c))
[a, b, c]
end
prism_test_def_node("a", ["b", "c"])
CODE
# recursive multis
assert_prism_eval(<<-CODE)
def self.prism_test_def_node(a, (b, *c, (d, *e, f)))
[a, b, c, d, d, e, f]
end
prism_test_def_node("a", ["b", "c", ["d", "e", "f"]])
CODE
# Many arguments
assert_prism_eval(<<-CODE)
def self.prism_test_def_node(a, (b, *c, d), e = 1, *f, g, (h, *i, j), k:, l: 1, **m)
[a, b, c, d, e, f, g, h, i, j, k, l, m]
end
prism_test_def_node(
"a",
["b", "c1", "c2", "d"],
"e",
"f1", "f2",
"g",
["h", "i1", "i2", "j"],
k: "k",
l: "l",
m1: "m1",
m2: "m2"
) )
CODE
end end
def test_method_parameters def test_method_parameters