[ruby/stringio] Add missing row separator encoding conversion

(https://github.com/ruby/stringio/pull/69)

The conversion logic is borrowed from ruby/ruby's io.c:
40391faeab/io.c (L4059-L4079)

Fix ruby/stringio#68

Reported by IWAMOTO Kouichi. Thanks!!!

https://github.com/ruby/stringio/commit/4b170c1a68
This commit is contained in:
Sutou Kouhei 2023-11-08 09:46:10 +09:00 committed by git
parent 2f07963609
commit 7ed37388fb
2 changed files with 49 additions and 22 deletions

View File

@ -1143,38 +1143,57 @@ struct getline_arg {
}; };
static struct getline_arg * static struct getline_arg *
prepare_getline_args(struct getline_arg *arg, int argc, VALUE *argv) prepare_getline_args(struct StringIO *ptr, struct getline_arg *arg, int argc, VALUE *argv)
{ {
VALUE str, lim, opts; VALUE rs, lim, opts;
long limit = -1; long limit = -1;
int respect_chomp; int respect_chomp;
argc = rb_scan_args(argc, argv, "02:", &str, &lim, &opts); argc = rb_scan_args(argc, argv, "02:", &rs, &lim, &opts);
respect_chomp = argc == 0 || !NIL_P(str); respect_chomp = argc == 0 || !NIL_P(rs);
switch (argc) { switch (argc) {
case 0: case 0:
str = rb_rs; rs = rb_rs;
break; break;
case 1: case 1:
if (!NIL_P(str) && !RB_TYPE_P(str, T_STRING)) { if (!NIL_P(rs) && !RB_TYPE_P(rs, T_STRING)) {
VALUE tmp = rb_check_string_type(str); VALUE tmp = rb_check_string_type(rs);
if (NIL_P(tmp)) { if (NIL_P(tmp)) {
limit = NUM2LONG(str); limit = NUM2LONG(rs);
str = rb_rs; rs = rb_rs;
} }
else { else {
str = tmp; rs = tmp;
} }
} }
break; break;
case 2: case 2:
if (!NIL_P(str)) StringValue(str); if (!NIL_P(rs)) StringValue(rs);
if (!NIL_P(lim)) limit = NUM2LONG(lim); if (!NIL_P(lim)) limit = NUM2LONG(lim);
break; break;
} }
arg->rs = str; if (!NIL_P(rs)) {
rb_encoding *enc_rs, *enc_io;
enc_rs = rb_enc_get(rs);
enc_io = get_enc(ptr);
if (enc_rs != enc_io &&
(rb_enc_str_coderange(rs) != ENC_CODERANGE_7BIT ||
(RSTRING_LEN(rs) > 0 && !rb_enc_asciicompat(enc_io)))) {
if (rs == rb_rs) {
rs = rb_enc_str_new(0, 0, enc_io);
rb_str_buf_cat_ascii(rs, "\n");
rs = rs;
}
else {
rb_raise(rb_eArgError, "encoding mismatch: %s IO with %s RS",
rb_enc_name(enc_io),
rb_enc_name(enc_rs));
}
}
}
arg->rs = rs;
arg->limit = limit; arg->limit = limit;
arg->chomp = 0; arg->chomp = 0;
if (!NIL_P(opts)) { if (!NIL_P(opts)) {
@ -1302,15 +1321,15 @@ strio_getline(struct getline_arg *arg, struct StringIO *ptr)
static VALUE static VALUE
strio_gets(int argc, VALUE *argv, VALUE self) strio_gets(int argc, VALUE *argv, VALUE self)
{ {
struct StringIO *ptr = readable(self);
struct getline_arg arg; struct getline_arg arg;
VALUE str; VALUE str;
if (prepare_getline_args(&arg, argc, argv)->limit == 0) { if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) {
struct StringIO *ptr = readable(self);
return rb_enc_str_new(0, 0, get_enc(ptr)); return rb_enc_str_new(0, 0, get_enc(ptr));
} }
str = strio_getline(&arg, readable(self)); str = strio_getline(&arg, ptr);
rb_lastline_set(str); rb_lastline_set(str);
return str; return str;
} }
@ -1347,16 +1366,16 @@ static VALUE
strio_each(int argc, VALUE *argv, VALUE self) strio_each(int argc, VALUE *argv, VALUE self)
{ {
VALUE line; VALUE line;
struct StringIO *ptr = readable(self);
struct getline_arg arg; struct getline_arg arg;
StringIO(self);
RETURN_ENUMERATOR(self, argc, argv); RETURN_ENUMERATOR(self, argc, argv);
if (prepare_getline_args(&arg, argc, argv)->limit == 0) { if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) {
rb_raise(rb_eArgError, "invalid limit: 0 for each_line"); rb_raise(rb_eArgError, "invalid limit: 0 for each_line");
} }
while (!NIL_P(line = strio_getline(&arg, readable(self)))) { while (!NIL_P(line = strio_getline(&arg, ptr))) {
rb_yield(line); rb_yield(line);
} }
return self; return self;
@ -1374,15 +1393,15 @@ static VALUE
strio_readlines(int argc, VALUE *argv, VALUE self) strio_readlines(int argc, VALUE *argv, VALUE self)
{ {
VALUE ary, line; VALUE ary, line;
struct StringIO *ptr = readable(self);
struct getline_arg arg; struct getline_arg arg;
StringIO(self); if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) {
ary = rb_ary_new();
if (prepare_getline_args(&arg, argc, argv)->limit == 0) {
rb_raise(rb_eArgError, "invalid limit: 0 for readlines"); rb_raise(rb_eArgError, "invalid limit: 0 for readlines");
} }
while (!NIL_P(line = strio_getline(&arg, readable(self)))) { ary = rb_ary_new();
while (!NIL_P(line = strio_getline(&arg, ptr))) {
rb_ary_push(ary, line); rb_ary_push(ary, line);
} }
return ary; return ary;

View File

@ -88,6 +88,14 @@ class TestStringIO < Test::Unit::TestCase
assert_string("", Encoding::UTF_8, StringIO.new("foo").gets(0)) assert_string("", Encoding::UTF_8, StringIO.new("foo").gets(0))
end end
def test_gets_utf_16
stringio = StringIO.new("line1\nline2\nline3\n".encode("utf-16le"))
assert_equal("line1\n".encode("utf-16le"), stringio.gets)
assert_equal("line2\n".encode("utf-16le"), stringio.gets)
assert_equal("line3\n".encode("utf-16le"), stringio.gets)
assert_nil(stringio.gets)
end
def test_gets_chomp def test_gets_chomp
assert_equal(nil, StringIO.new("").gets(chomp: true)) assert_equal(nil, StringIO.new("").gets(chomp: true))
assert_equal("", StringIO.new("\n").gets(chomp: true)) assert_equal("", StringIO.new("\n").gets(chomp: true))