Implement Enumerator::Product and Enumerator.product [Feature #18685]

This commit is contained in:
Akinori MUSHA 2022-07-29 13:56:54 +09:00
parent 2338845421
commit 1a73a6cdd2
No known key found for this signature in database
GPG Key ID: C5D7717121727E34
Notes: git 2022-07-30 20:05:57 +09:00
2 changed files with 391 additions and 1 deletions

View File

@ -125,7 +125,7 @@
*/
VALUE rb_cEnumerator;
static VALUE rb_cLazy;
static ID id_rewind, id_new, id_to_enum;
static ID id_rewind, id_new, id_to_enum, id_each_entry;
static ID id_next, id_result, id_receiver, id_arguments, id_memo, id_method, id_force;
static ID id_begin, id_end, id_step, id_exclude_end;
static VALUE sym_each, sym_cycle, sym_yield;
@ -194,6 +194,12 @@ struct enum_chain {
long pos;
};
static VALUE rb_cEnumProduct;
struct enum_product {
VALUE enums;
};
VALUE rb_cArithSeq;
/*
@ -3347,6 +3353,335 @@ enumerator_plus(VALUE obj, VALUE eobj)
return new_enum_chain(rb_ary_new_from_args(2, obj, eobj));
}
/*
* Document-class: Enumerator::Product
*
* Enumerator::Product generates a Cartesian product of any number of
* enumerable objects. Iterating over the product of enumerable
* objects is roughly equivalent to nested each_entry loops where the
* loop for the rightmost object is put innermost.
*
* innings = Enumerator::Product.new(1..9, ['top', 'bottom'])
*
* innings.each do |i, h|
* p [i, h]
* end
* # [1, "top"]
* # [1, "bottom"]
* # [2, "top"]
* # [2, "bottom"]
* # [3, "top"]
* # [3, "bottom"]
* # ...
* # [9, "top"]
* # [9, "bottom"]
*
* The method used against each enumerable object is `each_entry`
* instead of `each` so that the product of N enumerable objects
* yields exactly N arguments in each iteration.
*
* When no enumerator is given, it calls a given block once yielding
* an empty argument list.
*
* This type of objects can be created by Enumerator.product.
*/
static void
enum_product_mark(void *p)
{
struct enum_product *ptr = p;
rb_gc_mark_movable(ptr->enums);
}
static void
enum_product_compact(void *p)
{
struct enum_product *ptr = p;
ptr->enums = rb_gc_location(ptr->enums);
}
#define enum_product_free RUBY_TYPED_DEFAULT_FREE
static size_t
enum_product_memsize(const void *p)
{
return sizeof(struct enum_product);
}
static const rb_data_type_t enum_product_data_type = {
"product",
{
enum_product_mark,
enum_product_free,
enum_product_memsize,
enum_product_compact,
},
0, 0, RUBY_TYPED_FREE_IMMEDIATELY
};
static struct enum_product *
enum_product_ptr(VALUE obj)
{
struct enum_product *ptr;
TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr);
if (!ptr || ptr->enums == Qundef) {
rb_raise(rb_eArgError, "uninitialized product");
}
return ptr;
}
/* :nodoc: */
static VALUE
enum_product_allocate(VALUE klass)
{
struct enum_product *ptr;
VALUE obj;
obj = TypedData_Make_Struct(klass, struct enum_product, &enum_product_data_type, ptr);
ptr->enums = Qundef;
return obj;
}
/*
* call-seq:
* Enumerator::Product.new(*enums) -> enum
*
* Generates a new enumerator object that generates a Cartesian
* product of given enumerable objects.
*
* e = Enumerator::Product.new(1..3, [4, 5])
* e.to_a #=> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]
* e.size #=> 6
*/
static VALUE
enum_product_initialize(VALUE obj, VALUE enums)
{
struct enum_product *ptr;
rb_check_frozen(obj);
TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr);
if (!ptr) rb_raise(rb_eArgError, "unallocated product");
ptr->enums = rb_obj_freeze(enums);
return obj;
}
/* :nodoc: */
static VALUE
enum_product_init_copy(VALUE obj, VALUE orig)
{
struct enum_product *ptr0, *ptr1;
if (!OBJ_INIT_COPY(obj, orig)) return obj;
ptr0 = enum_product_ptr(orig);
TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr1);
if (!ptr1) rb_raise(rb_eArgError, "unallocated product");
ptr1->enums = ptr0->enums;
return obj;
}
static VALUE
enum_product_total_size(VALUE enums)
{
VALUE total = INT2FIX(1);
long i;
for (i = 0; i < RARRAY_LEN(enums); i++) {
VALUE size = enum_size(RARRAY_AREF(enums, i));
if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
return size;
}
if (!RB_INTEGER_TYPE_P(size)) {
return Qnil;
}
total = rb_funcall(total, '*', 1, size);
}
return total;
}
/*
* call-seq:
* obj.size -> int, Float::INFINITY or nil
*
* Returns the total size of the enumerator product calculated by
* multiplying the sizes of enumerables in the product. If any of the
* enumerables reports its size as nil or Float::INFINITY, that value
* is returned as the size.
*/
static VALUE
enum_product_size(VALUE obj)
{
return enum_product_total_size(enum_product_ptr(obj)->enums);
}
static VALUE
enum_product_enum_size(VALUE obj, VALUE args, VALUE eobj)
{
return enum_product_size(obj);
}
struct product_state {
VALUE obj;
VALUE block;
int argc;
VALUE *argv;
int index;
};
static VALUE product_each(VALUE, struct product_state *);
static VALUE
product_each_i(RB_BLOCK_CALL_FUNC_ARGLIST(value, state))
{
struct product_state *pstate = (struct product_state *)state;
pstate->argv[pstate->index++] = value;
VALUE val = product_each(pstate->obj, pstate);
pstate->index--;
return val;
}
static VALUE
product_each(VALUE obj, struct product_state *pstate)
{
struct enum_product *ptr = enum_product_ptr(obj);
VALUE enums = ptr->enums;
if (pstate->index < pstate->argc) {
VALUE eobj = RARRAY_AREF(enums, pstate->index);
rb_block_call(eobj, id_each_entry, 0, NULL, product_each_i, (VALUE)pstate);
} else {
rb_funcallv(pstate->block, id_call, pstate->argc, pstate->argv);
}
return obj;
}
static VALUE
enum_product_run(VALUE obj, VALUE block)
{
struct enum_product *ptr = enum_product_ptr(obj);
int argc = RARRAY_LENINT(ptr->enums);
struct product_state state = {
.obj = obj,
.block = block,
.index = 0,
.argc = argc,
.argv = ALLOCA_N(VALUE, argc),
};
return product_each(obj, &state);
}
/*
* call-seq:
* obj.each { |...| ... } -> obj
* obj.each -> enumerator
*
* Iterates over the elements of the first enumerable by calling the
* "each_entry" method on it with the given arguments, then proceeds
* to the following enumerables in sequence until all of the
* enumerables are exhausted.
*
* If no block is given, returns an enumerator. Otherwise, returns self.
*/
static VALUE
enum_product_each(VALUE obj)
{
RETURN_SIZED_ENUMERATOR(obj, 0, 0, enum_product_enum_size);
return enum_product_run(obj, rb_block_proc());
}
/*
* call-seq:
* obj.rewind -> obj
*
* Rewinds the product enumerator by calling the "rewind" method on
* each enumerable in reverse order. Each call is performed only if
* the enumerable responds to the method.
*/
static VALUE
enum_product_rewind(VALUE obj)
{
struct enum_product *ptr = enum_product_ptr(obj);
VALUE enums = ptr->enums;
long i;
for (i = 0; i < RARRAY_LEN(enums); i++) {
rb_check_funcall(RARRAY_AREF(enums, i), id_rewind, 0, 0);
}
return obj;
}
static VALUE
inspect_enum_product(VALUE obj, VALUE dummy, int recur)
{
VALUE klass = rb_obj_class(obj);
struct enum_product *ptr;
TypedData_Get_Struct(obj, struct enum_product, &enum_product_data_type, ptr);
if (!ptr || ptr->enums == Qundef) {
return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass));
}
if (recur) {
return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass));
}
return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums);
}
/*
* call-seq:
* obj.inspect -> string
*
* Returns a printable version of the product enumerator.
*/
static VALUE
enum_product_inspect(VALUE obj)
{
return rb_exec_recursive(inspect_enum_product, obj, 0);
}
/*
* call-seq:
* Enumerator.product(*enums) -> enumerator
*
* Generates a new enumerator object that generates a Cartesian
* product of given enumerable objects. This is equivalent to
* Enumerator::Product.new.
*
* e = Enumerator.product(1..3, [4, 5])
* e.to_a #=> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]
* e.size #=> 6
*/
static VALUE
enumerator_s_product(VALUE klass, VALUE enums)
{
VALUE obj = enum_product_initialize(enum_product_allocate(rb_cEnumProduct), enums);
if (rb_block_given_p()) {
return enum_product_run(obj, rb_block_proc());
} else {
return obj;
}
}
/*
* Document-class: Enumerator::ArithmeticSequence
*
@ -4214,6 +4549,22 @@ InitVM_Enumerator(void)
rb_undef_method(rb_cEnumChain, "peek");
rb_undef_method(rb_cEnumChain, "peek_values");
/* Product */
rb_cEnumProduct = rb_define_class_under(rb_cEnumerator, "Product", rb_cEnumerator);
rb_define_alloc_func(rb_cEnumProduct, enum_product_allocate);
rb_define_method(rb_cEnumProduct, "initialize", enum_product_initialize, -2);
rb_define_method(rb_cEnumProduct, "initialize_copy", enum_product_init_copy, 1);
rb_define_method(rb_cEnumProduct, "each", enum_product_each, 0);
rb_define_method(rb_cEnumProduct, "size", enum_product_size, 0);
rb_define_method(rb_cEnumProduct, "rewind", enum_product_rewind, 0);
rb_define_method(rb_cEnumProduct, "inspect", enum_product_inspect, 0);
rb_undef_method(rb_cEnumProduct, "feed");
rb_undef_method(rb_cEnumProduct, "next");
rb_undef_method(rb_cEnumProduct, "next_values");
rb_undef_method(rb_cEnumProduct, "peek");
rb_undef_method(rb_cEnumProduct, "peek_values");
rb_define_singleton_method(rb_cEnumerator, "product", enumerator_s_product, -2);
/* ArithmeticSequence */
rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator);
rb_undef_alloc_func(rb_cArithSeq);
@ -4249,6 +4600,7 @@ Init_Enumerator(void)
id_method = rb_intern_const("method");
id_force = rb_intern_const("force");
id_to_enum = rb_intern_const("to_enum");
id_each_entry = rb_intern_const("each_entry");
id_begin = rb_intern_const("begin");
id_end = rb_intern_const("end");
id_step = rb_intern_const("step");

