diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index ee5e848267b..f2fd74e0ccb 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -886,16 +886,29 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph, } static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, - Neighborhood *start_nodes, uint ef, size_t layer, - Neighborhood *result, bool skip_deleted) + Neighborhood *start_nodes, uint result_size, + size_t layer, Neighborhood *result, bool construction) { DBUG_ASSERT(start_nodes->num > 0); result->num= 0; MEM_ROOT * const root= graph->in_use->mem_root; + Queue candidates, best; + bool skip_deleted; + uint ef= result_size; - Queue candidates; - Queue best; + if (construction) + { + skip_deleted= false; + if (ef > 1) + ef= std::max(ef_construction, ef); + } + else + { + skip_deleted= layer == 0; + if (ef > 1 || layer == 0) + ef= std::max(graph->in_use->variables.mhnsw_min_limit, ef); + } // WARNING! heuristic here const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer)); @@ -905,23 +918,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, candidates.init(10000, false, Visited::cmp); best.init(ef, true, Visited::cmp); + DBUG_ASSERT(start_nodes->num <= result_size); for (size_t i=0; i < start_nodes->num; i++) { Visited *v= visited.create(start_nodes->links[i]); candidates.push(v); if (skip_deleted && v->node->deleted) continue; - if (best.elements() < ef) - best.push(v); - else if (v->distance_to_target < best.top()->distance_to_target) - best.replace_top(v); + best.push(v); } float furthest_best= FLT_MAX; while (candidates.elements()) { const Visited &cur= *candidates.pop(); - if (cur.distance_to_target > furthest_best && best.elements() == ef) + if (cur.distance_to_target > furthest_best && best.is_full()) break; // All possible candidates are worse than what we have visited.flush(); @@ -941,7 +952,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, if (int err= links[i]->load(graph)) return err; Visited *v= visited.create(links[i]); - if (best.elements() < ef) + if (!best.is_full()) { candidates.push(v); if (skip_deleted && v->node->deleted) @@ -966,6 +977,9 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok } + while (best.elements() > result_size) + best.pop(); + result->num= best.elements(); for (FVectorNode **links= result->links + result->num; best.elements();) *--links= best.pop()->node; @@ -1033,9 +1047,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) root_make_savepoint(thd->mem_root, &memroot_sv); SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); }); + const size_t max_found= ctx->max_neighbors(0); Neighborhood candidates, start_nodes; - candidates.init(thd->alloc(ef_construction + 7), ef_construction); - start_nodes.init(thd->alloc(ef_construction + 7), ef_construction); + candidates.init(thd->alloc(max_found + 7), max_found); + start_nodes.init(thd->alloc(max_found + 7), max_found); start_nodes.links[start_nodes.num++]= ctx->start; const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M); @@ -1063,7 +1078,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) { uint max_neighbors= ctx->max_neighbors(cur_layer); if (int err= search_layer(ctx, graph, target->vec, &start_nodes, - ef_construction, cur_layer, &candidates, false)) + max_neighbors, cur_layer, &candidates, true)) return err; if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates, @@ -1106,11 +1121,9 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) if (err) return err; - size_t ef= thd->variables.mhnsw_min_limit; - Neighborhood candidates, start_nodes; - candidates.init(thd->alloc(ef + 7), ef); - start_nodes.init(thd->alloc(ef + 7), ef); + candidates.init(thd->alloc(limit + 7), limit); + start_nodes.init(thd->alloc(limit + 7), limit); // one could put all max_layer nodes in start_nodes // but it has no effect on the recall or speed @@ -1146,8 +1159,8 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) std::swap(start_nodes, candidates); } - if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0, - &candidates, true)) + if (int err= search_layer(ctx, graph, target, &start_nodes, + static_cast(limit), 0, &candidates, false)) return err; if (limit > candidates.num)