[ruby/prism] Fix global variable read off end

https://github.com/ruby/prism/commit/3f2c34b53d
This commit is contained in:
Kevin Newton 2025-01-05 12:29:02 -05:00 committed by git
parent 22e9fa81ca
commit 179e2cfa91

View File

@ -1649,22 +1649,25 @@ pm_arguments_validate_block(pm_parser_t *parser, pm_arguments_t *arguments, pm_b
* the function pointer or can just directly use the UTF-8 functions.
*/
static inline size_t
char_is_identifier_start(const pm_parser_t *parser, const uint8_t *b) {
char_is_identifier_start(const pm_parser_t *parser, const uint8_t *b, ptrdiff_t n) {
if (n <= 0) return 0;
if (parser->encoding_changed) {
size_t width;
if ((width = parser->encoding->alpha_char(b, parser->end - b)) != 0) {
if ((width = parser->encoding->alpha_char(b, n)) != 0) {
return width;
} else if (*b == '_') {
return 1;
} else if (*b >= 0x80) {
return parser->encoding->char_width(b, parser->end - b);
return parser->encoding->char_width(b, n);
} else {
return 0;
}
} else if (*b < 0x80) {
return (pm_encoding_unicode_table[*b] & PRISM_ENCODING_ALPHABETIC_BIT ? 1 : 0) || (*b == '_');
} else {
return pm_encoding_utf_8_char_width(b, parser->end - b);
return pm_encoding_utf_8_char_width(b, n);
}
}
@ -1673,11 +1676,13 @@ char_is_identifier_start(const pm_parser_t *parser, const uint8_t *b) {
* has not been changed.
*/
static inline size_t
char_is_identifier_utf8(const uint8_t *b, const uint8_t *end) {
if (*b < 0x80) {
char_is_identifier_utf8(const uint8_t *b, ptrdiff_t n) {
if (n <= 0) {
return 0;
} else if (*b < 0x80) {
return (*b == '_') || (pm_encoding_unicode_table[*b] & PRISM_ENCODING_ALPHANUMERIC_BIT ? 1 : 0);
} else {
return pm_encoding_utf_8_char_width(b, end - b);
return pm_encoding_utf_8_char_width(b, n);
}
}
@ -1687,20 +1692,24 @@ char_is_identifier_utf8(const uint8_t *b, const uint8_t *end) {
* it's important that it be as fast as possible.
*/
static inline size_t
char_is_identifier(const pm_parser_t *parser, const uint8_t *b) {
if (parser->encoding_changed) {
char_is_identifier(const pm_parser_t *parser, const uint8_t *b, ptrdiff_t n) {
if (n <= 0) {
return 0;
} else if (parser->encoding_changed) {
size_t width;
if ((width = parser->encoding->alnum_char(b, parser->end - b)) != 0) {
if ((width = parser->encoding->alnum_char(b, n)) != 0) {
return width;
} else if (*b == '_') {
return 1;
} else if (*b >= 0x80) {
return parser->encoding->char_width(b, parser->end - b);
return parser->encoding->char_width(b, n);
} else {
return 0;
}
} else {
return char_is_identifier_utf8(b, n);
}
return char_is_identifier_utf8(b, parser->end);
}
// Here we're defining a perfect hash for the characters that are allowed in
@ -2895,7 +2904,7 @@ pm_call_node_writable_p(const pm_parser_t *parser, const pm_call_node_t *node) {
(node->message_loc.start != NULL) &&
(node->message_loc.end[-1] != '!') &&
(node->message_loc.end[-1] != '?') &&
char_is_identifier_start(parser, node->message_loc.start) &&
char_is_identifier_start(parser, node->message_loc.start, parser->end - node->message_loc.start) &&
(node->opening_loc.start == NULL) &&
(node->arguments == NULL) &&
(node->block == NULL)
@ -9082,10 +9091,10 @@ lex_global_variable(pm_parser_t *parser) {
parser->current.end++;
size_t width;
if (parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0) {
if ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0) {
do {
parser->current.end += width;
} while (parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0);
} while ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0);
// $0 isn't allowed to be followed by anything.
pm_diagnostic_id_t diag_id = parser->version == PM_OPTIONS_VERSION_CRUBY_3_3 ? PM_ERR_INVALID_VARIABLE_GLOBAL_3_3 : PM_ERR_INVALID_VARIABLE_GLOBAL;
@ -9114,10 +9123,10 @@ lex_global_variable(pm_parser_t *parser) {
default: {
size_t width;
if ((width = char_is_identifier(parser, parser->current.end)) > 0) {
if ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0) {
do {
parser->current.end += width;
} while (allow_multiple && parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0);
} while (allow_multiple && (width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0);
} else if (pm_char_is_whitespace(peek(parser))) {
// If we get here, then we have a $ followed by whitespace,
// which is not allowed.
@ -9182,11 +9191,11 @@ lex_identifier(pm_parser_t *parser, bool previous_command_start) {
bool encoding_changed = parser->encoding_changed;
if (encoding_changed) {
while (current_end < end && (width = char_is_identifier(parser, current_end)) > 0) {
while ((width = char_is_identifier(parser, current_end, end - current_end)) > 0) {
current_end += width;
}
} else {
while (current_end < end && (width = char_is_identifier_utf8(current_end, end)) > 0) {
while ((width = char_is_identifier_utf8(current_end, end - current_end)) > 0) {
current_end += width;
}
}
@ -9360,7 +9369,7 @@ lex_interpolation(pm_parser_t *parser, const uint8_t *pound) {
const uint8_t *variable = pound + 2;
if (*variable == '@' && pound + 3 < parser->end) variable++;
if (char_is_identifier_start(parser, variable)) {
if (char_is_identifier_start(parser, variable, parser->end - variable)) {
// At this point we're sure that we've either hit an embedded instance
// or class variable. In this case we'll first need to check if we've
// already consumed content.
@ -9409,7 +9418,7 @@ lex_interpolation(pm_parser_t *parser, const uint8_t *pound) {
// or a global name punctuation character, then we've hit an embedded
// global variable.
if (
char_is_identifier_start(parser, check) ||
char_is_identifier_start(parser, check, parser->end - check) ||
(pound[2] != '-' && (pm_char_is_decimal_digit(pound[2]) || char_is_global_name_punctuation(pound[2])))
) {
// In this case we've hit an embedded global variable. First check to
@ -10135,7 +10144,7 @@ lex_question_mark(pm_parser_t *parser) {
!(parser->encoding->alnum_char(parser->current.end, parser->end - parser->current.end) || peek(parser) == '_') ||
(
(parser->current.end + encoding_width >= parser->end) ||
!char_is_identifier(parser, parser->current.end + encoding_width)
!char_is_identifier(parser, parser->current.end + encoding_width, parser->end - (parser->current.end + encoding_width))
)
) {
lex_state_set(parser, PM_LEX_STATE_END);
@ -10155,21 +10164,22 @@ lex_question_mark(pm_parser_t *parser) {
static pm_token_type_t
lex_at_variable(pm_parser_t *parser) {
pm_token_type_t type = match(parser, '@') ? PM_TOKEN_CLASS_VARIABLE : PM_TOKEN_INSTANCE_VARIABLE;
size_t width;
const uint8_t *end = parser->end;
if (parser->current.end < parser->end && (width = char_is_identifier_start(parser, parser->current.end)) > 0) {
size_t width;
if ((width = char_is_identifier_start(parser, parser->current.end, end - parser->current.end)) > 0) {
parser->current.end += width;
while (parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0) {
while ((width = char_is_identifier(parser, parser->current.end, end - parser->current.end)) > 0) {
parser->current.end += width;
}
} else if (parser->current.end < parser->end && pm_char_is_decimal_digit(*parser->current.end)) {
} else if (parser->current.end < end && pm_char_is_decimal_digit(*parser->current.end)) {
pm_diagnostic_id_t diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_INCOMPLETE_VARIABLE_CLASS : PM_ERR_INCOMPLETE_VARIABLE_INSTANCE;
if (parser->version == PM_OPTIONS_VERSION_CRUBY_3_3) {
diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_INCOMPLETE_VARIABLE_CLASS_3_3 : PM_ERR_INCOMPLETE_VARIABLE_INSTANCE_3_3;
}
size_t width = parser->encoding->char_width(parser->current.end, parser->end - parser->current.end);
size_t width = parser->encoding->char_width(parser->current.end, end - parser->current.end);
PM_PARSER_ERR_TOKEN_FORMAT(parser, parser->current, diag_id, (int) ((parser->current.end + width) - parser->current.start), (const char *) parser->current.start);
} else {
pm_diagnostic_id_t diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_CLASS_VARIABLE_BARE : PM_ERR_INSTANCE_VARIABLE_BARE;
@ -11157,13 +11167,13 @@ parser_lex(pm_parser_t *parser) {
if (parser->current.end >= parser->end) {
parser->current.end = end;
} else if (quote == PM_HEREDOC_QUOTE_NONE && (width = char_is_identifier(parser, parser->current.end)) == 0) {
} else if (quote == PM_HEREDOC_QUOTE_NONE && (width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) == 0) {
parser->current.end = end;
} else {
if (quote == PM_HEREDOC_QUOTE_NONE) {
parser->current.end += width;
while ((parser->current.end < parser->end) && (width = char_is_identifier(parser, parser->current.end))) {
while ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end))) {
parser->current.end += width;
}
} else {
@ -11348,7 +11358,7 @@ parser_lex(pm_parser_t *parser) {
} else {
const uint8_t delim = peek_offset(parser, 1);
if ((delim != '\'') && (delim != '"') && !char_is_identifier(parser, parser->current.end + 1)) {
if ((delim != '\'') && (delim != '"') && !char_is_identifier(parser, parser->current.end + 1, parser->end - (parser->current.end + 1))) {
pm_parser_warn_token(parser, &parser->current, PM_WARN_AMBIGUOUS_PREFIX_AMPERSAND);
}
}
@ -11786,7 +11796,7 @@ parser_lex(pm_parser_t *parser) {
default: {
if (*parser->current.start != '_') {
size_t width = char_is_identifier_start(parser, parser->current.start);
size_t width = char_is_identifier_start(parser, parser->current.start, parser->end - parser->current.start);
// If this isn't the beginning of an identifier, then
// it's an invalid token as we've exhausted all of the
@ -13720,7 +13730,7 @@ parse_write(pm_parser_t *parser, pm_node_t *target, pm_token_t *operator, pm_nod
return target;
}
if (char_is_identifier_start(parser, call->message_loc.start)) {
if (char_is_identifier_start(parser, call->message_loc.start, parser->end - call->message_loc.start)) {
// When we get here, we have a method call, because it was
// previously marked as a method call but now we have an =. This
// looks like:
@ -17052,7 +17062,7 @@ pm_slice_is_valid_local(const pm_parser_t *parser, const uint8_t *start, const u
if (length == 0) return false;
// First ensure that it starts with a valid identifier starting character.
size_t width = char_is_identifier_start(parser, start);
size_t width = char_is_identifier_start(parser, start, end - start);
if (width == 0) return false;
// Next, ensure that it's not an uppercase character.
@ -17065,7 +17075,7 @@ pm_slice_is_valid_local(const pm_parser_t *parser, const uint8_t *start, const u
// Next, iterate through all of the bytes of the string to ensure that they
// are all valid identifier characters.
const uint8_t *cursor = start + width;
while ((cursor < end) && (width = char_is_identifier(parser, cursor))) cursor += width;
while ((width = char_is_identifier(parser, cursor, end - cursor))) cursor += width;
return cursor == end;
}