src: more automatic memory management in node_crypto.cc

Prefer custom smart pointers fitted to the OpenSSL data structures
over more manual memory management and lots of `goto`s.

PR-URL: https://github.com/nodejs/node/pull/20238
Reviewed-By: Daniel Bevenius <daniel.bevenius@gmail.com>
Reviewed-By: Tobias Nießen <tniessen@tnie.de>
Reviewed-By: Joyee Cheung <joyeec9h3@gmail.com>
Reviewed-By: Tiancheng "Timothy" Gu <timothygu99@gmail.com>
Reviewed-By: Ben Noordhuis <info@bnoordhuis.nl>
Reviewed-By: James M Snell <jasnell@gmail.com>
This commit is contained in:
Anna Henningsen 2018-04-23 18:29:11 +02:00
parent 1cf7ef6433
commit d7cba76856
No known key found for this signature in database
GPG Key ID: 9C63F3A6CD2AD8F9
4 changed files with 723 additions and 896 deletions

File diff suppressed because it is too large Load Diff

View File

@ -75,6 +75,32 @@ struct MarkPopErrorOnReturn {
~MarkPopErrorOnReturn() { ERR_pop_to_mark(); } ~MarkPopErrorOnReturn() { ERR_pop_to_mark(); }
}; };
template <typename T, void (*function)(T*)>
struct FunctionDeleter {
void operator()(T* pointer) const { function(pointer); }
typedef std::unique_ptr<T, FunctionDeleter> Pointer;
};
template <typename T, void (*function)(T*)>
using DeleteFnPtr = typename FunctionDeleter<T, function>::Pointer;
// Define smart pointers for the most commonly used OpenSSL types:
using X509Pointer = DeleteFnPtr<X509, X509_free>;
using BIOPointer = DeleteFnPtr<BIO, BIO_free_all>;
using SSLCtxPointer = DeleteFnPtr<SSL_CTX, SSL_CTX_free>;
using SSLSessionPointer = DeleteFnPtr<SSL_SESSION, SSL_SESSION_free>;
using SSLPointer = DeleteFnPtr<SSL, SSL_free>;
using EVPKeyPointer = DeleteFnPtr<EVP_PKEY, EVP_PKEY_free>;
using EVPKeyCtxPointer = DeleteFnPtr<EVP_PKEY_CTX, EVP_PKEY_CTX_free>;
using EVPMDPointer = DeleteFnPtr<EVP_MD_CTX, EVP_MD_CTX_free>;
using RSAPointer = DeleteFnPtr<RSA, RSA_free>;
using BignumPointer = DeleteFnPtr<BIGNUM, BN_free>;
using NetscapeSPKIPointer = DeleteFnPtr<NETSCAPE_SPKI, NETSCAPE_SPKI_free>;
using ECGroupPointer = DeleteFnPtr<EC_GROUP, EC_GROUP_free>;
using ECPointPointer = DeleteFnPtr<EC_POINT, EC_POINT_free>;
using ECKeyPointer = DeleteFnPtr<EC_KEY, EC_KEY_free>;
using DHPointer = DeleteFnPtr<DH, DH_free>;
enum CheckResult { enum CheckResult {
CHECK_CERT_REVOKED = 0, CHECK_CERT_REVOKED = 0,
CHECK_OK = 1 CHECK_OK = 1
@ -87,14 +113,14 @@ extern void UseExtraCaCerts(const std::string& file);
class SecureContext : public BaseObject { class SecureContext : public BaseObject {
public: public:
~SecureContext() override { ~SecureContext() override {
FreeCTXMem(); Reset();
} }
static void Initialize(Environment* env, v8::Local<v8::Object> target); static void Initialize(Environment* env, v8::Local<v8::Object> target);
SSL_CTX* ctx_; SSLCtxPointer ctx_;
X509* cert_; X509Pointer cert_;
X509* issuer_; X509Pointer issuer_;
#ifndef OPENSSL_NO_ENGINE #ifndef OPENSSL_NO_ENGINE
bool client_cert_engine_provided_ = false; bool client_cert_engine_provided_ = false;
#endif // !OPENSSL_NO_ENGINE #endif // !OPENSSL_NO_ENGINE
@ -171,28 +197,16 @@ class SecureContext : public BaseObject {
#endif #endif
SecureContext(Environment* env, v8::Local<v8::Object> wrap) SecureContext(Environment* env, v8::Local<v8::Object> wrap)
: BaseObject(env, wrap), : BaseObject(env, wrap) {
ctx_(nullptr),
cert_(nullptr),
issuer_(nullptr) {
MakeWeak(); MakeWeak();
env->isolate()->AdjustAmountOfExternalAllocatedMemory(kExternalSize); env->isolate()->AdjustAmountOfExternalAllocatedMemory(kExternalSize);
} }
void FreeCTXMem() { inline void Reset() {
if (!ctx_) {
return;
}
env()->isolate()->AdjustAmountOfExternalAllocatedMemory(-kExternalSize); env()->isolate()->AdjustAmountOfExternalAllocatedMemory(-kExternalSize);
SSL_CTX_free(ctx_); ctx_.reset();
if (cert_ != nullptr) cert_.reset();
X509_free(cert_); issuer_.reset();
if (issuer_ != nullptr)
X509_free(issuer_);
ctx_ = nullptr;
cert_ = nullptr;
issuer_ = nullptr;
} }
}; };
@ -215,20 +229,15 @@ class SSLWrap {
cert_cb_(nullptr), cert_cb_(nullptr),
cert_cb_arg_(nullptr), cert_cb_arg_(nullptr),
cert_cb_running_(false) { cert_cb_running_(false) {
ssl_ = SSL_new(sc->ctx_); ssl_.reset(SSL_new(sc->ctx_.get()));
CHECK(ssl_);
env_->isolate()->AdjustAmountOfExternalAllocatedMemory(kExternalSize); env_->isolate()->AdjustAmountOfExternalAllocatedMemory(kExternalSize);
CHECK_NE(ssl_, nullptr);
} }
virtual ~SSLWrap() { virtual ~SSLWrap() {
DestroySSL(); DestroySSL();
if (next_sess_ != nullptr) {
SSL_SESSION_free(next_sess_);
next_sess_ = nullptr;
}
} }
inline SSL* ssl() const { return ssl_; }
inline void enable_session_callbacks() { session_callbacks_ = true; } inline void enable_session_callbacks() { session_callbacks_ = true; }
inline bool is_server() const { return kind_ == kServer; } inline bool is_server() const { return kind_ == kServer; }
inline bool is_client() const { return kind_ == kClient; } inline bool is_client() const { return kind_ == kClient; }
@ -319,8 +328,8 @@ class SSLWrap {
Environment* const env_; Environment* const env_;
Kind kind_; Kind kind_;
SSL_SESSION* next_sess_; SSLSessionPointer next_sess_;
SSL* ssl_; SSLPointer ssl_;
bool session_callbacks_; bool session_callbacks_;
bool new_session_wait_; bool new_session_wait_;
@ -344,10 +353,6 @@ class SSLWrap {
class CipherBase : public BaseObject { class CipherBase : public BaseObject {
public: public:
~CipherBase() override {
EVP_CIPHER_CTX_free(ctx_);
}
static void Initialize(Environment* env, v8::Local<v8::Object> target); static void Initialize(Environment* env, v8::Local<v8::Object> target);
protected: protected:
@ -407,7 +412,7 @@ class CipherBase : public BaseObject {
} }
private: private:
EVP_CIPHER_CTX* ctx_; DeleteFnPtr<EVP_CIPHER_CTX, EVP_CIPHER_CTX_free> ctx_;
const CipherKind kind_; const CipherKind kind_;
bool auth_tag_set_; bool auth_tag_set_;
unsigned int auth_tag_len_; unsigned int auth_tag_len_;
@ -418,8 +423,6 @@ class CipherBase : public BaseObject {
class Hmac : public BaseObject { class Hmac : public BaseObject {
public: public:
~Hmac() override;
static void Initialize(Environment* env, v8::Local<v8::Object> target); static void Initialize(Environment* env, v8::Local<v8::Object> target);
protected: protected:
@ -438,13 +441,11 @@ class Hmac : public BaseObject {
} }
private: private:
HMAC_CTX* ctx_; DeleteFnPtr<HMAC_CTX, HMAC_CTX_free> ctx_;
}; };
class Hash : public BaseObject { class Hash : public BaseObject {
public: public:
~Hash() override;
static void Initialize(Environment* env, v8::Local<v8::Object> target); static void Initialize(Environment* env, v8::Local<v8::Object> target);
bool HashInit(const char* hash_type); bool HashInit(const char* hash_type);
@ -463,7 +464,7 @@ class Hash : public BaseObject {
} }
private: private:
EVP_MD_CTX* mdctx_; EVPMDPointer mdctx_;
bool finalized_; bool finalized_;
}; };
@ -480,19 +481,16 @@ class SignBase : public BaseObject {
} Error; } Error;
SignBase(Environment* env, v8::Local<v8::Object> wrap) SignBase(Environment* env, v8::Local<v8::Object> wrap)
: BaseObject(env, wrap), : BaseObject(env, wrap) {
mdctx_(nullptr) {
} }
~SignBase() override;
Error Init(const char* sign_type); Error Init(const char* sign_type);
Error Update(const char* data, int len); Error Update(const char* data, int len);
protected: protected:
void CheckThrow(Error error); void CheckThrow(Error error);
EVP_MD_CTX* mdctx_; EVPMDPointer mdctx_;
}; };
class Sign : public SignBase { class Sign : public SignBase {
@ -573,12 +571,6 @@ class PublicKeyCipher {
class DiffieHellman : public BaseObject { class DiffieHellman : public BaseObject {
public: public:
~DiffieHellman() override {
if (dh != nullptr) {
DH_free(dh);
}
}
static void Initialize(Environment* env, v8::Local<v8::Object> target); static void Initialize(Environment* env, v8::Local<v8::Object> target);
bool Init(int primeLength, int g); bool Init(int primeLength, int g);
@ -603,8 +595,7 @@ class DiffieHellman : public BaseObject {
DiffieHellman(Environment* env, v8::Local<v8::Object> wrap) DiffieHellman(Environment* env, v8::Local<v8::Object> wrap)
: BaseObject(env, wrap), : BaseObject(env, wrap),
initialised_(false), initialised_(false),
verifyError_(0), verifyError_(0) {
dh(nullptr) {
MakeWeak(); MakeWeak();
} }
@ -618,29 +609,26 @@ class DiffieHellman : public BaseObject {
bool initialised_; bool initialised_;
int verifyError_; int verifyError_;
DH* dh; DHPointer dh_;
}; };
class ECDH : public BaseObject { class ECDH : public BaseObject {
public: public:
~ECDH() override { ~ECDH() override {
if (key_ != nullptr)
EC_KEY_free(key_);
key_ = nullptr;
group_ = nullptr; group_ = nullptr;
} }
static void Initialize(Environment* env, v8::Local<v8::Object> target); static void Initialize(Environment* env, v8::Local<v8::Object> target);
static EC_POINT* BufferToPoint(Environment* env, static ECPointPointer BufferToPoint(Environment* env,
const EC_GROUP* group, const EC_GROUP* group,
char* data, char* data,
size_t len); size_t len);
protected: protected:
ECDH(Environment* env, v8::Local<v8::Object> wrap, EC_KEY* key) ECDH(Environment* env, v8::Local<v8::Object> wrap, ECKeyPointer&& key)
: BaseObject(env, wrap), : BaseObject(env, wrap),
key_(key), key_(std::move(key)),
group_(EC_KEY_get0_group(key_)) { group_(EC_KEY_get0_group(key_.get())) {
MakeWeak(); MakeWeak();
CHECK_NE(group_, nullptr); CHECK_NE(group_, nullptr);
} }
@ -654,9 +642,9 @@ class ECDH : public BaseObject {
static void SetPublicKey(const v8::FunctionCallbackInfo<v8::Value>& args); static void SetPublicKey(const v8::FunctionCallbackInfo<v8::Value>& args);
bool IsKeyPairValid(); bool IsKeyPairValid();
bool IsKeyValidForCurve(const BIGNUM* private_key); bool IsKeyValidForCurve(const BignumPointer& private_key);
EC_KEY* key_; ECKeyPointer key_;
const EC_GROUP* group_; const EC_GROUP* group_;
}; };

View File

@ -74,8 +74,10 @@ TLSWrap::TLSWrap(Environment* env,
CHECK_NE(sc, nullptr); CHECK_NE(sc, nullptr);
// We've our own session callbacks // We've our own session callbacks
SSL_CTX_sess_set_get_cb(sc_->ctx_, SSLWrap<TLSWrap>::GetSessionCallback); SSL_CTX_sess_set_get_cb(sc_->ctx_.get(),
SSL_CTX_sess_set_new_cb(sc_->ctx_, SSLWrap<TLSWrap>::NewSessionCallback); SSLWrap<TLSWrap>::GetSessionCallback);
SSL_CTX_sess_set_new_cb(sc_->ctx_.get(),
SSLWrap<TLSWrap>::NewSessionCallback);
stream->PushStreamListener(this); stream->PushStreamListener(this);
@ -116,35 +118,36 @@ void TLSWrap::InitSSL() {
crypto::NodeBIO::FromBIO(enc_in_)->AssignEnvironment(env()); crypto::NodeBIO::FromBIO(enc_in_)->AssignEnvironment(env());
crypto::NodeBIO::FromBIO(enc_out_)->AssignEnvironment(env()); crypto::NodeBIO::FromBIO(enc_out_)->AssignEnvironment(env());
SSL_set_bio(ssl_, enc_in_, enc_out_); SSL_set_bio(ssl_.get(), enc_in_, enc_out_);
// NOTE: This could be overridden in SetVerifyMode // NOTE: This could be overridden in SetVerifyMode
SSL_set_verify(ssl_, SSL_VERIFY_NONE, crypto::VerifyCallback); SSL_set_verify(ssl_.get(), SSL_VERIFY_NONE, crypto::VerifyCallback);
#ifdef SSL_MODE_RELEASE_BUFFERS #ifdef SSL_MODE_RELEASE_BUFFERS
long mode = SSL_get_mode(ssl_); // NOLINT(runtime/int) long mode = SSL_get_mode(ssl_.get()); // NOLINT(runtime/int)
SSL_set_mode(ssl_, mode | SSL_MODE_RELEASE_BUFFERS); SSL_set_mode(ssl_.get(), mode | SSL_MODE_RELEASE_BUFFERS);
#endif // SSL_MODE_RELEASE_BUFFERS #endif // SSL_MODE_RELEASE_BUFFERS
SSL_set_app_data(ssl_, this); SSL_set_app_data(ssl_.get(), this);
SSL_set_info_callback(ssl_, SSLInfoCallback); SSL_set_info_callback(ssl_.get(), SSLInfoCallback);
#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
if (is_server()) { if (is_server()) {
SSL_CTX_set_tlsext_servername_callback(sc_->ctx_, SelectSNIContextCallback); SSL_CTX_set_tlsext_servername_callback(sc_->ctx_.get(),
SelectSNIContextCallback);
} }
#endif // SSL_CTRL_SET_TLSEXT_SERVERNAME_CB #endif // SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
ConfigureSecureContext(sc_); ConfigureSecureContext(sc_);
SSL_set_cert_cb(ssl_, SSLWrap<TLSWrap>::SSLCertCallback, this); SSL_set_cert_cb(ssl_.get(), SSLWrap<TLSWrap>::SSLCertCallback, this);
if (is_server()) { if (is_server()) {
SSL_set_accept_state(ssl_); SSL_set_accept_state(ssl_.get());
} else if (is_client()) { } else if (is_client()) {
// Enough space for server response (hello, cert) // Enough space for server response (hello, cert)
crypto::NodeBIO::FromBIO(enc_in_)->set_initial(kInitialClientBufferLength); crypto::NodeBIO::FromBIO(enc_in_)->set_initial(kInitialClientBufferLength);
SSL_set_connect_state(ssl_); SSL_set_connect_state(ssl_.get());
} else { } else {
// Unexpected // Unexpected
ABORT(); ABORT();
@ -342,7 +345,7 @@ Local<Value> TLSWrap::GetSSLError(int status, int* err, std::string* msg) {
if (ssl_ == nullptr) if (ssl_ == nullptr)
return Local<Value>(); return Local<Value>();
*err = SSL_get_error(ssl_, status); *err = SSL_get_error(ssl_.get(), status);
switch (*err) { switch (*err) {
case SSL_ERROR_NONE: case SSL_ERROR_NONE:
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
@ -395,7 +398,7 @@ void TLSWrap::ClearOut() {
char out[kClearOutChunkSize]; char out[kClearOutChunkSize];
int read; int read;
for (;;) { for (;;) {
read = SSL_read(ssl_, out, sizeof(out)); read = SSL_read(ssl_.get(), out, sizeof(out));
if (read <= 0) if (read <= 0)
break; break;
@ -421,7 +424,7 @@ void TLSWrap::ClearOut() {
} }
} }
int flags = SSL_get_shutdown(ssl_); int flags = SSL_get_shutdown(ssl_.get());
if (!eof_ && flags & SSL_RECEIVED_SHUTDOWN) { if (!eof_ && flags & SSL_RECEIVED_SHUTDOWN) {
eof_ = true; eof_ = true;
EmitRead(UV_EOF); EmitRead(UV_EOF);
@ -469,7 +472,7 @@ bool TLSWrap::ClearIn() {
for (i = 0; i < buffers.size(); ++i) { for (i = 0; i < buffers.size(); ++i) {
size_t avail = buffers[i].len; size_t avail = buffers[i].len;
char* data = buffers[i].base; char* data = buffers[i].base;
written = SSL_write(ssl_, data, avail); written = SSL_write(ssl_.get(), data, avail);
CHECK(written == -1 || written == static_cast<int>(avail)); CHECK(written == -1 || written == static_cast<int>(avail));
if (written == -1) if (written == -1)
break; break;
@ -610,7 +613,7 @@ int TLSWrap::DoWrite(WriteWrap* w,
int written = 0; int written = 0;
for (i = 0; i < count; i++) { for (i = 0; i < count; i++) {
written = SSL_write(ssl_, bufs[i].base, bufs[i].len); written = SSL_write(ssl_.get(), bufs[i].base, bufs[i].len);
CHECK(written == -1 || written == static_cast<int>(bufs[i].len)); CHECK(written == -1 || written == static_cast<int>(bufs[i].len));
if (written == -1) if (written == -1)
break; break;
@ -690,8 +693,8 @@ ShutdownWrap* TLSWrap::CreateShutdownWrap(Local<Object> req_wrap_object) {
int TLSWrap::DoShutdown(ShutdownWrap* req_wrap) { int TLSWrap::DoShutdown(ShutdownWrap* req_wrap) {
crypto::MarkPopErrorOnReturn mark_pop_error_on_return; crypto::MarkPopErrorOnReturn mark_pop_error_on_return;
if (ssl_ != nullptr && SSL_shutdown(ssl_) == 0) if (ssl_ && SSL_shutdown(ssl_.get()) == 0)
SSL_shutdown(ssl_); SSL_shutdown(ssl_.get());
shutdown_ = true; shutdown_ = true;
EncOut(); EncOut();
@ -726,7 +729,7 @@ void TLSWrap::SetVerifyMode(const FunctionCallbackInfo<Value>& args) {
} }
// Always allow a connection. We'll reject in javascript. // Always allow a connection. We'll reject in javascript.
SSL_set_verify(wrap->ssl_, verify_mode, crypto::VerifyCallback); SSL_set_verify(wrap->ssl_.get(), verify_mode, crypto::VerifyCallback);
} }
@ -783,7 +786,7 @@ void TLSWrap::GetServername(const FunctionCallbackInfo<Value>& args) {
CHECK_NE(wrap->ssl_, nullptr); CHECK_NE(wrap->ssl_, nullptr);
const char* servername = SSL_get_servername(wrap->ssl_, const char* servername = SSL_get_servername(wrap->ssl_.get(),
TLSEXT_NAMETYPE_host_name); TLSEXT_NAMETYPE_host_name);
if (servername != nullptr) { if (servername != nullptr) {
args.GetReturnValue().Set(OneByteString(env->isolate(), servername)); args.GetReturnValue().Set(OneByteString(env->isolate(), servername));
@ -808,7 +811,7 @@ void TLSWrap::SetServername(const FunctionCallbackInfo<Value>& args) {
#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
node::Utf8Value servername(env->isolate(), args[0].As<String>()); node::Utf8Value servername(env->isolate(), args[0].As<String>());
SSL_set_tlsext_host_name(wrap->ssl_, *servername); SSL_set_tlsext_host_name(wrap->ssl_.get(), *servername);
#endif // SSL_CTRL_SET_TLSEXT_SERVERNAME_CB #endif // SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
} }

View File

@ -410,7 +410,6 @@ class BufferValue : public MaybeStackBuffer<char> {
// Use this when a variable or parameter is unused in order to explicitly // Use this when a variable or parameter is unused in order to explicitly
// silence a compiler warning about that. // silence a compiler warning about that.
template <typename T> inline void USE(T&&) {} template <typename T> inline void USE(T&&) {}
} // namespace node
// Run a function when exiting the current scope. // Run a function when exiting the current scope.
struct OnScopeLeave { struct OnScopeLeave {
@ -420,6 +419,37 @@ struct OnScopeLeave {
~OnScopeLeave() { fn_(); } ~OnScopeLeave() { fn_(); }
}; };
// Simple RAII wrapper for contiguous data that uses malloc()/free().
template<typename T>
struct MallocedBuffer {
T* data;
size_t size;
T* release() {
T* ret = data;
data = nullptr;
return ret;
}
MallocedBuffer() : data(nullptr) {}
explicit MallocedBuffer(size_t size) : data(Malloc<T>(size)), size(size) {}
MallocedBuffer(MallocedBuffer&& other) : data(other.data), size(other.size) {
other.data = nullptr;
}
MallocedBuffer& operator=(MallocedBuffer&& other) {
this->~MallocedBuffer();
return *new(this) MallocedBuffer(other);
}
~MallocedBuffer() {
free(data);
}
MallocedBuffer(const MallocedBuffer&) = delete;
MallocedBuffer& operator=(const MallocedBuffer&) = delete;
};
} // namespace node
#endif // defined(NODE_WANT_INTERNALS) && NODE_WANT_INTERNALS #endif // defined(NODE_WANT_INTERNALS) && NODE_WANT_INTERNALS
#endif // SRC_UTIL_H_ #endif // SRC_UTIL_H_