diff --git a/class.c b/class.c index 6bf17aaa47..a11285fc97 100644 --- a/class.c +++ b/class.c @@ -124,6 +124,8 @@ rb_class_foreach_subclass(VALUE klass, void (*f)(VALUE, VALUE), VALUE arg) while (cur) { VALUE curklass = cur->klass; cur = cur->next; + // do not trigger GC during f, otherwise the cur will become + // a dangling pointer if the subclass is collected f(curklass, arg); } } @@ -1334,13 +1336,24 @@ rb_mod_ancestors(VALUE mod) return ary; } -static void -class_descendants_recursive(VALUE klass, VALUE ary) +struct subclass_traverse_data { + VALUE *buffer; + long count; +}; + +static void +class_descendants_recursive(VALUE klass, VALUE v) +{ + struct subclass_traverse_data *data = (struct subclass_traverse_data *) v; + if (BUILTIN_TYPE(klass) == T_CLASS && !FL_TEST(klass, FL_SINGLETON)) { - rb_ary_push(ary, klass); + if (data->buffer) { + data->buffer[data->count] = klass; + } + data->count++; } - rb_class_foreach_subclass(klass, class_descendants_recursive, ary); + rb_class_foreach_subclass(klass, class_descendants_recursive, v); } /* @@ -1364,9 +1377,19 @@ class_descendants_recursive(VALUE klass, VALUE ary) VALUE rb_class_descendants(VALUE klass) { - VALUE ary = rb_ary_new(); - rb_class_foreach_subclass(klass, class_descendants_recursive, ary); - return ary; + struct subclass_traverse_data data = { NULL, 0 }; + + // estimate the count of subclasses + rb_class_foreach_subclass(klass, class_descendants_recursive, (VALUE) &data); + + // this allocation may cause GC which may reduce the subclasses + data.buffer = ALLOCA_N(VALUE, data.count); + data.count = 0; + + // enumerate subclasses + rb_class_foreach_subclass(klass, class_descendants_recursive, (VALUE) &data); + + return rb_ary_new_from_values(data.count, data.buffer); } static void diff --git a/test/ruby/test_class.rb b/test/ruby/test_class.rb index 96bca08601..034f4c6d20 100644 --- a/test/ruby/test_class.rb +++ b/test/ruby/test_class.rb @@ -755,4 +755,10 @@ class TestClass < Test::Unit::TestCase assert_include(object_descendants, sc) assert_include(object_descendants, ssc) end + + def test_descendants_gc + c = Class.new + 100000.times { Class.new(c) } + assert(c.descendants.size <= 100000) + end end