Clean up search_layer()
so that it returns only as many elements as needed; the caller no longer needs to overallocate result arrays for throwaway nodes.
This commit is contained in:
parent
fa2078ddff
commit
885eb19823
@ -886,16 +886,29 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
|
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
|
||||||
Neighborhood *start_nodes, uint ef, size_t layer,
|
Neighborhood *start_nodes, uint result_size,
|
||||||
Neighborhood *result, bool skip_deleted)
|
size_t layer, Neighborhood *result, bool construction)
|
||||||
{
|
{
|
||||||
DBUG_ASSERT(start_nodes->num > 0);
|
DBUG_ASSERT(start_nodes->num > 0);
|
||||||
result->num= 0;
|
result->num= 0;
|
||||||
|
|
||||||
MEM_ROOT * const root= graph->in_use->mem_root;
|
MEM_ROOT * const root= graph->in_use->mem_root;
|
||||||
|
Queue<Visited> candidates, best;
|
||||||
|
bool skip_deleted;
|
||||||
|
uint ef= result_size;
|
||||||
|
|
||||||
Queue<Visited> candidates;
|
if (construction)
|
||||||
Queue<Visited> best;
|
{
|
||||||
|
skip_deleted= false;
|
||||||
|
if (ef > 1)
|
||||||
|
ef= std::max(ef_construction, ef);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
skip_deleted= layer == 0;
|
||||||
|
if (ef > 1 || layer == 0)
|
||||||
|
ef= std::max(graph->in_use->variables.mhnsw_min_limit, ef);
|
||||||
|
}
|
||||||
|
|
||||||
// WARNING! heuristic here
|
// WARNING! heuristic here
|
||||||
const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
|
const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
|
||||||
@ -905,23 +918,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
|
|||||||
candidates.init(10000, false, Visited::cmp);
|
candidates.init(10000, false, Visited::cmp);
|
||||||
best.init(ef, true, Visited::cmp);
|
best.init(ef, true, Visited::cmp);
|
||||||
|
|
||||||
|
DBUG_ASSERT(start_nodes->num <= result_size);
|
||||||
for (size_t i=0; i < start_nodes->num; i++)
|
for (size_t i=0; i < start_nodes->num; i++)
|
||||||
{
|
{
|
||||||
Visited *v= visited.create(start_nodes->links[i]);
|
Visited *v= visited.create(start_nodes->links[i]);
|
||||||
candidates.push(v);
|
candidates.push(v);
|
||||||
if (skip_deleted && v->node->deleted)
|
if (skip_deleted && v->node->deleted)
|
||||||
continue;
|
continue;
|
||||||
if (best.elements() < ef)
|
best.push(v);
|
||||||
best.push(v);
|
|
||||||
else if (v->distance_to_target < best.top()->distance_to_target)
|
|
||||||
best.replace_top(v);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float furthest_best= FLT_MAX;
|
float furthest_best= FLT_MAX;
|
||||||
while (candidates.elements())
|
while (candidates.elements())
|
||||||
{
|
{
|
||||||
const Visited &cur= *candidates.pop();
|
const Visited &cur= *candidates.pop();
|
||||||
if (cur.distance_to_target > furthest_best && best.elements() == ef)
|
if (cur.distance_to_target > furthest_best && best.is_full())
|
||||||
break; // All possible candidates are worse than what we have
|
break; // All possible candidates are worse than what we have
|
||||||
|
|
||||||
visited.flush();
|
visited.flush();
|
||||||
@ -941,7 +952,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
|
|||||||
if (int err= links[i]->load(graph))
|
if (int err= links[i]->load(graph))
|
||||||
return err;
|
return err;
|
||||||
Visited *v= visited.create(links[i]);
|
Visited *v= visited.create(links[i]);
|
||||||
if (best.elements() < ef)
|
if (!best.is_full())
|
||||||
{
|
{
|
||||||
candidates.push(v);
|
candidates.push(v);
|
||||||
if (skip_deleted && v->node->deleted)
|
if (skip_deleted && v->node->deleted)
|
||||||
@ -966,6 +977,9 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
|
|||||||
set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok
|
set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
while (best.elements() > result_size)
|
||||||
|
best.pop();
|
||||||
|
|
||||||
result->num= best.elements();
|
result->num= best.elements();
|
||||||
for (FVectorNode **links= result->links + result->num; best.elements();)
|
for (FVectorNode **links= result->links + result->num; best.elements();)
|
||||||
*--links= best.pop()->node;
|
*--links= best.pop()->node;
|
||||||
@ -1033,9 +1047,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
|
|||||||
root_make_savepoint(thd->mem_root, &memroot_sv);
|
root_make_savepoint(thd->mem_root, &memroot_sv);
|
||||||
SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); });
|
SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); });
|
||||||
|
|
||||||
|
const size_t max_found= ctx->max_neighbors(0);
|
||||||
Neighborhood candidates, start_nodes;
|
Neighborhood candidates, start_nodes;
|
||||||
candidates.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction);
|
candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
|
||||||
start_nodes.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction);
|
start_nodes.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
|
||||||
start_nodes.links[start_nodes.num++]= ctx->start;
|
start_nodes.links[start_nodes.num++]= ctx->start;
|
||||||
|
|
||||||
const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
|
const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
|
||||||
@ -1063,7 +1078,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
|
|||||||
{
|
{
|
||||||
uint max_neighbors= ctx->max_neighbors(cur_layer);
|
uint max_neighbors= ctx->max_neighbors(cur_layer);
|
||||||
if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
|
if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
|
||||||
ef_construction, cur_layer, &candidates, false))
|
max_neighbors, cur_layer, &candidates, true))
|
||||||
return err;
|
return err;
|
||||||
|
|
||||||
if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
|
if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
|
||||||
@ -1106,11 +1121,9 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
|
|||||||
if (err)
|
if (err)
|
||||||
return err;
|
return err;
|
||||||
|
|
||||||
size_t ef= thd->variables.mhnsw_min_limit;
|
|
||||||
|
|
||||||
Neighborhood candidates, start_nodes;
|
Neighborhood candidates, start_nodes;
|
||||||
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef);
|
candidates.init(thd->alloc<FVectorNode*>(limit + 7), limit);
|
||||||
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef);
|
start_nodes.init(thd->alloc<FVectorNode*>(limit + 7), limit);
|
||||||
|
|
||||||
// one could put all max_layer nodes in start_nodes
|
// one could put all max_layer nodes in start_nodes
|
||||||
// but it has no effect on the recall or speed
|
// but it has no effect on the recall or speed
|
||||||
@ -1146,8 +1159,8 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
|
|||||||
std::swap(start_nodes, candidates);
|
std::swap(start_nodes, candidates);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0,
|
if (int err= search_layer(ctx, graph, target, &start_nodes,
|
||||||
&candidates, true))
|
static_cast<uint>(limit), 0, &candidates, false))
|
||||||
return err;
|
return err;
|
||||||
|
|
||||||
if (limit > candidates.num)
|
if (limit > candidates.num)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user