[rubygems/rubygems] Add a Marshal.load replacement that walks an AST to safely load permitted classes/symbols

https://github.com/rubygems/rubygems/commit/7e4478fe73
This commit is contained in:
Samuel Giddins 2023-08-18 13:35:23 -07:00 committed by git
parent c47608494f
commit d182d83ce9
10 changed files with 895 additions and 5 deletions

View File

@ -604,6 +604,16 @@ An Array (#{env.inspect}) was passed in from #{caller[3]}
@yaml_loaded = true
end
@safe_marshal_loaded = false
##
# Requires rubygems/safe_marshal the first time it is called. Once
# @safe_marshal_loaded has been set, further calls do nothing and return nil.
def self.load_safe_marshal
  unless @safe_marshal_loaded
    require_relative "rubygems/safe_marshal"
    @safe_marshal_loaded = true
  end
end
##
# The file name and line number of the caller of the caller of this method.
#

View File

@ -411,7 +411,8 @@ class Gem::Indexer
# +dest+. For a latest index, does not ensure the new file is minimal.
def update_specs_index(index, source, dest)
specs_index = Marshal.load Gem.read_binary(source)
Gem.load_safe_marshal
specs_index = Gem::SafeMarshal.safe_load Gem.read_binary(source)
index.each do |spec|
platform = spec.original_platform

View File

@ -0,0 +1,71 @@
# frozen_string_literal: true
require "stringio"

require_relative "safe_marshal/reader"
require_relative "safe_marshal/visitors/to_ruby"
module Gem
  ###
  # This module is used for safely loading Marshal specs from a gem. The
  # `safe_load` method defined on this module is specifically designed for
  # loading Gem specifications.
  module SafeMarshal
    # Class names +safe_load+ allows a payload to reference.
    PERMITTED_CLASSES = %w[
      Time
      Date
      Gem::Dependency
      Gem::NameTuple
      Gem::Platform
      Gem::Requirement
      Gem::Specification
      Gem::Version
      Gem::Version::Requirement
      YAML::Syck::DefaultKey
      YAML::PrivateType
    ].freeze
    private_constant :PERMITTED_CLASSES

    # Symbol names (including instance variable names) +safe_load+ allows a
    # payload to contain.
    PERMITTED_SYMBOLS = %w[
      E
      offset
      zone
      nano_num
      nano_den
      submicro
      @_zone
      @cpu
      @force_ruby_platform
      @marshal_with_utc_coercion
      @name
      @os
      @platform
      @prerelease
      @requirement
      @taguri
      @type
      @type_id
      @value
      @version
      @version_requirement
      @version_requirements
      development
      runtime
    ].freeze
    private_constant :PERMITTED_SYMBOLS

    ##
    # Loads the Marshal dump +input+ (a String) permitting only
    # PERMITTED_CLASSES and PERMITTED_SYMBOLS.
    def self.safe_load(input)
      load(input, permitted_classes: PERMITTED_CLASSES, permitted_symbols: PERMITTED_SYMBOLS)
    end

    ##
    # Parses the Marshal dump +input+ (a String) with Reader and converts the
    # resulting tree with Visitors::ToRuby, passing +permitted_classes+ and
    # +permitted_symbols+ through unchanged.
    #
    # Wraps +input+ in a StringIO, so the stringio library must be loaded.
    #
    # NOTE(review): Visitors::ToRuby compares class names as Strings, so the
    # default of [::Symbol] does not appear to permit anything -- confirm
    # whether ["Symbol"] was intended.
    def self.load(input, permitted_classes: [::Symbol], permitted_symbols: [])
      root = Reader.new(StringIO.new(input, "r")).read!
      Visitors::ToRuby.new(permitted_classes: permitted_classes, permitted_symbols: permitted_symbols).visit(root)
    end
  end
end

View File

@ -0,0 +1,138 @@
# frozen_string_literal: true
module Gem
  module SafeMarshal
    ##
    # Plain value classes making up the tree that Gem::SafeMarshal::Reader
    # builds from a Marshal stream. They only hold the decoded pieces; none of
    # them carries behaviour.
    module Elements
      # Common superclass of every node type.
      class Element
      end

      # A symbol; +name+ is its text as read from the stream.
      class Symbol < Element
        attr_reader :name

        def initialize(name:)
          @name = name
        end
      end

      # An object serialised with a class-level +_load+; +name+ is the class
      # name element and +binary_string+ the raw payload.
      class UserDefined < Element
        attr_reader :name, :binary_string

        def initialize(name:, binary_string:)
          @name = name
          @binary_string = binary_string
        end
      end

      # An object serialised with +marshal_load+; +name+ is the class name
      # element and +data+ the element handed to +marshal_load+.
      class UserMarshal < Element
        attr_reader :name, :data

        def initialize(name:, data:)
          @name = name
          @data = data
        end
      end

      # A string; +str+ holds its bytes.
      class String < Element
        attr_reader :str

        def initialize(str:)
          @str = str
        end
      end

      # A hash; +pairs+ is an array of [key element, value element].
      class Hash < Element
        attr_reader :pairs

        def initialize(pairs:)
          @pairs = pairs
        end
      end

      # A hash that additionally carries a +default+ value element.
      class HashWithDefaultValue < Hash
        attr_reader :default

        def initialize(default:, **rest)
          super(**rest)
          @default = default
        end
      end

      # An array; +elements+ holds the member elements in order.
      class Array < Element
        attr_reader :elements

        def initialize(elements:)
          @elements = elements
        end
      end

      # An integer; +int+ is the decoded value.
      class Integer < Element
        attr_reader :int

        def initialize(int:)
          @int = int
        end
      end

      # The +true+ literal; TRUE is the shared frozen instance.
      class True < Element
        def initialize; end

        TRUE = new.freeze
      end

      # The +false+ literal; FALSE is the shared frozen instance.
      class False < Element
        def initialize; end

        FALSE = new.freeze
      end

      # Wraps +object+ together with +ivars+, an array of
      # [name element, value element] pairs.
      class WithIvars < Element
        attr_reader :object, :ivars

        def initialize(object:, ivars:)
          @object = object
          @ivars = ivars
        end
      end

      # A plain object; +name+ is its class name element.
      class Object < Element
        attr_reader :name

        def initialize(name:)
          @name = name
        end
      end

      # The +nil+ literal; NIL is the shared frozen instance.
      class Nil < Element
        NIL = new.freeze
      end

      # A back-reference to a previously read object, by +offset+.
      class ObjectLink < Element
        attr_reader :offset

        def initialize(offset:)
          @offset = offset
        end
      end

      # A back-reference to a previously read symbol, by +offset+.
      class SymbolLink < Element
        attr_reader :offset

        def initialize(offset:)
          @offset = offset
        end
      end

      # A float, kept as the textual +string+ found in the stream.
      class Float < Element
        attr_reader :string

        def initialize(string:)
          @string = string
        end
      end

      # A big integer: +sign+ is the sign byte, +data+ the magnitude bytes.
      class Bignum < Element # rubocop:disable Lint/UnifiedInteger
        attr_reader :sign, :data

        def initialize(sign:, data:)
          @sign = sign
          @data = data
        end
      end
    end
  end
end

View File

@ -0,0 +1,182 @@
# frozen_string_literal: true
require_relative "elements"
module Gem
  module SafeMarshal
    ##
    # Parses a Marshal byte stream into a tree of Elements nodes. Only the
    # structure of the stream is decoded here: class names stay as Symbol
    # elements and nothing named in the stream is looked up or instantiated.
    class Reader
      # Raised by #read! when bytes remain after the root element.
      class UnconsumedBytesError < StandardError
      end

      # +io+ is the stream to parse; it is consumed through +read+, +getbyte+
      # and +eof?+.
      def initialize(io)
        @io = io
      end

      ##
      # Reads the version header followed by exactly one root element and
      # returns that element.
      #
      # Raises UnconsumedBytesError when bytes follow the root element,
      # ArgumentError ("marshal data too short") when the stream ends early,
      # and RuntimeError for an unsupported version or type byte.
      def read!
        read_header
        root = read_element
        raise UnconsumedBytesError unless @io.eof?
        root
      end

      private

      MARSHAL_VERSION = [Marshal::MAJOR_VERSION, Marshal::MINOR_VERSION].map(&:chr).join.freeze
      private_constant :MARSHAL_VERSION

      # Message used for every truncated-stream failure; it is the same text
      # Marshal.load uses.
      TOO_SHORT = "marshal data too short"
      private_constant :TOO_SHORT

      # Checks that the stream starts with this Ruby's Marshal version bytes.
      def read_header
        v = @io.read(2)
        raise "Unsupported marshal version #{v.inspect}, expected #{MARSHAL_VERSION.inspect}" unless v == MARSHAL_VERSION
      end

      # Returns the next byte as an Integer. A stream that ends here raises
      # ArgumentError instead of handing +nil+ to the callers.
      def read_byte
        @io.getbyte || raise(ArgumentError, TOO_SHORT)
      end

      # Returns exactly +count+ bytes. ArgumentError is raised on a short
      # read, matching Marshal.load, so callers that rescue ArgumentError
      # around loading keep treating truncated data as a load failure.
      def read_bytes(count)
        bytes = @io.read(count)
        raise ArgumentError, TOO_SHORT unless bytes && bytes.bytesize == count
        bytes
      end

      # Decodes Marshal's variable-length integer: the first byte either is
      # the value itself (offset by 5) or says how many little-endian bytes
      # follow and whether the value is negative.
      def read_integer
        b = read_byte
        case b
        when 0x00
          0
        when 0x01
          read_byte
        when 0x02
          read_bytes(2).unpack1("S<")
        when 0x03
          (read_bytes(3) + "\0").unpack1("L<")
        when 0x04
          read_bytes(4).unpack1("L<")
        when 0xFC
          read_bytes(4).unpack1("L<") | -0x100000000
        when 0xFD
          (read_bytes(3) + "\0").unpack1("L<") | -0x1000000
        when 0xFE
          read_bytes(2).unpack1("s<") | -0x10000
        when 0xFF
          read_byte | -0x100
        else
          signed = (b ^ 128) - 128
          if b >= 128
            signed + 5
          else
            signed - 5
          end
        end
      end

      # Reads one element, dispatching on its type byte.
      def read_element
        type = read_byte
        case type
        when 34 then read_string # ?"
        when 48 then read_nil # ?0
        when 58 then read_symbol # ?:
        when 59 then read_symbol_link # ?;
        when 64 then read_object_link # ?@
        when 70 then read_false # ?F
        when 73 then read_object_with_ivars # ?I
        when 84 then read_true # ?T
        when 85 then read_user_marshal # ?U
        when 91 then read_array # ?[
        when 102 then read_float # ?f
        when 105 then Elements::Integer.new int: read_integer # ?i
        when 108 then read_bignum # ?l
        when 111 then read_object # ?o
        when 117 then read_user_defined # ?u
        when 123 then read_hash # ?{
        when 125 then read_hash_with_default_value # ?}
        else
          # Also reached for the Marshal types this reader has no
          # implementation for: ?e extended object, ?c class, ?m module,
          # ?M class or module, ?d data, ?/ regexp, ?S struct, ?C user class.
          raise "Unsupported marshal type discriminator #{type.chr.inspect} (#{type})"
        end
      end

      def read_symbol
        Elements::Symbol.new name: read_bytes(read_integer)
      end

      def read_string
        Elements::String.new(str: read_bytes(read_integer))
      end

      def read_true
        Elements::True::TRUE
      end

      def read_false
        Elements::False::FALSE
      end

      def read_user_defined
        Elements::UserDefined.new(name: read_element, binary_string: read_bytes(read_integer))
      end

      def read_array
        Elements::Array.new(elements: Array.new(read_integer) do |_i|
          read_element
        end)
      end

      def read_object_with_ivars
        Elements::WithIvars.new(object: read_element, ivars:
          Array.new(read_integer) do
            [read_element, read_element]
          end)
      end

      def read_symbol_link
        Elements::SymbolLink.new offset: read_integer
      end

      def read_user_marshal
        Elements::UserMarshal.new(name: read_element, data: read_element)
      end

      def read_object_link
        Elements::ObjectLink.new(offset: read_integer)
      end

      def read_hash
        pairs = Array.new(read_integer) do
          [read_element, read_element]
        end
        Elements::Hash.new(pairs: pairs)
      end

      def read_hash_with_default_value
        pairs = Array.new(read_integer) do
          [read_element, read_element]
        end
        Elements::HashWithDefaultValue.new(pairs: pairs, default: read_element)
      end

      # A plain object is returned as WithIvars wrapping an Elements::Object,
      # the same shape read_object_with_ivars produces.
      def read_object
        Elements::WithIvars.new(
          object: Elements::Object.new(name: read_element),
          ivars: Array.new(read_integer) do
            [read_element, read_element]
          end
        )
      end

      def read_nil
        Elements::Nil::NIL
      end

      def read_float
        Elements::Float.new string: read_bytes(read_integer)
      end

      # The length in the stream counts 16-bit units, hence the doubling.
      def read_bignum
        Elements::Bignum.new(sign: read_byte, data: read_bytes(read_integer * 2))
      end
    end
  end
end

View File

@ -0,0 +1,266 @@
# frozen_string_literal: true
require_relative "visitor"
module Gem::SafeMarshal
  module Visitors
    ##
    # Turns an Elements tree back into Ruby objects, raising as soon as the
    # tree names a class or symbol that was not explicitly permitted.
    class ToRuby < Visitor
      # +permitted_classes+ and +permitted_symbols+ are arrays of names given
      # as Strings.
      def initialize(permitted_classes:, permitted_symbols:)
        @permitted_classes = permitted_classes
        # Class names are read as symbols too, and "E" is the ivar checked
        # below for string encodings, so both are always permitted symbols.
        @permitted_symbols = permitted_symbols | permitted_classes | ["E"]
        # Tables that ObjectLink / SymbolLink offsets index into.
        @objects = []
        @symbols = []
        @class_cache = {}
        # Path to the element being visited; only used in error messages.
        @stack = ["root"]
      end

      def inspect # :nodoc:
        format("#<%s permitted_classes: %p permitted_symbols: %p>", self.class, @permitted_classes, @permitted_symbols)
      end

      ##
      # Converts +target+ and returns the resulting Ruby object. Callers push
      # a path segment onto @stack before calling; that segment and anything
      # pushed while visiting are removed again on the way out.
      def visit(target)
        depth = @stack.size
        super
      ensure
        @stack.slice!(depth.pred..)
      end

      private

      # The array is registered before its elements are visited so that an
      # element may link back to the array itself.
      def visit_Gem_SafeMarshal_Elements_Array(a)
        register_object([]).replace(a.elements.each_with_index.map do |e, i|
          @stack << "[#{i}]"
          visit(e)
        end)
      end

      def visit_Gem_SafeMarshal_Elements_Symbol(s)
        resolve_symbol(s.name)
      end

      # Visits every [name, value] ivar pair and returns the converted pairs.
      def map_ivars(ivars)
        ivars.map.with_index do |(k, v), i|
          @stack << "ivar #{i}"
          k = visit(k)
          @stack << k
          next k, visit(v)
        end
      end

      # Builds the wrapped object, then applies its ivars. Time and String
      # carry some of their state in pseudo-ivars, which are consumed here
      # instead of being set as instance variables.
      def visit_Gem_SafeMarshal_Elements_WithIvars(e)
        idx = 0
        object_offset = @objects.size
        @stack << "object"
        object = visit(e.object)
        ivars = map_ivars(e.ivars)

        case e.object
        when Elements::UserDefined
          if object.class == ::Time
            offset = zone = nano_num = nano_den = nil
            ivars.reject! do |k, v|
              case k
              when :offset
                offset = v
              when :zone
                zone = v
              when :nano_num
                nano_num = v
              when :nano_den
                nano_den = v
              when :submicro
              else
                next false
              end
              true
            end

            object = object.localtime offset if offset

            if (nano_den || nano_num) && !(nano_den && nano_num)
              # Plain inspect: pretty_inspect only exists once pp is loaded.
              raise FormatError, "Must have all of nano_den, nano_num for Time #{e.inspect}"
            elsif nano_den && nano_num
              nano = Rational(nano_num, nano_den)
              nsec, subnano = nano.divmod(1)
              nano = nsec + subnano
              object = Time.at(object.to_r, nano, :nanosecond)
            end

            if zone
              require "time"
              Time.send(:force_zone!, object, zone, offset)
            end

            # The Time may have been rebuilt above; keep the link table
            # pointing at the final instance.
            @objects[object_offset] = object
          end
        when Elements::String
          enc = nil
          ivars.each do |k, v|
            case k
            when :E
              case v
              when TrueClass
                enc = "UTF-8"
              when FalseClass
                enc = "US-ASCII"
              end
            else
              break
            end
            idx += 1
          end
          object.replace ::String.new(object, encoding: enc)
        end

        ivars[idx..].each do |k, v|
          object.instance_variable_set k, v
        end
        object
      end

      def visit_Gem_SafeMarshal_Elements_Hash(o)
        hash = register_object({})
        o.pairs.each_with_index do |(k, v), i|
          @stack << i
          k = visit(k)
          @stack << k
          hash[k] = visit(v)
        end
        hash
      end

      def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(o)
        hash = visit_Gem_SafeMarshal_Elements_Hash(o)
        @stack << :default
        hash.default = visit(o.default)
        hash
      end

      def visit_Gem_SafeMarshal_Elements_Object(o)
        register_object(resolve_class(o.name).allocate)
      end

      def visit_Gem_SafeMarshal_Elements_ObjectLink(o)
        @objects[o.offset]
      end

      def visit_Gem_SafeMarshal_Elements_SymbolLink(o)
        @symbols[o.offset]
      end

      def visit_Gem_SafeMarshal_Elements_UserDefined(o)
        register_object(resolve_class(o.name).send(:_load, o.binary_string))
      end

      # The instance is registered before its data is visited so the data may
      # link back to it.
      def visit_Gem_SafeMarshal_Elements_UserMarshal(o)
        register_object(resolve_class(o.name).allocate).tap do |object|
          @stack << :data
          object.marshal_load visit(o.data)
        end
      end

      def visit_Gem_SafeMarshal_Elements_Integer(i)
        i.int
      end

      def visit_Gem_SafeMarshal_Elements_Nil(_)
        nil
      end

      def visit_Gem_SafeMarshal_Elements_True(_)
        true
      end

      def visit_Gem_SafeMarshal_Elements_False(_)
        false
      end

      def visit_Gem_SafeMarshal_Elements_String(s)
        register_object(s.str)
      end

      # Floats take a slot in Marshal's object table, so the value has to be
      # registered; otherwise every later object link is off by one and a
      # repeated float resolves to the wrong entry.
      def visit_Gem_SafeMarshal_Elements_Float(f)
        value =
          case f.string
          when "inf"
            ::Float::INFINITY
          when "-inf"
            -::Float::INFINITY
          when "nan"
            ::Float::NAN
          else
            f.string.to_f
          end
        register_object(value)
      end

      # Rebuilds the integer from its little-endian magnitude bytes. Like
      # floats, bignums take a slot in the object table and are registered.
      def visit_Gem_SafeMarshal_Elements_Bignum(b)
        result = 0
        b.data.each_byte.with_index do |byte, exp|
          result += (byte * 2**(exp * 8))
        end

        case b.sign
        when 43 # ?+
          register_object(result)
        when 45 # ?-
          register_object(-result)
        else
          raise FormatError, "Unexpected sign for Bignum #{b.sign.chr.inspect} (#{b.sign})"
        end
      end

      # Returns the class named by element +n+, provided the name is in the
      # permitted list. Results are cached per element.
      def resolve_class(n)
        @class_cache[n] ||= begin
          name = nil
          case n
          when Elements::Symbol, Elements::SymbolLink
            @stack << "class name"
            name = visit(n)
          else
            raise FormatError, "Class names must be Symbol or SymbolLink"
          end
          to_s = name.to_s
          raise UnpermittedClassError.new(name: name, stack: @stack.dup) unless @permitted_classes.include?(to_s)
          begin
            ::Object.const_get(to_s)
          rescue NameError
            raise ArgumentError, "Undefined class #{to_s.inspect}"
          end
        end
      end

      # The permission check runs before +to_sym+, so an unpermitted name is
      # never converted to a Symbol.
      def resolve_symbol(name)
        raise UnpermittedSymbolError.new(symbol: name, stack: @stack.dup) unless @permitted_symbols.include?(name)
        sym = name.to_sym
        @symbols << sym
        sym
      end

      # Appends +o+ to the object-link table and returns it.
      def register_object(o)
        @objects << o
        o
      end

      # Raised when the stream contains a symbol that is not permitted.
      class UnpermittedSymbolError < StandardError
        def initialize(symbol:, stack:)
          @symbol = symbol
          @stack = stack
          super "Attempting to load unpermitted symbol #{symbol.inspect} @ #{stack.join "."}"
        end
      end

      # Raised when the stream names a class that is not permitted.
      class UnpermittedClassError < StandardError
        def initialize(name:, stack:)
          @name = name
          @stack = stack
          super "Attempting to load unpermitted class #{name.inspect} @ #{stack.join "."}"
        end
      end

      # Raised when the element tree is structurally invalid.
      class FormatError < StandardError
      end
    end
  end
end

View File

@ -0,0 +1,74 @@
# frozen_string_literal: true
module Gem::SafeMarshal::Visitors
  ##
  # Base visitor over Gem::SafeMarshal::Elements trees. Each element class is
  # dispatched to a visit_<class name with "::" replaced by "_"> method. The
  # implementations here only walk into child elements and return nothing
  # meaningful; subclasses override the ones they need.
  class Visitor
    # Dispatches +target+ to the visit method registered for its class.
    # Classes without an entry go to #visit_unknown_element through the
    # table's default. A plain lookup is required for that: Hash#fetch
    # ignores the default and would raise KeyError instead.
    def visit(target)
      send DISPATCH[target.class], target
    end

    private

    # Element class => visit method name, built from the constants of
    # Elements (the Element base class itself is left out).
    DISPATCH = Gem::SafeMarshal::Elements.constants.each_with_object({}) do |c, h|
      next if c == :Element

      klass = Gem::SafeMarshal::Elements.const_get(c)
      h[klass] = :"visit_#{klass.name.gsub("::", "_")}"
    end.tap {|h| h.default = :visit_unknown_element }.compare_by_identity.freeze
    private_constant :DISPATCH

    def visit_unknown_element(e)
      raise ArgumentError, "Attempting to visit unknown element #{e.inspect}"
    end

    def visit_Gem_SafeMarshal_Elements_Array(target)
      target.elements.each {|e| visit(e) }
    end

    def visit_Gem_SafeMarshal_Elements_Bignum(target); end
    def visit_Gem_SafeMarshal_Elements_False(target); end
    def visit_Gem_SafeMarshal_Elements_Float(target); end

    def visit_Gem_SafeMarshal_Elements_Hash(target)
      target.pairs.each do |k, v|
        visit(k)
        visit(v)
      end
    end

    def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(target)
      visit_Gem_SafeMarshal_Elements_Hash(target)
      visit(target.default)
    end

    def visit_Gem_SafeMarshal_Elements_Integer(target); end
    def visit_Gem_SafeMarshal_Elements_Nil(target); end

    def visit_Gem_SafeMarshal_Elements_Object(target)
      visit(target.name)
    end

    def visit_Gem_SafeMarshal_Elements_ObjectLink(target); end
    def visit_Gem_SafeMarshal_Elements_String(target); end
    def visit_Gem_SafeMarshal_Elements_Symbol(target); end
    def visit_Gem_SafeMarshal_Elements_SymbolLink(target); end
    def visit_Gem_SafeMarshal_Elements_True(target); end

    def visit_Gem_SafeMarshal_Elements_UserDefined(target)
      visit(target.name)
    end

    def visit_Gem_SafeMarshal_Elements_UserMarshal(target)
      visit(target.name)
      visit(target.data)
    end

    def visit_Gem_SafeMarshal_Elements_WithIvars(target)
      visit(target.object)
      target.ivars.each do |k, v|
        visit(k)
        visit(v)
      end
    end
  end
end

View File

@ -135,8 +135,9 @@ class Gem::Source
if File.exist? local_spec
spec = Gem.read_binary local_spec
Gem.load_safe_marshal
spec = begin
Marshal.load(spec)
Gem::SafeMarshal.safe_load(spec)
rescue StandardError
nil
end
@ -157,8 +158,9 @@ class Gem::Source
end
end
Gem.load_safe_marshal
# TODO: Investigate setting Gem::Specification#loaded_from to a URI
Marshal.load spec
Gem::SafeMarshal.safe_load spec
end
##
@ -188,8 +190,9 @@ class Gem::Source
spec_dump = fetcher.cache_update_path spec_path, local_file, update_cache?
Gem.load_safe_marshal
begin
Gem::NameTuple.from_list Marshal.load(spec_dump)
Gem::NameTuple.from_list Gem::SafeMarshal.safe_load(spec_dump)
rescue ArgumentError
if update_cache? && !retried
FileUtils.rm local_file

View File

@ -1300,12 +1300,13 @@ class Gem::Specification < Gem::BasicSpecification
def self._load(str)
Gem.load_yaml
Gem.load_safe_marshal
yaml_set = false
retry_count = 0
array = begin
Marshal.load str
Gem::SafeMarshal.safe_load str
rescue ArgumentError => e
# Avoid an infinite retry loop when the argument error has nothing to do
# with the classes not being defined.

View File

@ -0,0 +1,144 @@
# frozen_string_literal: true
require_relative "helper"
require "date"
require "rubygems/safe_marshal"
# Checks that Gem::SafeMarshal.safe_load gives the same results as
# Marshal.load for the kinds of values exercised below, and that a symbol
# outside the permitted list is rejected.
class TestGemSafeMarshal < Gem::TestCase
  def test_repeated_symbol
    assert_safe_load_as [:development, :development]
  end

  def test_repeated_string
    s = "hello"
    a = [s]
    assert_safe_load_as [s, a, s, a]
    assert_safe_load_as [s, s]
  end

  def test_recursive_string
    s = String.new("hello")
    s.instance_variable_set(:@type, s)
    assert_safe_load_as s, additional_methods: [:instance_variables]
  end

  def test_recursive_array
    a = []
    a << a
    assert_safe_load_as a
  end

  def test_time_loads
    assert_safe_load_as Time.new
  end

  def test_time_with_zone_loads
    assert_safe_load_as Time.now(in: "+04:00")
  end

  def test_string_with_encoding
    assert_safe_load_as String.new("abc", encoding: "US-ASCII")
    assert_safe_load_as String.new("abc", encoding: "UTF-8")
  end

  def test_string_with_ivar
    assert_safe_load_as String.new("abc").tap { _1.instance_variable_set :@type, "type" }
  end

  def test_time_with_ivar
    assert_safe_load_as Time.new.tap { _1.instance_variable_set :@type, "type" }
  end

  # One generated test per Time value, covering sub-second parts given in
  # milli-, micro- and nanoseconds.
  secs = Time.new(2000, 12, 31, 23, 59, 59).to_i
  [
    Time.at(secs, 1, :millisecond),
    Time.at(secs, 1.1, :millisecond),
    Time.at(secs, 1.01, :millisecond),
    Time.at(secs, 1, :microsecond),
    Time.at(secs, 1.1, :microsecond),
    Time.at(secs, 1.01, :microsecond),
    Time.at(secs, 1, :nanosecond),
    Time.at(secs, 1.1, :nanosecond),
    Time.at(secs, 1.01, :nanosecond),
    Time.at(secs, 1.001, :nanosecond),
    Time.at(secs, 1.00001, :nanosecond),
    Time.at(secs, 1.00001, :nanosecond).tap {|t| t.instance_variable_set :@type, "type" },
  ].each_with_index do |t, i|
    define_method("test_time_#{i} #{t.inspect}") do
      assert_safe_load_as t, additional_methods: [:ctime, :to_f, :to_r, :to_i, :zone, :subsec, :instance_variables, :to_a]
    end
  end

  def test_floats
    [0.0, Float::INFINITY, Float::NAN, 1.1, 3e7].each do |f|
      assert_safe_load_as f
      assert_safe_load_as(-f)
    end
  end

  def test_hash_with_ivar
    assert_safe_load_as({ runtime: :development }.tap { _1.instance_variable_set :@type, "null" })
  end

  def test_hash_with_default_value
    assert_safe_load_as Hash.new([])
  end

  def test_frozen_object
    assert_safe_load_as Gem::Version.new("1.abc").freeze
  end

  def test_date
    assert_safe_load_as Date.new
  end

  # One generated test per integer; each also checks the negated value and
  # both neighbours.
  [
    0, 1, 2, 3, 4, 5, 6, 122, 123, 124, 127, 128, 255, 256, 257,
    2**16, 2**16 - 1, 2**20 - 1,
    2**28, 2**28 - 1,
    2**32, 2**32 - 1,
    2**63, 2**63 - 1
  ].
    each do |i|
    define_method("test_int_ #{i}") do
      assert_safe_load_as i
      assert_safe_load_as(-i)
      assert_safe_load_as(i + 1)
      assert_safe_load_as(i - 1)
    end
  end

  def test_gem_spec_disallowed_symbol
    e = assert_raise(Gem::SafeMarshal::Visitors::ToRuby::UnpermittedSymbolError) do
      spec = Gem::Specification.new do |s|
        s.name = "hi"
        s.version = "1.2.3"
        s.dependencies << Gem::Dependency.new("rspec", Gem::Requirement.new([">= 1.2.3"]), :runtime).tap { _1.instance_variable_set(:@name, :rspec) }
      end
      Gem::SafeMarshal.safe_load(Marshal.dump(spec))
    end

    # Expected value first, so a failure labels the two strings correctly.
    assert_equal "Attempting to load unpermitted symbol \"rspec\" @ root.[9].[0].@name", e.message
  end

  # Dumps +x+ with Marshal and asserts that Gem::SafeMarshal.safe_load and
  # Marshal.load agree on the result: equality (when +x+ equals itself),
  # +to_s+, +inspect+, every method named in +additional_methods+, and the
  # bytes produced by dumping the loaded values again.
  def assert_safe_load_as(x, additional_methods: [])
    dumped = Marshal.dump(x)
    loaded = Marshal.load(dumped)
    safe_loaded = Gem::SafeMarshal.safe_load(dumped)

    # NaN != NaN, for example
    if x == x # rubocop:disable Lint/BinaryOperatorWithIdenticalOperands
      # assert_equal x, safe_loaded, "should load #{dumped.inspect}"
      assert_equal loaded, safe_loaded, "should equal what Marshal.load returns"
    end

    assert_equal x.to_s, safe_loaded.to_s, "should have equal to_s"
    assert_equal x.inspect, safe_loaded.inspect, "should have equal inspect"
    additional_methods.each do |m|
      assert_equal loaded.send(m), safe_loaded.send(m), "should have equal #{m}"
    end
    assert_equal Marshal.dump(loaded), Marshal.dump(safe_loaded), "should Marshal.dump the same"
  end
end