From 40dbbd8de499668590e8af51a15799fbc430595e Mon Sep 17 00:00:00 2001
From: Daiki Ueno <ueno@gnu.org>
Date: Wed, 10 Jan 2024 19:13:17 +0900
Subject: [PATCH] rsa-psk: minimize branching after decryption

This moves any non-trivial code between gnutls_privkey_decrypt_data2
and the function return in _gnutls_proc_rsa_psk_client_kx up until the
decryption.  This also avoids an extra memcpy to session->key.key.

Signed-off-by: Daiki Ueno <ueno@gnu.org>

Upstream-Status: Backport [https://gitlab.com/gnutls/gnutls/-/commit/40dbbd8de499668590e8af51a15799fbc430595e]
CVE: CVE-2024-0553
Signed-off-by: Vijay Anusuri <vanusuri@mvista.com>
---
 lib/auth/rsa_psk.c | 68 ++++++++++++++++++++++++----------------------
 1 file changed, 35 insertions(+), 33 deletions(-)

diff --git a/lib/auth/rsa_psk.c b/lib/auth/rsa_psk.c
index 93c2dc9..c6cfb92 100644
--- a/lib/auth/rsa_psk.c
+++ b/lib/auth/rsa_psk.c
@@ -269,7 +269,6 @@ _gnutls_proc_rsa_psk_client_kx(gnutls_session_t session, uint8_t * data,
 	int ret, dsize;
 	ssize_t data_size = _data_size;
 	gnutls_psk_server_credentials_t cred;
-	gnutls_datum_t premaster_secret = { NULL, 0 };
 	volatile uint8_t ver_maj, ver_min;
 
 	cred = (gnutls_psk_server_credentials_t)
@@ -329,24 +328,48 @@ _gnutls_proc_rsa_psk_client_kx(gnutls_session_t session, uint8_t * data,
 	ver_maj = _gnutls_get_adv_version_major(session);
 	ver_min = _gnutls_get_adv_version_minor(session);
 
-	premaster_secret.data = gnutls_malloc(GNUTLS_MASTER_SIZE);
-	if (premaster_secret.data == NULL) {
+	/* Find the key of this username. A random value will be
+	 * filled in if the key is not found.
+	 */
+	ret = _gnutls_psk_pwd_find_entry(session, info->username,
+			                 strlen(info->username), &pwd_psk);
+	if (ret < 0)
+		return gnutls_assert_val(ret);
+
+	/* Allocate memory for premaster secret, and fill in the
+	 * fields except the decryption result.
+	 */
+	session->key.key.size = 2 + GNUTLS_MASTER_SIZE + 2 + pwd_psk.size;
+	session->key.key.data = gnutls_malloc(session->key.key.size);
+	if (session->key.key.data == NULL) {
 		gnutls_assert();
+		_gnutls_free_key_datum(&pwd_psk);
+		/* No need to zeroize, as the secret is not copied in yet */
+		_gnutls_free_datum(&session->key.key);
 		return GNUTLS_E_MEMORY_ERROR;
 	}
-	premaster_secret.size = GNUTLS_MASTER_SIZE;
 
 	/* Fallback value when decryption fails. Needs to be unpredictable. */
-	ret = gnutls_rnd(GNUTLS_RND_NONCE, premaster_secret.data,
-			 premaster_secret.size);
+	ret = gnutls_rnd(GNUTLS_RND_NONCE, session->key.key.data + 2,
+			 GNUTLS_MASTER_SIZE);
 	if (ret < 0) {
 		gnutls_assert();
-		goto cleanup;
+		_gnutls_free_key_datum(&pwd_psk);
+		/* No need to zeroize, as the secret is not copied in yet */
+		_gnutls_free_datum(&session->key.key);
+		return ret;
 	}
 
+	_gnutls_write_uint16(GNUTLS_MASTER_SIZE, session->key.key.data);
+	_gnutls_write_uint16(pwd_psk.size,
+			     &session->key.key.data[2 + GNUTLS_MASTER_SIZE]);
+	memcpy(&session->key.key.data[2 + GNUTLS_MASTER_SIZE + 2], pwd_psk.data,
+	       pwd_psk.size);
+	_gnutls_free_key_datum(&pwd_psk);
+
 	gnutls_privkey_decrypt_data2(session->internals.selected_key, 0,
-				     &ciphertext, premaster_secret.data,
-				     premaster_secret.size);
+				     &ciphertext, session->key.key.data + 2,
+				     GNUTLS_MASTER_SIZE);
 	/* After this point, any conditional on failure that cause differences
 	 * in execution may create a timing or cache access pattern side
 	 * channel that can be used as an oracle, so tread carefully */
@@ -365,31 +388,10 @@ _gnutls_proc_rsa_psk_client_kx(gnutls_session_t session, uint8_t * data,
 	/* This is here to avoid the version check attack
 	 * discussed above.
 	 */
-	premaster_secret.data[0] = ver_maj;
-	premaster_secret.data[1] = ver_min;
+	session->key.key.data[2] = ver_maj;
+	session->key.key.data[3] = ver_min;
 
-	/* find the key of this username
-	 */
-	ret =
-	    _gnutls_psk_pwd_find_entry(session, info->username, strlen(info->username), &pwd_psk);
-	if (ret < 0) {
-		gnutls_assert();
-		goto cleanup;
-	}
-
-	ret =
-	    set_rsa_psk_session_key(session, &pwd_psk, &premaster_secret);
-	if (ret < 0) {
-		gnutls_assert();
-		goto cleanup;
-	}
-
-	ret = 0;
-      cleanup:
-	_gnutls_free_key_datum(&pwd_psk);
-	_gnutls_free_temp_key_datum(&premaster_secret);
-
-	return ret;
+	return 0;
 }
 
 static int
-- 
2.25.1