diff --git a/crypto/rsa/rsa_crpt.c b/crypto/rsa/rsa_crpt.c index 21c922e609..d6c7353406 100644 --- a/crypto/rsa/rsa_crpt.c +++ b/crypto/rsa/rsa_crpt.c @@ -61,28 +61,16 @@ int RSA_flags(const RSA *r) void RSA_blinding_off(RSA *rsa) { - BN_BLINDING_free(rsa->blinding); - rsa->blinding = NULL; rsa->flags &= ~RSA_FLAG_BLINDING; rsa->flags |= RSA_FLAG_NO_BLINDING; } int RSA_blinding_on(RSA *rsa, BN_CTX *ctx) { - int ret = 0; - - if (rsa->blinding != NULL) - RSA_blinding_off(rsa); - - rsa->blinding = RSA_setup_blinding(rsa, ctx); - if (rsa->blinding == NULL) - goto err; rsa->flags |= RSA_FLAG_BLINDING; rsa->flags &= ~RSA_FLAG_NO_BLINDING; - ret = 1; - err: - return ret; + return 1; } static BIGNUM *rsa_get_public_exp(const BIGNUM *d, const BIGNUM *p, @@ -162,8 +150,6 @@ BN_BLINDING *RSA_setup_blinding(RSA *rsa, BN_CTX *in_ctx) goto err; } - BN_BLINDING_set_current_thread(ret); - err: BN_CTX_end(ctx); if (ctx != in_ctx) diff --git a/crypto/rsa/rsa_lib.c b/crypto/rsa/rsa_lib.c index d9ceb80880..77bb330bbb 100644 --- a/crypto/rsa/rsa_lib.c +++ b/crypto/rsa/rsa_lib.c @@ -25,6 +25,7 @@ #include "crypto/bn.h" #include "crypto/evp.h" #include "crypto/rsa.h" +#include "crypto/sparse_array.h" #include "crypto/security_bits.h" #include "rsa_local.h" @@ -92,6 +93,10 @@ static RSA *rsa_new_intern(ENGINE *engine, OSSL_LIB_CTX *libctx) return NULL; } + ret->blindings_sa = ossl_rsa_alloc_blinding(); + if (ret->blindings_sa == NULL) + goto err; + ret->libctx = libctx; ret->meth = RSA_get_default_method(); #if !defined(OPENSSL_NO_ENGINE) && !defined(FIPS_MODULE) @@ -181,8 +186,7 @@ void RSA_free(RSA *r) RSA_PSS_PARAMS_free(r->pss); sk_RSA_PRIME_INFO_pop_free(r->prime_infos, ossl_rsa_multip_info_free); #endif - BN_BLINDING_free(r->blinding); - BN_BLINDING_free(r->mt_blinding); + ossl_rsa_free_blinding(r); OPENSSL_free(r); } @@ -1380,4 +1384,5 @@ int EVP_PKEY_CTX_set_rsa_keygen_primes(EVP_PKEY_CTX *ctx, int primes) return evp_pkey_ctx_set_params_strict(ctx, params); } + #endif diff --git a/crypto/rsa/rsa_local.h b/crypto/rsa/rsa_local.h index db9eb2a1df..90122f97fc 100644 --- a/crypto/rsa/rsa_local.h +++ b/crypto/rsa/rsa_local.h @@ -93,8 +93,7 @@ struct rsa_st { BN_MONT_CTX *_method_mod_n; BN_MONT_CTX *_method_mod_p; BN_MONT_CTX *_method_mod_q; - BN_BLINDING *blinding; - BN_BLINDING *mt_blinding; + void *blindings_sa; CRYPTO_RWLOCK *lock; int dirty_cnt; @@ -196,5 +195,7 @@ int ossl_rsa_fips186_4_gen_prob_primes(RSA *rsa, RSA_ACVP_TEST *test, int ossl_rsa_padding_add_PKCS1_type_2_ex(OSSL_LIB_CTX *libctx, unsigned char *to, int tlen, const unsigned char *from, int flen); +void ossl_rsa_free_blinding(RSA *rsa); +void *ossl_rsa_alloc_blinding(void); #endif /* OSSL_CRYPTO_RSA_LOCAL_H */ diff --git a/crypto/rsa/rsa_ossl.c b/crypto/rsa/rsa_ossl.c index 0c0c73c65c..a1d4877342 100644 --- a/crypto/rsa/rsa_ossl.c +++ b/crypto/rsa/rsa_ossl.c @@ -15,12 +15,15 @@ #include "internal/cryptlib.h" #include "crypto/bn.h" +#include "crypto/sparse_array.h" #include "rsa_local.h" #include "internal/constant_time.h" #include #include #include +DEFINE_SPARSE_ARRAY_OF(BN_BLINDING); + static int rsa_ossl_public_encrypt(int flen, const unsigned char *from, unsigned char *to, RSA *rsa, int padding); static int rsa_ossl_private_encrypt(int flen, const unsigned char *from, @@ -207,86 +210,76 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from, return r; } -static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx) +static void free_bn_blinding(ossl_uintmax_t idx, BN_BLINDING *b, void *arg) +{ + BN_BLINDING_free(b); +} + +void ossl_rsa_free_blinding(RSA *rsa) +{ + SPARSE_ARRAY_OF(BN_BLINDING) *blindings = rsa->blindings_sa; + + ossl_sa_BN_BLINDING_doall_arg(blindings, free_bn_blinding, NULL); + ossl_sa_BN_BLINDING_free(blindings); +} + +void *ossl_rsa_alloc_blinding(void) +{ + return ossl_sa_BN_BLINDING_new(); +} + +static BN_BLINDING *ossl_rsa_get_thread_bn_blinding(RSA *rsa) +{ + SPARSE_ARRAY_OF(BN_BLINDING) *blindings = rsa->blindings_sa; + uintptr_t tid = (uintptr_t)CRYPTO_THREAD_get_current_id(); + + return ossl_sa_BN_BLINDING_get(blindings, tid); +} + +static int ossl_rsa_set_thread_bn_blinding(RSA *rsa, BN_BLINDING *b) +{ + SPARSE_ARRAY_OF(BN_BLINDING) *blindings = rsa->blindings_sa; + uintptr_t tid = (uintptr_t)CRYPTO_THREAD_get_current_id(); + + return ossl_sa_BN_BLINDING_set(blindings, tid, b); +} + +static BN_BLINDING *rsa_get_blinding(RSA *rsa, BN_CTX *ctx) { BN_BLINDING *ret; if (!CRYPTO_THREAD_read_lock(rsa->lock)) return NULL; - if (rsa->blinding == NULL) { - /* - * This dance with upgrading the lock from read to write will be - * slower in cases of a single use RSA object, but should be - * significantly better in multi-thread cases (e.g. servers). It's - * probably worth it. - */ - CRYPTO_THREAD_unlock(rsa->lock); - if (!CRYPTO_THREAD_write_lock(rsa->lock)) - return NULL; - if (rsa->blinding == NULL) - rsa->blinding = RSA_setup_blinding(rsa, ctx); - } - - ret = rsa->blinding; - if (ret == NULL) - goto err; - - if (BN_BLINDING_is_current_thread(ret)) { - /* rsa->blinding is ours! */ - - *local = 1; - } else { - /* resort to rsa->mt_blinding instead */ - - /* - * instructs rsa_blinding_convert(), rsa_blinding_invert() that the - * BN_BLINDING is shared, meaning that accesses require locks, and - * that the blinding factor must be stored outside the BN_BLINDING - */ - *local = 0; - - if (rsa->mt_blinding == NULL) { - CRYPTO_THREAD_unlock(rsa->lock); - if (!CRYPTO_THREAD_write_lock(rsa->lock)) - return NULL; - if (rsa->mt_blinding == NULL) - rsa->mt_blinding = RSA_setup_blinding(rsa, ctx); - } - ret = rsa->mt_blinding; - } - - err: + ret = ossl_rsa_get_thread_bn_blinding(rsa); CRYPTO_THREAD_unlock(rsa->lock); + + if (ret == NULL) { + ret = RSA_setup_blinding(rsa, ctx); + if (!CRYPTO_THREAD_write_lock(rsa->lock)) { + BN_BLINDING_free(ret); + ret = NULL; + } else { + if (!ossl_rsa_set_thread_bn_blinding(rsa, ret)) { + BN_BLINDING_free(ret); + ret = NULL; + } + } + CRYPTO_THREAD_unlock(rsa->lock); + } + return ret; } -static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind, - BN_CTX *ctx) +static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BN_CTX *ctx) { - if (unblind == NULL) { - /* - * Local blinding: store the unblinding factor in BN_BLINDING. - */ - return BN_BLINDING_convert_ex(f, NULL, b, ctx); - } else { - /* - * Shared blinding: store the unblinding factor outside BN_BLINDING. - */ - int ret; - - if (!BN_BLINDING_lock(b)) - return 0; - - ret = BN_BLINDING_convert_ex(f, unblind, b, ctx); - BN_BLINDING_unlock(b); - - return ret; - } + /* + * Local blinding: store the unblinding factor in BN_BLINDING. + */ + return BN_BLINDING_convert_ex(f, NULL, b, ctx); } -static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind, - BN_CTX *ctx) +static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BN_CTX *ctx) { /* * For local blinding, unblind is set to NULL, and BN_BLINDING_invert_ex @@ -297,7 +290,7 @@ static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind, * to access the blinding without a lock. */ BN_set_flags(f, BN_FLG_CONSTTIME); - return BN_BLINDING_invert_ex(f, unblind, b, ctx); + return BN_BLINDING_invert_ex(f, NULL, b, ctx); } /* signing */ @@ -308,13 +301,6 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from, int i, num = 0, r = -1; unsigned char *buf = NULL; BN_CTX *ctx = NULL; - int local_blinding = 0; - /* - * Used only if the blinding structure is shared. A non-NULL unblind - * instructs rsa_blinding_convert() and rsa_blinding_invert() to store - * the unblinding factor outside the blinding structure. - */ - BIGNUM *unblind = NULL; BN_BLINDING *blinding = NULL; if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL) @@ -359,19 +345,13 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from, goto err; if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) { - blinding = rsa_get_blinding(rsa, &local_blinding, ctx); + blinding = rsa_get_blinding(rsa, ctx); if (blinding == NULL) { ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR); goto err; } - } - if (blinding != NULL) { - if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) { - ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB); - goto err; - } - if (!rsa_blinding_convert(blinding, f, unblind, ctx)) + if (!rsa_blinding_convert(blinding, f, ctx)) goto err; } @@ -405,7 +385,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from, } if (blinding) - if (!rsa_blinding_invert(blinding, ret, unblind, ctx)) + if (!rsa_blinding_invert(blinding, ret, ctx)) goto err; if (padding == RSA_X931_PADDING) { @@ -524,13 +504,6 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from, unsigned char *buf = NULL; unsigned char kdk[SHA256_DIGEST_LENGTH] = {0}; BN_CTX *ctx = NULL; - int local_blinding = 0; - /* - * Used only if the blinding structure is shared. A non-NULL unblind - * instructs rsa_blinding_convert() and rsa_blinding_invert() to store - * the unblinding factor outside the blinding structure. - */ - BIGNUM *unblind = NULL; BN_BLINDING *blinding = NULL; /* @@ -606,19 +579,13 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from, goto err; if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) { - blinding = rsa_get_blinding(rsa, &local_blinding, ctx); + blinding = rsa_get_blinding(rsa, ctx); if (blinding == NULL) { ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR); goto err; } - } - if (blinding != NULL) { - if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) { - ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB); - goto err; - } - if (!rsa_blinding_convert(blinding, f, unblind, ctx)) + if (!rsa_blinding_convert(blinding, f, ctx)) goto err; } @@ -652,7 +619,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from, } if (blinding) - if (!rsa_blinding_invert(blinding, ret, unblind, ctx)) + if (!rsa_blinding_invert(blinding, ret, ctx)) goto err; /*