// Post-image of hunk "@@ -33,6 +33,48 @@" of sql/vector_mhnsw.cc
// (diff 83ef816f6d5..248ea58e263).
// NOTE(review): builtin type spellings are used below so this block does not
// depend on project headers; they are assumed to be identical to the
// project's uint / ulonglong aliases - confirm.
static constexpr unsigned int ef_construction= 10;
static constexpr unsigned int max_ef= 10000;
static constexpr size_t subdist_part= 192;
static constexpr float subdist_margin= 1.05f;
static constexpr double subdist_stddev_threshold= 0.05; // 3σ, p>99.9%
static constexpr unsigned long long subdist_stddev_valid= 10000; // sufficient

/*
  The class below can assume normal distribution and only collect
  M1 and M2, or go beyond that and collect M3 and M4 to account
  for non-gaussian. The second mode is useful for research and tuning,
  but M2 is all we need in production for now.
*/
#define STATS_NON_GAUSSIAN 0 /* 0 for no, 3 for yes */

/*
  Running moments of a stream of doubles.

    n       number of samples seen so far
    M1      running mean
    M2      sum of squared deviations from the mean
    M3, M4  sums of cubed / fourth-power deviations; only maintained
            when STATS_NON_GAUSSIAN is non-zero, otherwise they stay 0

  Two collectors can be merged, so every thread may accumulate privately
  and fold its result into a shared collector later.
*/
struct stats_collector
{
  unsigned long long n= 0;
  double M1= 0, M2= 0, M3= 0, M4= 0;

  /*
    Merge the aggregate B (nB samples with moments M1B..M4B) into this
    collector, using the pairwise (parallel) form of Welford's algorithm.

    n + nB must be non-zero, the update divides by it.
  */
  void add(unsigned long long nB, double M1B, double M2B, double M3B,
           double M4B)
  {
    /*
      Counts are taken as doubles up front: in unsigned integer arithmetic
      (nA - nB) wraps around whenever nA < nB (e.g. for the very first
      sample) and nA*nA overflows for large counts.
    */
    const double cntA= static_cast<double>(n);
    const double cntB= static_cast<double>(nB);
    n+= nB;
    const double d= M1B - M1, dn= d / n, t1= d * dn * cntA * cntB;
    M1+= dn * cntB;
#if STATS_NON_GAUSSIAN
    /* M4 needs the old M3 and M2, M3 needs the old M2: keep this order */
    M4+= M4B + t1 * dn * dn * (cntA * cntA - cntA * cntB + cntB * cntB)
             + 6 * dn * dn * (cntA * cntA * M2B + cntB * cntB * M2)
             + 4 * dn * (cntA * M3B - cntB * M3);
    M3+= M3B + dn * t1 * (cntA - cntB) + 3 * dn * (cntA * M2B - cntB * M2);
#else
    (void) M3B;
    (void) M4B;
#endif
    /*
      B's own sum of squares must be carried over too. Without M2B the
      merge of two collectors only keeps the between-group term and the
      resulting stddev() is too small.
    */
    M2+= M2B + t1;
  }

  /* Add one sample; non-finite values (NaN, +-inf) are ignored. */
  void add(double x) { if (std::isfinite(x)) add(1, x, 0, 0, 0); }

  /* Merge another collector; an empty one is a no-op. */
  void add(const stats_collector &b)
  {
    if (b.n)
      add(b.n, b.M1, b.M2, b.M3, b.M4);
  }

  double mean() const { return M1; }
  /* population standard deviation (divides by n); meaningless for n == 0 */
  double stddev() const { return std::sqrt(M2/n); }
  double skewness() const { return M3/M2/stddev(); }
  /* excess kurtosis; evaluates to 0 when M4 is not collected and M2 > 0 */
  double kurtosis() const { return n*M4/M2/M2 - STATS_NON_GAUSSIAN; }

  /*
    Cornish–Fisher expansion: the value at the position where a standard
    normal distribution would have the quantile z, corrected for the
    collected skewness and excess kurtosis.
  */
  double quantile(double z) const
  {
    const double skew= skewness();
    return mean() + stddev() * (z +
           skew / 6 * (z*z-1) +
           kurtosis() / 24 * (z*z*z-3*z) -
           skew * skew / 36 * (2*z*z*z-5*z));
  }
};

