[ruby/prism] Fix up PR failings

https://github.com/ruby/prism/commit/11255f636e
This commit is contained in:
Kevin Newton 2023-10-12 11:38:57 -04:00
parent d06523bc52
commit 11e946da2f
4 changed files with 157 additions and 46 deletions

View File

@ -370,10 +370,9 @@ module Prism
@embexpr_balance -= 1 @embexpr_balance -= 1
when :on_tstring_content when :on_tstring_content
if embexpr_balance == 0 if embexpr_balance == 0
token.value.split(/(?<=\n)/).each_with_index do |line, index| line = token.value
next if line.strip.empty? && line.end_with?("\n")
next if !(dedent_next || index > 0)
if !(line.strip.empty? && line.end_with?("\n")) && dedent_next
leading = line[/\A(\s*)\n?/, 1] leading = line[/\A(\s*)\n?/, 1]
next_dedent = 0 next_dedent = 0
@ -430,6 +429,45 @@ module Prism
return results return results
end end
# If the minimum common whitespace is 0, then we need to concatenate
# string nodes together that are immediately adjacent.
if dedent == 0
results = []
embexpr_balance = 0
index = 0
max_index = tokens.length
while index < max_index
token = tokens[index]
index += 1
case token.event
when :on_embexpr_beg, :on_heredoc_beg
embexpr_balance += 1
results << token
when :on_embexpr_end, :on_heredoc_end
embexpr_balance -= 1
results << token
when :on_tstring_content
if embexpr_balance == 0
results << token
while index < max_index && tokens[index].event == :on_tstring_content
token.value << tokens[index].value
index += 1
end
else
results << token
end
else
results << token
end
end
return results
end
# Otherwise, we're going to run through each token in the list and # Otherwise, we're going to run through each token in the list and
# insert on_ignored_sp tokens for the amount of dedent that we need to # insert on_ignored_sp tokens for the amount of dedent that we need to
# perform. We also need to remove the dedent from the beginning of # perform. We also need to remove the dedent from the beginning of
@ -787,10 +825,6 @@ module Prism
# We sort by location to compare against Ripper's output # We sort by location to compare against Ripper's output
tokens.sort_by!(&:location) tokens.sort_by!(&:location)
if result_value.size - 1 > tokens.size
raise StandardError, "Lost tokens when performing lex_compat"
end
ParseResult.new(tokens, result.comments, result.errors, result.warnings, []) ParseResult.new(tokens, result.comments, result.errors, result.warnings, [])
end end
end end

View File

