Optimize Range#bsearch by reducing the number of Integer#+ calls

This commit is contained in:
Kouhei Yanagita 2023-09-19 20:18:00 +09:00 committed by Nobuyoshi Nakada
parent 91042ec0ae
commit 4199e49cad

51
range.c
View File

@ -649,27 +649,30 @@ bsearch_integer_range(VALUE beg, VALUE end, int excl)
VALUE low = rb_to_int(beg); VALUE low = rb_to_int(beg);
VALUE high = rb_to_int(end); VALUE high = rb_to_int(end);
VALUE mid, org_high; VALUE mid;
ID id_div; ID id_div;
CONST_ID(id_div, "div"); CONST_ID(id_div, "div");
if (excl) high = rb_funcall(high, '-', 1, INT2FIX(1)); if (!excl) high = rb_funcall(high, '+', 1, INT2FIX(1));
org_high = high; low = rb_funcall(low, '-', 1, INT2FIX(1));
while (rb_cmpint(rb_funcall(low, id_cmp, 1, high), low, high) < 0) { /*
mid = rb_funcall(rb_funcall(high, '+', 1, low), id_div, 1, INT2FIX(2)); * This loop must continue while low + 1 < high.
* Instead of checking low + 1 < high, check low < mid, where mid = (low + high) / 2.
* This is to avoid the cost of calculating low + 1 on each iteration.
* Note that this condition replacement is valid because Integer#div always rounds
* towards negative infinity.
*/
while (mid = rb_funcall(rb_funcall(high, '+', 1, low), id_div, 1, INT2FIX(2)),
rb_cmpint(rb_funcall(low, id_cmp, 1, mid), low, mid) < 0) {
BSEARCH_CHECK(mid); BSEARCH_CHECK(mid);
if (smaller) { if (smaller) {
high = mid; high = mid;
} }
else { else {
low = rb_funcall(mid, '+', 1, INT2FIX(1)); low = mid;
} }
} }
if (rb_equal(low, org_high)) {
BSEARCH_CHECK(low);
if (!smaller) return Qnil;
}
return satisfied; return satisfied;
} }
@ -696,8 +699,14 @@ range_bsearch(VALUE range)
* by the mantissa. This is true with or without implicit bit. * by the mantissa. This is true with or without implicit bit.
* *
* Finding the average of two ints needs to be careful about * Finding the average of two ints needs to be careful about
* potential overflow (since float to long can use 64 bits) * potential overflow (since float to long can use 64 bits).
* as well as the fact that -1/2 can be 0 or -1 in C89. *
* The half-open interval (low, high] indicates where the target is located.
* The loop continues until low and high are adjacent.
*
* -1/2 can be either 0 or -1 in C89. However, when low and high are not adjacent,
* the rounding direction of mid = (low + high) / 2 does not affect the result of
* the binary search.
* *
* Note that -0.0 is mapped to the same int as 0.0 as we don't want * Note that -0.0 is mapped to the same int as 0.0 as we don't want
* (-1...0.0).bsearch to yield -0.0. * (-1...0.0).bsearch to yield -0.0.
@ -706,23 +715,19 @@ range_bsearch(VALUE range)
#define BSEARCH(conv, excl) \ #define BSEARCH(conv, excl) \
do { \ do { \
RETURN_ENUMERATOR(range, 0, 0); \ RETURN_ENUMERATOR(range, 0, 0); \
if (excl) high--; \ if (!(excl)) high++; \
org_high = high; \ low--; \
while (low < high) { \ while (low + 1 < high) { \
mid = ((high < 0) == (low < 0)) ? low + ((high - low) / 2) \ mid = ((high < 0) == (low < 0)) ? low + ((high - low) / 2) \
: (low < -high) ? -((-1 - low - high)/2 + 1) : (low + high) / 2; \ : (low + high) / 2; \
BSEARCH_CHECK(conv(mid)); \ BSEARCH_CHECK(conv(mid)); \
if (smaller) { \ if (smaller) { \
high = mid; \ high = mid; \
} \ } \
else { \ else { \
low = mid + 1; \ low = mid; \
} \ } \
} \ } \
if (low == org_high) { \
BSEARCH_CHECK(conv(low)); \
if (!smaller) return Qnil; \
} \
return satisfied; \ return satisfied; \
} while (0) } while (0)
@ -730,7 +735,7 @@ range_bsearch(VALUE range)
do { \ do { \
long low = FIX2LONG(beg); \ long low = FIX2LONG(beg); \
long high = FIX2LONG(end); \ long high = FIX2LONG(end); \
long mid, org_high; \ long mid; \
BSEARCH(INT2FIX, (excl)); \ BSEARCH(INT2FIX, (excl)); \
} while (0) } while (0)
@ -744,7 +749,7 @@ range_bsearch(VALUE range)
else if (RB_FLOAT_TYPE_P(beg) || RB_FLOAT_TYPE_P(end)) { else if (RB_FLOAT_TYPE_P(beg) || RB_FLOAT_TYPE_P(end)) {
int64_t low = double_as_int64(NIL_P(beg) ? -HUGE_VAL : RFLOAT_VALUE(rb_Float(beg))); int64_t low = double_as_int64(NIL_P(beg) ? -HUGE_VAL : RFLOAT_VALUE(rb_Float(beg)));
int64_t high = double_as_int64(NIL_P(end) ? HUGE_VAL : RFLOAT_VALUE(rb_Float(end))); int64_t high = double_as_int64(NIL_P(end) ? HUGE_VAL : RFLOAT_VALUE(rb_Float(end)));
int64_t mid, org_high; int64_t mid;
BSEARCH(int64_as_double_to_num, EXCL(range)); BSEARCH(int64_as_double_to_num, EXCL(range));
} }
#endif #endif