[rubygems/rubygems] Aggressively optimize allocations in SafeMarshal

Reduces allocations in a bundle install --full-index by an order of magnitude

Main wins are (a) getting rid of excessive string allocations for the exception message stack

(b) Avoiding hash allocations caused by kwargs for #initialize

(c) avoid using unpack to do bit math, it's easy enough to do by hand

(d) special case the most common elements so they can be read without an allocation

(e) avoid string allocations every time a symbol->string lookup is done by using Symbol#name

https://github.com/rubygems/rubygems/commit/7d2ee51402
This commit is contained in:
Samuel Giddins 2023-09-20 21:32:09 -07:00 committed by git
parent a49d17a080
commit 0a423d4c4e
4 changed files with 252 additions and 88 deletions

View File

@ -7,14 +7,14 @@ module Gem
end
class Symbol < Element
def initialize(name:)
def initialize(name)
@name = name
end
attr_reader :name
end
class UserDefined < Element
def initialize(name:, binary_string:)
def initialize(name, binary_string)
@name = name
@binary_string = binary_string
end
@ -23,7 +23,7 @@ module Gem
end
class UserMarshal < Element
def initialize(name:, data:)
def initialize(name, data)
@name = name
@data = data
end
@ -32,7 +32,7 @@ module Gem
end
class String < Element
def initialize(str:)
def initialize(str)
@str = str
end
@ -40,7 +40,7 @@ module Gem
end
class Hash < Element
def initialize(pairs:)
def initialize(pairs)
@pairs = pairs
end
@ -48,8 +48,8 @@ module Gem
end
class HashWithDefaultValue < Hash
def initialize(default:, **kwargs)
super(**kwargs)
def initialize(pairs, default)
super(pairs)
@default = default
end
@ -57,7 +57,7 @@ module Gem
end
class Array < Element
def initialize(elements:)
def initialize(elements)
@elements = elements
end
@ -65,7 +65,7 @@ module Gem
end
class Integer < Element
def initialize(int:)
def initialize(int)
@int = int
end
@ -86,7 +86,7 @@ module Gem
end
class WithIvars < Element
def initialize(object:,ivars:)
def initialize(object, ivars)
@object = object
@ivars = ivars
end
@ -95,7 +95,7 @@ module Gem
end
class Object < Element
def initialize(name:)
def initialize(name)
@name = name
end
attr_reader :name
@ -106,28 +106,28 @@ module Gem
end
class ObjectLink < Element
def initialize(offset:)
def initialize(offset)
@offset = offset
end
attr_reader :offset
end
class SymbolLink < Element
def initialize(offset:)
def initialize(offset)
@offset = offset
end
attr_reader :offset
end
class Float < Element
def initialize(string:)
def initialize(string)
@string = string
end
attr_reader :string
end
class Bignum < Element # rubocop:disable Lint/UnifiedInteger
def initialize(sign:, data:)
def initialize(sign, data)
@sign = sign
@data = data
end

View File

