diff --git a/contrib/dblink/dblink.c b/contrib/dblink/dblink.c index a6b28faf72d..a33ce3425b9 100644 --- a/contrib/dblink/dblink.c +++ b/contrib/dblink/dblink.c @@ -100,7 +100,7 @@ static PGresult *storeQueryResult(volatile storeInfo *sinfo, PGconn *conn, const static void storeRow(volatile storeInfo *sinfo, PGresult *res, bool first); static remoteConn *getConnectionByName(const char *name); static HTAB *createConnHash(void); -static void createNewConnection(const char *name, remoteConn *rconn); +static remoteConn *createNewConnection(const char *name); static void deleteConnection(const char *name); static char **get_pkey_attnames(Relation rel, int16 *indnkeyatts); static char **get_text_array_contents(ArrayType *array, int *numitems); @@ -113,7 +113,7 @@ static HeapTuple get_tuple_of_interest(Relation rel, int *pkattnums, int pknumat static Relation get_rel_from_relname(text *relname_text, LOCKMODE lockmode, AclMode aclmode); static char *generate_relation_name(Relation rel); static void dblink_connstr_check(const char *connstr); -static void dblink_security_check(PGconn *conn, remoteConn *rconn); +static void dblink_security_check(PGconn *conn, const char *connname); static void dblink_res_error(PGconn *conn, const char *conname, PGresult *res, bool fail, const char *fmt,...) pg_attribute_printf(5, 6); static char *get_connect_string(const char *servername); @@ -131,16 +131,22 @@ static remoteConn *pconn = NULL; static HTAB *remoteConnHash = NULL; /* - * Following is list that holds multiple remote connections. + * Following is hash that holds multiple remote connections. * Calling convention of each dblink function changes to accept - * connection name as the first parameter. The connection list is + * connection name as the first parameter. The connection hash is * much like ecpg e.g. a mapping between a name and a PGconn object. + * + * To avoid potentially leaking a PGconn object in case of out-of-memory + * errors, we first create the hash entry, then open the PGconn. + * Hence, a hash entry whose rconn.conn pointer is NULL must be + * understood as a leftover from a failed create; it should be ignored + * by lookup operations, and silently replaced by create operations. */ typedef struct remoteConnHashEnt { char name[NAMEDATALEN]; - remoteConn *rconn; + remoteConn rconn; } remoteConnHashEnt; /* initial number of connection hashes */ @@ -239,7 +245,7 @@ dblink_get_conn(char *conname_or_str, errmsg("could not establish connection"), errdetail_internal("%s", msg))); } - dblink_security_check(conn, rconn); + dblink_security_check(conn, NULL); if (PQclientEncoding(conn) != GetDatabaseEncoding()) PQsetClientEncoding(conn, GetDatabaseEncodingName()); freeconn = true; @@ -299,15 +305,6 @@ dblink_connect(PG_FUNCTION_ARGS) else if (PG_NARGS() == 1) conname_or_str = text_to_cstring(PG_GETARG_TEXT_PP(0)); - if (connname) - { - rconn = (remoteConn *) MemoryContextAlloc(TopMemoryContext, - sizeof(remoteConn)); - rconn->conn = NULL; - rconn->openCursorCount = 0; - rconn->newXactForCursor = false; - } - /* first check for valid foreign data server */ connstr = get_connect_string(conname_or_str); if (connstr == NULL) @@ -338,6 +335,13 @@ dblink_connect(PG_FUNCTION_ARGS) #endif } + /* if we need a hashtable entry, make that first, since it might fail */ + if (connname) + { + rconn = createNewConnection(connname); + Assert(rconn->conn == NULL); + } + /* OK to make connection */ conn = PQconnectdb(connstr); @@ -346,8 +350,8 @@ dblink_connect(PG_FUNCTION_ARGS) msg = pchomp(PQerrorMessage(conn)); PQfinish(conn); ReleaseExternalFD(); - if (rconn) - pfree(rconn); + if (connname) + deleteConnection(connname); ereport(ERROR, (errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), @@ -356,16 +360,16 @@ dblink_connect(PG_FUNCTION_ARGS) } /* check password actually used if not superuser */ - dblink_security_check(conn, rconn); + dblink_security_check(conn, connname); /* attempt to set client encoding to match server encoding, if needed */ if (PQclientEncoding(conn) != GetDatabaseEncoding()) PQsetClientEncoding(conn, GetDatabaseEncodingName()); + /* all OK, save away the conn */ if (connname) { rconn->conn = conn; - createNewConnection(connname, rconn); } else { @@ -409,10 +413,7 @@ dblink_disconnect(PG_FUNCTION_ARGS) PQfinish(conn); ReleaseExternalFD(); if (rconn) - { deleteConnection(conname); - pfree(rconn); - } else pconn->conn = NULL; @@ -1336,6 +1337,9 @@ dblink_get_connections(PG_FUNCTION_ARGS) hash_seq_init(&status, remoteConnHash); while ((hentry = (remoteConnHashEnt *) hash_seq_search(&status)) != NULL) { + /* ignore it if it's not an open connection */ + if (hentry->rconn.conn == NULL) + continue; /* stash away current value */ astate = accumArrayResult(astate, CStringGetTextDatum(hentry->name), @@ -2597,8 +2601,8 @@ getConnectionByName(const char *name) hentry = (remoteConnHashEnt *) hash_search(remoteConnHash, key, HASH_FIND, NULL); - if (hentry) - return hentry->rconn; + if (hentry && hentry->rconn.conn != NULL) + return &hentry->rconn; return NULL; } @@ -2614,8 +2618,8 @@ createConnHash(void) return hash_create("Remote Con hash", NUMCONN, &ctl, HASH_ELEM); } -static void -createNewConnection(const char *name, remoteConn *rconn) +static remoteConn * +createNewConnection(const char *name) { remoteConnHashEnt *hentry; bool found; @@ -2629,19 +2633,15 @@ createNewConnection(const char *name, remoteConn *rconn) hentry = (remoteConnHashEnt *) hash_search(remoteConnHash, key, HASH_ENTER, &found); - if (found) - { - PQfinish(rconn->conn); - ReleaseExternalFD(); - pfree(rconn); - + if (found && hentry->rconn.conn != NULL) ereport(ERROR, (errcode(ERRCODE_DUPLICATE_OBJECT), errmsg("duplicate connection name"))); - } - hentry->rconn = rconn; - strlcpy(hentry->name, name, sizeof(hentry->name)); + /* New, or reusable, so initialize the rconn struct to zeroes */ + memset(&hentry->rconn, 0, sizeof(remoteConn)); + + return &hentry->rconn; } static void @@ -2667,7 +2667,7 @@ deleteConnection(const char *name) } static void -dblink_security_check(PGconn *conn, remoteConn *rconn) +dblink_security_check(PGconn *conn, const char *connname) { if (!superuser()) { @@ -2675,8 +2675,8 @@ dblink_security_check(PGconn *conn, remoteConn *rconn) { PQfinish(conn); ReleaseExternalFD(); - if (rconn) - pfree(rconn); + if (connname) + deleteConnection(connname); ereport(ERROR, (errcode(ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED),