MDEV-36205 subdist optimization

* in addition to X.distance_to(Y) operation, introduce
  X.distance_greater_than(Y,Z) to be used like

  - if (X.distance_to(Y) > Z)
  + if (X.distance_greater_than(Y,Z) > Z)

  here we can afford to calculate the distance approximately,
  as long as it's greater than Z

* calculate the distance approximately, by looking only at the
  first 192 dimensions and extrapolating to the full vector length

* if it's larger than Z*1.05 (safety margin), assume |X-Y| > Z too
  and just return the extrapolated distance

* 192 and 1.05 were chosen empirically (they gave the best test results)

* the optimization is obviously not applicable to vectors shorter than
  192 dimensions; it is also not applied to vectors shorter than 2*192
  dimensions, because it has to stop the SIMD calculation and compute
  the intermediate distance, which is not cheap

* this works reasonably well for dbpedia-openai and for randomized
  (all vectors multiplied by a random orthogonal matrix) variants
  of gist and *mnist, but completely destroys recall for
  original (non-randomized) gist and *mnist.

* up to 50% speedup for dbpedia-openai-500k, 40% faster run for gist-960
This commit is contained in:
Sergei Golubchik 2025-06-03 22:56:41 +02:00
parent 3d360db8b1
commit da4ae022d4

View File