View File

@ -906,4 +906,42 @@ class TestEnumerator < Test::Unit::TestCase
e.chain.each(&->{})
assert_equal(true, e.is_lambda)
end
def test_product
e = Enumerator::Product.new
assert_instance_of(Enumerator::Product, e)
assert_kind_of(Enumerator, e)
assert_equal(1, e.size)
elts = []
e.each { |*x| elts << x }
assert_equal [[]], elts
e = Enumerator::Product.new(1..3, %w[a b])
assert_instance_of(Enumerator::Product, e)
assert_kind_of(Enumerator, e)
assert_equal(3 * 2, e.size)
elts = []
e.each { |*x| elts << x }
assert_equal [[1, "a"], [1, "b"], [2, "a"], [2, "b"], [3, "a"], [3, "b"]], elts
e = Enumerator.product(1..3, %w[a b])
assert_instance_of(Enumerator::Product, e)
elts = []
ret = Enumerator.product(1..3, %w[a b]) { |*x| elts << x }
assert_instance_of(Enumerator::Product, ret)
assert_equal [[1, "a"], [1, "b"], [2, "a"], [2, "b"], [3, "a"], [3, "b"]], elts
e = Enumerator.product(1.., 'a'..'c')
assert_equal(Float::INFINITY, e.size)
assert_equal [[1, "a"], [1, "b"], [1, "c"], [2, "a"]], e.take(4)
e = Enumerator.product(1.., Enumerator.new { |y| y << 'a' << 'b' })
assert_equal(Float::INFINITY, e.size)
assert_equal [[1, "a"], [1, "b"], [2, "a"], [2, "b"]], e.take(4)
e = Enumerator.product(1..3, Enumerator.new { |y| y << 'a' << 'b' })
assert_equal(nil, e.size)
assert_equal [[1, "a"], [1, "b"], [2, "a"], [2, "b"]], e.take(4)
end
end