[ruby/openssl] Reduce OpenSSL::Buffering#do_write overhead

[Bug #20972]

The `rb_str_new_freeze` was added in https://github.com/ruby/openssl/issues/452
to better handle concurrent use of a Socket, but SSL sockets can't be used
concurrently AFAIK, so we might as well just error cleanly.

By using `rb_str_locktmp` we can ensure attempts at concurrent write
will raise an error, be we avoid causing a copy of the bytes.

We also use the newer `String#append_as_bytes` method when available
to save on some more copies.

https://github.com/ruby/openssl/commit/0d8c17aa85

Co-Authored-By: luke.gru@gmail.com
This commit is contained in:
Jean Boussier 2024-12-21 11:27:09 +01:00 committed by git
parent ccb4ba45ed
commit 2f5d31d38a
2 changed files with 58 additions and 25 deletions

View File

@ -24,25 +24,21 @@ module OpenSSL::Buffering
# A buffer which will retain binary encoding.
class Buffer < String
BINARY = Encoding::BINARY
unless String.method_defined?(:append_as_bytes)
alias_method :_append, :<<
def append_as_bytes(string)
if string.encoding == Encoding::BINARY
_append(string)
else
_append(string.b)
end
def initialize
super
force_encoding(BINARY)
end
def << string
if string.encoding == BINARY
super(string)
else
super(string.b)
self
end
return self
end
alias concat <<
alias_method :concat, :append_as_bytes
alias_method :<<, :append_as_bytes
end
##
@ -352,22 +348,32 @@ module OpenSSL::Buffering
def do_write(s)
@wbuffer = Buffer.new unless defined? @wbuffer
@wbuffer << s
@wbuffer.force_encoding(Encoding::BINARY)
@wbuffer.append_as_bytes(s)
@sync ||= false
buffer_size = @wbuffer.size
buffer_size = @wbuffer.bytesize
if @sync or buffer_size > BLOCK_SIZE
nwrote = 0
begin
while nwrote < buffer_size do
begin
nwrote += syswrite(@wbuffer[nwrote, buffer_size - nwrote])
chunk = if nwrote > 0
@wbuffer.byteslice(nwrote, @wbuffer.bytesize)
else
@wbuffer
end
nwrote += syswrite(chunk)
rescue Errno::EAGAIN
retry
end
end
ensure
@wbuffer[0, nwrote] = ""
if nwrote < @wbuffer.bytesize
@wbuffer[0, nwrote] = ""
else
@wbuffer.clear
end
end
end
end

View File

@ -2054,28 +2054,32 @@ ossl_ssl_read_nonblock(int argc, VALUE *argv, VALUE self)
}
static VALUE
ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
ossl_ssl_write_internal_safe(VALUE _args)
{
VALUE *args = (VALUE*)_args;
VALUE self = args[0];
VALUE str = args[1];
VALUE opts = args[2];
SSL *ssl;
rb_io_t *fptr;
int num, nonblock = opts != Qfalse;
VALUE tmp, cb_state;
VALUE cb_state;
GetSSL(self, ssl);
if (!ssl_started(ssl))
rb_raise(eSSLError, "SSL session is not started yet");
tmp = rb_str_new_frozen(StringValue(str));
VALUE io = rb_attr_get(self, id_i_io);
GetOpenFile(io, fptr);
/* SSL_write(3ssl) manpage states num == 0 is undefined */
num = RSTRING_LENINT(tmp);
num = RSTRING_LENINT(str);
if (num == 0)
return INT2FIX(0);
for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num);
int nwritten = SSL_write(ssl, RSTRING_PTR(str), num);
cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
@ -2116,6 +2120,29 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
}
}
static VALUE
ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
{
VALUE args[3] = {self, str, opts};
int state;
str = StringValue(str);
int frozen = RB_OBJ_FROZEN(str);
if (!frozen) {
str = rb_str_locktmp(str);
}
VALUE result = rb_protect(ossl_ssl_write_internal_safe, (VALUE)args, &state);
if (!frozen) {
rb_str_unlocktmp(str);
}
if (state) {
rb_jump_tag(state);
}
return result;
}
/*
* call-seq:
* ssl.syswrite(string) => Integer