@ -49,19 +49,19 @@ module Gem
when 0x00
0
when 0x01
@io.read(1).unpack1("C")
read_byte
when 0x02
@io.read(2).unpack1("S<")
read_byte | (read_byte << 8)
when 0x03
(@io.read(3) + "\0").unpack1("L<")
read_byte | (read_byte << 8) | (read_byte << 16)
when 0x04
@io.read(4).unpack1("L<")
read_byte | (read_byte << 8) | (read_byte << 16) | (read_byte << 24)
when 0xFC
@io.read(4).unpack1("L<") | -0x100000000
read_byte | (read_byte << 8) | (read_byte << 16) | (read_byte << 24) | -0x100000000
when 0xFD
(@io.read(3) + "\0").unpack1("L<") | -0x1000000
read_byte | (read_byte << 8) | (read_byte << 16) | -0x1000000
when 0xFE
@io.read(2).unpack1("s<") | -0x10000
read_byte | (read_byte << 8) | -0x10000
when 0xFF
read_byte | -0x100
else
@ -88,31 +88,51 @@ module Gem
when 85 then read_user_marshal # ?U
when 91 then read_array # ?[
when 102 then read_float # ?f
when 105 then Elements::Integer.new int: read_integer # ?i
when 105 then Elements::Integer.new(read_integer) # ?i
when 108 then read_bignum # ?l
when 111 then read_object # ?o
when 117 then read_user_defined # ?u
when 123 then read_hash # ?{
when 125 then read_hash_with_default_value # ?}
when "e".ord then read_extended_object
when "c".ord then read_class
when "m".ord then read_module
when "M".ord then read_class_or_module
when "d".ord then read_data
when "/".ord then read_regexp
when "S".ord then read_struct
when "C".ord then read_user_class
when 101 then read_extended_object # ?e
when 99 then read_class # ?c
when 109 then read_module # ?m
when 77 then read_class_or_module # ?M
when 100 then read_data # ?d
when 47 then read_regexp # ?/
when 83 then read_struct # ?S
when 67 then read_user_class # ?C
else
raise Error, "Unknown marshal type discriminator #{type.chr.inspect} (#{type})"
end
end
STRING_E_SYMBOL = Elements::Symbol.new("E").freeze
private_constant :STRING_E_SYMBOL
def read_symbol
Elements::Symbol.new name: @io.read(read_integer)
len = read_integer
if len == 1
byte = read_byte
if byte == 69 # ?E
STRING_E_SYMBOL
else
Elements::Symbol.new(byte.chr)
end
else
name = -@io.read(len)
Elements::Symbol.new(name)
end
end
EMPTY_STRING = Elements::String.new("".b.freeze).freeze
private_constant :EMPTY_STRING
def read_string
Elements::String.new(str: @io.read(read_integer))
length = read_integer
return EMPTY_STRING if length == 0
str = @io.read(length)
Elements::String.new(str)
end
def read_true
@ -124,55 +144,108 @@ module Gem
end
def read_user_defined
Elements::UserDefined.new(name: read_element, binary_string: @io.read(read_integer))
name = read_element
binary_string = @io.read(read_integer)
Elements::UserDefined.new(name, binary_string)
end
EMPTY_ARRAY = Elements::Array.new([].freeze).freeze
private_constant :EMPTY_ARRAY
def read_array
Elements::Array.new(elements: Array.new(read_integer) do |_i|
read_element
end)
length = read_integer
return EMPTY_ARRAY if length == 0
elements = Array.new(length) do
read_element
end
Elements::Array.new(elements)
end
def read_object_with_ivars
Elements::WithIvars.new(object: read_element, ivars:
Array.new(read_integer) do
[read_element, read_element]
end)
object = read_element
ivars = Array.new(read_integer) do
[read_element, read_element]
end
Elements::WithIvars.new(object, ivars)
end
def read_symbol_link
Elements::SymbolLink.new offset: read_integer
offset = read_integer
Elements::SymbolLink.new(offset)
end
def read_user_marshal
Elements::UserMarshal.new(name: read_element, data: read_element)
name = read_element
data = read_element
Elements::UserMarshal.new(name, data)
end
# Object-link offsets that each appeared more than 10000 times while
# profiling `bundle install --full-index`. Offset 6 is by far the most
# common one (about a third of all object links). One frozen ObjectLink is
# built per offset up front so that read_object_link can hand it back
# instead of allocating a new element.
OBJECT_LINKS = [
  6, 30, 81, 34, 38, 50, 91, 42, 46, 150, 100, 104,
  108, 242, 246, 139, 143, 114, 308, 200, 54, 62, 1_286_245
].each_with_object({}) do |offset, table|
  table[offset] = Elements::ObjectLink.new(offset).freeze
end.freeze
private_constant :OBJECT_LINKS
def read_object_link
Elements::ObjectLink.new(offset: read_integer)
offset = read_integer
OBJECT_LINKS[offset] || Elements::ObjectLink.new(offset)
end
EMPTY_HASH = Elements::Hash.new([].freeze).freeze
private_constant :EMPTY_HASH
def read_hash
pairs = Array.new(read_integer) do
length = read_integer
return EMPTY_HASH if length == 0
pairs = Array.new(length) do
[read_element, read_element]
end
Elements::Hash.new(pairs: pairs)
Elements::Hash.new(pairs)
end
def read_hash_with_default_value
pairs = Array.new(read_integer) do
[read_element, read_element]
end
Elements::HashWithDefaultValue.new(pairs: pairs, default: read_element)
default = read_element
Elements::HashWithDefaultValue.new(pairs, default)
end
def read_object
Elements::WithIvars.new(
object: Elements::Object.new(name: read_element),
ivars: Array.new(read_integer) do
[read_element, read_element]
end
)
name = read_element
object = Elements::Object.new(name)
ivars = Array.new(read_integer) do
[read_element, read_element]
end
Elements::WithIvars.new(object, ivars)
end
def read_nil
@ -180,11 +253,14 @@ module Gem
end
def read_float
Elements::Float.new string: @io.read(read_integer)
string = @io.read(read_integer)
Elements::Float.new(string)
end
def read_bignum
Elements::Bignum.new(sign: read_byte, data: @io.read(read_integer * 2))
sign = read_byte
data = @io.read(read_integer * 2)
Elements::Bignum.new(sign, data)
end
def read_extended_object

