diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index 92b5d1eeb4c..ee5e848267b 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -29,10 +29,6 @@ ulonglong mhnsw_cache_size; static constexpr float alpha = 1.1f; static constexpr uint ef_construction= 10; -// SIMD definitions -#define SIMD_word (256/8) -#define SIMD_floats (SIMD_word/sizeof(float)) - enum Graph_table_fields { FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS }; @@ -44,19 +40,110 @@ class MHNSW_Context; class FVectorNode; /* - One vector, an array of ctx->vec_len floats - - Aligned on 32-byte (SIMD_word) boundary for SIMD, vector lenght - is zero-padded to multiples of 8, for the same reason. + One vector, an array of coordinates in ctx->vec_len dimensions */ -class FVector +#pragma pack(push, 1) +struct FVector { -public: - FVector(MHNSW_Context *ctx_, MEM_ROOT *root, const void *vec_); - float *vec; -protected: - FVector() : vec(nullptr) {} + static constexpr size_t data_header= sizeof(float); + static constexpr size_t alloc_header= data_header + sizeof(float); + + float abs2, scale; + int16_t dims[4]; + + uchar *data() const { return (uchar*)(&scale); } + + static size_t data_size(size_t n) + { return data_header + n*2; } + + static size_t data_to_value_size(size_t data_size) + { return (data_size - data_header)*2; } + + static const FVector *create(void *mem, const void *src, size_t src_len) + { + float scale=0, *v= (float *)src; + size_t vec_len= src_len / sizeof(float); + for (size_t i= 0; i < vec_len; i++) + if (std::abs(scale) < std::abs(get_float(v + i))) + scale= get_float(v + i); + + FVector *vec= align_ptr(mem); + vec->scale= scale ? scale/32767 : 1; + for (size_t i= 0; i < vec_len; i++) + vec->dims[i] = static_cast(std::round(get_float(v + i) / vec->scale)); + vec->postprocess(vec_len); + return vec; + } + + void postprocess(size_t vec_len) + { + fix_tail(vec_len); + abs2= scale * scale * dot_product(dims, dims, vec_len) / 2; + } + +#ifdef AVX2_IMPLEMENTATION + /************* AVX2 *****************************************************/ + static constexpr size_t AVX2_bytes= 256/8; + static constexpr size_t AVX2_dims= AVX2_bytes/sizeof(int16_t); + + AVX2_IMPLEMENTATION + static float dot_product(const int16_t *v1, const int16_t *v2, size_t len) + { + typedef float v8f __attribute__((vector_size(AVX2_bytes))); + union { v8f v; __m256 i; } tmp; + __m256i *p1= (__m256i*)v1; + __m256i *p2= (__m256i*)v2; + v8f d= {0}; + for (size_t i= 0; i < (len + AVX2_dims-1)/AVX2_dims; p1++, p2++, i++) + { + tmp.i= _mm256_cvtepi32_ps(_mm256_madd_epi16(*p1, *p2)); + d+= tmp.v; + } + return d[0] + d[1] + d[2] + d[3] + d[4] + d[5] + d[6] + d[7]; + } + + AVX2_IMPLEMENTATION + static size_t alloc_size(size_t n) + { return alloc_header + MY_ALIGN(n*2, AVX2_bytes) + AVX2_bytes - 1; } + + AVX2_IMPLEMENTATION + static FVector *align_ptr(void *ptr) + { return (FVector*)(MY_ALIGN(((intptr)ptr) + alloc_header, AVX2_bytes) + - alloc_header); } + + AVX2_IMPLEMENTATION + void fix_tail(size_t vec_len) + { + bzero(dims + vec_len, (MY_ALIGN(vec_len, AVX2_dims) - vec_len)*2); + } +#endif + + /************* no-SIMD default ******************************************/ + DEFAULT_IMPLEMENTATION + static float dot_product(const int16_t *v1, const int16_t *v2, size_t len) + { + int64_t d= 0; + for (size_t i= 0; i < len; i++) + d+= int32_t(v1[i]) * int32_t(v2[i]); + return static_cast(d); + } + + DEFAULT_IMPLEMENTATION + static size_t alloc_size(size_t n) { return alloc_header + n*2; } + + DEFAULT_IMPLEMENTATION + static FVector *align_ptr(void *ptr) { return (FVector*)ptr; } + + DEFAULT_IMPLEMENTATION + void fix_tail(size_t) { } + + float distance_to(const FVector *other, size_t vec_len) const + { + return abs2 + other->abs2 - scale * other->scale * + dot_product(dims, other->dims, vec_len); + } }; +#pragma pack(pop) /* An array of pointers to graph nodes @@ -86,30 +173,6 @@ struct Neighborhood: public Sql_alloc }; -#ifdef AVX2_IMPLEMENTATION -AVX2_IMPLEMENTATION -float vec_distance(float *v1, float *v2, size_t len) -{ - typedef float v8f __attribute__((vector_size(SIMD_word))); - v8f *p1= (v8f*)v1; - v8f *p2= (v8f*)v2; - v8f d= {0}; - for (size_t i= 0; i < len/SIMD_floats; p1++, p2++, i++) - { - v8f dist= *p1 - *p2; - d+= dist * dist; - } - return d[0] + d[1] + d[2] + d[3] + d[4] + d[5] + d[6] + d[7]; -} -#endif - -DEFAULT_IMPLEMENTATION -float vec_distance(float *v1, float *v2, size_t len) -{ - return euclidean_vec_distance(v1, v2, len); -} - - /* One node in a graph = one row in the graph table @@ -132,14 +195,15 @@ float vec_distance(float *v1, float *v2, size_t len) is constrained by mhnsw_cache_size, so every byte matters here */ #pragma pack(push, 1) -class FVectorNode: public FVector +class FVectorNode { private: MHNSW_Context *ctx; - float *make_vec(const void *v); + const FVector *make_vec(const void *v); int alloc_neighborhood(uint8_t layer); public: + const FVector *vec= nullptr; Neighborhood *neighbors= nullptr; uint8_t max_layer; bool stored:1, deleted:1; @@ -147,7 +211,7 @@ public: FVectorNode(MHNSW_Context *ctx_, const void *gref_); FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer, const void *vec_); - float distance_to(const FVector &other) const; + float distance_to(const FVector *other) const; int load(TABLE *graph); int load_from_record(TABLE *graph); int save(TABLE *graph); @@ -192,7 +256,7 @@ class MHNSW_Context : public Sql_alloc void *alloc_node_internal() { return alloc_root(&root, sizeof(FVectorNode) + gref_len + tref_len - + vec_len * sizeof(float) + SIMD_word - 1); + + FVector::alloc_size(vec_len)); } protected: @@ -252,7 +316,7 @@ public: void set_lengths(size_t len) { byte_len= len; - vec_len= MY_ALIGN(byte_len/sizeof(float), SIMD_floats); + vec_len= len / sizeof(float); } static int acquire(MHNSW_Context **ctx, TABLE *table, bool for_update); @@ -505,42 +569,26 @@ int MHNSW_Context::acquire(MHNSW_Context **ctx, TABLE *table, bool for_update) return err; graph->file->position(graph->record[0]); - (*ctx)->set_lengths(graph->field[FIELD_VEC]->value_length()); + (*ctx)->set_lengths(FVector::data_to_value_size(graph->field[FIELD_VEC]->value_length())); (*ctx)->start= (*ctx)->get_node(graph->file->ref); return (*ctx)->start->load_from_record(graph); } -/* copy the vector, aligned and padded for SIMD */ -static float *make_vec(void *mem, const void *src, size_t src_len) +/* copy the vector, preprocessed as needed */ +const FVector *FVectorNode::make_vec(const void *v) { - auto dst= (float*)MY_ALIGN((intptr)mem, SIMD_word); - memcpy(dst, src, src_len); - const size_t start= src_len/sizeof(float); - for (size_t i= start; i < MY_ALIGN(start, SIMD_floats); i++) - dst[i]=0.0f; - return dst; -} - -FVector::FVector(MHNSW_Context *ctx, MEM_ROOT *root, const void *vec_) -{ - vec= make_vec(alloc_root(root, ctx->vec_len * sizeof(float) + SIMD_word - 1), - vec_, ctx->byte_len); -} - -float *FVectorNode::make_vec(const void *v) -{ - return ::make_vec(tref() + tref_len(), v, ctx->byte_len); + return FVector::create(tref() + tref_len(), v, ctx->byte_len); } FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_) - : FVector(), ctx(ctx_), stored(true), deleted(false) + : ctx(ctx_), stored(true), deleted(false) { memcpy(gref(), gref_, gref_len()); } FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer, const void *vec_) - : FVector(), ctx(ctx_), stored(false), deleted(false) + : ctx(ctx_), stored(false), deleted(false) { DBUG_ASSERT(tref_); memset(gref(), 0xff, gref_len()); // important: larger than any real gref @@ -550,9 +598,9 @@ FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer, alloc_neighborhood(layer); } -float FVectorNode::distance_to(const FVector &other) const +float FVectorNode::distance_to(const FVector *other) const { - return vec_distance(vec, other.vec, ctx->vec_len); + return vec->distance_to(other, ctx->vec_len); } int FVectorNode::alloc_neighborhood(uint8_t layer) @@ -603,9 +651,11 @@ int FVectorNode::load_from_record(TABLE *graph) if (unlikely(!v)) return my_errno= HA_ERR_CRASHED; - if (v->length() != ctx->byte_len) + if (v->length() != FVector::data_size(ctx->vec_len)) return my_errno= HA_ERR_CRASHED; - float *vec_ptr= make_vec(v->ptr()); + FVector *vec_ptr= FVector::align_ptr(tref() + tref_len()); + memcpy(vec_ptr->data(), v->ptr(), v->length()); + vec_ptr->postprocess(ctx->vec_len); longlong layer= graph->field[FIELD_LAYER]->val_int(); if (layer > 100) // 10e30 nodes at M=2, more at larger M's @@ -676,13 +726,13 @@ struct Visited : public Sql_alloc class VisitedSet { MEM_ROOT *root; - const FVector ⌖ + const FVector *target; PatternedSimdBloomFilter map; const FVectorNode *nodes[8]= {0,0,0,0,0,0,0,0}; size_t idx= 1; // to record 0 in the filter public: uint count= 0; - VisitedSet(MEM_ROOT *root, const FVector &target, uint size) : + VisitedSet(MEM_ROOT *root, const FVector *target, uint size) : root(root), target(target), map(size, 0.01f) {} Visited *create(FVectorNode *node) { @@ -730,10 +780,10 @@ static int select_neighbors(MHNSW_Context *ctx, TABLE *graph, size_t layer, FVectorNode *node= candidates.links[i]; if (int err= node->load(graph)) return err; - pq.push(new (root) Visited(node, node->distance_to(target))); + pq.push(new (root) Visited(node, node->distance_to(target.vec))); } if (extra_candidate) - pq.push(new (root) Visited(extra_candidate, extra_candidate->distance_to(target))); + pq.push(new (root) Visited(extra_candidate, extra_candidate->distance_to(target.vec))); DBUG_ASSERT(pq.elements()); neighbors.num= 0; @@ -745,7 +795,7 @@ static int select_neighbors(MHNSW_Context *ctx, TABLE *graph, size_t layer, const float target_dista= vec->distance_to_target / alpha; bool discard= false; for (size_t i=0; i < neighbors.num; i++) - if ((discard= node->distance_to(*neighbors.links[i]) < target_dista)) + if ((discard= node->distance_to(neighbors.links[i]->vec) < target_dista)) break; if (!discard) target.push_neighbor(layer, node); @@ -775,7 +825,7 @@ int FVectorNode::save(TABLE *graph) graph->field[FIELD_TREF]->set_notnull(); graph->field[FIELD_TREF]->store_binary(tref(), tref_len()); } - graph->field[FIELD_VEC]->store_binary((uchar*)vec, ctx->byte_len); + graph->field[FIELD_VEC]->store_binary(vec->data(), FVector::data_size(ctx->vec_len)); size_t total_size= 0; for (size_t i=0; i <= max_layer; i++) @@ -835,7 +885,7 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph, return 0; } -static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector &target, +static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, Neighborhood *start_nodes, uint ef, size_t layer, Neighborhood *result, bool skip_deleted) { @@ -1003,8 +1053,8 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) for (cur_layer= max_layer; cur_layer > target_layer; cur_layer--) { - if (int err= search_layer(ctx, graph, *target, &start_nodes, 1, cur_layer, - &candidates, false)) + if (int err= search_layer(ctx, graph, target->vec, &start_nodes, 1, + cur_layer, &candidates, false)) return err; std::swap(start_nodes, candidates); } @@ -1012,7 +1062,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) for (; cur_layer >= 0; cur_layer--) { uint max_neighbors= ctx->max_neighbors(cur_layer); - if (int err= search_layer(ctx, graph, *target, &start_nodes, + if (int err= search_layer(ctx, graph, target->vec, &start_nodes, ef_construction, cur_layer, &candidates, false)) return err; @@ -1069,13 +1119,20 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) /* if the query vector is NULL or invalid, VEC_DISTANCE will return NULL, so the result is basically unsorted, we can return rows - in any order. For simplicity let's sort by the start_node. + in any order. Let's use some hardcoded value here */ if (!res || ctx->byte_len != res->length()) - (res= &buf)->set((char*)start_nodes.links[0]->vec, ctx->byte_len, &my_charset_bin); + { + res= &buf; + buf.alloc(ctx->byte_len); + buf.length(ctx->byte_len); + for (size_t i=0; i < ctx->vec_len; i++) + ((float*)buf.ptr())[i]= i == 0; + } const longlong max_layer= start_nodes.links[0]->max_layer; - FVector target(ctx, thd->mem_root, res->ptr()); + auto target= FVector::create(thd->alloc(FVector::alloc_size(ctx->vec_len)), + res->ptr(), res->length()); if (int err= graph->file->ha_rnd_init(0)) return err;