@ -6279,15 +6279,15 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
return; return;
} }
case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': { case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': {
uint8_t value = *parser->current.end - '0'; uint8_t value = (uint8_t) (*parser->current.end - '0');
parser->current.end++; parser->current.end++;
if (pm_char_is_octal_digit(peek(parser))) { if (pm_char_is_octal_digit(peek(parser))) {
value = (uint8_t) ((value << 3) | (*parser->current.end - '0')); value = ((uint8_t) (value << 3)) | ((uint8_t) (*parser->current.end - '0'));
parser->current.end++; parser->current.end++;
if (pm_char_is_octal_digit(peek(parser))) { if (pm_char_is_octal_digit(peek(parser))) {
value = (uint8_t) ((value << 3) | (*parser->current.end - '0')); value = ((uint8_t) (value << 3)) | ((uint8_t) (*parser->current.end - '0'));
parser->current.end++; parser->current.end++;
} }
} }
@ -6400,8 +6400,12 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
} }
case 'c': { case 'c': {
parser->current.end++; parser->current.end++;
uint8_t peeked = peek(parser); if (parser->current.end == parser->end) {
pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_CONTROL);
return;
}
uint8_t peeked = peek(parser);
switch (peeked) { switch (peeked) {
case '?': { case '?': {
parser->current.end++; parser->current.end++;
@ -6436,8 +6440,12 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
} }
parser->current.end++; parser->current.end++;
uint8_t peeked = peek(parser); if (parser->current.end == parser->end) {
pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_CONTROL);
return;
}
uint8_t peeked = peek(parser);
switch (peeked) { switch (peeked) {
case '?': { case '?': {
parser->current.end++; parser->current.end++;
@ -6472,8 +6480,12 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
} }
parser->current.end++; parser->current.end++;
uint8_t peeked = peek(parser); if (parser->current.end == parser->end) {
pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META);
return;
}
uint8_t peeked = peek(parser);
if (peeked == '\\') { if (peeked == '\\') {
if (flags & PM_ESCAPE_FLAG_META) { if (flags & PM_ESCAPE_FLAG_META) {
pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META_REPEAT); pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META_REPEAT);
@ -8073,7 +8085,8 @@ parser_lex(pm_parser_t *parser) {
// If this terminator doesn't actually close the list, then // If this terminator doesn't actually close the list, then
// we need to continue on past it. // we need to continue on past it.
if (lex_mode->as.list.nesting > 0) { if (lex_mode->as.list.nesting > 0) {
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
lex_mode->as.list.nesting--; lex_mode->as.list.nesting--;
continue; continue;
} }
@ -8099,7 +8112,6 @@ parser_lex(pm_parser_t *parser) {
// and find the next breakpoint. // and find the next breakpoint.
if (*breakpoint == '\\') { if (*breakpoint == '\\') {
parser->current.end = breakpoint + 1; parser->current.end = breakpoint + 1;
pm_token_buffer_escape(parser, &token_buffer);
// If we've hit the end of the file, then break out of the // If we've hit the end of the file, then break out of the
// loop by setting the breakpoint to NULL. // loop by setting the breakpoint to NULL.
@ -8108,7 +8120,9 @@ parser_lex(pm_parser_t *parser) {
continue; continue;
} }
pm_token_buffer_escape(parser, &token_buffer);
uint8_t peeked = peek(parser); uint8_t peeked = peek(parser);
switch (peeked) { switch (peeked) {
case ' ': case ' ':
case '\f': case '\f':
@ -8185,13 +8199,19 @@ parser_lex(pm_parser_t *parser) {
// If we've hit the incrementor, then we need to skip past it // If we've hit the incrementor, then we need to skip past it
// and find the next breakpoint. // and find the next breakpoint.
assert(*breakpoint == lex_mode->as.list.incrementor); assert(*breakpoint == lex_mode->as.list.incrementor);
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
lex_mode->as.list.nesting++; lex_mode->as.list.nesting++;
continue; continue;
} }
// If we were unable to find a breakpoint, then this token hits the end of if (parser->current.end > parser->current.start) {
// the file. pm_token_buffer_flush(parser, &token_buffer);
LEX(PM_TOKEN_STRING_CONTENT);
}
// If we were unable to find a breakpoint, then this token hits the
// end of the file.
LEX(PM_TOKEN_EOF); LEX(PM_TOKEN_EOF);
} }
case PM_LEX_REGEXP: { case PM_LEX_REGEXP: {
@ -8223,7 +8243,8 @@ parser_lex(pm_parser_t *parser) {
while (breakpoint != NULL) { while (breakpoint != NULL) {
// If we hit a null byte, skip directly past it. // If we hit a null byte, skip directly past it.
if (*breakpoint == '\0') { if (*breakpoint == '\0') {
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
continue; continue;
} }
@ -8244,7 +8265,8 @@ parser_lex(pm_parser_t *parser) {
if (lex_mode->as.regexp.terminator != '\n') { if (lex_mode->as.regexp.terminator != '\n') {
// If the terminator is not a newline, then we can set // If the terminator is not a newline, then we can set
// the next breakpoint and continue. // the next breakpoint and continue.
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
continue; continue;
} }
} }
@ -8253,7 +8275,8 @@ parser_lex(pm_parser_t *parser) {
// token to return. // token to return.
if (*breakpoint == lex_mode->as.regexp.terminator) { if (*breakpoint == lex_mode->as.regexp.terminator) {
if (lex_mode->as.regexp.nesting > 0) { if (lex_mode->as.regexp.nesting > 0) {
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
lex_mode->as.regexp.nesting--; lex_mode->as.regexp.nesting--;
continue; continue;
} }
@ -8282,9 +8305,17 @@ parser_lex(pm_parser_t *parser) {
// and find the next breakpoint. // and find the next breakpoint.
if (*breakpoint == '\\') { if (*breakpoint == '\\') {
parser->current.end = breakpoint + 1; parser->current.end = breakpoint + 1;
pm_token_buffer_escape(parser, &token_buffer);
// If we've hit the end of the file, then break out of the
// loop by setting the breakpoint to NULL.
if (parser->current.end == parser->end) {
breakpoint = NULL;
continue;
}
pm_token_buffer_escape(parser, &token_buffer);
uint8_t peeked = peek(parser); uint8_t peeked = peek(parser);
switch (peeked) { switch (peeked) {
case '\r': case '\r':
parser->current.end++; parser->current.end++;
@ -8357,13 +8388,19 @@ parser_lex(pm_parser_t *parser) {
// If we've hit the incrementor, then we need to skip past it // If we've hit the incrementor, then we need to skip past it
// and find the next breakpoint. // and find the next breakpoint.
assert(*breakpoint == lex_mode->as.regexp.incrementor); assert(*breakpoint == lex_mode->as.regexp.incrementor);
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
lex_mode->as.regexp.nesting++; lex_mode->as.regexp.nesting++;
continue; continue;
} }
// At this point, the breakpoint is NULL which means we were unable to if (parser->current.end > parser->current.start) {
// find anything before the end of the file. pm_token_buffer_flush(parser, &token_buffer);
LEX(PM_TOKEN_STRING_CONTENT);
}
// If we were unable to find a breakpoint, then this token hits the
// end of the file.
LEX(PM_TOKEN_EOF); LEX(PM_TOKEN_EOF);
} }
case PM_LEX_STRING: { case PM_LEX_STRING: {
@ -8397,7 +8434,8 @@ parser_lex(pm_parser_t *parser) {
// continue lexing. // continue lexing.
if (lex_mode->as.string.incrementor != '\0' && *breakpoint == lex_mode->as.string.incrementor) { if (lex_mode->as.string.incrementor != '\0' && *breakpoint == lex_mode->as.string.incrementor) {
lex_mode->as.string.nesting++; lex_mode->as.string.nesting++;
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
continue; continue;
} }
@ -8408,7 +8446,8 @@ parser_lex(pm_parser_t *parser) {
// If this terminator doesn't actually close the string, then we need // If this terminator doesn't actually close the string, then we need
// to continue on past it. // to continue on past it.
if (lex_mode->as.string.nesting > 0) { if (lex_mode->as.string.nesting > 0) {
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
lex_mode->as.string.nesting--; lex_mode->as.string.nesting--;
continue; continue;
} }
@ -8449,7 +8488,8 @@ parser_lex(pm_parser_t *parser) {
if (*breakpoint == '\n') { if (*breakpoint == '\n') {
if (parser->heredoc_end == NULL) { if (parser->heredoc_end == NULL) {
pm_newline_list_append(&parser->newline_list, breakpoint); pm_newline_list_append(&parser->newline_list, breakpoint);
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
continue; continue;
} else { } else {
parser->current.end = breakpoint + 1; parser->current.end = breakpoint + 1;
@ -8462,12 +8502,12 @@ parser_lex(pm_parser_t *parser) {
switch (*breakpoint) { switch (*breakpoint) {
case '\0': case '\0':
// Skip directly past the null character. // Skip directly past the null character.
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
break; break;
case '\\': { case '\\': {
// Here we hit escapes. // Here we hit escapes.
parser->current.end = breakpoint + 1; parser->current.end = breakpoint + 1;
pm_token_buffer_escape(parser, &token_buffer);
// If we've hit the end of the file, then break out of // If we've hit the end of the file, then break out of
// the loop by setting the breakpoint to NULL. // the loop by setting the breakpoint to NULL.
@ -8476,7 +8516,9 @@ parser_lex(pm_parser_t *parser) {
continue; continue;
} }
pm_token_buffer_escape(parser, &token_buffer);
uint8_t peeked = peek(parser); uint8_t peeked = peek(parser);
switch (peeked) { switch (peeked) {
case '\\': case '\\':
pm_token_buffer_push(&token_buffer, '\\'); pm_token_buffer_push(&token_buffer, '\\');
@ -8557,9 +8599,13 @@ parser_lex(pm_parser_t *parser) {
} }
} }
if (parser->current.end > parser->current.start) {
pm_token_buffer_flush(parser, &token_buffer);
LEX(PM_TOKEN_STRING_CONTENT);
}
// If we've hit the end of the string, then this is an unterminated // If we've hit the end of the string, then this is an unterminated
// string. In that case we'll return the EOF token. // string. In that case we'll return the EOF token.
parser->current.end = parser->end;
LEX(PM_TOKEN_EOF); LEX(PM_TOKEN_EOF);
} }
case PM_LEX_HEREDOC: { case PM_LEX_HEREDOC: {
@ -8649,7 +8695,8 @@ parser_lex(pm_parser_t *parser) {
switch (*breakpoint) { switch (*breakpoint) {
case '\0': case '\0':
// Skip directly past the null character. // Skip directly past the null character.
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
break; break;
case '\n': { case '\n': {
if (parser->heredoc_end != NULL && (parser->heredoc_end > breakpoint)) { if (parser->heredoc_end != NULL && (parser->heredoc_end > breakpoint)) {
@ -8703,7 +8750,8 @@ parser_lex(pm_parser_t *parser) {
// Otherwise we hit a newline and it wasn't followed by // Otherwise we hit a newline and it wasn't followed by
// a terminator, so we can continue parsing. // a terminator, so we can continue parsing.
breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); parser->current.end = breakpoint + 1;
breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end);
break; break;
} }
case '\\': { case '\\': {
@ -8714,7 +8762,6 @@ parser_lex(pm_parser_t *parser) {
// newline so that we can still potentially find the // newline so that we can still potentially find the
// terminator of the heredoc. // terminator of the heredoc.
parser->current.end = breakpoint + 1; parser->current.end = breakpoint + 1;
pm_token_buffer_escape(parser, &token_buffer);
// If we've hit the end of the file, then break out of // If we've hit the end of the file, then break out of
// the loop by setting the breakpoint to NULL. // the loop by setting the breakpoint to NULL.
@ -8723,7 +8770,9 @@ parser_lex(pm_parser_t *parser) {
continue; continue;
} }
pm_token_buffer_escape(parser, &token_buffer);
uint8_t peeked = peek(parser); uint8_t peeked = peek(parser);
if (quote == PM_HEREDOC_QUOTE_SINGLE) { if (quote == PM_HEREDOC_QUOTE_SINGLE) {
switch (peeked) { switch (peeked) {
case '\r': case '\r':
@ -8796,9 +8845,14 @@ parser_lex(pm_parser_t *parser) {
was_escaped_newline = false; was_escaped_newline = false;
} }
if (parser->current.end > parser->current.start) {
parser->current.end = parser->end;
pm_token_buffer_flush(parser, &token_buffer);
LEX(PM_TOKEN_STRING_CONTENT);
}
// If we've hit the end of the string, then this is an unterminated // If we've hit the end of the string, then this is an unterminated
// heredoc. In that case we'll return the EOF token. // heredoc. In that case we'll return the EOF token.
parser->current.end = parser->end;
LEX(PM_TOKEN_EOF); LEX(PM_TOKEN_EOF);
} }
} }

View File

@ -136,7 +136,7 @@ module Prism
source = "<<-END + /b\nEND\n" source = "<<-END + /b\nEND\n"
assert_errors expression(source), source, [ assert_errors expression(source), source, [
["Expected a closing delimiter for the regular expression", 10..10] ["Expected a closing delimiter for the regular expression", 16..16]
] ]
end end

View File

@ -2,7 +2,7 @@
require_relative "test_helper" require_relative "test_helper"
return if Prism::BACKEND == :FFI return if RUBY_VERSION < "3.1.0" || Prism::BACKEND == :FFI
module Prism module Prism
class UnescapeTest < TestCase class UnescapeTest < TestCase
@ -53,22 +53,40 @@ module Prism
end end
class List < Base class List < Base
def ruby_result(escape) = ruby(escape) { |value| value.first.to_s } def ruby_result(escape)
def prism_result(escape) = prism(escape) { |node| node.elements.first.unescaped } ruby(escape) { |value| value.first.to_s }
end
def prism_result(escape)
prism(escape) { |node| node.elements.first.unescaped }
end
end end
class Symbol < Base class Symbol < Base
def ruby_result(escape) = ruby(escape, &:to_s) def ruby_result(escape)
def prism_result(escape) = prism(escape, &:unescaped) ruby(escape, &:to_s)
end
def prism_result(escape)
prism(escape, &:unescaped)
end
end end
class String < Base class String < Base
def ruby_result(escape) = ruby(escape, &:itself) def ruby_result(escape)
def prism_result(escape) = prism(escape, &:unescaped) ruby(escape, &:itself)
end
def prism_result(escape)
prism(escape, &:unescaped)
end
end end
class Heredoc < Base class Heredoc < Base
def ruby_result(escape) = ruby(escape, &:itself) def ruby_result(escape)
ruby(escape, &:itself)
end
def prism_result(escape) def prism_result(escape)
prism(escape) do |node| prism(escape) do |node|
case node.type case node.type
@ -82,8 +100,13 @@ module Prism
end end
class RegExp < Base class RegExp < Base
def ruby_result(escape) = ruby(escape, &:source) def ruby_result(escape)
def prism_result(escape) = prism(escape, &:unescaped) ruby(escape, &:source)
end
def prism_result(escape)
prism(escape, &:unescaped)
end
end end
end end