diff --git a/ChangeLog b/ChangeLog index c7091e002a..089fbcb52b 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,21 @@ +Wed Sep 16 06:30:07 2009 Marc-Andre Lafortune + + * thread.c (rb_exec_recursive_outer, rb_exec_recursive): Added method + to short-circuit to the outermost level in case of recursion + + * test/ruby/test_thread.rb (test_recursive_outer): Test for above + + * hash.c (rb_hash_hash): Return a sensible hash for in case of + recursion [ruby-core:24648] + + * range.c (rb_range_hash): ditto + + * struct.c (rb_struct_hash): ditto + + * array.c (rb_array_hash): ditto + + * test/ruby/test_array.rb (test_hash2): test for above + Wed Sep 16 06:17:33 2009 Marc-Andre Lafortune * vm_eval.c (rb_catch_obj, rb_catch, rb_f_catch): No longer use the diff --git a/array.c b/array.c index d16cccb751..55904ddcdd 100644 --- a/array.c +++ b/array.c @@ -2884,13 +2884,15 @@ recursive_hash(VALUE ary, VALUE dummy, int recur) st_index_t h; VALUE n; - if (recur) { - rb_raise(rb_eArgError, "recursive key for hash"); - } h = rb_hash_start(RARRAY_LEN(ary)); - for (i=0; intbl) return LONG2FIX(0); hval = RHASH(hash)->ntbl->num_entries; - rb_hash_foreach(hash, hash_i, (VALUE)&hval); + if (recur) + hval = rb_hash_end(rb_hash_uint(rb_hash_start(rb_hash(rb_cHash)), hval)); + else + rb_hash_foreach(hash, hash_i, (VALUE)&hval); return INT2FIX(hval); } @@ -1577,7 +1577,7 @@ recursive_hash(VALUE hash, VALUE dummy, int recur) static VALUE rb_hash_hash(VALUE hash) { - return rb_exec_recursive(recursive_hash, hash, 0); + return rb_exec_recursive_outer(recursive_hash, hash, 0); } static int diff --git a/include/ruby/intern.h b/include/ruby/intern.h index 3adec625da..d0e321303e 100644 --- a/include/ruby/intern.h +++ b/include/ruby/intern.h @@ -343,6 +343,7 @@ void rb_thread_atfork(void); void rb_thread_atfork_before_exec(void); VALUE rb_exec_recursive(VALUE(*)(VALUE, VALUE, int),VALUE,VALUE); VALUE rb_exec_recursive_paired(VALUE(*)(VALUE, VALUE, int),VALUE,VALUE,VALUE); +VALUE rb_exec_recursive_outer(VALUE(*)(VALUE, VALUE, int),VALUE,VALUE); /* file.c */ VALUE rb_file_s_expand_path(int, VALUE *); VALUE rb_file_expand_path(VALUE, VALUE); diff --git a/range.c b/range.c index 739831536e..a394f00b6e 100644 --- a/range.c +++ b/range.c @@ -207,14 +207,13 @@ recursive_hash(VALUE range, VALUE dummy, int recur) st_index_t hash = EXCL(range); VALUE v; - if (recur) { - rb_raise(rb_eArgError, "recursive key for hash"); - } hash = rb_hash_start(hash); - v = rb_hash(RANGE_BEG(range)); - hash = rb_hash_uint(hash, NUM2LONG(v)); - v = rb_hash(RANGE_END(range)); - hash = rb_hash_uint(hash, NUM2LONG(v)); + if (!recur) { + v = rb_hash(RANGE_BEG(range)); + hash = rb_hash_uint(hash, NUM2LONG(v)); + v = rb_hash(RANGE_END(range)); + hash = rb_hash_uint(hash, NUM2LONG(v)); + } hash = rb_hash_uint(hash, EXCL(range) << 24); hash = rb_hash_end(hash); @@ -233,7 +232,7 @@ recursive_hash(VALUE range, VALUE dummy, int recur) static VALUE range_hash(VALUE range) { - return rb_exec_recursive(recursive_hash, range, 0); + return rb_exec_recursive_outer(recursive_hash, range, 0); } static void diff --git a/struct.c b/struct.c index 463fa392b5..f9adfb868b 100644 --- a/struct.c +++ b/struct.c @@ -810,13 +810,12 @@ recursive_hash(VALUE s, VALUE dummy, int recur) st_index_t h; VALUE n; - if (recur) { - rb_raise(rb_eArgError, "recursive key for hash"); - } h = rb_hash_start(rb_hash(rb_obj_class(s))); - for (i = 0; i < RSTRUCT_LEN(s); i++) { - n = rb_hash(RSTRUCT_PTR(s)[i]); - h = rb_hash_uint(h, NUM2LONG(n)); + if (!recur) { + for (i = 0; i < RSTRUCT_LEN(s); i++) { + n = rb_hash(RSTRUCT_PTR(s)[i]); + h = rb_hash_uint(h, NUM2LONG(n)); + } } h = rb_hash_end(h); return INT2FIX(h); @@ -832,7 +831,7 @@ recursive_hash(VALUE s, VALUE dummy, int recur) static VALUE rb_struct_hash(VALUE s) { - return rb_exec_recursive(recursive_hash, s, 0); + return rb_exec_recursive_outer(recursive_hash, s, 0); } /* diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb index 2723b37c6a..fd3b98432d 100644 --- a/test/ruby/test_array.rb +++ b/test/ruby/test_array.rb @@ -1572,7 +1572,8 @@ class TestArray < Test::Unit::TestCase def test_hash2 a = [] a << a - assert_raise(ArgumentError) { a.hash } + assert_equal([[a]].hash, a.hash) + assert_not_equal([a, a].hash, a.hash) # Implementation dependent end def test_flatten2 diff --git a/test/ruby/test_thread.rb b/test/ruby/test_thread.rb index d0c35d8273..397da2c3f7 100644 --- a/test/ruby/test_thread.rb +++ b/test/ruby/test_thread.rb @@ -458,6 +458,19 @@ class TestThread < Test::Unit::TestCase end assert_raise(TypeError) { [o].inspect } end + + def test_recursive_outer + arr = [] + obj = Struct.new(:foo, :visited).new(arr, false) + arr << obj + def obj.hash + self[:visited] = true + super + raise "recursive_outer should short circuit intermediate calls" + end + assert_nothing_raised {arr.hash} + assert(obj[:visited]) + end end class TestThreadGroup < Test::Unit::TestCase diff --git a/thread.c b/thread.c index 5943cecdc0..cab4822e28 100644 --- a/thread.c +++ b/thread.c @@ -3494,33 +3494,77 @@ recursive_pop(VALUE list, VALUE obj, VALUE paired_obj) rb_hash_delete(list, obj); } +struct exec_recursive_params { + VALUE (*func) (VALUE, VALUE, int); + VALUE list; + VALUE obj; + VALUE objid; + VALUE pairid; + VALUE arg; +}; + +static VALUE +exec_recursive_i(VALUE tag, struct exec_recursive_params *p) +{ + VALUE result = Qundef; + int state; + + recursive_push(p->list, p->objid, p->pairid); + PUSH_TAG(); + if ((state = EXEC_TAG()) == 0) { + result = (*p->func) (p->obj, p->arg, Qfalse); + } + POP_TAG(); + recursive_pop(p->list, p->objid, p->pairid); + if (state) + JUMP_TAG(state); + return result; +} + /* * Calls func(obj, arg, recursive), where recursive is non-zero if the * current method is called recursively on obj, or on the pair + * If outer is 0, then the innermost func will be called with recursive set + * to Qtrue, otherwise the outermost func will be called. In the latter case, + * all inner func are short-circuited by throw. + * Implementation details: the value thrown is the recursive list which is + * proper to the current method and unlikely to be catched anywhere else. + * list[recursive_key] is used as a flag for the outermost call. */ static VALUE -exec_recursive(VALUE (*func) (VALUE, VALUE, int), VALUE obj, VALUE pairid, VALUE arg) +exec_recursive(VALUE (*func) (VALUE, VALUE, int), VALUE obj, VALUE pairid, VALUE arg, int outer) { - VALUE list = recursive_list_access(); - VALUE objid = rb_obj_id(obj); + struct exec_recursive_params p; + int outermost; + p.list = recursive_list_access(); + p.objid = rb_obj_id(obj); + outermost = outer && !recursive_check(p.list, ID2SYM(recursive_key), 0); - if (recursive_check(list, objid, pairid)) { + if (recursive_check(p.list, p.objid, pairid)) { + if (outer && !outermost) { + rb_throw_obj(p.list, p.list); + } return (*func) (obj, arg, Qtrue); } else { VALUE result = Qundef; - int state; + p.func = func; + p.obj = obj; + p.pairid = pairid; + p.arg = arg; - recursive_push(list, objid, pairid); - PUSH_TAG(); - if ((state = EXEC_TAG()) == 0) { - result = (*func) (obj, arg, Qfalse); + if (outermost) { + recursive_push(p.list, ID2SYM(recursive_key), 0); + result = rb_catch_obj(p.list, exec_recursive_i, (VALUE)&p); + recursive_pop(p.list, ID2SYM(recursive_key), 0); + if (result == p.list) { + result = (*func) (obj, arg, Qtrue); + } + } + else { + result = exec_recursive_i(0, &p); } - POP_TAG(); - recursive_pop(list, objid, pairid); - if (state) - JUMP_TAG(state); return result; } } @@ -3533,19 +3577,30 @@ exec_recursive(VALUE (*func) (VALUE, VALUE, int), VALUE obj, VALUE pairid, VALUE VALUE rb_exec_recursive(VALUE (*func) (VALUE, VALUE, int), VALUE obj, VALUE arg) { - return exec_recursive(func, obj, 0, arg); + return exec_recursive(func, obj, 0, arg, 0); } /* * Calls func(obj, arg, recursive), where recursive is non-zero if the - * current method is called recursively on the pair - * (in that order) + * current method is called recursively on the ordered pair */ VALUE rb_exec_recursive_paired(VALUE (*func) (VALUE, VALUE, int), VALUE obj, VALUE paired_obj, VALUE arg) { - return exec_recursive(func, obj, rb_obj_id(paired_obj), arg); + return exec_recursive(func, obj, rb_obj_id(paired_obj), arg, 0); +} + +/* + * If recursion is detected on the current method and obj, the outermost + * func will be called with (obj, arg, Qtrue). All inner func will be + * short-circuited using throw. + */ + +VALUE +rb_exec_recursive_outer(VALUE (*func) (VALUE, VALUE, int), VALUE obj, VALUE arg) +{ + return exec_recursive(func, obj, 0, arg, 1); } /* tracer */