From 0da820cb12274bc6368239be8d169a4591eb3148 Mon Sep 17 00:00:00 2001 From: Sergei Golubchik Date: Sat, 24 Aug 2024 15:09:51 +0200 Subject: [PATCH] mhnsw: use plugin index options and transaction_participant API --- mysql-test/main/vector.result | 31 +++++++ mysql-test/main/vector.test | 14 +++ sql/create_options.cc | 2 +- sql/sql_builtin.cc.in | 5 +- sql/sql_class.h | 4 - sql/sql_show.cc | 4 +- sql/sql_table.cc | 12 ++- sql/sys_vars.cc | 20 ----- sql/table.cc | 10 +++ sql/vector_mhnsw.cc | 158 +++++++++++++++++++++++++--------- sql/vector_mhnsw.h | 3 +- 11 files changed, 187 insertions(+), 76 deletions(-) diff --git a/mysql-test/main/vector.result b/mysql-test/main/vector.result index 6b34c229302..5702dd1c368 100644 --- a/mysql-test/main/vector.result +++ b/mysql-test/main/vector.result @@ -17,6 +17,37 @@ show keys from t1; Table Non_unique Key_name Seq_in_index Column_name Collation Cardinality Sub_part Packed Null Index_type Comment Index_comment Ignored t1 0 PRIMARY 1 id A 0 NULL NULL BTREE NO t1 1 v 1 v A NULL 1 NULL VECTOR NO +drop table t1; +set mhnsw_max_edges_per_node=@@mhnsw_max_edges_per_node+1; +create table t1 (id int auto_increment primary key, v blob not null, vector index (v)); +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `v` blob NOT NULL, + PRIMARY KEY (`id`), + VECTOR KEY `v` (`v`) `max_edges_per_node`=7 +) ENGINE=MyISAM DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_uca1400_ai_ci +show keys from t1; +Table Non_unique Key_name Seq_in_index Column_name Collation Cardinality Sub_part Packed Null Index_type Comment Index_comment Ignored +t1 0 PRIMARY 1 id A 0 NULL NULL BTREE NO +t1 1 v 1 v A NULL 1 NULL VECTOR NO +drop table t1; +create table t1 (id int auto_increment primary key, v blob not null, +vector index (v) max_edges_per_node=5); +show create table t1; +Table Create Table +t1 CREATE TABLE `t1` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `v` blob NOT NULL, + PRIMARY KEY (`id`), + VECTOR KEY `v` (`v`) `max_edges_per_node`=5 +) ENGINE=MyISAM DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_uca1400_ai_ci +show keys from t1; +Table Non_unique Key_name Seq_in_index Column_name Collation Cardinality Sub_part Packed Null Index_type Comment Index_comment Ignored +t1 0 PRIMARY 1 id A 0 NULL NULL BTREE NO +t1 1 v 1 v A NULL 1 NULL VECTOR NO +set mhnsw_max_edges_per_node=default; select * from information_schema.statistics where table_name='t1'; TABLE_CATALOG def TABLE_SCHEMA test diff --git a/mysql-test/main/vector.test b/mysql-test/main/vector.test index fec8fbaa955..e03c702fd94 100644 --- a/mysql-test/main/vector.test +++ b/mysql-test/main/vector.test @@ -11,6 +11,20 @@ create table t1 (id int auto_increment primary key, v blob not null, vector inde replace_result InnoDB MyISAM; show create table t1; show keys from t1; +drop table t1; +set mhnsw_max_edges_per_node=@@mhnsw_max_edges_per_node+1; +create table t1 (id int auto_increment primary key, v blob not null, vector index (v)); +replace_result InnoDB MyISAM; +show create table t1; +show keys from t1; +drop table t1; +create table t1 (id int auto_increment primary key, v blob not null, + vector index (v) max_edges_per_node=5); +replace_result InnoDB MyISAM; +show create table t1; +show keys from t1; +set mhnsw_max_edges_per_node=default; + query_vertical select * from information_schema.statistics where table_name='t1'; # print unpack("H*",pack("f*",map{rand}1..5)) insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'), diff --git a/sql/create_options.cc b/sql/create_options.cc index b4d6dfa8ba4..df71631706f 100644 --- a/sql/create_options.cc +++ b/sql/create_options.cc @@ -765,7 +765,7 @@ bool engine_table_options_frm_read(const uchar *buff, size_t length, buff++; } - for (count=0; count < share->keys; count++) + for (count=0; count < share->total_keys; count++) { while (buff < buff_end && *buff) { diff --git a/sql/sql_builtin.cc.in b/sql/sql_builtin.cc.in index 810f98a876c..1896954db46 100644 --- a/sql/sql_builtin.cc.in +++ b/sql/sql_builtin.cc.in @@ -31,8 +31,8 @@ extern #endif builtin_maria_plugin @mysql_mandatory_plugins@ @mysql_optional_plugins@ - builtin_maria_binlog_plugin, - builtin_maria_mysql_password_plugin; + builtin_maria_binlog_plugin, builtin_maria_mhnsw_plugin, + builtin_maria_mysql_password_plugin; struct st_maria_plugin *mysql_optional_plugins[]= { @@ -42,5 +42,6 @@ struct st_maria_plugin *mysql_optional_plugins[]= struct st_maria_plugin *mysql_mandatory_plugins[]= { builtin_maria_binlog_plugin, builtin_maria_mysql_password_plugin, + builtin_maria_mhnsw_plugin, @mysql_mandatory_plugins@ 0 }; diff --git a/sql/sql_class.h b/sql/sql_class.h index 45c99d44f60..2b6ea6b23b7 100644 --- a/sql/sql_class.h +++ b/sql/sql_class.h @@ -922,10 +922,6 @@ typedef struct system_variables my_bool binlog_alter_two_phase; Charset_collation_map_st character_set_collations; - - /* Temporary for HNSW tests */ - uint mhnsw_max_edges_per_node; - uint mhnsw_min_limit; } SV; /** diff --git a/sql/sql_show.cc b/sql/sql_show.cc index 02afcf82010..b7f2e80f3ef 100644 --- a/sql/sql_show.cc +++ b/sql/sql_show.cc @@ -70,6 +70,7 @@ #include "opt_trace.h" #include "my_cpu.h" #include "key.h" +#include "vector_mhnsw.h" #include "lex_symbol.h" #define KEYWORD_SIZE 64 @@ -2501,7 +2502,8 @@ int show_create_table_ex(THD *thd, TABLE_LIST *table_list, packet->append(STRING_WITH_LEN(" */ ")); } append_create_options(thd, packet, key_info->option_list, check_options, - hton->index_options); + (key_info->algorithm == HA_KEY_ALG_VECTOR + ? mhnsw_index_options : hton->index_options)); } if (table->versioned()) diff --git a/sql/sql_table.cc b/sql/sql_table.cc index 9e13c30376c..8729883cc04 100644 --- a/sql/sql_table.cc +++ b/sql/sql_table.cc @@ -63,6 +63,7 @@ #include "rpl_mi.h" #include "rpl_rli.h" #include "log.h" +#include "vector_mhnsw.h" #ifdef WITH_WSREP #include "wsrep_mysqld.h" @@ -3209,6 +3210,8 @@ mysql_prepare_create_table_finalize(THD *thd, HA_CREATE_INFO *create_info, uint key_length=0; Create_field *auto_increment_key= 0; Key_part_spec *column; + st_plugin_int *index_plugin= hton2plugin[create_info->db_type->slot]; + ha_create_table_option *index_options= create_info->db_type->index_options; is_hash_field_needed= false; if (key->type == Key::IGNORE_KEY) @@ -3250,6 +3253,8 @@ mysql_prepare_create_table_finalize(THD *thd, HA_CREATE_INFO *create_info, } if (key->key_create_info.algorithm == HA_KEY_ALG_UNDEF) key->key_create_info.algorithm= HA_KEY_ALG_VECTOR; + index_plugin= mhnsw_plugin; + index_options= mhnsw_index_options; break; case Key::IGNORE_KEY: DBUG_ASSERT(0); @@ -3267,10 +3272,9 @@ mysql_prepare_create_table_finalize(THD *thd, HA_CREATE_INFO *create_info, key_info->usable_key_parts= key_number; key_info->algorithm= key->key_create_info.algorithm; key_info->option_list= key->option_list; - if (parse_option_list(thd, create_info->db_type, &key_info->option_struct, - &key_info->option_list, - create_info->db_type->index_options, FALSE, - thd->mem_root)) + if (parse_option_list(thd, index_plugin, &key_info->option_struct, + &key_info->option_list, index_options, + FALSE, thd->mem_root)) DBUG_RETURN(TRUE); if (key->type == Key::FULLTEXT) diff --git a/sql/sys_vars.cc b/sql/sys_vars.cc index ff49e2ffeac..b12601a922a 100644 --- a/sql/sys_vars.cc +++ b/sql/sys_vars.cc @@ -7448,23 +7448,3 @@ static Sys_var_ulonglong Sys_binlog_large_commit_threshold( // Allow a smaller minimum value for debug builds to help with testing VALID_RANGE(IF_DBUG(100, 10240) * 1024, ULLONG_MAX), DEFAULT(128 * 1024 * 1024), BLOCK_SIZE(1)); - -static Sys_var_uint Sys_mhnsw_min_limit( - "mhnsw_min_limit", - "Defines the minimal number of result candidates to look for in the " - "vector index for ORDER BY ... LIMIT N queries. The search will never " - "search for less rows than that, even if LIMIT is smaller. " - "This notably improves the search quality at low LIMIT values, " - "at the expense of search time", - SESSION_VAR(mhnsw_min_limit), CMD_LINE(REQUIRED_ARG), - VALID_RANGE(1, 65535), DEFAULT(20), BLOCK_SIZE(1)); -static Sys_var_uint Sys_mhnsw_max_edges_per_node( - "mhnsw_max_edges_per_node", - "Larger values means slower INSERT, larger index size and higher " - "memory consumption, but better search results", - SESSION_VAR(mhnsw_max_edges_per_node), CMD_LINE(REQUIRED_ARG), - VALID_RANGE(3, 200), DEFAULT(6), BLOCK_SIZE(1)); -static Sys_var_ulonglong Sys_mhnsw_cache_size( - "mhnsw_cache_size", "Size of the cache for the MHNSW vector index", - GLOBAL_VAR(mhnsw_cache_size), CMD_LINE(REQUIRED_ARG), - VALID_RANGE(1024*1024, SIZE_T_MAX), DEFAULT(16*1024*1024), BLOCK_SIZE(1)); diff --git a/sql/table.cc b/sql/table.cc index 15f68b0fe69..08ac70da752 100644 --- a/sql/table.cc +++ b/sql/table.cc @@ -3426,6 +3426,16 @@ int TABLE_SHARE::init_from_binary_frm_image(THD *thd, bool write, if (parse_engine_table_options(thd, handler_file->partition_ht(), share)) goto err; + if (share->hlindexes()) + { + DBUG_ASSERT(share->hlindexes() == 1); + keyinfo= share->key_info + share->keys; + if (parse_option_list(thd, mhnsw_plugin, &keyinfo->option_struct, + &keyinfo->option_list, mhnsw_index_options, + TRUE, thd->mem_root)) + goto err; + } + if (share->found_next_number_field) { reg_field= *share->found_next_number_field; diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index e153a1c1007..d560fd323d6 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -17,19 +17,38 @@ #include #include "key.h" // key_copy() +#include "create_options.h" #include "vector_mhnsw.h" #include "item_vectorfunc.h" #include #include #include "bloom_filters.h" -ulonglong mhnsw_cache_size; - // Algorithm parameters static constexpr float alpha = 1.1f; static constexpr float generosity = 1.1f; static constexpr uint ef_construction= 10; +static ulonglong mhnsw_cache_size; +static MYSQL_SYSVAR_ULONGLONG(cache_size, mhnsw_cache_size, + PLUGIN_VAR_RQCMDARG, "Size of the cache for the MHNSW vector index", + nullptr, nullptr, 16*1024*1024, 1024*1024, SIZE_T_MAX, 1); +static MYSQL_THDVAR_UINT(min_limit, PLUGIN_VAR_RQCMDARG, + "Defines the minimal number of result candidates to look for in the " + "vector index for ORDER BY ... LIMIT N queries. The search will never " + "search for less rows than that, even if LIMIT is smaller. " + "This notably improves the search quality at low LIMIT values, " + "at the expense of search time", nullptr, nullptr, 20, 1, 65535, 1); +static MYSQL_THDVAR_UINT(max_edges_per_node, PLUGIN_VAR_RQCMDARG, + "Larger values means slower INSERT, larger index size and higher " + "memory consumption, but better search results", + nullptr, nullptr, 6, 3, 200, 1); + +struct ha_index_option_struct +{ + ulonglong M; // option struct does not support uint +}; + enum Graph_table_fields { FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS }; @@ -277,7 +296,7 @@ public: MHNSW_Context(TABLE *t) : tref_len(t->file->ref_length), gref_len(t->hlindex->file->ref_length), - M(t->in_use->variables.mhnsw_max_edges_per_node) + M(static_cast(t->s->key_info[t->s->keys].option_struct->M)) { mysql_rwlock_init(PSI_INSTRUMENT_ME, &commit_lock); mysql_mutex_init(PSI_INSTRUMENT_ME, &cache_lock, MY_MUTEX_INIT_FAST); @@ -429,58 +448,59 @@ public: reset(nullptr); } - static MHNSW_Trx *get_from_thd(THD *thd, TABLE *table); + static MHNSW_Trx *get_from_thd(TABLE *table, bool for_update); // it's okay in a transaction-local cache, there's no concurrent access Hash_set &get_cache() { return node_cache; } - /* fake handlerton to use thd->ha_data and to get notified of commits */ - static struct MHNSW_hton : public handlerton - { - MHNSW_hton() - { - db_type= DB_TYPE_HLINDEX_HELPER; - flags = HTON_NOT_USER_SELECTABLE | HTON_HIDDEN; - savepoint_offset= 0; - savepoint_set= [](THD *, void *){ return 0; }; - savepoint_rollback_can_release_mdl= [](THD *){ return true; }; - savepoint_rollback= do_savepoint_rollback; - commit= do_commit; - rollback= do_rollback; - } - static int do_commit(THD *thd, bool); - static int do_rollback(THD *thd, bool); - static int do_savepoint_rollback(THD *thd, void *); - } hton; + static transaction_participant tp; + static int do_commit(THD *thd, bool); + static int do_savepoint_rollback(THD *thd, void *); + static int do_rollback(THD *thd, bool); }; -MHNSW_Trx::MHNSW_hton MHNSW_Trx::hton; - -int MHNSW_Trx::MHNSW_hton::do_savepoint_rollback(THD *thd, void *) +struct transaction_participant MHNSW_Trx::tp= { - for (auto trx= static_cast(thd_get_ha_data(thd, &hton)); + 0, 0, 0, + nullptr, /* close_connection */ + [](THD *, void *){ return 0; }, /* savepoint_set */ + MHNSW_Trx::do_savepoint_rollback, + [](THD *thd){ return true; }, /*savepoint_rollback_can_release_mdl*/ + nullptr, /*savepoint_release*/ + MHNSW_Trx::do_commit, MHNSW_Trx::do_rollback, + nullptr, /* prepare */ + nullptr, /* recover */ + nullptr, nullptr, /* commit/rollback_by_xid */ + nullptr, nullptr, /* recover_rollback_by_xid/recovery_done */ + nullptr, nullptr, nullptr, /* snapshot, commit/prepare_ordered */ + nullptr, nullptr /* checkpoint, versioned */ +}; + +int MHNSW_Trx::do_savepoint_rollback(THD *thd, void *) +{ + for (auto trx= static_cast(thd_get_ha_data(thd, &tp)); trx; trx= trx->next) trx->reset(nullptr); return 0; } -int MHNSW_Trx::MHNSW_hton::do_rollback(THD *thd, bool) +int MHNSW_Trx::do_rollback(THD *thd, bool) { MHNSW_Trx *trx_next; - for (auto trx= static_cast(thd_get_ha_data(thd, &hton)); + for (auto trx= static_cast(thd_get_ha_data(thd, &tp)); trx; trx= trx_next) { trx_next= trx->next; trx->~MHNSW_Trx(); } - thd_set_ha_data(current_thd, &hton, nullptr); + thd_set_ha_data(current_thd, &tp, nullptr); return 0; } -int MHNSW_Trx::MHNSW_hton::do_commit(THD *thd, bool) +int MHNSW_Trx::do_commit(THD *thd, bool) { MHNSW_Trx *trx_next; - for (auto trx= static_cast(thd_get_ha_data(thd, &hton)); + for (auto trx= static_cast(thd_get_ha_data(thd, &tp)); trx; trx= trx_next) { trx_next= trx->next; @@ -504,23 +524,30 @@ int MHNSW_Trx::MHNSW_hton::do_commit(THD *thd, bool) } trx->~MHNSW_Trx(); } - thd_set_ha_data(current_thd, &hton, nullptr); + thd_set_ha_data(current_thd, &tp, nullptr); return 0; } -MHNSW_Trx *MHNSW_Trx::get_from_thd(THD *thd, TABLE *table) +MHNSW_Trx *MHNSW_Trx::get_from_thd(TABLE *table, bool for_update) { - auto trx= static_cast(thd_get_ha_data(thd, &hton)); + if (!table->file->has_transactions()) + return NULL; + + THD *thd= table->in_use; + auto trx= static_cast(thd_get_ha_data(thd, &tp)); + if (!for_update && !trx) + return NULL; + while (trx && trx->table_share != table->s) trx= trx->next; if (!trx) { trx= new (&thd->transaction->mem_root) MHNSW_Trx(table); - trx->next= static_cast(thd_get_ha_data(thd, &hton)); - thd_set_ha_data(thd, &hton, trx); + trx->next= static_cast(thd_get_ha_data(thd, &tp)); + thd_set_ha_data(thd, &tp, trx); if (!trx->next) { bool all= thd_test_options(thd, OPTION_NOT_AUTOCOMMIT | OPTION_BEGIN); - trans_register_ha(thd, all, &hton, 0); + trans_register_ha(thd, all, &tp, 0); } } return trx; @@ -546,12 +573,8 @@ MHNSW_Context *MHNSW_Context::get_from_share(TABLE_SHARE *share, TABLE *table) int MHNSW_Context::acquire(MHNSW_Context **ctx, TABLE *table, bool for_update) { TABLE *graph= table->hlindex; - THD *thd= table->in_use; - if (table->file->has_transactions() && - (for_update || thd_get_ha_data(thd, &MHNSW_Trx::hton))) - *ctx= MHNSW_Trx::get_from_thd(thd, table); - else + if (!(*ctx= MHNSW_Trx::get_from_thd(table, for_update))) { *ctx= MHNSW_Context::get_from_share(table->s, table); if (table->file->has_transactions()) @@ -908,7 +931,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, { skip_deleted= layer == 0; if (ef > 1 || layer == 0) - ef= std::max(graph->in_use->variables.mhnsw_min_limit, ef); + ef= std::max(THDVAR(graph->in_use, min_limit), ef); } // WARNING! heuristic here @@ -1287,3 +1310,52 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length) len= my_snprintf(s, len, templ, ref_length); return {s, len}; } + +/* + Declare the plugin and index options +*/ + +ha_create_table_option mhnsw_index_options[]= +{ + HA_IOPTION_SYSVAR("max_edges_per_node", M, max_edges_per_node), + HA_IOPTION_END +}; + +st_plugin_int *mhnsw_plugin; + +static int mhnsw_init(void *p) +{ + mhnsw_plugin= (st_plugin_int *)p; + mhnsw_plugin->data= &MHNSW_Trx::tp; + if (setup_transaction_participant(mhnsw_plugin)) + return 1; + + return resolve_sysvar_table_options(mhnsw_index_options); +} + +static int mhnsw_deinit(void *) +{ + free_sysvar_table_options(mhnsw_index_options); + return 0; +} + +static struct st_mysql_storage_engine mhnsw_daemon= +{ MYSQL_DAEMON_INTERFACE_VERSION }; + +static struct st_mysql_sys_var *mhnsw_sys_vars[]= +{ + MYSQL_SYSVAR(cache_size), + MYSQL_SYSVAR(max_edges_per_node), + MYSQL_SYSVAR(min_limit), + NULL +}; + +maria_declare_plugin(mhnsw) +{ + MYSQL_DAEMON_PLUGIN, + &mhnsw_daemon, "mhnsw", "MariaDB plc", + "A plugin for mhnsw vector index algorithm", + PLUGIN_LICENSE_GPL, mhnsw_init, mhnsw_deinit, 0x0100, NULL, + mhnsw_sys_vars, "1.0", MariaDB_PLUGIN_MATURITY_STABLE +} +maria_declare_plugin_end; diff --git a/sql/vector_mhnsw.h b/sql/vector_mhnsw.h index 12e109b339d..0e885a719f0 100644 --- a/sql/vector_mhnsw.h +++ b/sql/vector_mhnsw.h @@ -29,4 +29,5 @@ int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo); int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate); void mhnsw_free(TABLE_SHARE *share); -extern ulonglong mhnsw_cache_size; +extern ha_create_table_option mhnsw_index_options[]; +extern st_plugin_int *mhnsw_plugin;