[ruby/prism] Fix global variable read off end

https://github.com/ruby/prism/commit/3f2c34b53d
This commit is contained in:
Kevin Newton 2025-01-05 12:29:02 -05:00 committed by git
parent 22e9fa81ca
commit 179e2cfa91

View File

@ -1649,22 +1649,25 @@ pm_arguments_validate_block(pm_parser_t *parser, pm_arguments_t *arguments, pm_b
* the function pointer or can just directly use the UTF-8 functions. * the function pointer or can just directly use the UTF-8 functions.
*/ */
static inline size_t static inline size_t
char_is_identifier_start(const pm_parser_t *parser, const uint8_t *b) { char_is_identifier_start(const pm_parser_t *parser, const uint8_t *b, ptrdiff_t n) {
if (n <= 0) return 0;
if (parser->encoding_changed) { if (parser->encoding_changed) {
size_t width; size_t width;
if ((width = parser->encoding->alpha_char(b, parser->end - b)) != 0) {
if ((width = parser->encoding->alpha_char(b, n)) != 0) {
return width; return width;
} else if (*b == '_') { } else if (*b == '_') {
return 1; return 1;
} else if (*b >= 0x80) { } else if (*b >= 0x80) {
return parser->encoding->char_width(b, parser->end - b); return parser->encoding->char_width(b, n);
} else { } else {
return 0; return 0;
} }
} else if (*b < 0x80) { } else if (*b < 0x80) {
return (pm_encoding_unicode_table[*b] & PRISM_ENCODING_ALPHABETIC_BIT ? 1 : 0) || (*b == '_'); return (pm_encoding_unicode_table[*b] & PRISM_ENCODING_ALPHABETIC_BIT ? 1 : 0) || (*b == '_');
} else { } else {
return pm_encoding_utf_8_char_width(b, parser->end - b); return pm_encoding_utf_8_char_width(b, n);
} }
} }
@ -1673,11 +1676,13 @@ char_is_identifier_start(const pm_parser_t *parser, const uint8_t *b) {
* has not been changed. * has not been changed.
*/ */
static inline size_t static inline size_t
char_is_identifier_utf8(const uint8_t *b, const uint8_t *end) { char_is_identifier_utf8(const uint8_t *b, ptrdiff_t n) {
if (*b < 0x80) { if (n <= 0) {
return 0;
} else if (*b < 0x80) {
return (*b == '_') || (pm_encoding_unicode_table[*b] & PRISM_ENCODING_ALPHANUMERIC_BIT ? 1 : 0); return (*b == '_') || (pm_encoding_unicode_table[*b] & PRISM_ENCODING_ALPHANUMERIC_BIT ? 1 : 0);
} else { } else {
return pm_encoding_utf_8_char_width(b, end - b); return pm_encoding_utf_8_char_width(b, n);
} }
} }
@ -1687,20 +1692,24 @@ char_is_identifier_utf8(const uint8_t *b, const uint8_t *end) {
* it's important that it be as fast as possible. * it's important that it be as fast as possible.
*/ */
static inline size_t static inline size_t
char_is_identifier(const pm_parser_t *parser, const uint8_t *b) { char_is_identifier(const pm_parser_t *parser, const uint8_t *b, ptrdiff_t n) {
if (parser->encoding_changed) { if (n <= 0) {
return 0;
} else if (parser->encoding_changed) {
size_t width; size_t width;
if ((width = parser->encoding->alnum_char(b, parser->end - b)) != 0) {
if ((width = parser->encoding->alnum_char(b, n)) != 0) {
return width; return width;
} else if (*b == '_') { } else if (*b == '_') {
return 1; return 1;
} else if (*b >= 0x80) { } else if (*b >= 0x80) {
return parser->encoding->char_width(b, parser->end - b); return parser->encoding->char_width(b, n);
} else { } else {
return 0; return 0;
} }
} else {
return char_is_identifier_utf8(b, n);
} }
return char_is_identifier_utf8(b, parser->end);
} }
// Here we're defining a perfect hash for the characters that are allowed in // Here we're defining a perfect hash for the characters that are allowed in
@ -2895,7 +2904,7 @@ pm_call_node_writable_p(const pm_parser_t *parser, const pm_call_node_t *node) {
(node->message_loc.start != NULL) && (node->message_loc.start != NULL) &&
(node->message_loc.end[-1] != '!') && (node->message_loc.end[-1] != '!') &&
(node->message_loc.end[-1] != '?') && (node->message_loc.end[-1] != '?') &&
char_is_identifier_start(parser, node->message_loc.start) && char_is_identifier_start(parser, node->message_loc.start, parser->end - node->message_loc.start) &&
(node->opening_loc.start == NULL) && (node->opening_loc.start == NULL) &&
(node->arguments == NULL) && (node->arguments == NULL) &&
(node->block == NULL) (node->block == NULL)
@ -9082,10 +9091,10 @@ lex_global_variable(pm_parser_t *parser) {
parser->current.end++; parser->current.end++;
size_t width; size_t width;
if (parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0) { if ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0) {
do { do {
parser->current.end += width; parser->current.end += width;
} while (parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0); } while ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0);
// $0 isn't allowed to be followed by anything. // $0 isn't allowed to be followed by anything.
pm_diagnostic_id_t diag_id = parser->version == PM_OPTIONS_VERSION_CRUBY_3_3 ? PM_ERR_INVALID_VARIABLE_GLOBAL_3_3 : PM_ERR_INVALID_VARIABLE_GLOBAL; pm_diagnostic_id_t diag_id = parser->version == PM_OPTIONS_VERSION_CRUBY_3_3 ? PM_ERR_INVALID_VARIABLE_GLOBAL_3_3 : PM_ERR_INVALID_VARIABLE_GLOBAL;
@ -9114,10 +9123,10 @@ lex_global_variable(pm_parser_t *parser) {
default: { default: {
size_t width; size_t width;
if ((width = char_is_identifier(parser, parser->current.end)) > 0) { if ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0) {
do { do {
parser->current.end += width; parser->current.end += width;
} while (allow_multiple && parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0); } while (allow_multiple && (width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) > 0);
} else if (pm_char_is_whitespace(peek(parser))) { } else if (pm_char_is_whitespace(peek(parser))) {
// If we get here, then we have a $ followed by whitespace, // If we get here, then we have a $ followed by whitespace,
// which is not allowed. // which is not allowed.
@ -9182,11 +9191,11 @@ lex_identifier(pm_parser_t *parser, bool previous_command_start) {
bool encoding_changed = parser->encoding_changed; bool encoding_changed = parser->encoding_changed;
if (encoding_changed) { if (encoding_changed) {
while (current_end < end && (width = char_is_identifier(parser, current_end)) > 0) { while ((width = char_is_identifier(parser, current_end, end - current_end)) > 0) {
current_end += width; current_end += width;
} }
} else { } else {
while (current_end < end && (width = char_is_identifier_utf8(current_end, end)) > 0) { while ((width = char_is_identifier_utf8(current_end, end - current_end)) > 0) {
current_end += width; current_end += width;
} }
} }
@ -9360,7 +9369,7 @@ lex_interpolation(pm_parser_t *parser, const uint8_t *pound) {
const uint8_t *variable = pound + 2; const uint8_t *variable = pound + 2;
if (*variable == '@' && pound + 3 < parser->end) variable++; if (*variable == '@' && pound + 3 < parser->end) variable++;
if (char_is_identifier_start(parser, variable)) { if (char_is_identifier_start(parser, variable, parser->end - variable)) {
// At this point we're sure that we've either hit an embedded instance // At this point we're sure that we've either hit an embedded instance
// or class variable. In this case we'll first need to check if we've // or class variable. In this case we'll first need to check if we've
// already consumed content. // already consumed content.
@ -9409,7 +9418,7 @@ lex_interpolation(pm_parser_t *parser, const uint8_t *pound) {
// or a global name punctuation character, then we've hit an embedded // or a global name punctuation character, then we've hit an embedded
// global variable. // global variable.
if ( if (
char_is_identifier_start(parser, check) || char_is_identifier_start(parser, check, parser->end - check) ||
(pound[2] != '-' && (pm_char_is_decimal_digit(pound[2]) || char_is_global_name_punctuation(pound[2]))) (pound[2] != '-' && (pm_char_is_decimal_digit(pound[2]) || char_is_global_name_punctuation(pound[2])))
) { ) {
// In this case we've hit an embedded global variable. First check to // In this case we've hit an embedded global variable. First check to
@ -10135,7 +10144,7 @@ lex_question_mark(pm_parser_t *parser) {
!(parser->encoding->alnum_char(parser->current.end, parser->end - parser->current.end) || peek(parser) == '_') || !(parser->encoding->alnum_char(parser->current.end, parser->end - parser->current.end) || peek(parser) == '_') ||
( (
(parser->current.end + encoding_width >= parser->end) || (parser->current.end + encoding_width >= parser->end) ||
!char_is_identifier(parser, parser->current.end + encoding_width) !char_is_identifier(parser, parser->current.end + encoding_width, parser->end - (parser->current.end + encoding_width))
) )
) { ) {
lex_state_set(parser, PM_LEX_STATE_END); lex_state_set(parser, PM_LEX_STATE_END);
@ -10155,21 +10164,22 @@ lex_question_mark(pm_parser_t *parser) {
static pm_token_type_t static pm_token_type_t
lex_at_variable(pm_parser_t *parser) { lex_at_variable(pm_parser_t *parser) {
pm_token_type_t type = match(parser, '@') ? PM_TOKEN_CLASS_VARIABLE : PM_TOKEN_INSTANCE_VARIABLE; pm_token_type_t type = match(parser, '@') ? PM_TOKEN_CLASS_VARIABLE : PM_TOKEN_INSTANCE_VARIABLE;
size_t width; const uint8_t *end = parser->end;
if (parser->current.end < parser->end && (width = char_is_identifier_start(parser, parser->current.end)) > 0) { size_t width;
if ((width = char_is_identifier_start(parser, parser->current.end, end - parser->current.end)) > 0) {
parser->current.end += width; parser->current.end += width;
while (parser->current.end < parser->end && (width = char_is_identifier(parser, parser->current.end)) > 0) { while ((width = char_is_identifier(parser, parser->current.end, end - parser->current.end)) > 0) {
parser->current.end += width; parser->current.end += width;
} }
} else if (parser->current.end < parser->end && pm_char_is_decimal_digit(*parser->current.end)) { } else if (parser->current.end < end && pm_char_is_decimal_digit(*parser->current.end)) {
pm_diagnostic_id_t diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_INCOMPLETE_VARIABLE_CLASS : PM_ERR_INCOMPLETE_VARIABLE_INSTANCE; pm_diagnostic_id_t diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_INCOMPLETE_VARIABLE_CLASS : PM_ERR_INCOMPLETE_VARIABLE_INSTANCE;
if (parser->version == PM_OPTIONS_VERSION_CRUBY_3_3) { if (parser->version == PM_OPTIONS_VERSION_CRUBY_3_3) {
diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_INCOMPLETE_VARIABLE_CLASS_3_3 : PM_ERR_INCOMPLETE_VARIABLE_INSTANCE_3_3; diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_INCOMPLETE_VARIABLE_CLASS_3_3 : PM_ERR_INCOMPLETE_VARIABLE_INSTANCE_3_3;
} }
size_t width = parser->encoding->char_width(parser->current.end, parser->end - parser->current.end); size_t width = parser->encoding->char_width(parser->current.end, end - parser->current.end);
PM_PARSER_ERR_TOKEN_FORMAT(parser, parser->current, diag_id, (int) ((parser->current.end + width) - parser->current.start), (const char *) parser->current.start); PM_PARSER_ERR_TOKEN_FORMAT(parser, parser->current, diag_id, (int) ((parser->current.end + width) - parser->current.start), (const char *) parser->current.start);
} else { } else {
pm_diagnostic_id_t diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_CLASS_VARIABLE_BARE : PM_ERR_INSTANCE_VARIABLE_BARE; pm_diagnostic_id_t diag_id = (type == PM_TOKEN_CLASS_VARIABLE) ? PM_ERR_CLASS_VARIABLE_BARE : PM_ERR_INSTANCE_VARIABLE_BARE;
@ -11157,13 +11167,13 @@ parser_lex(pm_parser_t *parser) {
if (parser->current.end >= parser->end) { if (parser->current.end >= parser->end) {
parser->current.end = end; parser->current.end = end;
} else if (quote == PM_HEREDOC_QUOTE_NONE && (width = char_is_identifier(parser, parser->current.end)) == 0) { } else if (quote == PM_HEREDOC_QUOTE_NONE && (width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end)) == 0) {
parser->current.end = end; parser->current.end = end;
} else { } else {
if (quote == PM_HEREDOC_QUOTE_NONE) { if (quote == PM_HEREDOC_QUOTE_NONE) {
parser->current.end += width; parser->current.end += width;
while ((parser->current.end < parser->end) && (width = char_is_identifier(parser, parser->current.end))) { while ((width = char_is_identifier(parser, parser->current.end, parser->end - parser->current.end))) {
parser->current.end += width; parser->current.end += width;
} }
} else { } else {
@ -11348,7 +11358,7 @@ parser_lex(pm_parser_t *parser) {
} else { } else {
const uint8_t delim = peek_offset(parser, 1); const uint8_t delim = peek_offset(parser, 1);
if ((delim != '\'') && (delim != '"') && !char_is_identifier(parser, parser->current.end + 1)) { if ((delim != '\'') && (delim != '"') && !char_is_identifier(parser, parser->current.end + 1, parser->end - (parser->current.end + 1))) {
pm_parser_warn_token(parser, &parser->current, PM_WARN_AMBIGUOUS_PREFIX_AMPERSAND); pm_parser_warn_token(parser, &parser->current, PM_WARN_AMBIGUOUS_PREFIX_AMPERSAND);
} }
} }
@ -11786,7 +11796,7 @@ parser_lex(pm_parser_t *parser) {
default: { default: {
if (*parser->current.start != '_') { if (*parser->current.start != '_') {
size_t width = char_is_identifier_start(parser, parser->current.start); size_t width = char_is_identifier_start(parser, parser->current.start, parser->end - parser->current.start);
// If this isn't the beginning of an identifier, then // If this isn't the beginning of an identifier, then
// it's an invalid token as we've exhausted all of the // it's an invalid token as we've exhausted all of the
@ -13720,7 +13730,7 @@ parse_write(pm_parser_t *parser, pm_node_t *target, pm_token_t *operator, pm_nod
return target; return target;
} }
if (char_is_identifier_start(parser, call->message_loc.start)) { if (char_is_identifier_start(parser, call->message_loc.start, parser->end - call->message_loc.start)) {
// When we get here, we have a method call, because it was // When we get here, we have a method call, because it was
// previously marked as a method call but now we have an =. This // previously marked as a method call but now we have an =. This
// looks like: // looks like:
@ -17052,7 +17062,7 @@ pm_slice_is_valid_local(const pm_parser_t *parser, const uint8_t *start, const u
if (length == 0) return false; if (length == 0) return false;
// First ensure that it starts with a valid identifier starting character. // First ensure that it starts with a valid identifier starting character.
size_t width = char_is_identifier_start(parser, start); size_t width = char_is_identifier_start(parser, start, end - start);
if (width == 0) return false; if (width == 0) return false;
// Next, ensure that it's not an uppercase character. // Next, ensure that it's not an uppercase character.
@ -17065,7 +17075,7 @@ pm_slice_is_valid_local(const pm_parser_t *parser, const uint8_t *start, const u
// Next, iterate through all of the bytes of the string to ensure that they // Next, iterate through all of the bytes of the string to ensure that they
// are all valid identifier characters. // are all valid identifier characters.
const uint8_t *cursor = start + width; const uint8_t *cursor = start + width;
while ((cursor < end) && (width = char_is_identifier(parser, cursor))) cursor += width; while ((width = char_is_identifier(parser, cursor, end - cursor))) cursor += width;
return cursor == end; return cursor == end;
} }