View File

@ -7,7 +7,7 @@ module Gem::SafeMarshal
class ToRuby < Visitor
def initialize(permitted_classes:, permitted_symbols:, permitted_ivars:)
@permitted_classes = permitted_classes
@permitted_symbols = permitted_symbols | permitted_classes | ["E"]
@permitted_symbols = ["E"].concat(permitted_symbols).concat(permitted_classes)
@permitted_ivars = permitted_ivars
@objects = []
@ -15,6 +15,7 @@ module Gem::SafeMarshal
@class_cache = {}
@stack = ["root"]
@stack_idx = 1
end
def inspect # :nodoc:
@ -23,39 +24,61 @@ module Gem::SafeMarshal
end
def visit(target)
depth = @stack.size
stack_idx = @stack_idx
super
ensure
@stack.slice!(depth.pred..)
@stack_idx = stack_idx - 1
end
private
# Stores +element+ at the current stack position (overwriting whatever a
# previous, already-finished visit left there) and advances the position
# by one. Returns the new position.
def push_stack(element)
  slot = @stack_idx
  @stack[slot] = element
  @stack_idx = slot + 1
end
def visit_Gem_SafeMarshal_Elements_Array(a)
register_object([]).replace(a.elements.each_with_index.map do |e, i|
@stack << "[#{i}]"
visit(e)
end)
array = register_object([])
elements = a.elements
size = elements.size
idx = 0
# not idiomatic, but there's a huge number of IMEMOs allocated here, so we avoid the block
# because this is such a hot path when doing a bundle install with the full index
until idx == size
push_stack idx
array << visit(elements[idx])
idx += 1
end
array
end
def visit_Gem_SafeMarshal_Elements_Symbol(s)
name = s.name
raise UnpermittedSymbolError.new(symbol: name, stack: @stack.dup) unless @permitted_symbols.include?(name)
raise UnpermittedSymbolError.new(symbol: name, stack: formatted_stack) unless @permitted_symbols.include?(name)
visit_symbol_type(s)
end
def map_ivars(klass, ivars)
stack_idx = @stack_idx
ivars.map.with_index do |(k, v), i|
@stack << "ivar_#{i}"
@stack_idx = stack_idx
push_stack "ivar_"
push_stack i
k = resolve_ivar(klass, k)
@stack[-1] = k
@stack_idx = stack_idx
push_stack k
next k, visit(v)
end
end
def visit_Gem_SafeMarshal_Elements_WithIvars(e)
object_offset = @objects.size
@stack << "object"
push_stack "object"
object = visit(e.object)
ivars = map_ivars(object.class, e.ivars)
@ -76,12 +99,18 @@ module Gem::SafeMarshal
s = e.object.binary_string
marshal_string = "\x04\bIu:\tTime#{(s.size + 5).chr}#{s.b}".b
marshal_string << (internal.size + 5).chr
marshal_string = "\x04\bIu:\tTime".b
marshal_string.concat(s.size + 5)
marshal_string << s
marshal_string.concat(internal.size + 5)
internal.each do |k, v|
marshal_string << ":#{(k.size + 5).chr}#{k}#{Marshal.dump(v)[2..-1]}"
marshal_string.concat(":")
marshal_string.concat(k.size + 5)
marshal_string.concat(k.to_s)
dumped = Marshal.dump(v)
dumped[0, 2] = ""
marshal_string.concat(dumped)
end
object = @objects[object_offset] = Marshal.load(marshal_string)
@ -108,7 +137,7 @@ module Gem::SafeMarshal
true
end
object.replace ::String.new(object, encoding: enc)
object.force_encoding(enc) if enc
end
ivars.each do |k, v|
@ -121,9 +150,9 @@ module Gem::SafeMarshal
hash = register_object({})
o.pairs.each_with_index do |(k, v), i|
@stack << i
push_stack i
k = visit(k)
@stack << k
push_stack k
hash[k] = visit(v)
end
@ -132,7 +161,7 @@ module Gem::SafeMarshal
def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(o)
hash = visit_Gem_SafeMarshal_Elements_Hash(o)
@stack << :default
push_stack :default
hash.default = visit(o.default)
hash
end
@ -159,7 +188,7 @@ module Gem::SafeMarshal
idx = @objects.size
object = register_object(call_method(compat || klass, :allocate))
@stack << :data
push_stack :data
ret = call_method(object, :marshal_load, visit(o.data))
if compat
@ -186,7 +215,7 @@ module Gem::SafeMarshal
end
def visit_Gem_SafeMarshal_Elements_String(s)
register_object(s.str)
register_object(+s.str)
end
def visit_Gem_SafeMarshal_Elements_Float(f)
@ -221,7 +250,7 @@ module Gem::SafeMarshal
def resolve_class(n)
@class_cache[n] ||= begin
to_s = resolve_symbol_name(n)
raise UnpermittedClassError.new(name: to_s, stack: @stack.dup) unless @permitted_classes.include?(to_s)
raise UnpermittedClassError.new(name: to_s, stack: formatted_stack) unless @permitted_classes.include?(to_s)
visit_symbol_type(n)
begin
::Object.const_get(to_s)
@ -238,16 +267,17 @@ module Gem::SafeMarshal
Rational(num, den)
end
end
private_constant :RationalCompat
COMPAT_CLASSES = {}.tap do |h|
h[Rational] = RationalCompat
end.freeze
end.compare_by_identity.freeze
private_constant :COMPAT_CLASSES
def resolve_ivar(klass, name)
to_s = resolve_symbol_name(name)
raise UnpermittedIvarError.new(symbol: to_s, klass: klass, stack: @stack.dup) unless @permitted_ivars.fetch(klass.name, [].freeze).include?(to_s)
raise UnpermittedIvarError.new(symbol: to_s, klass: klass, stack: formatted_stack) unless @permitted_ivars.fetch(klass.name, [].freeze).include?(to_s)
visit_symbol_type(name)
end
@ -263,14 +293,28 @@ module Gem::SafeMarshal
end
end
def resolve_symbol_name(element)
case element
when Elements::Symbol
element.name
when Elements::SymbolLink
visit_Gem_SafeMarshal_Elements_SymbolLink(element).to_s
else
raise FormatError, "Expected symbol or symbol link, got #{element.inspect} @ #{@stack.join(".")}"
# This is a hot method, so avoid respond_to? checks on every invocation:
# the check below is evaluated once, outside the method bodies, and decides
# which of the two definitions of resolve_symbol_name gets installed.
#
# resolve_symbol_name(element) returns the name string for a symbol element:
# * Elements::Symbol     -> its stored name
# * Elements::SymbolLink -> the linked symbol, resolved through
#                           visit_Gem_SafeMarshal_Elements_SymbolLink
# Any other element raises FormatError, with the current formatted stack
# joined by "." in the message.
if :read.respond_to?(:name)
# Symbol#name is available: use it for symbol links so that a
# symbol->string lookup does not allocate a new String each time.
def resolve_symbol_name(element)
case element
when Elements::Symbol
element.name
when Elements::SymbolLink
visit_Gem_SafeMarshal_Elements_SymbolLink(element).name
else
raise FormatError, "Expected symbol or symbol link, got #{element.inspect} @ #{formatted_stack.join(".")}"
end
end
else
# Fallback when Symbol#name is not defined: identical, except that symbol
# links are converted with Symbol#to_s.
def resolve_symbol_name(element)
case element
when Elements::Symbol
element.name
when Elements::SymbolLink
visit_Gem_SafeMarshal_Elements_SymbolLink(element).to_s
else
raise FormatError, "Expected symbol or symbol link, got #{element.inspect} @ #{formatted_stack.join(".")}"
end
end
end
@ -287,6 +331,22 @@ module Gem::SafeMarshal
raise MethodCallError, "Unable to call #{method.inspect} on #{receiver.inspect}, perhaps it is a class using marshal compat, which is not visible in ruby? #{e}"
end
# Builds the human-readable form of the live part of the visit stack (the
# first @stack_idx entries of @stack), for use in error messages.
#
# Integer entries are rendered as "[n]", except when the entry rendered just
# before them is the literal "ivar_" marker: then the two are fused into a
# single "ivar_n" entry. Every non-Integer entry is passed through as is.
#
# Returns a new Array; @stack itself is not modified.
def formatted_stack
  @stack[0, @stack_idx].each_with_object([]) do |entry, out|
    unless entry.is_a?(Integer)
      out << entry
      next
    end

    if out.last == "ivar_"
      out[-1] = "ivar_#{entry}"
    else
      out << "[#{entry}]"
    end
  end
