From 2fbaff5351e6130929079d73575f0a00fe86a770 Mon Sep 17 00:00:00 2001
From: Alexander Momchilov <alexander.momchilov@shopify.com>
Date: Thu, 22 Aug 2024 18:30:59 -0400
Subject: [PATCH] [ruby/prism] Fix warning when `#!` ends with carriage return

https://github.com/ruby/prism/commit/5753fb6260
---
 prism/prism.c                      | 25 ++++++++++++------
 test/prism/result/warnings_test.rb | 42 ++++++++++++++++++++++++++++--
 2 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/prism/prism.c b/prism/prism.c
index 4d8a97b66e..19883ba9df 100644
--- a/prism/prism.c
+++ b/prism/prism.c
@@ -21767,7 +21767,7 @@ pm_strnstr(const char *big, const char *little, size_t big_length) {
  */
 static void
 pm_parser_warn_shebang_carriage_return(pm_parser_t *parser, const uint8_t *start, size_t length) {
-    if (length > 2 && start[length - 1] == '\n' && start[length - 2] == '\r') {
+    if (length > 2 && start[length - 2] == '\r' && start[length - 1] == '\n') {
         pm_parser_warn(parser, start, start + length, PM_WARN_SHEBANG_CARRIAGE_RETURN);
     }
 }
@@ -21960,11 +21960,17 @@ pm_parser_init(pm_parser_t *parser, const uint8_t *source, size_t size, const pm
 
         const char *engine;
         if ((engine = pm_strnstr((const char *) parser->start, "ruby", length)) != NULL) {
-            pm_parser_warn_shebang_carriage_return(parser, parser->start, length);
-            if (newline != NULL) parser->encoding_comment_start = newline + 1;
+            if (newline != NULL) {
+                size_t length_including_newline = length + 1;
+                pm_parser_warn_shebang_carriage_return(parser, parser->start, length_including_newline);
+
+                parser->encoding_comment_start = newline + 1;
+            }
+
             if (options != NULL && options->shebang_callback != NULL) {
                 pm_parser_init_shebang(parser, options, engine, length - ((size_t) (engine - (const char *) parser->start)));
             }
+
             search_shebang = false;
         } else if (!parser->parsing_eval) {
             search_shebang = true;
@@ -21994,17 +22000,20 @@ pm_parser_init(pm_parser_t *parser, const uint8_t *source, size_t size, const pm
 
             size_t length = (size_t) ((newline != NULL ? newline : parser->end) - cursor);
             if (length > 2 && cursor[0] == '#' && cursor[1] == '!') {
-                if (parser->newline_list.size == 1) {
-                    pm_parser_warn_shebang_carriage_return(parser, cursor, length);
-                }
-
                 const char *engine;
                 if ((engine = pm_strnstr((const char *) cursor, "ruby", length)) != NULL) {
                     found_shebang = true;
-                    if (newline != NULL) parser->encoding_comment_start = newline + 1;
+                    if (newline != NULL) {
+                        size_t length_including_newline = length + 1;
+                        pm_parser_warn_shebang_carriage_return(parser, cursor, length_including_newline);
+
+                        parser->encoding_comment_start = newline + 1;
+                    }
+
                     if (options != NULL && options->shebang_callback != NULL) {
                         pm_parser_init_shebang(parser, options, engine, length - ((size_t) (engine - (const char *) cursor)));
                     }
+
                     break;
                 }
             }
diff --git a/test/prism/result/warnings_test.rb b/test/prism/result/warnings_test.rb
index 8ccaec74ed..aeac7f80e6 100644
--- a/test/prism/result/warnings_test.rb
+++ b/test/prism/result/warnings_test.rb
@@ -321,6 +321,44 @@ module Prism
       assert_warning("tap { redo; foo }", "statement not reached")
     end
 
+    def test_shebang_ending_with_carriage_return
+      msg = "shebang line ending with \\r may cause problems"
+
+      assert_warning(<<~RUBY, msg, compare: false)
+        #!ruby\r
+        p(123)
+      RUBY
+
+      assert_warning(<<~RUBY, msg, compare: false)
+        #!ruby \r
+        p(123)
+      RUBY
+
+      assert_warning(<<~RUBY, msg, compare: false)
+        #!ruby -Eutf-8\r
+        p(123)
+      RUBY
+
+      # Used with the `-x` option, to ignore the script up until the first shebang that mentions "ruby".
+      assert_warning(<<~SCRIPT, msg, compare: false)
+        #!/usr/bin/env bash
+        # Some initial shell script or other content
+        # that Ruby should ignore
+        echo "This is shell script part"
+        exit 0
+
+        #! /usr/bin/env ruby -Eutf-8\r
+        # Ruby script starts here
+        puts "Hello from Ruby!"
+      SCRIPT
+
+      refute_warning("#ruby not_a_shebang\r\n", compare: false)
+
+      # CRuby doesn't emit the warning if a malformed file only has `\r` and not `\n`.
+      # https://bugs.ruby-lang.org/issues/20700
+      refute_warning("#!ruby\r", compare: false)
+    end
+
     def test_warnings_verbosity
       warning = Prism.parse("def foo; END { }; end").warnings.first
       assert_equal "END in method; use at_exit", warning.message
@@ -333,7 +371,7 @@ module Prism
 
     private
 
-    def assert_warning(source, *messages)
+    def assert_warning(source, *messages, compare: true)
       warnings = Prism.parse(source).warnings
       assert_equal messages.length, warnings.length, "Expected #{messages.length} warning(s) in #{source.inspect}, got #{warnings.map(&:message).inspect}"
 
@@ -341,7 +379,7 @@ module Prism
         assert_include warning.message, message
       end
 
-      if defined?(RubyVM::AbstractSyntaxTree)
+      if compare && defined?(RubyVM::AbstractSyntaxTree)
         stderr = capture_stderr { RubyVM::AbstractSyntaxTree.parse(source) }
         messages.each { |message| assert_include stderr, message }
       end