diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index d4cda8b0cf6..83ef816f6d5 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -31,6 +31,8 @@ static constexpr float NEAREST = -1.0f; static constexpr float alpha = 1.1f; static constexpr uint ef_construction= 10; static constexpr uint max_ef= 10000; +static constexpr size_t subdist_part= 192; +static constexpr float subdist_margin= 1.05f; /* graph related statistical data. stored in MHNSW_Share. @@ -88,9 +90,9 @@ class FVectorNode; struct FVector { static constexpr size_t data_header= sizeof(float); - static constexpr size_t alloc_header= data_header + sizeof(float); + static constexpr size_t alloc_header= data_header + sizeof(float)*2; - float abs2, scale; + float abs2, subabs2, scale; int16_t dims[4]; uchar *data() const { return (uchar*)(&scale); } @@ -101,18 +103,28 @@ struct FVector static size_t data_to_value_size(size_t data_size) { return (data_size - data_header)*2; } - static const FVector *create(metric_type metric, void *mem, const void *src, size_t src_len); + static const FVector *create(const MHNSW_Share *ctx, void *mem, const void *src); - void postprocess(size_t vec_len) + void postprocess(bool use_subdist, size_t vec_len) { + int16_t *d= dims; fix_tail(vec_len); - abs2= scale * scale * dot_product(dims, dims, vec_len) / 2; + if (use_subdist) + { + subabs2= scale * scale * dot_product(d, d, subdist_part) / 2; + d+= subdist_part; + vec_len-= subdist_part; + } + else + subabs2= 0; + abs2= subabs2 + scale * scale * dot_product(d, d, vec_len) / 2; } #ifdef AVX2_IMPLEMENTATION /************* AVX2 *****************************************************/ static constexpr size_t AVX2_bytes= 256/8; static constexpr size_t AVX2_dims= AVX2_bytes/sizeof(int16_t); + static_assert(subdist_part % AVX2_dims == 0); AVX2_IMPLEMENTATION static float dot_product(const int16_t *v1, const int16_t *v2, size_t len) @@ -150,6 +162,7 @@ struct FVector /************* AVX512 ****************************************************/ static constexpr size_t AVX512_bytes= 512/8; static constexpr size_t AVX512_dims= AVX512_bytes/sizeof(int16_t); + static_assert(subdist_part % AVX512_dims == 0); AVX512_IMPLEMENTATION static float dot_product(const int16_t *v1, const int16_t *v2, size_t len) @@ -191,6 +204,7 @@ struct FVector #ifdef NEON_IMPLEMENTATION static constexpr size_t NEON_bytes= 128 / 8; static constexpr size_t NEON_dims= NEON_bytes / sizeof(int16_t); + static_assert(subdist_part % NEON_dims == 0); static float dot_product(const int16_t *v1, const int16_t *v2, size_t len) { @@ -224,6 +238,7 @@ struct FVector /************* POWERPC *****************************************************/ static constexpr size_t POWER_bytes= 128 / 8; // Assume 128-bit vector width static constexpr size_t POWER_dims= POWER_bytes / sizeof(int16_t); + static_assert(subdist_part % POWER_dims == 0); static float dot_product(const int16_t *v1, const int16_t *v2, size_t len) { @@ -298,6 +313,19 @@ struct FVector return abs2 + other->abs2 - scale * other->scale * dot_product(dims, other->dims, vec_len); } + + float distance_greater_than(const FVector *other, size_t vec_len, float than) const + { + float k = scale * other->scale; + float dp= dot_product(dims, other->dims, subdist_part); + float subdist= (subabs2 + other->subabs2 - k * dp)/subdist_part*vec_len; + if (subdist > than) + return subdist; + dp+= dot_product(dims+subdist_part, other->dims+subdist_part, + vec_len - subdist_part); + double dist= abs2 + other->abs2 - k * dp; + return dist; + } }; #pragma pack(pop) @@ -368,6 +396,7 @@ public: FVectorNode(MHNSW_Share *ctx_, const void *tref_, uint8_t layer, const void *vec_); float distance_to(const FVector *other) const; + float distance_greater_than(const FVector *other, float than) const; int load(TABLE *graph); int load_from_record(TABLE *graph); int save(TABLE *graph); @@ -430,6 +459,7 @@ public: const uint gref_len; const uint M; metric_type metric; + bool use_subdist; MHNSW_Share(TABLE *t) : tref_len(t->file->ref_length), gref_len(t->hlindex->file->ref_length), @@ -475,6 +505,7 @@ public: { byte_len= len; vec_len= len / sizeof(float); + use_subdist= vec_len >= subdist_part * 2; } static int acquire(MHNSW_Share **ctx, TABLE *table, bool for_update); @@ -806,23 +837,25 @@ int MHNSW_Share::acquire(MHNSW_Share **ctx, TABLE *table, bool for_update) return 0; } -const FVector *FVector::create(metric_type metric, void *mem, const void *src, size_t src_len) +const FVector *FVector::create(const MHNSW_Share *ctx, void *mem, const void *src) { float scale=0, *v= (float *)src; - size_t vec_len= src_len / sizeof(float); - for (size_t i= 0; i < vec_len; i++) + for (size_t i= 0; i < ctx->vec_len; i++) if (std::abs(scale) < std::abs(get_float(v + i))) scale= get_float(v + i); FVector *vec= align_ptr(mem); vec->scale= scale ? scale/32767 : 1; - for (size_t i= 0; i < vec_len; i++) + for (size_t i= 0; i < ctx->vec_len; i++) vec->dims[i] = static_cast(std::round(get_float(v + i) / vec->scale)); - vec->postprocess(vec_len); - if (metric == COSINE) + vec->postprocess(ctx->use_subdist, ctx->vec_len); + if (ctx->metric == COSINE) { if (vec->abs2 > 0.0f) + { vec->scale/= std::sqrt(2*vec->abs2); + vec->subabs2/= 2*vec->abs2; + } vec->abs2= 0.5f; } return vec; @@ -831,7 +864,7 @@ const FVector *FVector::create(metric_type metric, void *mem, const void *src, s /* copy the vector, preprocessed as needed */ const FVector *FVectorNode::make_vec(const void *v) { - return FVector::create(ctx->metric, tref() + tref_len(), v, ctx->byte_len); + return FVector::create(ctx, tref() + tref_len(), v); } FVectorNode::FVectorNode(MHNSW_Share *ctx_, const void *gref_) @@ -857,6 +890,13 @@ float FVectorNode::distance_to(const FVector *other) const return vec->distance_to(other, ctx->vec_len); } +float FVectorNode::distance_greater_than(const FVector *other, float than) const +{ + if (!ctx->use_subdist) + return distance_to(other); + return vec->distance_greater_than(other, ctx->vec_len, than*subdist_margin); +} + int FVectorNode::alloc_neighborhood(uint8_t layer) { if (neighbors) @@ -909,7 +949,7 @@ int FVectorNode::load_from_record(TABLE *graph) return my_errno= HA_ERR_CRASHED; FVector *vec_ptr= FVector::align_ptr(tref() + tref_len()); memcpy(vec_ptr->data(), v->ptr(), v->length()); - vec_ptr->postprocess(ctx->vec_len); + vec_ptr->postprocess(ctx->use_subdist, ctx->vec_len); longlong layer= graph->field[FIELD_LAYER]->val_int(); if (layer > 100) // 10e30 nodes at M=2, more at larger M's @@ -995,17 +1035,15 @@ struct Visited : public Sql_alloc class VisitedSet { MEM_ROOT *root; - const FVector *target; PatternedSimdBloomFilter map; const FVectorNode *nodes[8]= {0,0,0,0,0,0,0,0}; size_t idx= 1; // to record 0 in the filter public: uint count= 0; - VisitedSet(MEM_ROOT *root, const FVector *target, uint size) : - root(root), target(target), map(size, 0.01f) {} - Visited *create(FVectorNode *node) + VisitedSet(MEM_ROOT *root, uint size) : root(root), map(size, 0.01f) {} + Visited *create(FVectorNode *node, float dist) { - auto *v= new (root) Visited(node, node->distance_to(target)); + auto *v= new (root) Visited(node, dist); insert(node); count++; return v; @@ -1064,7 +1102,8 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target, const float target_dista= std::max(32*FLT_EPSILON, vec->distance_to_target / alpha); bool discard= false; for (size_t i=0; i < neighbors.num; i++) - if ((discard= node->distance_to(neighbors.links[i]->vec) <= target_dista)) + if ((discard= node->distance_greater_than(neighbors.links[i]->vec, + target_dista) <= target_dista)) break; if (!discard) target->push_neighbor(p->layer, node); @@ -1194,7 +1233,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, const double est_heuristic= 8 * std::sqrt(p->ctx->max_neighbors(p->layer)); double est_size= est_heuristic * std::pow(ef, p->stats.ef_power); set_if_smaller(est_size, p->stats.graph_size/1.3); - VisitedSet visited(root, target, static_cast(est_size)); + VisitedSet visited(root, static_cast(est_size)); candidates.init(max_ef, false, Visited::cmp); best.init(ef, true, Visited::cmp); @@ -1202,7 +1241,8 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, DBUG_ASSERT(inout->num <= result_size); for (size_t i=0; i < inout->num; i++) { - Visited *v= visited.create(inout->links[i]); + auto node= inout->links[i]; + Visited *v= visited.create(node, node->distance_to(target)); p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target); candidates.push(v); if ((skip_deleted && v->node->deleted) || threshold > NEAREST) @@ -1234,11 +1274,11 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, continue; if (int err= links[i]->load(p->graph)) return err; - Visited *v= visited.create(links[i]); - if (v->distance_to_target <= threshold) - continue; if (!best.is_full()) { + Visited *v= visited.create(links[i], links[i]->distance_to(target)); + if (v->distance_to_target <= threshold) + continue; p->stats.diameter= std::max(p->stats.diameter, v->distance_to_target); candidates.safe_push(v); if (skip_deleted && v->node->deleted) @@ -1246,15 +1286,22 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, best.push(v); furthest_best= generous_furthest(best, p->stats.diameter, generosity); } - else if (v->distance_to_target < furthest_best) + else { - candidates.safe_push(v); - if (skip_deleted && v->node->deleted) + Visited *v= visited.create(links[i], + links[i]->distance_greater_than(target, furthest_best)); + if (v->distance_to_target <= threshold) continue; - if (v->distance_to_target < best.top()->distance_to_target) + if (v->distance_to_target < furthest_best) { - best.replace_top(v); - furthest_best= generous_furthest(best, p->stats.diameter, generosity); + candidates.safe_push(v); + if (skip_deleted && v->node->deleted) + continue; + if (v->distance_to_target < best.top()->distance_to_target) + { + best.replace_top(v); + furthest_best= generous_furthest(best, p->stats.diameter, generosity); + } } } } @@ -1432,8 +1479,8 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) } const longlong max_layer= candidates.links[0]->max_layer; - auto target= FVector::create(ctx->metric, thd->alloc(FVector::alloc_size(ctx->vec_len)), - res->ptr(), res->length()); + auto target= FVector::create(ctx, thd->alloc(FVector::alloc_size(ctx->vec_len)), + res->ptr()); if (int err= graph->file->ha_rnd_init(0)) return err;