end
class Error < StandardError
end

View File

@ -17,7 +17,9 @@ class TestGemSafeMarshal < Gem::TestCase
define_method("test_safe_load_marshal Float 30000000.0") { assert_safe_load_marshal "\x04\bf\b3e7" }
define_method("test_safe_load_marshal Float -30000000.0") { assert_safe_load_marshal "\x04\bf\t-3e7" }
define_method("test_safe_load_marshal Gem::Version #<Gem::Version \"1.abc\">") { assert_safe_load_marshal "\x04\bU:\x11Gem::Version[\x06I\"\n1.abc\x06:\x06ET" }
define_method("test_safe_load_marshal Hash {}") { assert_safe_load_marshal "\x04\b}\x00[\x00" }
define_method("test_safe_load_marshal Hash {} default value") { assert_safe_load_marshal "\x04\b}\x00[\x00", additional_methods: [:default] }
define_method("test_safe_load_marshal Hash {}") { assert_safe_load_marshal "\x04\b{\x00" }
define_method("test_safe_load_marshal Array {}") { assert_safe_load_marshal "\x04\b[\x00" }
define_method("test_safe_load_marshal Hash {:runtime=>:development}") { assert_safe_load_marshal "\x04\bI{\x06:\fruntime:\x10development\x06:\n@type[\x00", permitted_ivars: { "Hash" => %w[@type] } }
define_method("test_safe_load_marshal Integer -1") { assert_safe_load_marshal "\x04\bi\xFA" }
define_method("test_safe_load_marshal Integer -1048575") { assert_safe_load_marshal "\x04\bi\xFD\x01\x00\xF0" }
@ -124,6 +126,12 @@ class TestGemSafeMarshal < Gem::TestCase
assert_safe_load_as [:development, :development]
end
# Round-trips an array of symbols whose names are one byte long (:A, :E) or
# empty, each appearing more than once, and additionally compares
# instance_variables on the result.
# NOTE(review): with_const appears to temporarily replace
# Gem::SafeMarshal::PERMITTED_SYMBOLS for the duration of the block —
# confirm against the test helper.
def test_length_one_symbols
with_const(Gem::SafeMarshal, :PERMITTED_SYMBOLS, %w[E A b 0] << "") do
assert_safe_load_as [:A, :E, :E, :A, "".to_sym, "".to_sym], additional_methods: [:instance_variables]
end
end
def test_repeated_string
s = "hello"
a = [s]
@ -156,6 +164,12 @@ class TestGemSafeMarshal < Gem::TestCase
String.new("abc", encoding: "Windows-1256"),
String.new("abc", encoding: Encoding::BINARY),
String.new("abc", encoding: "UTF-32"),
String.new("", encoding: "US-ASCII"),
String.new("", encoding: "UTF-8"),
String.new("", encoding: "Windows-1256"),
String.new("", encoding: Encoding::BINARY),
String.new("", encoding: "UTF-32"),
].each do |s|
assert_safe_load_as s, additional_methods: [:encoding]
assert_safe_load_as [s, s], additional_methods: [->(a) { a.map(&:encoding) }]
@ -282,6 +296,20 @@ class TestGemSafeMarshal < Gem::TestCase
assert_equal e.message, "Attempting to load unpermitted symbol \"rspec\" @ root.[9].[0].@name"
end
# Safe-loading a dumped Gem::Specification must raise UnpermittedIvarError
# when one of its dependencies carries an extra @foobar instance variable,
# and the error message must include the path to the offending ivar
# (root.[9].[0].ivar_5).
def test_gem_spec_disallowed_ivar
e = assert_raise(Gem::SafeMarshal::Visitors::ToRuby::UnpermittedIvarError) do
spec = Gem::Specification.new do |s|
s.name = "hi"
s.version = "1.2.3"
# Smuggle an ivar onto the dependency that the test expects to be rejected.
s.dependencies << Gem::Dependency.new("rspec", Gem::Requirement.new([">= 1.2.3"]), :runtime).tap {|d| d.instance_variable_set(:@foobar, "rspec") }
end
Gem::SafeMarshal.safe_load(Marshal.dump(spec))
end
assert_equal e.message, "Attempting to set unpermitted ivar \"@foobar\" on object of class Gem::Dependency @ root.[9].[0].ivar_5"
end
def assert_safe_load_marshal(dumped, additional_methods: [], permitted_ivars: nil, equality: true, marshal_dump_equality: true)
loaded = Marshal.load(dumped)
safe_loaded =