cleanup: one Item_func_vec_distance class, not three
prepare for MDEV-35450 VEC_DISTANCE auto-detection
This commit is contained in:
parent
d2ec5ec9c2
commit
528249a20a
@ -6258,7 +6258,8 @@ class Create_func_vec_distance_euclidean: public Create_func_arg2
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
|
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
|
||||||
{ return new (thd->mem_root) Item_func_vec_distance_euclidean(thd, arg1, arg2); }
|
{ return new (thd->mem_root)
|
||||||
|
Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::EUCLIDEAN); }
|
||||||
|
|
||||||
static Create_func_vec_distance_euclidean s_singleton;
|
static Create_func_vec_distance_euclidean s_singleton;
|
||||||
|
|
||||||
@ -6274,7 +6275,8 @@ class Create_func_vec_distance_cosine: public Create_func_arg2
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
|
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
|
||||||
{ return new (thd->mem_root) Item_func_vec_distance_cosine(thd, arg1, arg2); }
|
{ return new (thd->mem_root)
|
||||||
|
Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::COSINE); }
|
||||||
|
|
||||||
static Create_func_vec_distance_cosine s_singleton;
|
static Create_func_vec_distance_cosine s_singleton;
|
||||||
|
|
||||||
|
@ -24,7 +24,47 @@
|
|||||||
#include "vector_mhnsw.h"
|
#include "vector_mhnsw.h"
|
||||||
#include "sql_type_vector.h"
|
#include "sql_type_vector.h"
|
||||||
|
|
||||||
key_map Item_func_vec_distance_common::part_of_sortkey() const
|
static double calc_distance_euclidean(float *v1, float *v2, size_t v_len)
|
||||||
|
{
|
||||||
|
double d= 0;
|
||||||
|
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
||||||
|
{
|
||||||
|
float dist= get_float(v1) - get_float(v2);
|
||||||
|
d+= dist * dist;
|
||||||
|
}
|
||||||
|
return sqrt(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
static double calc_distance_cosine(float *v1, float *v2, size_t v_len)
|
||||||
|
{
|
||||||
|
double dotp=0, abs1=0, abs2=0;
|
||||||
|
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
||||||
|
{
|
||||||
|
float f1= get_float(v1), f2= get_float(v2);
|
||||||
|
abs1+= f1 * f1;
|
||||||
|
abs2+= f2 * f2;
|
||||||
|
dotp+= f1 * f2;
|
||||||
|
}
|
||||||
|
return 1 - dotp/sqrt(abs1*abs2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
  Construct a two-argument distance item over a and b and remember which
  metric it represents. calc_distance is not assigned here; it is bound
  from the stored kind in fix_length_and_dec().
*/
Item_func_vec_distance::Item_func_vec_distance(THD *thd, Item *a, Item *b,
                                               distance_kind kind)
  : Item_real_func(thd, a, b),
    kind(kind)
{}
|
||||||
|
|
||||||
|
/*
  Bind calc_distance to the routine that implements this item's metric,
  mark the result as nullable, then defer to Item_real_func for the rest.

  Returns whatever Item_real_func::fix_length_and_dec() returns.
*/
bool Item_func_vec_distance::fix_length_and_dec(THD *thd)
{
  // Deliberately no default branch: a value outside the enum leaves
  // calc_distance untouched.
  switch (kind) {
  case EUCLIDEAN:
    calc_distance= calc_distance_euclidean;
    break;
  case COSINE:
    calc_distance= calc_distance_cosine;
    break;
  }

  // The result can be NULL ("if wrong dimensions"); presumably val_real()
  // yields NULL when the two operands disagree in size -- confirm there.
  set_maybe_null();
  return Item_real_func::fix_length_and_dec(thd);
}
|
||||||
|
|
||||||
|
key_map Item_func_vec_distance::part_of_sortkey() const
|
||||||
{
|
{
|
||||||
key_map map(0);
|
key_map map(0);
|
||||||
if (Item_field *item= get_field_arg())
|
if (Item_field *item= get_field_arg())
|
||||||
@ -33,13 +73,13 @@ key_map Item_func_vec_distance_common::part_of_sortkey() const
|
|||||||
KEY *keyinfo= f->table->s->key_info;
|
KEY *keyinfo= f->table->s->key_info;
|
||||||
for (uint i= f->table->s->keys; i < f->table->s->total_keys; i++)
|
for (uint i= f->table->s->keys; i < f->table->s->total_keys; i++)
|
||||||
if (keyinfo[i].algorithm == HA_KEY_ALG_VECTOR && f->key_start.is_set(i)
|
if (keyinfo[i].algorithm == HA_KEY_ALG_VECTOR && f->key_start.is_set(i)
|
||||||
&& mhnsw_uses_distance(f->table, keyinfo + i, this))
|
&& mhnsw_uses_distance(f->table, keyinfo + i) == kind)
|
||||||
map.set_bit(i);
|
map.set_bit(i);
|
||||||
}
|
}
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
double Item_func_vec_distance_common::val_real()
|
double Item_func_vec_distance::val_real()
|
||||||
{
|
{
|
||||||
String *r1= args[0]->val_str();
|
String *r1= args[0]->val_str();
|
||||||
String *r2= args[1]->val_str();
|
String *r2= args[1]->val_str();
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
#include "lex_string.h"
|
#include "lex_string.h"
|
||||||
#include "item_func.h"
|
#include "item_func.h"
|
||||||
|
|
||||||
class Item_func_vec_distance_common: public Item_real_func
|
class Item_func_vec_distance: public Item_real_func
|
||||||
{
|
{
|
||||||
Item_field *get_field_arg() const
|
Item_field *get_field_arg() const
|
||||||
{
|
{
|
||||||
@ -36,16 +36,20 @@ class Item_func_vec_distance_common: public Item_real_func
|
|||||||
{
|
{
|
||||||
return check_argument_types_or_binary(NULL, 0, arg_count);
|
return check_argument_types_or_binary(NULL, 0, arg_count);
|
||||||
}
|
}
|
||||||
virtual double calc_distance(float *v1, float *v2, size_t v_len) = 0;
|
double (*calc_distance)(float *v1, float *v2, size_t v_len);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Item_func_vec_distance_common(THD *thd, Item *a, Item *b)
|
enum distance_kind { EUCLIDEAN, COSINE } kind;
|
||||||
:Item_real_func(thd, a, b) {}
|
Item_func_vec_distance(THD *thd, Item *a, Item *b, distance_kind kind);
|
||||||
bool fix_length_and_dec(THD *thd) override
|
LEX_CSTRING func_name_cstring() const override
|
||||||
{
|
{
|
||||||
set_maybe_null(); // if wrong dimensions
|
static LEX_CSTRING name[3]= {
|
||||||
return Item_real_func::fix_length_and_dec(thd);
|
{ STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") },
|
||||||
|
{ STRING_WITH_LEN("VEC_DISTANCE_COSINE") }
|
||||||
|
};
|
||||||
|
return name[kind];
|
||||||
}
|
}
|
||||||
|
bool fix_length_and_dec(THD *thd) override;
|
||||||
double val_real() override;
|
double val_real() override;
|
||||||
Item *get_const_arg() const
|
Item *get_const_arg() const
|
||||||
{
|
{
|
||||||
@ -56,60 +60,8 @@ public:
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
key_map part_of_sortkey() const override;
|
key_map part_of_sortkey() const override;
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class Item_func_vec_distance_euclidean: public Item_func_vec_distance_common
|
|
||||||
{
|
|
||||||
double calc_distance(float *v1, float *v2, size_t v_len) override
|
|
||||||
{
|
|
||||||
double d= 0;
|
|
||||||
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
|
||||||
{
|
|
||||||
float dist= get_float(v1) - get_float(v2);
|
|
||||||
d+= dist * dist;
|
|
||||||
}
|
|
||||||
return sqrt(d);
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
Item_func_vec_distance_euclidean(THD *thd, Item *a, Item *b)
|
|
||||||
:Item_func_vec_distance_common(thd, a, b) {}
|
|
||||||
LEX_CSTRING func_name_cstring() const override
|
|
||||||
{
|
|
||||||
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") };
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
Item *do_get_copy(THD *thd) const override
|
Item *do_get_copy(THD *thd) const override
|
||||||
{ return get_item_copy<Item_func_vec_distance_euclidean>(thd, this); }
|
{ return get_item_copy<Item_func_vec_distance>(thd, this); }
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class Item_func_vec_distance_cosine: public Item_func_vec_distance_common
|
|
||||||
{
|
|
||||||
double calc_distance(float *v1, float *v2, size_t v_len) override
|
|
||||||
{
|
|
||||||
double dotp=0, abs1=0, abs2=0;
|
|
||||||
for (size_t i= 0; i < v_len; i++, v1++, v2++)
|
|
||||||
{
|
|
||||||
float f1= get_float(v1), f2= get_float(v2);
|
|
||||||
abs1+= f1 * f1;
|
|
||||||
abs2+= f2 * f2;
|
|
||||||
dotp+= f1 * f2;
|
|
||||||
}
|
|
||||||
return 1 - dotp/sqrt(abs1*abs2);
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
Item_func_vec_distance_cosine(THD *thd, Item *a, Item *b)
|
|
||||||
:Item_func_vec_distance_common(thd, a, b) {}
|
|
||||||
LEX_CSTRING func_name_cstring() const override
|
|
||||||
{
|
|
||||||
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_COSINE") };
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
Item *do_get_copy(THD *thd) const override
|
|
||||||
{ return get_item_copy<Item_func_vec_distance_cosine>(thd, this); }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@
|
|||||||
#include "create_options.h"
|
#include "create_options.h"
|
||||||
#include "table_cache.h"
|
#include "table_cache.h"
|
||||||
#include "vector_mhnsw.h"
|
#include "vector_mhnsw.h"
|
||||||
#include "item_vectorfunc.h"
|
|
||||||
#include <scope.h>
|
#include <scope.h>
|
||||||
#include <my_atomic_wrapper.h>
|
#include <my_atomic_wrapper.h>
|
||||||
#include "bloom_filters.h"
|
#include "bloom_filters.h"
|
||||||
@ -1290,7 +1289,7 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
|
|||||||
{
|
{
|
||||||
THD *thd= table->in_use;
|
THD *thd= table->in_use;
|
||||||
TABLE *graph= table->hlindex;
|
TABLE *graph= table->hlindex;
|
||||||
auto *fun= static_cast<Item_func_vec_distance_common*>(dist->real_item());
|
auto *fun= static_cast<Item_func_vec_distance*>(dist->real_item());
|
||||||
DBUG_ASSERT(fun);
|
DBUG_ASSERT(fun);
|
||||||
|
|
||||||
limit= std::min<ulonglong>(limit, max_ef);
|
limit= std::min<ulonglong>(limit, max_ef);
|
||||||
@ -1507,11 +1506,11 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
|
|||||||
return {s, len};
|
return {s, len};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist)
|
Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo)
|
||||||
{
|
{
|
||||||
if (keyinfo->option_struct->metric == EUCLIDEAN)
|
if (keyinfo->option_struct->metric == EUCLIDEAN)
|
||||||
return dynamic_cast<const Item_func_vec_distance_euclidean*>(dist) != NULL;
|
return Item_func_vec_distance::EUCLIDEAN;
|
||||||
return dynamic_cast<const Item_func_vec_distance_cosine*>(dist) != NULL;
|
return Item_func_vec_distance::COSINE;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <my_global.h>
|
#include <my_global.h>
|
||||||
#include "item.h"
|
#include "item_vectorfunc.h"
|
||||||
#include "m_string.h"
|
#include "m_string.h"
|
||||||
#include "structs.h"
|
#include "structs.h"
|
||||||
#include "table.h"
|
#include "table.h"
|
||||||
@ -33,7 +33,7 @@ int mhnsw_read_end(TABLE *table);
|
|||||||
int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo);
|
int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo);
|
||||||
int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate);
|
int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate);
|
||||||
void mhnsw_free(TABLE_SHARE *share);
|
void mhnsw_free(TABLE_SHARE *share);
|
||||||
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist);
|
Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo);
|
||||||
|
|
||||||
extern ha_create_table_option mhnsw_index_options[];
|
extern ha_create_table_option mhnsw_index_options[];
|
||||||
extern st_plugin_int *mhnsw_plugin;
|
extern st_plugin_int *mhnsw_plugin;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user