From 2d9db72d37e570dcfc5e8b0b74476360cae96056 Mon Sep 17 00:00:00 2001
From: Kevin Newton
Date: Thu, 11 Jan 2024 17:05:42 -0500
Subject: [PATCH] Write to constant path, call, and index in rescue

When compiling a rescue reference, you can write to a constant path, a
call, or an index. In these cases we need to compile the prefix
expression first and then perform the actual write or call.
---
 prism_compile.c | 296 +++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 265 insertions(+), 31 deletions(-)

diff --git a/prism_compile.c b/prism_compile.c
index 7bba3be0a1..cab2af42e9 100644
--- a/prism_compile.c
+++ b/prism_compile.c
@@ -207,21 +207,21 @@ parse_string_encoded(const pm_node_t *node, const pm_string_t *string, const pm_
 }
 
 static inline ID
-parse_symbol(const uint8_t *start, const uint8_t *end, pm_parser_t *parser)
+parse_symbol(const uint8_t *start, const uint8_t *end, const pm_parser_t *parser)
 {
     rb_encoding *enc = rb_enc_from_index(rb_enc_find_index(parser->encoding->name));
     return rb_intern3((const char *) start, end - start, enc);
 }
 
 static inline ID
-parse_string_symbol(const pm_string_t *string, pm_parser_t *parser)
+parse_string_symbol(const pm_string_t *string, const pm_parser_t *parser)
 {
     const uint8_t *start = pm_string_source(string);
     return parse_symbol(start, start + pm_string_length(string), parser);
 }
 
 static inline ID
-parse_location_symbol(pm_location_t *location, pm_parser_t *parser)
+parse_location_symbol(const pm_location_t *location, const pm_parser_t *parser)
 {
     return parse_symbol(location->start, location->end, parser);
 }
@@ -326,7 +326,7 @@ pm_new_regex(pm_regular_expression_node_t * cast, const pm_parser_t * parser) {
  * literal values can be compiled into a literal array.
  */
 static inline VALUE
-pm_static_literal_value(const pm_node_t *node, pm_scope_node_t *scope_node, pm_parser_t *parser)
+pm_static_literal_value(const pm_node_t *node, const pm_scope_node_t *scope_node, const pm_parser_t *parser)
 {
     // Every node that comes into this function should already be marked as
     // static literal. If it's not, then we have a bug somewhere.
@@ -1147,6 +1147,118 @@ pm_setup_args(pm_arguments_node_t *arguments_node, int *flags, struct rb_callinf
     return orig_argc;
 }
 
+/**
+ * A callinfo struct mirrors the information that is going to be passed to a
+ * callinfo object that will be used on a send instruction. We use it to
+ * communicate that information between the function that derives it and the
+ * function that uses it.
+ */
+typedef struct {
+    int argc;
+    int flags;
+    struct rb_callinfo_kwarg *kwargs;
+} pm_callinfo_t;
+
+/**
+ * Derive the callinfo from the given arguments node. This function assumes
+ * that the given callinfo struct points to zeroed memory.
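+ *
+ * For example, given the logic below, a plain call like `foo(1, 2)` derives
+ * argc == 2 with no additional flags, while `foo(*bar)` derives argc == 1
+ * with the VM_CALL_ARGS_SPLAT flag set.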
+ */
+static void
+pm_arguments_node_callinfo(pm_callinfo_t *callinfo, const pm_arguments_node_t *node, const pm_scope_node_t *scope_node, const pm_parser_t *parser)
+{
+    if (node == NULL) {
+        if (callinfo->flags & VM_CALL_FCALL) {
+            callinfo->flags |= VM_CALL_VCALL;
+        }
+    } else {
+        const pm_node_list_t *arguments = &node->arguments;
+        bool has_splat = false;
+
+        for (size_t argument_index = 0; argument_index < arguments->size; argument_index++) {
+            const pm_node_t *argument = arguments->nodes[argument_index];
+
+            switch (PM_NODE_TYPE(argument)) {
+              case PM_KEYWORD_HASH_NODE: {
+                pm_keyword_hash_node_t *keyword_hash = (pm_keyword_hash_node_t *) argument;
+                size_t elements_size = keyword_hash->elements.size;
+
+                if (PM_NODE_FLAG_P(node, PM_ARGUMENTS_NODE_FLAGS_CONTAINS_KEYWORD_SPLAT)) {
+                    for (size_t element_index = 0; element_index < elements_size; element_index++) {
+                        const pm_node_t *element = keyword_hash->elements.nodes[element_index];
+
+                        switch (PM_NODE_TYPE(element)) {
+                          case PM_ASSOC_NODE:
+                            callinfo->argc++;
+                            break;
+                          case PM_ASSOC_SPLAT_NODE:
+                            if (elements_size > 1) callinfo->flags |= VM_CALL_KW_SPLAT_MUT;
+                            callinfo->flags |= VM_CALL_KW_SPLAT;
+                            break;
+                          default:
+                            rb_bug("Unknown type in keyword argument %s\n", pm_node_type_to_str(PM_NODE_TYPE(element)));
+                        }
+                    }
+                    break;
+                } else if (PM_NODE_FLAG_P(keyword_hash, PM_KEYWORD_HASH_NODE_FLAGS_SYMBOL_KEYS)) {
+                    // All elements of the KeywordHashNode are AssocNode nodes
+                    // with symbol keys, so we can pass them as keyword
+                    // arguments.
+                    callinfo->flags |= VM_CALL_KWARG;
+
+                    callinfo->kwargs = rb_xmalloc_mul_add(elements_size, sizeof(VALUE), sizeof(struct rb_callinfo_kwarg));
+                    callinfo->kwargs->references = 0;
+                    callinfo->kwargs->keyword_len = (int) elements_size;
+
+                    VALUE *keywords = callinfo->kwargs->keywords;
+                    for (size_t element_index = 0; element_index < elements_size; element_index++) {
+                        pm_assoc_node_t *assoc = (pm_assoc_node_t *) keyword_hash->elements.nodes[element_index];
+                        keywords[element_index] = pm_static_literal_value(assoc->key, scope_node, parser);
+                    }
+                } else {
+                    // If they aren't all symbol keys then we need to construct
+                    // a new hash and pass that as an argument.
+                    callinfo->argc++;
+                    callinfo->flags |= VM_CALL_KW_SPLAT;
+
+                    if (elements_size > 1) {
+                        // A new hash will be created for the keyword arguments
+                        // in this case, so mark the method as passing a
+                        // mutable keyword splat.
+                        callinfo->flags |= VM_CALL_KW_SPLAT_MUT;
+                    }
+                }
+                break;
+              }
+              case PM_SPLAT_NODE: {
+                // Splat nodes add a splat flag and can change the way the
+                // arguments are loaded by combining them into a single array.
+                callinfo->flags |= VM_CALL_ARGS_SPLAT;
+                if (((pm_splat_node_t *) argument)->expression != NULL) callinfo->argc++;
+                has_splat = true;
+                break;
+              }
+              case PM_FORWARDING_ARGUMENTS_NODE: {
+                // Forwarding arguments indicate that a splat and a block are
+                // present, and increase the argument count by one.
+                callinfo->flags |= VM_CALL_ARGS_BLOCKARG | VM_CALL_ARGS_SPLAT;
+                callinfo->argc++;
+                break;
+              }
+              default: {
+                // A regular argument increases the argument count by one. If
+                // there is a splat and this is the last argument, then the
+                // argument count becomes 1 because all of the arguments get
+                // grouped into a single array.
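+                //
+                // For example, `foo(1, *bar, 2)` ends up with argc == 1 and
+                // the VM_CALL_ARGS_SPLAT flag set, because every argument is
+                // folded into the one splatted array.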
+                callinfo->argc++;
+                if (has_splat && (argument_index == arguments->size - 1)) callinfo->argc = 1;
+                break;
+              }
+            }
+        }
+    }
+}
+
 static void
 pm_compile_index_and_or_write_node(bool and_node, pm_node_t *receiver, pm_node_t *value, pm_arguments_node_t *arguments, pm_node_t *block, LINK_ANCHOR *const ret, rb_iseq_t *iseq, int lineno, const uint8_t * src, bool popped, pm_scope_node_t *scope_node, pm_parser_t *parser)
 {
@@ -3234,6 +3346,20 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
         return;
       }
+      case PM_CALL_TARGET_NODE: {
+        // Call targets can be used to indirectly call a method in places
+        // like rescue references, for loops, and multiple assignment. In
+        // those circumstances, it's necessary to first compile the receiver,
+        // then to compile the method call itself.
+        //
+        // Therefore in the main switch case here where we're compiling a
+        // call target, we're only going to compile the receiver. Then
+        // wherever we've called into pm_compile_node when we're compiling
+        // call targets, we'll need to make sure we compile the method call
+        // as well.
+        pm_call_target_node_t *cast = (pm_call_target_node_t *) node;
+        PM_COMPILE_NOT_POPPED(cast->receiver);
+        return;
+      }
       case PM_CASE_NODE: {
         pm_case_node_t *case_node = (pm_case_node_t *)node;
         bool has_predicate = case_node->predicate;
@@ -3731,13 +3857,23 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
         }
 
         ADD_INSN1(ret, &dummy_line_node, setconstant, child_name);
-        return ;
+        return;
       }
       case PM_CONSTANT_PATH_TARGET_NODE: {
-        pm_constant_path_target_node_t *cast = (pm_constant_path_target_node_t *)node;
+        // Constant path targets can be used to indirectly write a constant
+        // in places like rescue references, for loops, and multiple
+        // assignment. In those circumstances, it's necessary to first
+        // compile the parent, then to compile the child.
+        //
+        // Therefore in the main switch case here where we're compiling a
+        // constant path target, we're only going to compile the parent.
+        // Then wherever we've called into pm_compile_node when we're
+        // compiling constant path targets, we'll need to make sure we
+        // compile the child as well.
+        pm_constant_path_target_node_t *cast = (pm_constant_path_target_node_t *) node;
 
         if (cast->parent) {
-            PM_COMPILE(cast->parent);
+            PM_COMPILE_NOT_POPPED(cast->parent);
         }
 
         return;
@@ -4379,6 +4515,33 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
         return;
       }
 
+      case PM_INDEX_TARGET_NODE: {
+        // Index targets can be used to indirectly call a method in places
+        // like rescue references, for loops, and multiple assignment. In
+        // those circumstances, it's necessary to first compile the receiver
+        // and arguments, then to compile the method call itself.
+        //
+        // Therefore in the main switch case here where we're compiling an
+        // index target, we're only going to compile the receiver and
+        // arguments. Then wherever we've called into pm_compile_node when
+        // we're compiling index targets, we'll need to make sure we compile
+        // the method call as well.
+        //
+        // Note that index target nodes can have blocks attached to them in
+        // the form of the & operator. These blocks should almost always be
+        // compiled _after_ the value that is being written is added to the
+        // argument list, so we don't compile them here. Therefore at the
+        // places where these nodes are handled, blocks also need to be
+        // handled.
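+        //
+        // For example, in `begin; rescue => foo[:bar]; end`, this case only
+        // emits the code for `foo` and the `:bar` argument; the rescue
+        // handling later loads the error and sends []= itself.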
+        pm_index_target_node_t *cast = (pm_index_target_node_t *) node;
+        PM_COMPILE_NOT_POPPED(cast->receiver);
+
+        if (cast->arguments != NULL) {
+            int flags = 0;
+            struct rb_callinfo_kwarg *keywords = NULL;
+            pm_setup_args(cast->arguments, &flags, &keywords, iseq, ret, src, false, scope_node, dummy_line_node, parser);
+        }
+
+        return;
+      }
       case PM_INSTANCE_VARIABLE_AND_WRITE_NODE: {
         pm_instance_variable_and_write_node_t *instance_variable_and_write_node = (pm_instance_variable_and_write_node_t*) node;
@@ -5371,49 +5534,120 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
         return;
       }
       case PM_RESCUE_NODE: {
-        LABEL *excep_match = NEW_LABEL(lineno);
-        LABEL *rescue_end = NEW_LABEL(lineno);
-
-        ISEQ_COMPILE_DATA(iseq)->end_label = rescue_end;
-
-        pm_rescue_node_t *rescue_node = (pm_rescue_node_t *)node;
+        pm_rescue_node_t *cast = (pm_rescue_node_t *) node;
         iseq_set_exception_local_table(iseq);
-        pm_node_list_t exception_list = rescue_node->exceptions;
-
-        if (exception_list.size > 0) {
-            for (size_t i = 0; i < exception_list.size; i++) {
+
+        // First, establish the labels that we need to be able to jump to
+        // within this compilation block.
+        LABEL *exception_match_label = NEW_LABEL(lineno);
+        LABEL *rescue_end_label = NEW_LABEL(lineno);
+        ISEQ_COMPILE_DATA(iseq)->end_label = rescue_end_label;
+
+        // Next, compile each of the exceptions that we're going to be
+        // handling. For each one, we'll add instructions to check if the
+        // exception matches the raised one, and if it does then jump to
+        // exception_match_label. Otherwise it will fall through to the
+        // subsequent check. If there are no exceptions, we'll only check
+        // StandardError.
+        pm_node_list_t *exceptions = &cast->exceptions;
+
+        if (exceptions->size > 0) {
+            for (size_t index = 0; index < exceptions->size; index++) {
                 ADD_GETLOCAL(ret, &dummy_line_node, LVAR_ERRINFO, 0);
-                PM_COMPILE(exception_list.nodes[i]);
+                PM_COMPILE(exceptions->nodes[index]);
                 ADD_INSN1(ret, &dummy_line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_RESCUE));
-                ADD_INSN1(ret, &dummy_line_node, branchif, excep_match);
+                ADD_INSN1(ret, &dummy_line_node, branchif, exception_match_label);
             }
         } else {
             ADD_GETLOCAL(ret, &dummy_line_node, LVAR_ERRINFO, 0);
             ADD_INSN1(ret, &dummy_line_node, putobject, rb_eStandardError);
             ADD_INSN1(ret, &dummy_line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_RESCUE));
-            ADD_INSN1(ret, &dummy_line_node, branchif, excep_match);
+            ADD_INSN1(ret, &dummy_line_node, branchif, exception_match_label);
         }
 
-        ADD_INSN1(ret, &dummy_line_node, jump, rescue_end);
-        ADD_LABEL(ret, excep_match);
+        // If none of the exceptions that we are matching against matched,
+        // then we'll jump straight to rescue_end_label.
+        ADD_INSN1(ret, &dummy_line_node, jump, rescue_end_label);
+
+        // exception_match_label is where the control flow goes in the case
+        // that one of the exceptions matched. Here we will compile the
+        // instructions to handle the exception.
+        ADD_LABEL(ret, exception_match_label);
         ADD_TRACE(ret, RUBY_EVENT_RESCUE);
-        if (rescue_node->reference) {
-            ADD_GETLOCAL(ret, &dummy_line_node, LVAR_ERRINFO, 0);
-            PM_COMPILE((pm_node_t *)rescue_node->reference);
-        }
+
+        // If we have a reference to the exception, then we'll compile the
+        // write into the instruction sequence. This can look quite different
+        // depending on the kind of write being performed.
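+        //
+        // For example, each of these forms takes a different path below:
+        //
+        //     begin; rescue => err;       end # local variable write
+        //     begin; rescue => Foo::Bar;  end # constant path write
+        //     begin; rescue => Foo.bar;   end # call to the bar= method
+        //     begin; rescue => foo[:bar]; end # call to the []= method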
+        if (cast->reference) {
+            switch (PM_NODE_TYPE(cast->reference)) {
+              case PM_CALL_TARGET_NODE: {
+                // begin; rescue => Foo.bar; end
+                const pm_call_target_node_t *reference = (const pm_call_target_node_t *) cast->reference;
+                ID method_id = pm_constant_id_lookup(scope_node, reference->name);
+
+                PM_COMPILE((pm_node_t *) reference);
+                ADD_GETLOCAL(ret, &dummy_line_node, LVAR_ERRINFO, 0);
+
+                ADD_SEND(ret, &dummy_line_node, method_id, INT2NUM(1));
+                ADD_INSN(ret, &dummy_line_node, pop);
+                break;
+              }
+              case PM_CONSTANT_PATH_TARGET_NODE: {
+                // begin; rescue => Foo::Bar; end
+                const pm_constant_path_target_node_t *reference = (const pm_constant_path_target_node_t *) cast->reference;
+                const pm_constant_read_node_t *constant = (const pm_constant_read_node_t *) reference->child;
+
+                PM_COMPILE((pm_node_t *) reference);
+                ADD_GETLOCAL(ret, &dummy_line_node, LVAR_ERRINFO, 0);
+
+                ADD_INSN(ret, &dummy_line_node, swap);
+                ADD_INSN1(ret, &dummy_line_node, setconstant, ID2SYM(pm_constant_id_lookup(scope_node, constant->name)));
+                break;
+              }
+              case PM_INDEX_TARGET_NODE: {
+                // begin; rescue => foo[:bar]; end
+                const pm_index_target_node_t *reference = (const pm_index_target_node_t *) cast->reference;
+
+                pm_callinfo_t callinfo = { 0 };
+                pm_arguments_node_callinfo(&callinfo, reference->arguments, scope_node, parser);
+
+                PM_COMPILE((pm_node_t *) reference);
+                ADD_GETLOCAL(ret, &dummy_line_node, LVAR_ERRINFO, 0);
+
+                if (reference->block != NULL) {
+                    callinfo.flags |= VM_CALL_ARGS_BLOCKARG;
+                    PM_COMPILE_NOT_POPPED((pm_node_t *) reference->block);
+                }
+
+                ADD_SEND_R(ret, &dummy_line_node, idASET, INT2FIX(callinfo.argc + 1), NULL, INT2FIX(callinfo.flags), callinfo.kwargs);
+                ADD_INSN(ret, &dummy_line_node, pop);
+                break;
+              }
+              default:
+                // Indirectly writing to a variable or constant.
+                ADD_GETLOCAL(ret, &dummy_line_node, LVAR_ERRINFO, 0);
+                PM_COMPILE((pm_node_t *) cast->reference);
+                break;
+            }
+        }
 
-        if (rescue_node->statements) {
-            PM_COMPILE((pm_node_t *)rescue_node->statements);
-        }
-        else {
+        // If we have statements to execute, we'll compile them here.
+        // Otherwise we'll push nil onto the stack.
+        if (cast->statements) {
+            PM_COMPILE((pm_node_t *) cast->statements);
+        } else {
             PM_PUTNIL;
         }
 
         ADD_INSN(ret, &dummy_line_node, leave);
-        ADD_LABEL(ret, rescue_end);
 
-        if (rescue_node->consequent) {
-            PM_COMPILE((pm_node_t *)rescue_node->consequent);
+        // Here we'll insert rescue_end_label, which is jumped to when none
+        // of the exceptions matched. It will cause the control flow either
+        // to jump to the next rescue clause or to fall through to the
+        // subsequent instruction, which returns the raised error.
+        ADD_LABEL(ret, rescue_end_label);
+
+        if (cast->consequent) {
+            PM_COMPILE((pm_node_t *) cast->consequent);
         } else {
+            // There is no subsequent rescue clause, so load the raised error
+            // so that it becomes the result of this sequence.
             ADD_GETLOCAL(ret, &dummy_line_node, 1, 0);
         }
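+
+        // Putting it all together, `begin; rescue => Foo::Bar; end` checks
+        // the raised error against StandardError and, when it matches,
+        // compiles Foo, loads the error, swaps the two, and writes the error
+        // into Foo::Bar with setconstant before compiling the rescue body.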