Improvements to IO::Buffer read/write/pread/pwrite. (#7826)

- Fix IO::Buffer `read`/`write` to use a minimum length.
This commit is contained in:
Samuel Williams 2023-05-24 10:17:35 +09:00 committed by GitHub
parent 12dfd9d1c9
commit 135a0d26a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
Notes: git 2023-05-24 01:17:56 +00:00
Merged-By: ioquatix <samuel@codeotaku.com>
2 changed files with 198 additions and 84 deletions

View File

@ -991,17 +991,23 @@ io_buffer_readonly_p(VALUE self)
return RBOOL(rb_io_buffer_readonly_p(self));
}
// Set the RB_IO_BUFFER_LOCKED flag on the buffer.
// Raises rb_eIOBufferLockedError if the flag is already set.
static void
io_buffer_lock(struct rb_io_buffer *buffer)
{
    const int already_locked = (buffer->flags & RB_IO_BUFFER_LOCKED) != 0;

    if (already_locked) {
        rb_raise(rb_eIOBufferLockedError, "Buffer already locked!");
    }

    buffer->flags |= RB_IO_BUFFER_LOCKED;
}
// Lock the buffer wrapped by `self` and return `self`.
// Raises rb_eIOBufferLockedError (via io_buffer_lock) if the buffer is
// already locked.
VALUE
rb_io_buffer_lock(VALUE self)
{
    struct rb_io_buffer *buffer = NULL;
    TypedData_Get_Struct(self, struct rb_io_buffer, &rb_io_buffer_type, buffer);

    // The check-and-set lives in io_buffer_lock. Performing it inline here as
    // well would set the flag first and then make the helper raise every time.
    io_buffer_lock(buffer);

    return self;
}
@ -2433,18 +2439,120 @@ io_buffer_default_size(size_t page_size)
return platform_agnostic_default_size;
}
// Arguments shared by the blocking region begin/ensure callbacks; a pointer to
// this struct is passed to them cast to VALUE.
struct io_buffer_blocking_region_argument {
// The buffer that the ensure callback unlocks:
struct rb_io_buffer *buffer;
// The blocking function to run inside the blocking region:
rb_blocking_function_t *function;
// Opaque argument forwarded to the blocking function:
void *data;
// The descriptor forwarded to rb_thread_io_blocking_region:
int descriptor;
};
// "Begin" callback for the blocking region: unpack the argument struct and run
// the stored function through rb_thread_io_blocking_region, returning its result.
static VALUE
io_buffer_blocking_region_begin(VALUE _argument)
{
    struct io_buffer_blocking_region_argument *region = (struct io_buffer_blocking_region_argument *)_argument;

    rb_blocking_function_t *function = region->function;
    void *data = region->data;
    int descriptor = region->descriptor;

    return rb_thread_io_blocking_region(function, data, descriptor);
}
// "Ensure" callback for the blocking region: unlock the buffer stored in the
// argument struct. Always returns Qnil.
static VALUE
io_buffer_blocking_region_ensure(VALUE _argument)
{
    struct io_buffer_blocking_region_argument *region = (struct io_buffer_blocking_region_argument *)_argument;
    struct rb_io_buffer *locked_buffer = region->buffer;

    io_buffer_unlock(locked_buffer);

    return Qnil;
}
// Run `function(data)` in an IO blocking region for `descriptor`, keeping the
// buffer locked while it runs.
//
// If the buffer is not locked on entry, it is locked here and unlocked again by
// the ensure callback. If it is already locked, the function runs without
// touching the lock, which stays with whoever took it.
static VALUE
io_buffer_blocking_region(struct rb_io_buffer *buffer, rb_blocking_function_t *function, void *data, int descriptor)
{
    struct io_buffer_blocking_region_argument argument;
    argument.buffer = buffer;
    argument.function = function;
    argument.data = data;
    argument.descriptor = descriptor;

    VALUE region = (VALUE)&argument;

    if (!(buffer->flags & RB_IO_BUFFER_LOCKED)) {
        // Take the lock for the duration of the blocking region and release it
        // on the way out:
        io_buffer_lock(buffer);

        return rb_ensure(io_buffer_blocking_region_begin, region, io_buffer_blocking_region_ensure, region);
    }

    // Already locked: no ensure (unlock) step is needed.
    return io_buffer_blocking_region_begin(region);
}
// Extract the optional (length, offset) arguments shared by the read/write
// family of methods.
//
// argv[0] is the length and argv[1] is the offset; either may be omitted or nil.
// - offset defaults to 0.
// - length defaults to the rest of the buffer after the offset
//   (buffer->size - offset).
//
// Raises rb_eArgError if either value is negative, or if the default length is
// requested with an offset beyond the end of the buffer.
//
// Returns the rb_io_buffer wrapped by `self`; results are stored in *length and
// *offset.
static inline struct rb_io_buffer *
io_buffer_extract_arguments(VALUE self, int argc, VALUE argv[], size_t *length, size_t *offset)
{
    struct rb_io_buffer *buffer = NULL;
    TypedData_Get_Struct(self, struct rb_io_buffer, &rb_io_buffer_type, buffer);

    *offset = 0;

    // A nil offset is treated like an omitted one, mirroring the length handling below:
    if (argc >= 2 && !NIL_P(argv[1])) {
        if (rb_int_negative_p(argv[1])) {
            rb_raise(rb_eArgError, "Offset can't be negative!");
        }

        *offset = NUM2SIZET(argv[1]);
    }

    if (argc >= 1 && !NIL_P(argv[0])) {
        if (rb_int_negative_p(argv[0])) {
            rb_raise(rb_eArgError, "Length can't be negative!");
        }

        *length = NUM2SIZET(argv[0]);
    }
    else {
        // size_t subtraction would wrap around to a huge length if the offset
        // is past the end of the buffer, so reject that explicitly:
        if (*offset > buffer->size) {
            rb_raise(rb_eArgError, "The given offset is bigger than the buffer size!");
        }

        *length = buffer->size - *offset;
    }

    return buffer;
}
// Arguments for io_buffer_read_internal.
struct io_buffer_read_internal_argument {
    // The file descriptor to read from:
    int descriptor;

