// Patch summary (leading text before the first "diff --git" header).
//
// Files touched: sql/item_create.cc, sql/item_vectorfunc.cc,
// sql/item_vectorfunc.h, sql/vector_mhnsw.cc, sql/vector_mhnsw.h.
//
// What the patch does:
//  * Replaces the abstract class Item_func_vec_distance_common and its two
//    subclasses (Item_func_vec_distance_euclidean, Item_func_vec_distance_cosine)
//    with one class, Item_func_vec_distance, that carries a public
//    `enum distance_kind { EUCLIDEAN, COSINE } kind` member set by the
//    constructor.
//  * The pure virtual calc_distance() becomes a function-pointer member.
//    fix_length_and_dec() assigns it by switching on `kind`, then calls
//    set_maybe_null() and Item_real_func::fix_length_and_dec() as before.
//  * The two distance routines move out of the header into file-static
//    functions calc_distance_euclidean() / calc_distance_cosine() in
//    item_vectorfunc.cc; their loop bodies match the removed methods.
//  * func_name_cstring() returns an entry of a static LEX_CSTRING table
//    indexed by `kind` instead of one static name per subclass.
//  * In item_create.cc both Create_func_vec_distance_* builders now construct
//    Item_func_vec_distance, passing the matching distance_kind.
//  * mhnsw_uses_distance() drops its `const Item *dist` parameter and returns
//    a distance_kind: EUCLIDEAN when keyinfo->option_struct->metric ==
//    EUCLIDEAN, otherwise COSINE. part_of_sortkey() compares that result with
//    its own `kind`, replacing the dynamic_cast tests.
//  * vector_mhnsw.h includes "item_vectorfunc.h" in place of "item.h" (the new
//    return type names the nested enum); vector_mhnsw.cc drops its own
//    include of "item_vectorfunc.h".
//
// NOTE(review): `static LEX_CSTRING name[3]` has only two initialisers and the
//   enum shown has two enumerators, so the third element gets no explicit
//   value -- confirm whether the size 3 is intentional.
// NOTE(review): in the lines shown, calc_distance is assigned only inside
//   fix_length_and_dec(); the constructor initialises just the base class and
//   `kind`, and the switch has no default branch -- confirm val_real() can
//   never run before fix_length_and_dec().
// NOTE(review): calc_distance_cosine() divides by sqrt(abs1*abs2) with no
//   visible guard for a zero-magnitude vector (same as the removed method).
// NOTE(review): this copy of the patch has its newlines collapsed, and the
//   casts and two #include directives appear without their angle-bracketed
//   text (`static_cast(dist->real_item())`, `dynamic_cast(dist)`,
//   `#include #include`) -- presumably lost in extraction; verify against
//   the upstream commit before applying.
diff --git a/sql/item_create.cc b/sql/item_create.cc index edbac314564..ae0802d5fdf 100644 --- a/sql/item_create.cc +++ b/sql/item_create.cc @@ -6258,7 +6258,8 @@ class Create_func_vec_distance_euclidean: public Create_func_arg2 { public: Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override - { return new (thd->mem_root) Item_func_vec_distance_euclidean(thd, arg1, arg2); } + { return new (thd->mem_root) + Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::EUCLIDEAN); } static Create_func_vec_distance_euclidean s_singleton; @@ -6274,7 +6275,8 @@ class Create_func_vec_distance_cosine: public Create_func_arg2 { public: Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override - { return new (thd->mem_root) Item_func_vec_distance_cosine(thd, arg1, arg2); } + { return new (thd->mem_root) + Item_func_vec_distance(thd, arg1, arg2, Item_func_vec_distance::COSINE); } static Create_func_vec_distance_cosine s_singleton; diff --git a/sql/item_vectorfunc.cc b/sql/item_vectorfunc.cc index e7686657728..f4bc5bbe5b4 100644 --- a/sql/item_vectorfunc.cc +++ b/sql/item_vectorfunc.cc @@ -24,7 +24,47 @@ #include "vector_mhnsw.h" #include "sql_type_vector.h" -key_map Item_func_vec_distance_common::part_of_sortkey() const +static double calc_distance_euclidean(float *v1, float *v2, size_t v_len) +{ + double d= 0; + for (size_t i= 0; i < v_len; i++, v1++, v2++) + { + float dist= get_float(v1) - get_float(v2); + d+= dist * dist; + } + return sqrt(d); +} + +static double calc_distance_cosine(float *v1, float *v2, size_t v_len) +{ + double dotp=0, abs1=0, abs2=0; + for (size_t i= 0; i < v_len; i++, v1++, v2++) + { + float f1= get_float(v1), f2= get_float(v2); + abs1+= f1 * f1; + abs2+= f2 * f2; + dotp+= f1 * f2; + } + return 1 - dotp/sqrt(abs1*abs2); +} + +Item_func_vec_distance::Item_func_vec_distance(THD *thd, Item *a, Item *b, + distance_kind kind) + :Item_real_func(thd, a, b), kind(kind) +{ +} + +bool Item_func_vec_distance::fix_length_and_dec(THD *thd) +{ + 
switch (kind) { + case EUCLIDEAN: calc_distance= calc_distance_euclidean; break; + case COSINE: calc_distance= calc_distance_cosine; break; + } + set_maybe_null(); // if wrong dimensions + return Item_real_func::fix_length_and_dec(thd); +} + +key_map Item_func_vec_distance::part_of_sortkey() const { key_map map(0); if (Item_field *item= get_field_arg()) @@ -33,13 +73,13 @@ key_map Item_func_vec_distance_common::part_of_sortkey() const KEY *keyinfo= f->table->s->key_info; for (uint i= f->table->s->keys; i < f->table->s->total_keys; i++) if (keyinfo[i].algorithm == HA_KEY_ALG_VECTOR && f->key_start.is_set(i) - && mhnsw_uses_distance(f->table, keyinfo + i, this)) + && mhnsw_uses_distance(f->table, keyinfo + i) == kind) map.set_bit(i); } return map; } -double Item_func_vec_distance_common::val_real() +double Item_func_vec_distance::val_real() { String *r1= args[0]->val_str(); String *r2= args[1]->val_str(); diff --git a/sql/item_vectorfunc.h b/sql/item_vectorfunc.h index 58dc300c451..6e5a956c033 100644 --- a/sql/item_vectorfunc.h +++ b/sql/item_vectorfunc.h @@ -22,7 +22,7 @@ #include "lex_string.h" #include "item_func.h" -class Item_func_vec_distance_common: public Item_real_func +class Item_func_vec_distance: public Item_real_func { Item_field *get_field_arg() const { @@ -36,16 +36,20 @@ class Item_func_vec_distance_common: public Item_real_func { return check_argument_types_or_binary(NULL, 0, arg_count); } - virtual double calc_distance(float *v1, float *v2, size_t v_len) = 0; + double (*calc_distance)(float *v1, float *v2, size_t v_len); public: - Item_func_vec_distance_common(THD *thd, Item *a, Item *b) - :Item_real_func(thd, a, b) {} - bool fix_length_and_dec(THD *thd) override + enum distance_kind { EUCLIDEAN, COSINE } kind; + Item_func_vec_distance(THD *thd, Item *a, Item *b, distance_kind kind); + LEX_CSTRING func_name_cstring() const override { - set_maybe_null(); // if wrong dimensions - return Item_real_func::fix_length_and_dec(thd); + static LEX_CSTRING 
name[3]= { + { STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") }, + { STRING_WITH_LEN("VEC_DISTANCE_COSINE") } + }; + return name[kind]; } + bool fix_length_and_dec(THD *thd) override; double val_real() override; Item *get_const_arg() const { @@ -56,60 +60,8 @@ public: return NULL; } key_map part_of_sortkey() const override; -}; - - -class Item_func_vec_distance_euclidean: public Item_func_vec_distance_common -{ - double calc_distance(float *v1, float *v2, size_t v_len) override - { - double d= 0; - for (size_t i= 0; i < v_len; i++, v1++, v2++) - { - float dist= get_float(v1) - get_float(v2); - d+= dist * dist; - } - return sqrt(d); - } - -public: - Item_func_vec_distance_euclidean(THD *thd, Item *a, Item *b) - :Item_func_vec_distance_common(thd, a, b) {} - LEX_CSTRING func_name_cstring() const override - { - static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") }; - return name; - } Item *do_get_copy(THD *thd) const override - { return get_item_copy(thd, this); } -}; - - -class Item_func_vec_distance_cosine: public Item_func_vec_distance_common -{ - double calc_distance(float *v1, float *v2, size_t v_len) override - { - double dotp=0, abs1=0, abs2=0; - for (size_t i= 0; i < v_len; i++, v1++, v2++) - { - float f1= get_float(v1), f2= get_float(v2); - abs1+= f1 * f1; - abs2+= f2 * f2; - dotp+= f1 * f2; - } - return 1 - dotp/sqrt(abs1*abs2); - } - -public: - Item_func_vec_distance_cosine(THD *thd, Item *a, Item *b) - :Item_func_vec_distance_common(thd, a, b) {} - LEX_CSTRING func_name_cstring() const override - { - static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_COSINE") }; - return name; - } - Item *do_get_copy(THD *thd) const override - { return get_item_copy(thd, this); } + { return get_item_copy(thd, this); } }; diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index 8fbeffb3a58..85036f51e0a 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -20,7 +20,6 @@ #include "create_options.h" #include "table_cache.h" #include 
"vector_mhnsw.h" -#include "item_vectorfunc.h" #include #include #include "bloom_filters.h" @@ -1290,7 +1289,7 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) { THD *thd= table->in_use; TABLE *graph= table->hlindex; - auto *fun= static_cast(dist->real_item()); + auto *fun= static_cast(dist->real_item()); DBUG_ASSERT(fun); limit= std::min(limit, max_ef); @@ -1507,11 +1506,11 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length) return {s, len}; } -bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist) +Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo) { if (keyinfo->option_struct->metric == EUCLIDEAN) - return dynamic_cast(dist) != NULL; - return dynamic_cast(dist) != NULL; + return Item_func_vec_distance::EUCLIDEAN; + return Item_func_vec_distance::COSINE; } /* diff --git a/sql/vector_mhnsw.h b/sql/vector_mhnsw.h index 5f60fcf2d2b..fbb61e14773 100644 --- a/sql/vector_mhnsw.h +++ b/sql/vector_mhnsw.h @@ -16,7 +16,7 @@ */ #include -#include "item.h" +#include "item_vectorfunc.h" #include "m_string.h" #include "structs.h" #include "table.h" @@ -33,7 +33,7 @@ int mhnsw_read_end(TABLE *table); int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo); int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate); void mhnsw_free(TABLE_SHARE *share); -bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist); +Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo); extern ha_create_table_option mhnsw_index_options[]; extern st_plugin_int *mhnsw_plugin;