/*
  graph related statistical data. stored in MHNSW_Share.
*/
/* Remaining hunks of the patch: they pass a dgt_mode and a Stats accumulator through distance_greater_than(), and replace MHNSW_param::stats with MHNSW_param::acc plus a mode and max_est_size computed once in the constructor. */
@@ -43,8 +85,10 @@ struct Stats double ef_power= 0.6; // for the bloom filter size heuristic float diameter= 0; size_t graph_size= 0; + stats_collector subdist; }; + static ulonglong mhnsw_max_cache_size; static MYSQL_SYSVAR_ULONGLONG(max_cache_size, mhnsw_max_cache_size, PLUGIN_VAR_RQCMDARG, "Upper limit for one MHNSW vector index cache", @@ -314,7 +358,8 @@ struct FVector dot_product(dims, other->dims, vec_len); } - float distance_greater_than(const FVector *other, size_t vec_len, float than) const + float distance_greater_than(const FVector *other, size_t vec_len, float than, + Stats *stats) const { float k = scale * other->scale; float dp= dot_product(dims, other->dims, subdist_part); @@ -324,6 +369,7 @@ struct FVector dp+= dot_product(dims+subdist_part, other->dims+subdist_part, vec_len - subdist_part); double dist= abs2 + other->abs2 - k * dp; + stats->subdist.add(subdist/dist); /* on the path that computes the full distance, the ratio subdist/dist is fed into the collector; stats is dereferenced unconditionally, so callers must pass a valid pointer */ return dist; } }; @@ -356,6 +402,8 @@ struct Neighborhood: public Sql_alloc } }; +/* how to execute distance_greater_than() */ +enum dgt_mode { NOSTAT_NOSUBDIST, STAT_NOSUBDIST, STAT_SUBDIST }; /* the enumerator value is also used as the index into mul[] in FVectorNode::distance_greater_than() */ /* One node in a graph = one row in the graph table @@ -396,7 +444,8 @@ public: FVectorNode(MHNSW_Share *ctx_, const void *tref_, uint8_t layer, const void *vec_); float distance_to(const FVector *other) const; - float distance_greater_than(const FVector *other, float than) const; + float distance_greater_than(const FVector *other, float than, dgt_mode mode, + Stats *stats) const; int load(TABLE *graph); int load_from_record(TABLE *graph); int save(TABLE *graph); @@ -613,6 +662,7 @@ public: stats.graph_size+= addend.graph_size; stats.diameter= std::max(stats.diameter, addend.diameter); stats.ef_power= std::max(stats.ef_power, addend.ef_power); + stats.subdist.add(addend.subdist); /* merged before cache_lock is released on the next line */ mysql_mutex_unlock(&cache_lock); } }; @@ -890,11 +940,14 @@ float FVectorNode::distance_to(const FVector *other) const return vec->distance_to(other, ctx->vec_len); } -float FVectorNode::distance_greater_than(const 
FVector *other, float than) const +float FVectorNode::distance_greater_than(const FVector *other, float than, + dgt_mode mode, Stats *stats) const { - if (!ctx->use_subdist) + static constexpr float mul[3]= {0, 10, subdist_margin }; /* indexed by dgt_mode: slot 0 is never read because NOSTAT_NOSUBDIST returns just below; STAT_NOSUBDIST scales `than` by 10, STAT_SUBDIST by subdist_margin */ + if (mode == NOSTAT_NOSUBDIST) return distance_to(other); - return vec->distance_greater_than(other, ctx->vec_len, than*subdist_margin); + return vec->distance_greater_than(other, ctx->vec_len, + than*mul[mode], stats); } int FVectorNode::alloc_neighborhood(uint8_t layer) @@ -1003,10 +1056,28 @@ struct MHNSW_param MHNSW_Share *ctx; TABLE *graph; int layer; - Stats stats; + Stats acc; + dgt_mode mode; + double max_est_size; MHNSW_param(MHNSW_Share *ctx, TABLE *graph, int layer) : ctx(ctx), graph(graph), layer(layer) - { ctx->read_stats(&stats); } + { + Stats stats; + ctx->read_stats(&stats); + max_est_size= stats.graph_size/1.3; + acc.diameter= stats.diameter; + acc.ef_power= stats.ef_power; /* mode selection: without use_subdist always NOSTAT_NOSUBDIST; otherwise STAT_NOSUBDIST until more than subdist_stddev_valid samples were collected, after that STAT_SUBDIST only if their stddev is below subdist_stddev_threshold, else NOSTAT_NOSUBDIST */ + if (ctx->use_subdist) + { + if (stats.subdist.n > subdist_stddev_valid) + mode= stats.subdist.stddev() < subdist_stddev_threshold + ? STAT_SUBDIST : NOSTAT_NOSUBDIST; + else + mode= STAT_NOSUBDIST; + } + else + mode= NOSTAT_NOSUBDIST; + } }; /* one visited node during the search. caches the distance to target */ @@ -1103,7 +1174,7 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target, bool discard= false; for (size_t i=0; i < neighbors.num; i++) if ((discard= node->distance_greater_than(neighbors.links[i]->vec, - target_dista) <= target_dista)) + target_dista, p->mode, &p->acc) <= target_dista)) break; if (!discard) target->push_neighbor(p->layer, node); @@ -1231,8 +1302,8 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, // WARNING! 
heuristic here const double est_heuristic= 8 * std::sqrt(p->ctx->max_neighbors(p->layer)); - double est_size= est_heuristic * std::pow(ef, p->stats.ef_power); - set_if_smaller(est_size, p->stats.graph_size/1.3); + double est_size= est_heuristic * std::pow(ef, p->acc.ef_power); + est_size= std::min(est_size, p->max_est_size); /* max_est_size is graph_size/1.3, computed once in the MHNSW_param constructor */ VisitedSet visited(root, static_cast(est_size)); candidates.init(max_ef, false, Visited::cmp); @@ -1243,7 +1314,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, { auto node= inout->links[i]; Visited *v= visited.create(node, node->distance_to(target)); - p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target); + p->acc.diameter= std::max(p->acc.diameter, v->distance_to_target); candidates.push(v); if ((skip_deleted && v->node->deleted) || threshold > NEAREST) continue; @@ -1251,7 +1322,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, } float furthest_best= best.is_empty() ? FLT_MAX - : generous_furthest(best, p->stats.diameter, generosity); + : generous_furthest(best, p->acc.diameter, generosity); while (candidates.elements()) { const Visited &cur= *candidates.pop(); @@ -1279,17 +1350,18 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, Visited *v= visited.create(links[i], links[i]->distance_to(target)); if (v->distance_to_target <= threshold) continue; - p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target); + p->acc.diameter= std::max(p->acc.diameter, v->distance_to_target); candidates.safe_push(v); if (skip_deleted && v->node->deleted) continue; best.push(v); - furthest_best= generous_furthest(best, p->stats.diameter, generosity); + furthest_best= generous_furthest(best, p->acc.diameter, generosity); } else { Visited *v= visited.create(links[i], - links[i]->distance_greater_than(target, furthest_best)); + links[i]->distance_greater_than(target, furthest_best, + p->mode, &p->acc)); if (v->distance_to_target <= 
threshold) continue; if (v->distance_to_target < furthest_best) @@ -1300,7 +1372,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, if (v->distance_to_target < best.top()->distance_to_target) { best.replace_top(v); - furthest_best= generous_furthest(best, p->stats.diameter, generosity); + furthest_best= generous_furthest(best, p->acc.diameter, generosity); } } } @@ -1310,7 +1382,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, if (ef > 1 && visited.count > est_size) { double ef_power= std::log(visited.count/est_heuristic) / std::log(ef); - set_if_bigger(p->stats.ef_power, ef_power); + p->acc.ef_power= std::max(p->acc.ef_power, ef_power); } while (best.elements() > result_size) @@ -1386,6 +1458,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); }); MHNSW_param p(ctx, graph, max_layer); + p.acc.graph_size= 1; // we're adding one node to the graph for (; p.layer > target_layer; p.layer--) { @@ -1406,8 +1479,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) if (int err= target->save(graph)) return err; - p.stats.graph_size= 1; - ctx->add_to_stats(p.stats); + ctx->add_to_stats(p.acc); if (target_layer > max_layer) ctx->start= target; @@ -1502,6 +1574,7 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) graph->file->ha_rnd_end(); return err; } + ctx->add_to_stats(p.acc); /* folds the statistics accumulated in p.acc during this search back into the shared ones */ auto result= new (thd->mem_root) Search_context(&candidates, ctx, target); graph->context= result;