    // The base pointer to read into. Declared once, as char* rather than
    // void*, because the read loop advances it by byte counts:
    char *base;

    // The size of the buffer:
    size_t size;

    // The minimum number of bytes to read:
    size_t length;
};
// Blocking function for buffer reads: call read(2) repeatedly until at least
// `argument->length` bytes have been read, a read returns 0, or a read fails.
//
// Returns rb_fiber_scheduler_io_result(total, 0) on success, or
// rb_fiber_scheduler_io_result(result, errno) for a failed read.
// `argument->base` and `argument->size` are advanced as data arrives.
static VALUE
io_buffer_read_internal(void *_argument)
{
    size_t total = 0;
    struct io_buffer_read_internal_argument *argument = _argument;

    // No unconditional read-and-return before this loop: that would make the
    // minimum-length handling below unreachable.
    while (true) {
        ssize_t result = read(argument->descriptor, argument->base, argument->size);

        if (result < 0) {
            return rb_fiber_scheduler_io_result(result, errno);
        }
        else if (result == 0) {
            // Nothing more was read; report what has accumulated so far:
            return rb_fiber_scheduler_io_result(total, 0);
        }
        else {
            total += result;

            if (total >= argument->length) {
                return rb_fiber_scheduler_io_result(total, 0);
            }

            // Continue after the bytes just read:
            argument->base = argument->base + result;
            argument->size = argument->size - result;
        }
    }
}
VALUE
@ -2475,10 +2583,11 @@ rb_io_buffer_read(VALUE self, VALUE io, size_t length, size_t offset)
struct io_buffer_read_internal_argument argument = {
.descriptor = descriptor,
.base = base,
.size = length,
.size = size,
.length = length,
};
return rb_thread_io_blocking_region(io_buffer_read_internal, &argument, descriptor);
return io_buffer_blocking_region(buffer, io_buffer_read_internal, &argument, descriptor);
}
/*
@ -2508,23 +2617,12 @@ rb_io_buffer_read(VALUE self, VALUE io, size_t length, size_t offset)
// Argument-parsing wrapper around rb_io_buffer_read.
//
// argv is (io, [length, [offset]]): argv[0] is the io, and the optional length
// and offset are parsed and validated by io_buffer_extract_arguments.
// Accepts 1 to 3 arguments.
static VALUE
io_buffer_read(int argc, VALUE *argv, VALUE self)
{
    // Only the io is mandatory. A stricter 2..3 check here would defeat the
    // optional length.
    rb_check_arity(argc, 1, 3);

    VALUE io = argv[0];

    size_t length, offset;
    io_buffer_extract_arguments(self, argc-1, argv+1, &length, &offset);

    return rb_io_buffer_read(self, io, length, offset);
}
@ -2597,7 +2695,7 @@ rb_io_buffer_pread(VALUE self, VALUE io, rb_off_t from, size_t length, size_t of
.offset = from,
};
return rb_thread_io_blocking_region(io_buffer_pread_internal, &argument, descriptor);
return io_buffer_blocking_region(buffer, io_buffer_pread_internal, &argument, descriptor);
}
/*
@ -2629,41 +2727,55 @@ rb_io_buffer_pread(VALUE self, VALUE io, rb_off_t from, size_t length, size_t of
// Argument-parsing wrapper around rb_io_buffer_pread.
//
// argv is (io, from, [length, [offset]]): argv[0] is the io, argv[1] is
// converted to a file offset with NUM2OFFT, and the optional length and offset
// are parsed and validated by io_buffer_extract_arguments.
// Accepts 2 to 4 arguments.
static VALUE
io_buffer_pread(int argc, VALUE *argv, VALUE self)
{
    // Only io and from are mandatory. A stricter 3..4 check here would defeat
    // the optional length.
    rb_check_arity(argc, 2, 4);

    VALUE io = argv[0];
    rb_off_t from = NUM2OFFT(argv[1]);

    size_t length, offset;
    io_buffer_extract_arguments(self, argc-2, argv+2, &length, &offset);

    return rb_io_buffer_pread(self, io, from, length, offset);
}
// Arguments for io_buffer_write_internal.
struct io_buffer_write_internal_argument {
    // The file descriptor to write to:
    int descriptor;

    // The base pointer to write from. Declared once, as const char* rather
    // than const void*, because the write loop advances it by byte counts:
    const char *base;

    // The size of the buffer:
    size_t size;

    // The minimum length to write:
    size_t length;
};
// Blocking function for buffer writes: call write(2) repeatedly until at least
// `argument->length` bytes have been written, a write returns 0, or a write
// fails.
//
// Returns rb_fiber_scheduler_io_result(total, 0) on success, or
// rb_fiber_scheduler_io_result(result, errno) for a failed write.
// `argument->base` and `argument->size` are advanced as data is written.
static VALUE
io_buffer_write_internal(void *_argument)
{
    size_t total = 0;
    struct io_buffer_write_internal_argument *argument = _argument;

    // No unconditional write-and-return before this loop: that would make the
    // minimum-length handling below unreachable.
    while (true) {
        ssize_t result = write(argument->descriptor, argument->base, argument->size);

        if (result < 0) {
            return rb_fiber_scheduler_io_result(result, errno);
        }
        else if (result == 0) {
            // Nothing more was written; report what has accumulated so far:
            return rb_fiber_scheduler_io_result(total, 0);
        }
        else {
            total += result;

            if (total >= argument->length) {
                return rb_fiber_scheduler_io_result(total, 0);
            }

            // Continue after the bytes just written:
            argument->base = argument->base + result;
            argument->size = argument->size - result;
        }
    }
}
VALUE
@ -2694,18 +2806,22 @@ rb_io_buffer_write(VALUE self, VALUE io, size_t length, size_t offset)
struct io_buffer_write_internal_argument argument = {
.descriptor = descriptor,
.base = base,
.size = length,
.size = size,
.length = length,
};
return rb_thread_io_blocking_region(io_buffer_write_internal, &argument, descriptor);
return io_buffer_blocking_region(buffer, io_buffer_write_internal, &argument, descriptor);
}
/*
* call-seq: write(io, length, [offset]) -> written length or -errno
* call-seq: write(io, [length, [offset]]) -> written length or -errno
*
* Writes +length+ bytes from buffer into +io+, starting at
* Writes at least +length+ bytes from buffer into +io+, starting at
* +offset+ in the buffer. If an error occurs, return <tt>-errno</tt>.
*
* If +length+ is not given or nil, the whole buffer is written, minus
* the offset. If +length+ is zero, write will be called once.
*
* If +offset+ is not given, the bytes are taken from the beginning
* of the buffer.
*
@ -2717,23 +2833,12 @@ rb_io_buffer_write(VALUE self, VALUE io, size_t length, size_t offset)
// Argument-parsing wrapper around rb_io_buffer_write.
//
// argv is (io, [length, [offset]]): argv[0] is the io, and the optional length
// and offset are parsed and validated by io_buffer_extract_arguments.
// Accepts 1 to 3 arguments.
static VALUE
io_buffer_write(int argc, VALUE *argv, VALUE self)
{
    // Only the io is mandatory. A stricter 2..3 check here would defeat the
    // optional length.
    rb_check_arity(argc, 1, 3);

    VALUE io = argv[0];

    size_t length, offset;
    io_buffer_extract_arguments(self, argc-1, argv+1, &length, &offset);

    return rb_io_buffer_write(self, io, length, offset);
}
@ -2806,7 +2911,7 @@ rb_io_buffer_pwrite(VALUE self, VALUE io, rb_off_t from, size_t length, size_t o
.offset = from,
};
return rb_thread_io_blocking_region(io_buffer_pwrite_internal, &argument, descriptor);
return io_buffer_blocking_region(buffer, io_buffer_pwrite_internal, &argument, descriptor);
}
/*
@ -2828,25 +2933,13 @@ rb_io_buffer_pwrite(VALUE self, VALUE io, rb_off_t from, size_t length, size_t o
// Argument-parsing wrapper around rb_io_buffer_pwrite.
//
// argv is (io, from, [length, [offset]]): argv[0] is the io, argv[1] is
// converted to a file offset with NUM2OFFT, and the optional length and offset
// are parsed and validated by io_buffer_extract_arguments.
// Accepts 2 to 4 arguments.
static VALUE
io_buffer_pwrite(int argc, VALUE *argv, VALUE self)
{
    // Only io and from are mandatory. A stricter 3..4 check here would defeat
    // the optional length.
    rb_check_arity(argc, 2, 4);

    VALUE io = argv[0];
    rb_off_t from = NUM2OFFT(argv[1]);

    size_t length, offset;
    io_buffer_extract_arguments(self, argc-2, argv+2, &length, &offset);

    return rb_io_buffer_pwrite(self, io, from, length, offset);
}

View File

@ -361,17 +361,38 @@ class TestIOBuffer < Test::Unit::TestCase
input.close
end
# Yields a Tempfile containing "Hello World", rewound to the start, and
# closes and unlinks it afterwards even if the block raises.
#
# The helper only prepares and cleans up the file; the assertions belong in
# the calling tests' blocks. Returns the block's value.
def hello_world_tempfile
  io = Tempfile.new
  io.write("Hello World")
  io.seek(0)

  yield io
ensure
  # `io` is nil if Tempfile.new itself failed, hence the safe navigation.
  io&.close!
end
# Reading with neither length nor offset places the file's leading bytes at
# the start of the buffer.
def test_read
  hello_world_tempfile do |file|
    target = IO::Buffer.new(128)
    target.read(file)

    assert_equal "Hello", target.get_string(0, 5)
  end
end
# Reading with an explicit length of 5 places "Hello" at the start of the
# buffer.
def test_read_with_with_length
  hello_world_tempfile do |file|
    target = IO::Buffer.new(128)
    target.read(file, 5)

    assert_equal "Hello", target.get_string(0, 5)
  end
end
# Reading with a nil length and an offset of 6 places the file's leading
# bytes at byte 6 of the buffer.
def test_read_with_with_offset
  hello_world_tempfile do |file|
    target = IO::Buffer.new(128)
    target.read(file, nil, 6)

    assert_equal "Hello", target.get_string(6, 5)
  end
end
def test_write
@ -379,7 +400,7 @@ class TestIOBuffer < Test::Unit::TestCase
buffer = IO::Buffer.new(128)
buffer.set_string("Hello")
buffer.write(io, 5)
buffer.write(io)
io.seek(0)
assert_equal "Hello", io.read(5)