@ -31,6 +31,8 @@ static constexpr float NEAREST = -1.0f;
static constexpr float alpha = 1.1f;
static constexpr uint ef_construction= 10;
static constexpr uint max_ef= 10000;
static constexpr size_t subdist_part= 192;
static constexpr float subdist_margin= 1.05f;
/*
graph related statistical data. stored in MHNSW_Share.
@ -88,9 +90,9 @@ class FVectorNode;
struct FVector
{
static constexpr size_t data_header= sizeof(float);
static constexpr size_t alloc_header= data_header + sizeof(float);
static constexpr size_t alloc_header= data_header + sizeof(float)*2;
float abs2, scale;
float abs2, subabs2, scale;
int16_t dims[4];
uchar *data() const { return (uchar*)(&scale); }
@ -101,18 +103,28 @@ struct FVector
static size_t data_to_value_size(size_t data_size)
{ return (data_size - data_header)*2; }
static const FVector *create(metric_type metric, void *mem, const void *src, size_t src_len);
static const FVector *create(const MHNSW_Share *ctx, void *mem, const void *src);
void postprocess(size_t vec_len)
void postprocess(bool use_subdist, size_t vec_len)
{
int16_t *d= dims;
fix_tail(vec_len);
abs2= scale * scale * dot_product(dims, dims, vec_len) / 2;
if (use_subdist)
{
subabs2= scale * scale * dot_product(d, d, subdist_part) / 2;
d+= subdist_part;
vec_len-= subdist_part;
}
else
subabs2= 0;
abs2= subabs2 + scale * scale * dot_product(d, d, vec_len) / 2;
}
#ifdef AVX2_IMPLEMENTATION
/************* AVX2 *****************************************************/
static constexpr size_t AVX2_bytes= 256/8;
static constexpr size_t AVX2_dims= AVX2_bytes/sizeof(int16_t);
static_assert(subdist_part % AVX2_dims == 0);
AVX2_IMPLEMENTATION
static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
@ -150,6 +162,7 @@ struct FVector
/************* AVX512 ****************************************************/
static constexpr size_t AVX512_bytes= 512/8;
static constexpr size_t AVX512_dims= AVX512_bytes/sizeof(int16_t);
static_assert(subdist_part % AVX512_dims == 0);
AVX512_IMPLEMENTATION
static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
@ -191,6 +204,7 @@ struct FVector
#ifdef NEON_IMPLEMENTATION
static constexpr size_t NEON_bytes= 128 / 8;
static constexpr size_t NEON_dims= NEON_bytes / sizeof(int16_t);
static_assert(subdist_part % NEON_dims == 0);
static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
{
@ -224,6 +238,7 @@ struct FVector
/************* POWERPC *****************************************************/
static constexpr size_t POWER_bytes= 128 / 8; // Assume 128-bit vector width
static constexpr size_t POWER_dims= POWER_bytes / sizeof(int16_t);
static_assert(subdist_part % POWER_dims == 0);
static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
{
@ -298,6 +313,19 @@ struct FVector
return abs2 + other->abs2 - scale * other->scale *
dot_product(dims, other->dims, vec_len);
}
/*
  Distance to `other`, with an early exit for clearly distant vectors.

  First an estimate is computed from the leading subdist_part dimensions
  only and scaled up by vec_len/subdist_part. If that estimate is greater
  than `than`, it is returned as is (so the result is approximate in this
  case). Otherwise the dot product over the remaining dimensions is added
  and the same formula as in distance_to() is evaluated.

  @param other    vector to measure the distance to
  @param vec_len  number of dimensions in both vectors
  @param than     limit the estimate is compared against
  @return the extrapolated estimate if it exceeds `than`, otherwise
          abs2 + other->abs2 - scale * other->scale * (full dot product)

  NOTE(review): reads subdist_part dimensions unconditionally and passes
  vec_len - subdist_part as a length, so it presumably requires
  vec_len > subdist_part and a valid subabs2 on both vectors -- confirm
  that every caller checks this.
*/
float distance_greater_than(const FVector *other, size_t vec_len, float than) const
{
  const float k= scale * other->scale;
  float dp= dot_product(dims, other->dims, subdist_part);
  const float subdist= (subabs2 + other->subabs2 - k * dp)/subdist_part*vec_len;
  if (subdist > than)
    return subdist;
  /* the estimate is not conclusive, finish the calculation */
  dp+= dot_product(dims+subdist_part, other->dims+subdist_part,
                   vec_len - subdist_part);
  /* returned as float directly, like in distance_to() */
  return abs2 + other->abs2 - k * dp;
}
};
#pragma pack(pop)
@ -368,6 +396,7 @@ public:
FVectorNode(MHNSW_Share *ctx_, const void *tref_, uint8_t layer,
const void *vec_);
float distance_to(const FVector *other) const;
float distance_greater_than(const FVector *other, float than) const;
int load(TABLE *graph);
int load_from_record(TABLE *graph);
int save(TABLE *graph);
@ -430,6 +459,7 @@ public:
const uint gref_len;
const uint M;
metric_type metric;
bool use_subdist;
MHNSW_Share(TABLE *t)
: tref_len(t->file->ref_length), gref_len(t->hlindex->file->ref_length),
@ -475,6 +505,7 @@ public:
{
byte_len= len;
vec_len= len / sizeof(float);
use_subdist= vec_len >= subdist_part * 2;
}
static int acquire(MHNSW_Share **ctx, TABLE *table, bool for_update);
@ -806,23 +837,25 @@ int MHNSW_Share::acquire(MHNSW_Share **ctx, TABLE *table, bool for_update)
return 0;
}
const FVector *FVector::create(metric_type metric, void *mem, const void *src, size_t src_len)
const FVector *FVector::create(const MHNSW_Share *ctx, void *mem, const void *src)
{
float scale=0, *v= (float *)src;
size_t vec_len= src_len / sizeof(float);
for (size_t i= 0; i < vec_len; i++)
for (size_t i= 0; i < ctx->vec_len; i++)
if (std::abs(scale) < std::abs(get_float(v + i)))
scale= get_float(v + i);
FVector *vec= align_ptr(mem);
vec->scale= scale ? scale/32767 : 1;
for (size_t i= 0; i < vec_len; i++)
for (size_t i= 0; i < ctx->vec_len; i++)
vec->dims[i] = static_cast<int16_t>(std::round(get_float(v + i) / vec->scale));
vec->postprocess(vec_len);
if (metric == COSINE)
vec->postprocess(ctx->use_subdist, ctx->vec_len);
if (ctx->metric == COSINE)
{
if (vec->abs2 > 0.0f)
{
vec->scale/= std::sqrt(2*vec->abs2);
vec->subabs2/= 2*vec->abs2;
}
vec->abs2= 0.5f;
}
return vec;
@ -831,7 +864,7 @@ const FVector *FVector::create(metric_type metric, void *mem, const void *src, s
/* copy the vector, preprocessed as needed */
const FVector *FVectorNode::make_vec(const void *v)
{
return FVector::create(ctx->metric, tref() + tref_len(), v, ctx->byte_len);
return FVector::create(ctx, tref() + tref_len(), v);
}
FVectorNode::FVectorNode(MHNSW_Share *ctx_, const void *gref_)
@ -857,6 +890,13 @@ float FVectorNode::distance_to(const FVector *other) const
return vec->distance_to(other, ctx->vec_len);
}
/*
  Distance from this node's vector to `other`, possibly approximate.

  When ctx->use_subdist is set, the call is forwarded to
  FVector::distance_greater_than() with the limit multiplied by
  subdist_margin. Otherwise it is a plain distance_to().
*/
float FVectorNode::distance_greater_than(const FVector *other, float than) const
{
  if (ctx->use_subdist)
  {
    const float limit= than*subdist_margin;
    return vec->distance_greater_than(other, ctx->vec_len, limit);
  }
  return distance_to(other);
}
int FVectorNode::alloc_neighborhood(uint8_t layer)
{
if (neighbors)
@ -909,7 +949,7 @@ int FVectorNode::load_from_record(TABLE *graph)
return my_errno= HA_ERR_CRASHED;
FVector *vec_ptr= FVector::align_ptr(tref() + tref_len());
memcpy(vec_ptr->data(), v->ptr(), v->length());
vec_ptr->postprocess(ctx->vec_len);
vec_ptr->postprocess(ctx->use_subdist, ctx->vec_len);
longlong layer= graph->field[FIELD_LAYER]->val_int();
if (layer > 100) // 10e30 nodes at M=2, more at larger M's
@ -995,17 +1035,15 @@ struct Visited : public Sql_alloc
class VisitedSet
{
MEM_ROOT *root;
const FVector *target;
PatternedSimdBloomFilter<FVectorNode> map;
const FVectorNode *nodes[8]= {0,0,0,0,0,0,0,0};
size_t idx= 1; // to record 0 in the filter
public:
uint count= 0;
VisitedSet(MEM_ROOT *root, const FVector *target, uint size) :
root(root), target(target), map(size, 0.01f) {}
Visited *create(FVectorNode *node)
VisitedSet(MEM_ROOT *root, uint size) : root(root), map(size, 0.01f) {}
Visited *create(FVectorNode *node, float dist)
{
auto *v= new (root) Visited(node, node->distance_to(target));
auto *v= new (root) Visited(node, dist);
insert(node);
count++;
return v;
@ -1064,7 +1102,8 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target,
const float target_dista= std::max(32*FLT_EPSILON, vec->distance_to_target / alpha);
bool discard= false;
for (size_t i=0; i < neighbors.num; i++)
if ((discard= node->distance_to(neighbors.links[i]->vec) <= target_dista))
if ((discard= node->distance_greater_than(neighbors.links[i]->vec,
target_dista) <= target_dista))
break;
if (!discard)
target->push_neighbor(p->layer, node);
@ -1194,7 +1233,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
const double est_heuristic= 8 * std::sqrt(p->ctx->max_neighbors(p->layer));
double est_size= est_heuristic * std::pow(ef, p->stats.ef_power);
set_if_smaller(est_size, p->stats.graph_size/1.3);
VisitedSet visited(root, target, static_cast<uint>(est_size));
VisitedSet visited(root, static_cast<uint>(est_size));
candidates.init(max_ef, false, Visited::cmp);
best.init(ef, true, Visited::cmp);
@ -1202,7 +1241,8 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
DBUG_ASSERT(inout->num <= result_size);
for (size_t i=0; i < inout->num; i++)
{
Visited *v= visited.create(inout->links[i]);
auto node= inout->links[i];
Visited *v= visited.create(node, node->distance_to(target));
p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target);
candidates.push(v);
if ((skip_deleted && v->node->deleted) || threshold > NEAREST)
@ -1234,11 +1274,11 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
continue;
if (int err= links[i]->load(p->graph))
return err;
Visited *v= visited.create(links[i]);
if (v->distance_to_target <= threshold)
continue;
if (!best.is_full())
{
Visited *v= visited.create(links[i], links[i]->distance_to(target));
if (v->distance_to_target <= threshold)
continue;
p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target);
candidates.safe_push(v);
if (skip_deleted && v->node->deleted)
@ -1246,15 +1286,22 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
best.push(v);
furthest_best= generous_furthest(best, p->stats.diameter, generosity);
}
else if (v->distance_to_target < furthest_best)
else
{
candidates.safe_push(v);
if (skip_deleted && v->node->deleted)
Visited *v= visited.create(links[i],
links[i]->distance_greater_than(target, furthest_best));
if (v->distance_to_target <= threshold)
continue;
if (v->distance_to_target < best.top()->distance_to_target)
if (v->distance_to_target < furthest_best)
{
best.replace_top(v);
furthest_best= generous_furthest(best, p->stats.diameter, generosity);
candidates.safe_push(v);
if (skip_deleted && v->node->deleted)
continue;
if (v->distance_to_target < best.top()->distance_to_target)
{
best.replace_top(v);
furthest_best= generous_furthest(best, p->stats.diameter, generosity);
}
}
}
}
@ -1432,8 +1479,8 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
}
const longlong max_layer= candidates.links[0]->max_layer;
auto target= FVector::create(ctx->metric, thd->alloc(FVector::alloc_size(ctx->vec_len)),
res->ptr(), res->length());
auto target= FVector::create(ctx, thd->alloc(FVector::alloc_size(ctx->vec_len)),
res->ptr());
if (int err= graph->file->ha_rnd_init(0))
return err;