diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 883201f..cf4d84d 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -3891,6 +3891,37 @@ class TestPostHandshakeAuth(unittest.TestCase):
                 s.write(b'PHA')
                 self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
 
+    def test_bpo37428_pha_cert_none(self):
+        # verify that post_handshake_auth does not implicitly enable cert
+        # validation.
+        hostname = 'localhost'
+        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+        client_context.post_handshake_auth = True
+        client_context.load_cert_chain(SIGNED_CERTFILE)
+        # no cert validation and CA on client side
+        client_context.check_hostname = False
+        client_context.verify_mode = ssl.CERT_NONE
+
+        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        server_context.load_cert_chain(SIGNED_CERTFILE)
+        server_context.load_verify_locations(SIGNING_CA)
+        server_context.post_handshake_auth = True
+        server_context.verify_mode = ssl.CERT_REQUIRED
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'FALSE\n')
+                s.write(b'PHA')
+                self.assertEqual(s.recv(1024), b'OK\n')
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'TRUE\n')
+                # server cert has not been validated
+                self.assertEqual(s.getpeercert(), {})
+
 
 def test_main(verbose=False):
     if support.verbose:
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index ec366f0..9bf1cde 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -732,6 +732,26 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
 #endif
     SSL_set_mode(self->ssl, mode);
 
+#ifdef TLS1_3_VERSION
+    if (sslctx->post_handshake_auth == 1) {
+        if (socket_type == PY_SSL_SERVER) {
+            /* bpo-37428: OpenSSL does not ignore SSL_VERIFY_POST_HANDSHAKE.
+             * Set SSL_VERIFY_POST_HANDSHAKE flag only for server sockets and
+             * only in combination with SSL_VERIFY_PEER flag. */
+            int verify_mode = SSL_get_verify_mode(self->ssl);
+            if (verify_mode & SSL_VERIFY_PEER) {
+                int (*verify_cb)(int, X509_STORE_CTX *) = NULL;
+                verify_cb = SSL_get_verify_callback(self->ssl);
+                verify_mode |= SSL_VERIFY_POST_HANDSHAKE;
+                SSL_set_verify(self->ssl, verify_mode, verify_cb);
+            }
+        } else {
+            /* client socket */
+            SSL_set_post_handshake_auth(self->ssl, 1);
+        }
+    }
+#endif
+
 #if HAVE_SNI
     if (server_hostname != NULL) {
         /* Don't send SNI for IP addresses. We cannot simply use inet_aton() and
@@ -2765,10 +2785,10 @@ _set_verify_mode(PySSLContext *self, enum py_ssl_cert_requirements n)
                         "invalid value for verify_mode");
         return -1;
     }
-#ifdef TLS1_3_VERSION
-    if (self->post_handshake_auth)
-        mode |= SSL_VERIFY_POST_HANDSHAKE;
-#endif
+
+    /* bpo-37428: newPySSLSocket() sets SSL_VERIFY_POST_HANDSHAKE flag for
+     * server sockets and SSL_set_post_handshake_auth() for client. */
+
     /* keep current verify cb */
     verify_cb = SSL_CTX_get_verify_callback(self->ctx);
     SSL_CTX_set_verify(self->ctx, mode, verify_cb);
@@ -3346,8 +3366,6 @@ get_post_handshake_auth(PySSLContext *self, void *c) {
 #if TLS1_3_VERSION
 static int
 set_post_handshake_auth(PySSLContext *self, PyObject *arg, void *c) {
-    int (*verify_cb)(int, X509_STORE_CTX *) = NULL;
-    int mode = SSL_CTX_get_verify_mode(self->ctx);
     int pha = PyObject_IsTrue(arg);
 
     if (pha == -1) {
@@ -3355,17 +3373,8 @@ set_post_handshake_auth(PySSLContext *self, PyObject *arg, void *c) {
     }
     self->post_handshake_auth = pha;
 
-    /* client-side socket setting, ignored by server-side */
-    SSL_CTX_set_post_handshake_auth(self->ctx, pha);
-
-    /* server-side socket setting, ignored by client-side */
-    verify_cb = SSL_CTX_get_verify_callback(self->ctx);
-    if (pha) {
-        mode |= SSL_VERIFY_POST_HANDSHAKE;
-    } else {
-        mode ^= SSL_VERIFY_POST_HANDSHAKE;
-    }
-    SSL_CTX_set_verify(self->ctx, mode, verify_cb);
+    /* bpo-37428: newPySSLSocket() sets SSL_VERIFY_POST_HANDSHAKE flag for
+     * server sockets and SSL_set_post_handshake_auth() for client. */
 
     return 0;
 }