WALinuxAgent/wla-Jira-https-issues.redhat.com-browse-RHEL-129954.patch

707 lines
36 KiB
Diff

From 0e90372ba24091860266bb0a3c33fc20e38a1a97 Mon Sep 17 00:00:00 2001
From: Norberto Arrieta <narrieta@users.noreply.github.com>
Date: Tue, 4 Mar 2025 12:55:27 -0800
Subject: [PATCH] Jira: https://issues.redhat.com/browse/RHEL-129954
RH-Author: yuxisun <None>
RH-MergeRequest: 23: Support for FIPS 140-3 (#3324)
RH-Jira: RHEL-129954
RH-Acked-by: Vitaly Kuznetsov <vkuznets@redhat.com>
RH-Acked-by: Miroslav Rezanina <mrezanin@redhat.com>
RH-Commit: [1/1] da147f85a89d1375c0f4d7e36fffd0f68b231770
Support for FIPS 140-3 (#3324)
When fetching certificates from the WireServer, the Agent uses DES_EDE3_CBC, and the PFX it receives has a MAC computed using PKCS12KDF. Both are deprecated in FIPS 140-3.
This PR switches to AES128_CBC for communication with the WireServer (a subsequent PR will change it to AES256_CBC) and skips MAC verification when it is not needed.
The changes also include some minor cleanup to remove data structures that are not used.
Upstream PR: https://github.com/Azure/WALinuxAgent/pull/3324
Signed-off-by: Yuxin Sun <yuxisun@redhat.com>
---
azurelinuxagent/common/event.py | 20 ++
azurelinuxagent/common/protocol/goal_state.py | 216 +++++++++++-------
azurelinuxagent/common/protocol/restapi.py | 24 --
azurelinuxagent/common/protocol/wire.py | 18 +-
azurelinuxagent/common/utils/cryptutil.py | 46 ++--
azurelinuxagent/ga/update.py | 1 +
tests/common/protocol/test_goal_state.py | 81 ++++++-
tests/common/protocol/test_hostplugin.py | 35 ++-
tests/common/protocol/test_wire.py | 8 +-
tests/ga/test_update.py | 2 +-
10 files changed, 275 insertions(+), 176 deletions(-)
diff --git a/azurelinuxagent/common/event.py b/azurelinuxagent/common/event.py
index 6b9521ca..9b8a926e 100644
--- a/azurelinuxagent/common/event.py
+++ b/azurelinuxagent/common/event.py
@@ -93,6 +93,7 @@ class WALAEventOperation:
FetchGoalState = "FetchGoalState"
Firewall = "Firewall"
GoalState = "GoalState"
+ GoalStateCertificates = "GoalStateCertificates"
GoalStateUnsupportedFeatures = "GoalStateUnsupportedFeatures"
HealthCheck = "HealthCheck"
HealthObservation = "HealthObservation"
@@ -733,6 +734,25 @@ def error(op, fmt, *args):
add_event(op=op, message=fmt.format(*args), is_success=False, log_event=False)
+class LogEvent(object):
+ """
+ Helper class with info()/warn()/error() that log through a specific logger instance and also report each message via add_event().
+ """
+ def __init__(self, logger_):
+ self._logger = logger_
+
+ def info(self, op, fmt, *args):
+ self._logger.info(fmt, *args)
+ add_event(op=op, message=fmt.format(*args), is_success=True)
+
+ def warn(self, op, fmt, *args):
+ self._logger.warn(fmt, *args)
+ add_event(op=op, message="[WARNING] " + fmt.format(*args), is_success=False, log_event=False)  # warnings are reported as failed events
+
+ def error(self, op, fmt, *args):
+ self._logger.error(fmt, *args)
+ add_event(op=op, message=fmt.format(*args), is_success=False, log_event=False)
+
def add_log_event(level, message, forced=False, reporter=__event_logger__):
"""
:param level: LoggerLevel of the log event
diff --git a/azurelinuxagent/common/protocol/goal_state.py b/azurelinuxagent/common/protocol/goal_state.py
index f94f3ae5..2556cc73 100644
--- a/azurelinuxagent/common/protocol/goal_state.py
+++ b/azurelinuxagent/common/protocol/goal_state.py
@@ -24,15 +24,14 @@ import json
from azurelinuxagent.common import conf
from azurelinuxagent.common import logger
from azurelinuxagent.common.AgentGlobals import AgentGlobals
-from azurelinuxagent.common.datacontract import set_properties
-from azurelinuxagent.common.event import add_event, WALAEventOperation
+from azurelinuxagent.common.event import add_event, WALAEventOperation, LogEvent
from azurelinuxagent.common.exception import ProtocolError, ResourceGoneError
from azurelinuxagent.common.future import ustr
from azurelinuxagent.common.protocol.extensions_goal_state_factory import ExtensionsGoalStateFactory
from azurelinuxagent.common.protocol.extensions_goal_state import VmSettingsParseError, GoalStateSource
from azurelinuxagent.common.protocol.hostplugin import VmSettingsNotSupported, VmSettingsSupportStopped
-from azurelinuxagent.common.protocol.restapi import Cert, CertList, RemoteAccessUser, RemoteAccessUsersList, ExtHandlerPackage, ExtHandlerPackageList
-from azurelinuxagent.common.utils import fileutil
+from azurelinuxagent.common.protocol.restapi import RemoteAccessUser, RemoteAccessUsersList, ExtHandlerPackage, ExtHandlerPackageList
+from azurelinuxagent.common.utils import fileutil, shellutil
from azurelinuxagent.common.utils.archive import GoalStateHistory, SHARED_CONF_FILE_NAME
from azurelinuxagent.common.utils.cryptutil import CryptUtil
from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext, getattrib, gettext
@@ -41,6 +40,7 @@ from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, find
GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate"
CERTS_FILE_NAME = "Certificates.xml"
P7M_FILE_NAME = "Certificates.p7m"
+PFX_FILE_NAME = "Certificates.pfx"
PEM_FILE_NAME = "Certificates.pem"
TRANSPORT_CERT_FILE_NAME = "TransportCert.pem"
TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem"
@@ -282,16 +282,8 @@ class GoalState(object):
self._check_and_download_missing_certs_on_disk()
def _download_certificates(self, certs_uri):
- xml_text = self._wire_client.fetch_config(certs_uri, self._wire_client.get_header_for_cert())
- certs = Certificates(xml_text, self.logger)
- # Log and save the certificates summary (i.e. the thumbprint but not the certificate itself) to the goal state history
- for c in certs.summary:
- message = "Downloaded certificate {0}".format(c)
- self.logger.info(message)
- add_event(op=WALAEventOperation.GoalState, message=message)
- if len(certs.warnings) > 0:
- self.logger.warn(certs.warnings)
- add_event(op=WALAEventOperation.GoalState, message=certs.warnings)
+ certs = Certificates(self._wire_client, certs_uri, self.logger)  # download/decryption errors are reported inside Certificates; on failure certs.summary is empty
+ # Save the certificates summary (i.e. the thumbprints but not the certificates themselves) to the goal state history
if self._save_to_history:
self._history.save_certificates(json.dumps(certs.summary))
return certs
@@ -511,31 +503,83 @@ class SharedConfig(object):
self.xml_text = xml_text
-class Certificates(object):
- def __init__(self, xml_text, my_logger):
- self.cert_list = CertList()
- self.summary = [] # debugging info
- self.warnings = []
+class Certificates(LogEvent):
+ def __init__(self, wire_client, uri, logger_):
+ super(Certificates, self).__init__(logger_)
+ self.summary = []  # list of {"thumbprint", "hasPrivateKey"} dicts; stays empty if any step below fails
+ self._crypt_util = CryptUtil(conf.get_openssl_cmd())
- # Save the certificates
- local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
- fileutil.write_file(local_file, xml_text)
+ try:
+ pfx_file = self._download_certificates_pfx(wire_client, uri)
+ if pfx_file is None: # The response from the WireServer may not have any certificates
+ return
- # Separate the certificates into individual files.
- xml_doc = parse_doc(xml_text)
- data = findtext(xml_doc, "Data")
- if data is None:
- return
+ try:
+ pem_file = self._convert_certificates_pfx_to_pem(pfx_file)
+ finally:
+ self._remove_file(pfx_file)  # the decrypted pfx is kept only long enough to convert it to pem
- # if the certificates format is not Pkcs7BlobWithPfxContents do not parse it
- certificate_format = findtext(xml_doc, "Format")
- if certificate_format and certificate_format != "Pkcs7BlobWithPfxContents":
- message = "The Format is not Pkcs7BlobWithPfxContents. Format is {0}".format(certificate_format)
- my_logger.warn(message)
- add_event(op=WALAEventOperation.GoalState, message=message)
- return
+ self.summary = self._extract_certificate(pem_file)
+
+ for c in self.summary:
+ self.info(WALAEventOperation.GoalStateCertificates, "Downloaded certificate {0}", c)
+
+ except Exception as e:
+ self.error(WALAEventOperation.GoalStateCertificates, "Error fetching the goal state certificates: {0}", ustr(e))
+
+ def _remove_file(self, file):  # best-effort delete: a failure is reported as a warning and never raised
+ if os.path.exists(file):
+ try:
+ os.remove(file)
+ except Exception as e:
+ self.warn(WALAEventOperation.GoalStateCertificates, "Failed to remove {0}: {1}", file, ustr(e))
+
+ def _download_certificates_pfx(self, wire_client, uri):
+ """
+ Downloads the certificates from the WireServer (trying AES128_CBC first, then DES_EDE3_CBC) and saves them to a pfx file.
+ Returns the full path of the pfx file, or None if the response has no "Data" element or its Format is not Pkcs7BlobWithPfxContents; raises if no cypher succeeds.
+ """
+ trans_prv_file = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
+ xml_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
+ pfx_file = os.path.join(conf.get_lib_dir(), PFX_FILE_NAME)
+
+ for cypher in ["AES128_CBC", "DES_EDE3_CBC"]:
+ headers = wire_client.get_headers_for_encrypted_request(cypher)
+
+ try:
+ xml_text = wire_client.fetch_config(uri, headers)
+ except Exception as e:
+ self.warn(WALAEventOperation.GoalStateCertificates, "Error in Certificates request [cypher: {0}]: {1}", cypher, ustr(e))
+ continue
- cryptutil = CryptUtil(conf.get_openssl_cmd())
+ fileutil.write_file(xml_file, xml_text)
+
+ xml_doc = parse_doc(xml_text)
+ data = findtext(xml_doc, "Data")
+ if data is None:
+ self.info(WALAEventOperation.GoalStateCertificates, "The Data element of the Certificates response is empty")
+ return None
+ certificate_format = findtext(xml_doc, "Format")
+ if certificate_format and certificate_format != "Pkcs7BlobWithPfxContents":
+ self.warn(WALAEventOperation.GoalStateCertificates, "The Certificates format is not Pkcs7BlobWithPfxContents; skipping. Format is {0}", certificate_format)
+ return None
+
+ p7m_file = Certificates._create_p7m_file(data)
+
+ try:
+ self._crypt_util.decrypt_certificates_p7m(p7m_file, trans_prv_file, trans_cert_file, pfx_file)
+ except shellutil.CommandError as e:
+ self.warn(WALAEventOperation.GoalStateCertificates, "Error in transport decryption [cypher: {0}]: {1}", cypher, ustr(e))
+ self._remove_file(pfx_file)
+ continue
+
+ return pfx_file
+
+ raise Exception("Cannot download certificates using any of the supported cyphers")
+
+ @staticmethod
+ def _create_p7m_file(data):
p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
p7m = ("MIME-Version:1.0\n" # pylint: disable=W1308
"Content-Disposition: attachment; filename=\"{0}\"\n"
@@ -543,68 +587,72 @@ class Certificates(object):
"Content-Transfer-Encoding: base64\n"
"\n"
"{2}").format(p7m_file, p7m_file, data)
-
fileutil.write_file(p7m_file, p7m)
+ return p7m_file
- trans_prv_file = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME)
- trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
+ def _convert_certificates_pfx_to_pem(self, pfx_file):
+ """
+ Convert the pfx file to a pem file, first skipping MAC verification (-nomacver) and then with it; returns the pem path, or raises if both attempts fail.
+ """
pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
- # decrypt certificates
- cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, pem_file)
+ for nomacver in [True, False]:
+ try:
+ self._crypt_util.convert_pfx_to_pem(pfx_file, nomacver, pem_file)
+ return pem_file
+ except shellutil.CommandError as e:
+ self._remove_file(pem_file) # An error may leave an empty pem file, which can produce a failure on some versions of OpenSSL (e.g. 3.2.2) on the next invocation
+ self.warn(WALAEventOperation.GoalStateCertificates, "Error converting PFX to PEM [-nomacver: {0}]: {1}", nomacver, ustr(e))
+ continue
+
+ raise Exception("Cannot convert PFX to PEM")
+
+ def _extract_certificate(self, pem_file):
+ """
+ Split the pem file into individual .crt/.prv files in the lib directory, named by certificate thumbprint; returns a list of {"thumbprint", "hasPrivateKey"} dicts.
+ """
# The parsing process use public key to match prv and crt.
- buf = []
- prvs = {}
- thumbprints = {}
+ private_keys = {} # map of private keys indexed by public key
+ thumbprints = {} # map of thumbprints indexed by public key
+ buffer = [] # buffer for reading lines belonging to a certificate or private key
index = 0
- v1_cert_list = []
-
- # Ensure pem_file exists before read the certs data since decrypt_p7m may clear the pem_file wen decryption fails
- if os.path.exists(pem_file):
- with open(pem_file) as pem:
- for line in pem.readlines():
- buf.append(line)
- if re.match(r'[-]+END.*KEY[-]+', line):
- tmp_file = Certificates._write_to_tmp_file(index, 'prv', buf)
- pub = cryptutil.get_pubkey_from_prv(tmp_file)
- prvs[pub] = tmp_file
- buf = []
- index += 1
- elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
- tmp_file = Certificates._write_to_tmp_file(index, 'crt', buf)
- pub = cryptutil.get_pubkey_from_crt(tmp_file)
- thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
- thumbprints[pub] = thumbprint
- # Rename crt with thumbprint as the file name
- crt = "{0}.crt".format(thumbprint)
- v1_cert_list.append({
- "name": None,
- "thumbprint": thumbprint
- })
- os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
- buf = []
- index += 1
+
+ with open(pem_file) as pem:
+ for line in pem.readlines():
+ buffer.append(line)
+ if re.match(r'[-]+END.*KEY[-]+', line):
+ tmp_file = Certificates._write_to_tmp_file(index, 'prv', buffer)
+ pub = self._crypt_util.get_pubkey_from_prv(tmp_file)
+ private_keys[pub] = tmp_file
+ buffer = []
+ index += 1
+ elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
+ tmp_file = Certificates._write_to_tmp_file(index, 'crt', buffer)
+ pub = self._crypt_util.get_pubkey_from_crt(tmp_file)
+ thumbprint = self._crypt_util.get_thumbprint_from_crt(tmp_file)
+ thumbprints[pub] = thumbprint
+ # Rename crt with thumbprint as the file name
+ crt = "{0}.crt".format(thumbprint)
+ os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
+ buffer = []
+ index += 1
# Rename prv key with thumbprint as the file name
- for pubkey in prvs:
- thumbprint = thumbprints[pubkey]
+ for pubkey in private_keys:
+ thumbprint = thumbprints.get(pubkey) # a private key may have no matching certificate; None selects the else branch instead of raising KeyError
if thumbprint:
- tmp_file = prvs[pubkey]
+ tmp_file = private_keys[pubkey]
prv = "{0}.prv".format(thumbprint)
os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))
else:
- # Since private key has *no* matching certificate,
- # it will not be named correctly
- self.warnings.append("Found NO matching cert/thumbprint for private key!")
+ # Since private key has *no* matching certificate, it will not be named correctly
+ self.warn(WALAEventOperation.GoalStateCertificates, "Found a private key with no matching cert/thumbprint!")
+ certificates = []
for pubkey, thumbprint in thumbprints.items():
- has_private_key = pubkey in prvs
- self.summary.append({"thumbprint": thumbprint, "hasPrivateKey": has_private_key})
-
- for v1_cert in v1_cert_list:
- cert = Cert()
- set_properties("certs", cert, v1_cert)
- self.cert_list.certificates.append(cert)
+ has_private_key = pubkey in private_keys
+ certificates.append({"thumbprint": thumbprint, "hasPrivateKey": has_private_key})
+ return certificates
@staticmethod
def _write_to_tmp_file(index, suffix, buf):
@@ -614,9 +662,7 @@ class Certificates(object):
class EmptyCertificates:
def __init__(self):
- self.cert_list = CertList()
- self.summary = [] # debugging info
- self.warnings = []
+ self.summary = []  # same attribute as Certificates.summary; always empty here
class RemoteAccess(object):
"""
diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py
index 54e020c1..7e563b4a 100644
--- a/azurelinuxagent/common/protocol/restapi.py
+++ b/azurelinuxagent/common/protocol/restapi.py
@@ -43,30 +43,6 @@ class VMInfo(DataContract):
self.tenantName = tenantName
-class CertificateData(DataContract):
- def __init__(self, certificateData=None):
- self.certificateData = certificateData
-
-
-class Cert(DataContract):
- def __init__(self,
- name=None,
- thumbprint=None,
- certificateDataUri=None,
- storeName=None,
- storeLocation=None):
- self.name = name
- self.thumbprint = thumbprint
- self.certificateDataUri = certificateDataUri
- self.storeLocation = storeLocation
- self.storeName = storeName
-
-
-class CertList(DataContract):
- def __init__(self):
- self.certificates = DataContractList(Cert)
-
-
class VMAgentFamily(object):
def __init__(self, name):
self.name = name
diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py
index 00a01f09..0277b7f0 100644
--- a/azurelinuxagent/common/protocol/wire.py
+++ b/azurelinuxagent/common/protocol/wire.py
@@ -115,8 +115,7 @@ class WireProtocol(DataContract):
return vminfo
def get_certs(self):
- certificates = self.client.get_certs()
- return certificates.cert_list
+ return self.client.get_certs()  # NOTE(review): now returns the certificates object itself (read .summary), not a CertList; confirm no caller still reads .certificates
def get_goal_state(self):
return self.client.get_goal_state()
@@ -1140,13 +1139,11 @@ class WireClient(object):
"Content-Type": "text/xml;charset=utf-8"
}
- def get_header_for_cert(self):
- return self._get_header_for_encrypted_request("DES_EDE3_CBC")
-
def get_header_for_remote_access(self):
- return self._get_header_for_encrypted_request("AES128_CBC")
+ return self.get_headers_for_encrypted_request("AES128_CBC")  # remote access keeps requesting AES128_CBC explicitly
- def _get_header_for_encrypted_request(self, cypher):
+ @staticmethod
+ def get_headers_for_encrypted_request(cypher):
trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME)
try:
content = fileutil.read_file(trans_cert_file)
@@ -1154,12 +1151,15 @@ class WireClient(object):
raise ProtocolError("Failed to read {0}: {1}".format(trans_cert_file, e))
cert = get_bytes_from_pem(content)
- return {
+ headers = {
"x-ms-agent-name": "WALinuxAgent",
"x-ms-version": PROTOCOL_VERSION,
- "x-ms-cipher-name": cypher,
"x-ms-guest-agent-public-x509-cert": cert
}
+ if cypher is not None: # the cypher header is optional, currently defaults to AES128_CBC
+ headers["x-ms-cipher-name"] = cypher
+
+ return headers
def get_host_plugin(self):
if self._host_plugin is None:
diff --git a/azurelinuxagent/common/utils/cryptutil.py b/azurelinuxagent/common/utils/cryptutil.py
index 00126e25..789a9486 100644
--- a/azurelinuxagent/common/utils/cryptutil.py
+++ b/azurelinuxagent/common/utils/cryptutil.py
@@ -86,36 +86,22 @@ class CryptUtil(object):
thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper()
return thumbprint
- def decrypt_p7m(self, p7m_file, trans_prv_file, trans_cert_file, pem_file):
-
- def _cleanup_files(files_to_cleanup):
- for file_path in files_to_cleanup:
- if os.path.exists(file_path):
- try:
- os.remove(file_path)
- logger.info("Removed file {0}", file_path)
- except Exception as e:
- logger.error("Failed to remove file {0}: {1}", file_path, ustr(e))
-
- if not os.path.exists(p7m_file):
- raise IOError(errno.ENOENT, "File not found", p7m_file)
- elif not os.path.exists(trans_prv_file):
- raise IOError(errno.ENOENT, "File not found", trans_prv_file)
- else:
- try:
- shellutil.run_pipe([
- [self.openssl_cmd, "cms", "-decrypt", "-in", p7m_file, "-inkey", trans_prv_file, "-recip", trans_cert_file],
- [self.openssl_cmd, "pkcs12", "-nodes", "-password", "pass:", "-out", pem_file]])
- except shellutil.CommandError as command_error:
- logger.error("Failed to decrypt {0} (return code: {1})\n[stdout]\n{2}\n[stderr]\n{3}",
- p7m_file, command_error.returncode, command_error.stdout, command_error.stderr)
- # If the decryption fails, old version of openssl overwrite the output file(if exist) with empty data while
- # new version of openssl(3.2.2) does not overwrite the output file, So output file may contain old certs data.
- # Correcting the behavior by removing the temporary output files since having empty/no data is makes sense when decryption fails
- # otherwise we end up processing old certs again.
- files_to_remove = [p7m_file, pem_file]
- logger.info("Removing temporary state certificate files {0}", files_to_remove)
- _cleanup_files(files_to_remove)
+ def decrypt_certificates_p7m(self, p7m_file, trans_prv_file, trans_cert_file, pfx_file):  # writes the stdout of 'openssl cms -decrypt' into pfx_file; raises shellutil.CommandError on failure
+ umask = None
+ try:
+ umask = os.umask(0o077)  # create the decrypted pfx with owner-only permissions
+ with open(pfx_file, "wb") as pfx_file_:
+ shellutil.run_command([self.openssl_cmd, "cms", "-decrypt", "-in", p7m_file, "-inkey", trans_prv_file, "-recip", trans_cert_file], stdout=pfx_file_)
+ finally:
+ if umask is not None:
+ os.umask(umask)  # restore the caller's umask
+
+ def convert_pfx_to_pem(self, pfx_file, nomacver, pem_file):  # converts pfx_file to pem_file with 'openssl pkcs12'; raises shellutil.CommandError on failure
+ command = [self.openssl_cmd, "pkcs12", "-nodes", "-password", "pass:", "-in", pfx_file, "-out", pem_file]
+ if nomacver:
+ command.append("-nomacver")  # skip MAC verification of the PFX
+
+ shellutil.run_command(command)
def crt_to_ssh(self, input_file, output_file):
with open(output_file, "ab") as file_out:
diff --git a/azurelinuxagent/ga/update.py b/azurelinuxagent/ga/update.py
index 7ab19101..f806ff26 100644
--- a/azurelinuxagent/ga/update.py
+++ b/azurelinuxagent/ga/update.py
@@ -85,6 +85,7 @@ READONLY_FILE_GLOBS = [
"*.p7m",
"*.pem",
"*.prv",
+ "Certificates.xml",
"ovf-env.xml"
]
diff --git a/tests/common/protocol/test_goal_state.py b/tests/common/protocol/test_goal_state.py
index a5f89587..9b70ce05 100644
--- a/tests/common/protocol/test_goal_state.py
+++ b/tests/common/protocol/test_goal_state.py
@@ -6,6 +6,7 @@ import datetime
import glob
import os
import re
+import subprocess
import shutil
import time
@@ -492,9 +493,85 @@ class GoalStateTestCase(AgentTestCase, HttpRequestPredicates):
goal_state = GoalState(protocol.client)
- self.assertEqual(0, len(goal_state.certs.summary), "Cert list should be empty")
- self.assertEqual(1, http_get_handler.certificate_requests, "There should have been exactly 1 requests for the goal state certificates")
+ self.assertEqual(0, len(goal_state.certs.summary), "Certificates should be empty")
+ self.assertEqual(2, http_get_handler.certificate_requests, "There should have been exactly 2 requests for the goal state certificates") # 1 for the initial request, 1 for the retry with an older cypher
+ def test_goal_state_should_try_legacy_cypher_and_then_fail_when_no_cyphers_are_supported_by_the_wireserver(self):
+ cyphers = []
+ def http_get_handler(url, *_, **kwargs):
+ if HttpRequestPredicates.is_certificates_request(url):
+ cypher = kwargs["headers"].get("x-ms-cipher-name")
+ if cypher is None:
+ raise Exception("x-ms-cipher-name header is missing from the Certificates request")
+ cyphers.append(cypher)  # record the cypher requested on each attempt
+ return MockHttpResponse(status=400, body="unsupported cypher: {0}".format(cypher).encode('utf-8'))  # the mock WireServer rejects every cypher
+ return None
+
+ with mock_wire_protocol(wire_protocol_data.DATA_FILE) as protocol:
+ with patch("azurelinuxagent.common.event.LogEvent.error") as log_error_patch:
+ protocol.set_http_handlers(http_get_handler=http_get_handler)
+ goal_state = GoalState(protocol.client)
+
+ log_error_args, _ = log_error_patch.call_args  # positional args of the last LogEvent.error() call; args[0] is the operation
+
+ self.assertEqual(cyphers, ["AES128_CBC", "DES_EDE3_CBC"], "There should have been 2 requests for the goal state certificates (AES128_CBC and DES_EDE3_CBC)")
+ self.assertEqual(log_error_args[0], "GoalStateCertificates", "An error fetching the goal state Certificates should have been reported")
+ self.assertEqual(0, len(goal_state.certs.summary), "Certificates should be empty")
+ self.assertFalse(os.path.exists(os.path.join(conf.get_lib_dir(), "Certificates.pfx")), "The Certificates.pfx file should not have been created")
+
+ def test_goal_state_should_try_legacy_cypher_and_then_fail_when_no_cyphers_are_supported_by_openssl(self):
+ cyphers = []
+ def http_get_handler(url, *_, **kwargs):
+ if HttpRequestPredicates.is_certificates_request(url):
+ cyphers.append(kwargs["headers"].get("x-ms-cipher-name"))
+ return None
+
+ original_popen = subprocess.Popen
+ openssl = conf.get_openssl_cmd()
+ decrypt_calls = []
+ def mock_fail_popen(command, *args, **kwargs):
+ if len(command) > 3 and command[0:3] == [openssl, "cms", "-decrypt"]:  # intercept only 'openssl cms -decrypt'
+ decrypt_calls.append(command)
+ command[1] = "fake_openssl_command" # force an error on the openssl to simulate a decryption failure
+ return original_popen(command, *args, **kwargs)
+
+ with mock_wire_protocol(wire_protocol_data.DATA_FILE) as protocol:
+ protocol.set_http_handlers(http_get_handler=http_get_handler)
+ with patch("azurelinuxagent.common.event.LogEvent.error") as log_error_patch:
+ with patch("azurelinuxagent.ga.cgroupapi.subprocess.Popen", mock_fail_popen):  # NOTE(review): patches Popen on the subprocess module object reached via cgroupapi, i.e. process-wide
+ goal_state = GoalState(protocol.client)
+
+ log_error_args, _ = log_error_patch.call_args  # positional args of the last LogEvent.error() call; args[0] is the operation
+
+ self.assertEqual(cyphers, ["AES128_CBC", "DES_EDE3_CBC"], "There should have been 2 requests for the goal state certificates (AES128_CBC and DES_EDE3_CBC)")
+ self.assertEqual(2, len(decrypt_calls), "There should have been 2 calls to 'openssl cms -decrypt'")
+ self.assertEqual(log_error_args[0], "GoalStateCertificates", "An error fetching the goal state Certificates should have been reported")
+ self.assertEqual(0, len(goal_state.certs.summary), "Certificates should be empty")
+ self.assertFalse(os.path.exists(os.path.join(conf.get_lib_dir(), "Certificates.pfx")), "The Certificates.pfx file should not have been created")
+
+ def test_goal_state_should_try_without_and_with_mac_verification_then_fail_when_the_pfx_cannot_be_converted(self):
+ original_popen = subprocess.Popen
+ openssl = conf.get_openssl_cmd()
+ nomacver = []
+
+ def mock_fail_popen(command, *args, **kwargs):
+ if len(command) > 2 and command[0] == openssl and command[1] == "pkcs12":  # intercept only the PFX-to-PEM conversion
+ nomacver.append("-nomacver" in command)  # record whether this attempt skipped MAC verification
+ # force an error on the openssl to simulate the conversion failure
+ command[1] = "fake_openssl_command"
+ return original_popen(command, *args, **kwargs)
+
+
+ with mock_wire_protocol(wire_protocol_data.DATA_FILE) as protocol:
+ with patch("azurelinuxagent.common.event.LogEvent.error") as log_error_patch:
+ with patch("azurelinuxagent.ga.cgroupapi.subprocess.Popen", mock_fail_popen):  # NOTE(review): patches Popen on the subprocess module object reached via cgroupapi, i.e. process-wide
+ goal_state = GoalState(protocol.client)
+
+ log_error_args, _ = log_error_patch.call_args  # positional args of the last LogEvent.error() call; args[0] is the operation
+
+ self.assertEqual(nomacver, [True, False], "There should have been 2 attempts to parse the PFX (with and without -nomacver)")
+ self.assertEqual(log_error_args[0], "GoalStateCertificates", "An error fetching the goal state Certificates should have been reported")
+ self.assertEqual(0, len(goal_state.certs.summary), "Certificates should be empty")
def test_it_should_raise_when_goal_state_properties_not_initialized(self):
with GoalStateTestCase._create_protocol_ws_and_hgap_in_sync() as protocol:
diff --git a/tests/common/protocol/test_hostplugin.py b/tests/common/protocol/test_hostplugin.py
index 4c97c73f..7d94139b 100644
--- a/tests/common/protocol/test_hostplugin.py
+++ b/tests/common/protocol/test_hostplugin.py
@@ -365,8 +365,7 @@ class TestHostPlugin(HttpRequestPredicates, AgentTestCase):
# ensure host plugin is not set as default
self.assertFalse(wire.HostPluginProtocol.is_default_channel)
- @patch("azurelinuxagent.common.event.add_event")
- def test_put_status_error_reporting(self, patch_add_event):
+ def test_put_status_error_reporting(self):
"""
Validate the telemetry when uploading status fails
"""
@@ -377,22 +376,22 @@ class TestHostPlugin(HttpRequestPredicates, AgentTestCase):
put_error = wire.HttpError("put status http error")
with patch.object(restutil, "http_put", side_effect=put_error):
- with patch.object(wire.HostPluginProtocol,
- "ensure_initialized", return_value=True):
- self.assertRaises(wire.ProtocolError, wire_protocol_client.upload_status_blob)
-
- # The agent tries to upload via HostPlugin and that fails due to
- # http_put having a side effect of "put_error"
- #
- # The agent tries to upload using a direct connection, and that succeeds.
- self.assertEqual(1, wire_protocol_client.status_blob.upload.call_count) # pylint: disable=no-member
- # The agent never touches the default protocol is this code path, so no change.
- self.assertFalse(wire.HostPluginProtocol.is_default_channel)
- # The agent never logs telemetry event for direct fallback
- self.assertEqual(1, patch_add_event.call_count)
- self.assertEqual('ReportStatus', patch_add_event.call_args[1]['op'])
- self.assertTrue('Falling back to direct' in patch_add_event.call_args[1]['message'])
- self.assertEqual(True, patch_add_event.call_args[1]['is_success'])
+ with patch.object(wire.HostPluginProtocol, "ensure_initialized", return_value=True):
+ with patch("azurelinuxagent.common.event.add_event") as patch_add_event:
+ self.assertRaises(wire.ProtocolError, wire_protocol_client.upload_status_blob)
+
+ # The agent tries to upload via HostPlugin and that fails due to
+ # http_put having a side effect of "put_error"
+ #
+ # The agent tries to upload using a direct connection, and that succeeds.
+ self.assertEqual(1, wire_protocol_client.status_blob.upload.call_count) # pylint: disable=no-member
+ # The agent never touches the default protocol is this code path, so no change.
+ self.assertFalse(wire.HostPluginProtocol.is_default_channel)
+ # The agent never logs telemetry event for direct fallback
+ self.assertEqual(1, patch_add_event.call_count)
+ self.assertEqual('ReportStatus', patch_add_event.call_args[1]['op'])
+ self.assertTrue('Falling back to direct' in patch_add_event.call_args[1]['message'])
+ self.assertEqual(True, patch_add_event.call_args[1]['is_success'])
def test_validate_http_request_when_uploading_status(self):
"""Validate correct set of data is sent to HostGAPlugin when reporting VM status"""
diff --git a/tests/common/protocol/test_wire.py b/tests/common/protocol/test_wire.py
index c3dc9461..bec4634f 100644
--- a/tests/common/protocol/test_wire.py
+++ b/tests/common/protocol/test_wire.py
@@ -497,12 +497,6 @@ class TestWireProtocol(AgentTestCase, HttpRequestPredicates):
client.report_event(self._get_telemetry_events_generator(event_list), flush=True)
self.assertEqual(mock_http_request.call_count, 3)
- def test_get_header_for_cert_should_use_triple_des(self, *_):
- with mock_wire_protocol(wire_protocol_data.DATA_FILE) as protocol:
- headers = protocol.client.get_header_for_cert()
- self.assertIn("x-ms-cipher-name", headers)
- self.assertEqual(headers["x-ms-cipher-name"], "DES_EDE3_CBC", "Unexpected x-ms-cipher-name")
-
def test_get_header_for_remote_access_should_use_aes128(self, *_):
with mock_wire_protocol(wire_protocol_data.DATA_FILE) as protocol:
headers = protocol.client.get_header_for_remote_access()
@@ -1096,7 +1090,7 @@ class UpdateGoalStateTestCase(HttpRequestPredicates, AgentTestCase):
self.assertEqual(protocol.client.get_hosting_env().deployment_name, new_hosting_env_deployment_name)
self.assertEqual(protocol.client.get_shared_conf().xml_text, new_shared_conf)
self.assertEqual(sequence_number, new_sequence_number)
- self.assertEqual(len(protocol.client.get_certs().cert_list.certificates), 0)
+ self.assertEqual(len(protocol.client.get_certs().summary), 0)
self.assertEqual(protocol.client.get_host_plugin().container_id, new_container_id)
self.assertEqual(protocol.client.get_host_plugin().role_config_name, new_role_config_name)
diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py
index 167e69dc..376e9fc0 100644
--- a/tests/ga/test_update.py
+++ b/tests/ga/test_update.py
@@ -2059,7 +2059,7 @@ class TryUpdateGoalStateTestCase(HttpRequestPredicates, AgentTestCase):
# Double check the certificates are correct
goal_state = protocol.get_goal_state()
- thumbprints = [c.thumbprint for c in goal_state.certs.cert_list.certificates]
+ thumbprints = [c["thumbprint"] for c in goal_state.certs.summary]
for extension in goal_state.extensions_goal_state.extensions:
for settings in extension.settings:
--
2.47.3