MDEV-36205 autodetection of subdist applicability

* skip subdist optimization for vectors shorter than 2*192, as before

* for longer vectors do not enable the optimization at first, but
  still calculate the extrapolated distance and accumulate statistics
  of extrapolated_distance/true_distance ratios.

  * calculated for every distance_greater_than() call
  * accumulated using the parallel variant of Welford's online algorithm;
    only M1 and M2 are tracked, as experiments show that higher-order
    central moments are not needed here
  * stored in MHNSW_Share

* after accumulating 10000 (best by test) data points, enable
  subdist optimization if the standard deviation of ratios (as above)
  is below 0.05; this distinguishes applicable datasets from
  non-applicable datasets with a >99.9% probability

* to be on the safe side, keep calculating statistics (it's cheap)
  and disable subdist optimization again if standard deviation
  grows above 0.05
This commit is contained in:
Sergei Golubchik 2025-06-07 21:09:11 +02:00
parent 1315806024
commit 455f07b746

View File

@ -33,6 +33,48 @@ static constexpr uint ef_construction= 10;
static constexpr uint max_ef= 10000; static constexpr uint max_ef= 10000;
static constexpr size_t subdist_part= 192; static constexpr size_t subdist_part= 192;
static constexpr float subdist_margin= 1.05f; static constexpr float subdist_margin= 1.05f;
static constexpr double subdist_stddev_threshold= 0.05; // 3σ, p>99.9%
static constexpr ulonglong subdist_stddev_valid= 10000; // sufficient
/*
The class below can assume normal distribution and only collect
M1 and M2, or go beyond that and collect M3 and M4 to account
for non-gaussian. The second mode is useful for research and tuning,
but M2 is all we need in production for now.
*/
#define STATS_NON_GAUSSIAN 0 /* 0 for no, 3 for yes */
/*
  Running accumulator of the central moments of a stream of samples.

  n   number of samples accumulated so far
  M1  mean of the samples
  M2  sum of squared deviations from the mean
  M3  sum of cubed deviations from the mean        (STATS_NON_GAUSSIAN only)
  M4  sum of 4th powers of deviations from the mean (STATS_NON_GAUSSIAN only)

  When STATS_NON_GAUSSIAN is 0, M3 and M4 are never updated and stay 0.
*/
struct stats_collector
{
  ulonglong n= 0;
  double M1= 0, M2= 0, M3= 0, M4= 0;

  /*
    Merge a batch of nB samples, described by its own moments M1B..M4B,
    into this accumulator. A single sample x is the batch (1, x, 0, 0, 0).
  */
  void add(ulonglong nB, double M1B, double M2B, double M3B, double M4B)
  { // parallel Welford's online algorithm
    if (!nB)
      return;                   // empty batch; also avoids 0/0 when n == 0
    // counts are converted to double up front: unsigned nA-nB would wrap
    // around whenever the incoming batch is larger than the accumulator
    const double cntA= static_cast<double>(n);
    const double cntB= static_cast<double>(nB);
    n+= nB;
    const double d= M1B-M1, dn= d/n, t1= d*dn*cntA*cntB;
    M1+= dn*nB;
    // higher moments must be updated first, they need the old M2 (and M3)
#if STATS_NON_GAUSSIAN
    M4+= M4B + t1*dn*dn*(cntA*cntA-cntA*cntB+cntB*cntB)
             + 6*dn*dn*(cntA*cntA*M2B+cntB*cntB*M2)
             + 4*dn*(cntA*M3B-cntB*M3);
    M3+= M3B + dn*t1*(cntA-cntB) + 3*dn*(cntA*M2B-cntB*M2);
#endif
    // M2B carries the spread inside the incoming batch, t1 accounts for
    // the distance between the two batch means
    M2+= M2B + t1;
  }

  /* add one sample; non-finite values (NaN, +-inf) are ignored */
  void add(double x) { if (std::isfinite(x)) add(1, x, 0, 0, 0); }

  /* merge another collector into this one; an empty one is a no-op */
  void add(const stats_collector &b)
  { if (b.n) add(b.n, b.M1, b.M2, b.M3, b.M4); }

  double mean() const { return M1; }

  /* population standard deviation; NaN while no samples were added */
  double stddev() const { return std::sqrt(M2/n); }

  double skewness() const { return M3/M2/stddev(); }

  /* excess kurtosis (the macro is 3 when higher moments are collected) */
  double kurtosis() const { return n*M4/M2/M2 - STATS_NON_GAUSSIAN; }

  /*
    Value of the distribution at the standard normal quantile z,
    corrected for skewness and kurtosis (Cornish-Fisher expansion).
    With M3 == M4 == 0 this reduces to mean() + stddev()*z.
  */
  double quantile(double z) const
  {
    const double skew= skewness();
    return mean() + stddev() * (z +
           skew / 6 * (z*z-1) +
           kurtosis() / 24 * (z*z*z-3*z) -
           skew * skew / 36 * (2*z*z*z-5*z));
  }
};
/* /*
graph related statistical data. stored in MHNSW_Share. graph related statistical data. stored in MHNSW_Share.
@ -43,8 +85,10 @@ struct Stats
double ef_power= 0.6; // for the bloom filter size heuristic double ef_power= 0.6; // for the bloom filter size heuristic
float diameter= 0; float diameter= 0;
size_t graph_size= 0; size_t graph_size= 0;
stats_collector subdist;
}; };
static ulonglong mhnsw_max_cache_size; static ulonglong mhnsw_max_cache_size;
static MYSQL_SYSVAR_ULONGLONG(max_cache_size, mhnsw_max_cache_size, static MYSQL_SYSVAR_ULONGLONG(max_cache_size, mhnsw_max_cache_size,
PLUGIN_VAR_RQCMDARG, "Upper limit for one MHNSW vector index cache", PLUGIN_VAR_RQCMDARG, "Upper limit for one MHNSW vector index cache",
@ -314,7 +358,8 @@ struct FVector
dot_product(dims, other->dims, vec_len); dot_product(dims, other->dims, vec_len);
} }
float distance_greater_than(const FVector *other, size_t vec_len, float than) const float distance_greater_than(const FVector *other, size_t vec_len, float than,
Stats *stats) const
{ {
float k = scale * other->scale; float k = scale * other->scale;
float dp= dot_product(dims, other->dims, subdist_part); float dp= dot_product(dims, other->dims, subdist_part);
@ -324,6 +369,7 @@ struct FVector
dp+= dot_product(dims+subdist_part, other->dims+subdist_part, dp+= dot_product(dims+subdist_part, other->dims+subdist_part,
vec_len - subdist_part); vec_len - subdist_part);
double dist= abs2 + other->abs2 - k * dp; double dist= abs2 + other->abs2 - k * dp;
stats->subdist.add(subdist/dist);
return dist; return dist;
} }
}; };
@ -356,6 +402,8 @@ struct Neighborhood: public Sql_alloc
} }
}; };
/*
  how to execute distance_greater_than()

  The enumerator value is also used as an index into the table of
  threshold multipliers in FVectorNode::distance_greater_than(),
  so the order of the enumerators must match that table.

  NOSTAT_NOSUBDIST  fall back to a plain distance_to(), the Stats
                    argument is not used
  STAT_NOSUBDIST    go through FVector::distance_greater_than() with the
                    threshold multiplied by 10 (presumably to make the
                    early exit practically unreachable while statistics
                    are still being collected -- verify against the callee)
  STAT_SUBDIST      go through FVector::distance_greater_than() with the
                    threshold multiplied by subdist_margin
*/
enum dgt_mode { NOSTAT_NOSUBDIST, STAT_NOSUBDIST, STAT_SUBDIST };
/* /*
One node in a graph = one row in the graph table One node in a graph = one row in the graph table
@ -396,7 +444,8 @@ public:
FVectorNode(MHNSW_Share *ctx_, const void *tref_, uint8_t layer, FVectorNode(MHNSW_Share *ctx_, const void *tref_, uint8_t layer,
const void *vec_); const void *vec_);
float distance_to(const FVector *other) const; float distance_to(const FVector *other) const;
float distance_greater_than(const FVector *other, float than) const; float distance_greater_than(const FVector *other, float than, dgt_mode mode,
Stats *stats) const;
int load(TABLE *graph); int load(TABLE *graph);
int load_from_record(TABLE *graph); int load_from_record(TABLE *graph);
int save(TABLE *graph); int save(TABLE *graph);
@ -613,6 +662,7 @@ public:
stats.graph_size+= addend.graph_size; stats.graph_size+= addend.graph_size;
stats.diameter= std::max(stats.diameter, addend.diameter); stats.diameter= std::max(stats.diameter, addend.diameter);
stats.ef_power= std::max(stats.ef_power, addend.ef_power); stats.ef_power= std::max(stats.ef_power, addend.ef_power);
stats.subdist.add(addend.subdist);
mysql_mutex_unlock(&cache_lock); mysql_mutex_unlock(&cache_lock);
} }
}; };
@ -890,11 +940,14 @@ float FVectorNode::distance_to(const FVector *other) const
return vec->distance_to(other, ctx->vec_len); return vec->distance_to(other, ctx->vec_len);
} }
float FVectorNode::distance_greater_than(const FVector *other, float than) const float FVectorNode::distance_greater_than(const FVector *other, float than,
dgt_mode mode, Stats *stats) const
{ {
if (!ctx->use_subdist) static constexpr float mul[3]= {0, 10, subdist_margin };
if (mode == NOSTAT_NOSUBDIST)
return distance_to(other); return distance_to(other);
return vec->distance_greater_than(other, ctx->vec_len, than*subdist_margin); return vec->distance_greater_than(other, ctx->vec_len,
than*mul[mode], stats);
} }
int FVectorNode::alloc_neighborhood(uint8_t layer) int FVectorNode::alloc_neighborhood(uint8_t layer)
@ -1003,10 +1056,28 @@ struct MHNSW_param
MHNSW_Share *ctx; MHNSW_Share *ctx;
TABLE *graph; TABLE *graph;
int layer; int layer;
Stats stats; Stats acc;
dgt_mode mode;
double max_est_size;
MHNSW_param(MHNSW_Share *ctx, TABLE *graph, int layer) MHNSW_param(MHNSW_Share *ctx, TABLE *graph, int layer)
: ctx(ctx), graph(graph), layer(layer) : ctx(ctx), graph(graph), layer(layer)
{ ctx->read_stats(&stats); } {
Stats stats;
ctx->read_stats(&stats);
max_est_size= stats.graph_size/1.3;
acc.diameter= stats.diameter;
acc.ef_power= stats.ef_power;
if (ctx->use_subdist)
{
if (stats.subdist.n > subdist_stddev_valid)
mode= stats.subdist.stddev() < subdist_stddev_threshold
? STAT_SUBDIST : NOSTAT_NOSUBDIST;
else
mode= STAT_NOSUBDIST;
}
else
mode= NOSTAT_NOSUBDIST;
}
}; };
/* one visited node during the search. caches the distance to target */ /* one visited node during the search. caches the distance to target */
@ -1103,7 +1174,7 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target,
bool discard= false; bool discard= false;
for (size_t i=0; i < neighbors.num; i++) for (size_t i=0; i < neighbors.num; i++)
if ((discard= node->distance_greater_than(neighbors.links[i]->vec, if ((discard= node->distance_greater_than(neighbors.links[i]->vec,
target_dista) <= target_dista)) target_dista, p->mode, &p->acc) <= target_dista))
break; break;
if (!discard) if (!discard)
target->push_neighbor(p->layer, node); target->push_neighbor(p->layer, node);
@ -1231,8 +1302,8 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
// WARNING! heuristic here // WARNING! heuristic here
const double est_heuristic= 8 * std::sqrt(p->ctx->max_neighbors(p->layer)); const double est_heuristic= 8 * std::sqrt(p->ctx->max_neighbors(p->layer));
double est_size= est_heuristic * std::pow(ef, p->stats.ef_power); double est_size= est_heuristic * std::pow(ef, p->acc.ef_power);
set_if_smaller(est_size, p->stats.graph_size/1.3); est_size= std::min(est_size, p->max_est_size);
VisitedSet visited(root, static_cast<uint>(est_size)); VisitedSet visited(root, static_cast<uint>(est_size));
candidates.init(max_ef, false, Visited::cmp); candidates.init(max_ef, false, Visited::cmp);
@ -1243,7 +1314,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
{ {
auto node= inout->links[i]; auto node= inout->links[i];
Visited *v= visited.create(node, node->distance_to(target)); Visited *v= visited.create(node, node->distance_to(target));
p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target); p->acc.diameter= std::max(p->acc.diameter, v->distance_to_target);
candidates.push(v); candidates.push(v);
if ((skip_deleted && v->node->deleted) || threshold > NEAREST) if ((skip_deleted && v->node->deleted) || threshold > NEAREST)
continue; continue;
@ -1251,7 +1322,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
} }
float furthest_best= best.is_empty() ? FLT_MAX float furthest_best= best.is_empty() ? FLT_MAX
: generous_furthest(best, p->stats.diameter, generosity); : generous_furthest(best, p->acc.diameter, generosity);
while (candidates.elements()) while (candidates.elements())
{ {
const Visited &cur= *candidates.pop(); const Visited &cur= *candidates.pop();
@ -1279,17 +1350,18 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
Visited *v= visited.create(links[i], links[i]->distance_to(target)); Visited *v= visited.create(links[i], links[i]->distance_to(target));
if (v->distance_to_target <= threshold) if (v->distance_to_target <= threshold)
continue; continue;
p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target); p->acc.diameter= std::max(p->acc.diameter, v->distance_to_target);
candidates.safe_push(v); candidates.safe_push(v);
if (skip_deleted && v->node->deleted) if (skip_deleted && v->node->deleted)
continue; continue;
best.push(v); best.push(v);
furthest_best= generous_furthest(best, p->stats.diameter, generosity); furthest_best= generous_furthest(best, p->acc.diameter, generosity);
} }
else else
{ {
Visited *v= visited.create(links[i], Visited *v= visited.create(links[i],
links[i]->distance_greater_than(target, furthest_best)); links[i]->distance_greater_than(target, furthest_best,
p->mode, &p->acc));
if (v->distance_to_target <= threshold) if (v->distance_to_target <= threshold)
continue; continue;
if (v->distance_to_target < furthest_best) if (v->distance_to_target < furthest_best)
@ -1300,7 +1372,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
if (v->distance_to_target < best.top()->distance_to_target) if (v->distance_to_target < best.top()->distance_to_target)
{ {
best.replace_top(v); best.replace_top(v);
furthest_best= generous_furthest(best, p->stats.diameter, generosity); furthest_best= generous_furthest(best, p->acc.diameter, generosity);
} }
} }
} }
@ -1310,7 +1382,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
if (ef > 1 && visited.count > est_size) if (ef > 1 && visited.count > est_size)
{ {
double ef_power= std::log(visited.count/est_heuristic) / std::log(ef); double ef_power= std::log(visited.count/est_heuristic) / std::log(ef);
set_if_bigger(p->stats.ef_power, ef_power); p->acc.ef_power= std::max(p->acc.ef_power, ef_power);
} }
while (best.elements() > result_size) while (best.elements() > result_size)
@ -1386,6 +1458,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); }); SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });
MHNSW_param p(ctx, graph, max_layer); MHNSW_param p(ctx, graph, max_layer);
p.acc.graph_size= 1; // we're adding one node to the graph
for (; p.layer > target_layer; p.layer--) for (; p.layer > target_layer; p.layer--)
{ {
@ -1406,8 +1479,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
if (int err= target->save(graph)) if (int err= target->save(graph))
return err; return err;
p.stats.graph_size= 1; ctx->add_to_stats(p.acc);
ctx->add_to_stats(p.stats);
if (target_layer > max_layer) if (target_layer > max_layer)
ctx->start= target; ctx->start= target;
@ -1502,6 +1574,7 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
graph->file->ha_rnd_end(); graph->file->ha_rnd_end();
return err; return err;
} }
ctx->add_to_stats(p.acc);
auto result= new (thd->mem_root) Search_context(&candidates, ctx, target); auto result= new (thd->mem_root) Search_context(&candidates, ctx, target);
graph->context= result; graph->context= result;