Check for cyclic prepend before making origin

It's important to only make the origin when the prepend goes
through, as the precense of the origin informs whether to do an
origin backfill.

This plus 2d877327e fix [Bug #17590].
This commit is contained in:
Alan Wu 2021-02-12 22:45:08 -05:00
parent 67d2619463
commit 58e8220605
Notes: git 2021-02-23 07:58:18 +09:00
2 changed files with 48 additions and 21 deletions

56
class.c
View File

@ -351,7 +351,7 @@ copy_tables(VALUE clone, VALUE orig)
} }
} }
static void ensure_origin(VALUE klass); static bool ensure_origin(VALUE klass);
/* :nodoc: */ /* :nodoc: */
VALUE VALUE
@ -1014,27 +1014,31 @@ clear_module_cache_i(ID id, VALUE val, void *data)
return ID_TABLE_CONTINUE; return ID_TABLE_CONTINUE;
} }
static bool
module_in_super_chain(const VALUE klass, VALUE module)
{
struct rb_id_table *const klass_m_tbl = RCLASS_M_TBL(RCLASS_ORIGIN(klass));
if (klass_m_tbl) {
while (module) {
if (klass_m_tbl == RCLASS_M_TBL(module))
return true;
module = RCLASS_SUPER(module);
}
}
return false;
}
static int static int
include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super) do_include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super, bool check_cyclic)
{ {
VALUE p, iclass, origin_stack = 0; VALUE p, iclass, origin_stack = 0;
int method_changed = 0, constant_changed = 0, add_subclass; int method_changed = 0, constant_changed = 0, add_subclass;
long origin_len; long origin_len;
VALUE klass_origin = RCLASS_ORIGIN(klass); VALUE klass_origin = RCLASS_ORIGIN(klass);
struct rb_id_table *const klass_m_tbl = RCLASS_M_TBL(klass_origin);
VALUE original_klass = klass; VALUE original_klass = klass;
if (klass_m_tbl) { if (check_cyclic && module_in_super_chain(klass, module))
VALUE original_module = module; return -1;
while (module) {
if (klass_m_tbl == RCLASS_M_TBL(module))
return -1;
module = RCLASS_SUPER(module);
}
module = original_module;
}
while (module) { while (module) {
int c_seen = FALSE; int c_seen = FALSE;
@ -1129,6 +1133,12 @@ include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super)
return method_changed; return method_changed;
} }
static int
include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super)
{
return do_include_modules_at(klass, c, module, search_super, true);
}
static enum rb_id_table_iterator_result static enum rb_id_table_iterator_result
move_refined_method(ID key, VALUE value, void *data) move_refined_method(ID key, VALUE value, void *data)
{ {
@ -1169,7 +1179,7 @@ cache_clear_refined_method(ID key, VALUE value, void *data)
return ID_TABLE_CONTINUE; return ID_TABLE_CONTINUE;
} }
static void static bool
ensure_origin(VALUE klass) ensure_origin(VALUE klass)
{ {
VALUE origin = RCLASS_ORIGIN(klass); VALUE origin = RCLASS_ORIGIN(klass);
@ -1182,20 +1192,24 @@ ensure_origin(VALUE klass)
RCLASS_M_TBL_INIT(klass); RCLASS_M_TBL_INIT(klass);
rb_id_table_foreach(RCLASS_M_TBL(origin), cache_clear_refined_method, (void *)klass); rb_id_table_foreach(RCLASS_M_TBL(origin), cache_clear_refined_method, (void *)klass);
rb_id_table_foreach(RCLASS_M_TBL(origin), move_refined_method, (void *)klass); rb_id_table_foreach(RCLASS_M_TBL(origin), move_refined_method, (void *)klass);
return true;
} }
return false;
} }
void void
rb_prepend_module(VALUE klass, VALUE module) rb_prepend_module(VALUE klass, VALUE module)
{ {
int changed = 0; int changed;
bool klass_had_no_origin = RCLASS_ORIGIN(klass) == klass; bool klass_had_no_origin;
ensure_includable(klass, module); ensure_includable(klass, module);
ensure_origin(klass); if (module_in_super_chain(klass, module))
changed = include_modules_at(klass, klass, module, FALSE); rb_raise(rb_eArgError, "cyclic prepend detected");
if (changed < 0)
rb_raise(rb_eArgError, "cyclic prepend detected"); klass_had_no_origin = ensure_origin(klass);
changed = do_include_modules_at(klass, klass, module, FALSE, false);
RUBY_ASSERT(changed >= 0); // already checked for cyclic prepend above
if (changed) { if (changed) {
rb_vm_check_redefinition_by_prepend(klass); rb_vm_check_redefinition_by_prepend(klass);
} }

View File

@ -485,6 +485,19 @@ class TestModule < Test::Unit::TestCase
assert_equal([m], m.ancestors) assert_equal([m], m.ancestors)
end end
def test_bug17590
m = Module.new
c = Class.new
c.prepend(m)
c.include(m)
m.prepend(m) rescue nil
m2 = Module.new
m2.prepend(m)
c.include(m2)
assert_equal([m, c, m2] + Object.ancestors, c.ancestors)
end
def test_prepend_works_with_duped_classes def test_prepend_works_with_duped_classes
m = Module.new m = Module.new
a = Class.new do a = Class.new do