diff --git a/src/subscription_manager/managerlib.py b/src/subscription_manager/managerlib.py index df2deb997..dac08371e 100644 --- a/src/subscription_manager/managerlib.py +++ b/src/subscription_manager/managerlib.py @@ -749,166 +749,6 @@ def lookup_provided_products(self, pool_id: str) -> Optional[List[Tuple[str, str return provided_products -class ImportFileExtractor: - """ - Responsible for checking an import file and pulling cert and key from it. - An import file may include only the certificate, but may also include its - key. - - An import file is processed looking for: - - -----BEGIN ----- - - .. - -----END ----- - - and will only process if it finds CERTIFICATE or KEY in the text. - - For example the following would locate a key and cert. - - -----BEGIN CERTIFICATE----- - - -----END CERTIFICATE----- - -----BEGIN PUBLIC KEY----- - - -----END PUBLIC KEY----- - - """ - - _REGEX_START_GROUP = "start" - _REGEX_CONTENT_GROUP = "content" - _REGEX_END_GROUP = "end" - _REGEX = r"(?P<%s>[-]*BEGIN[\w\ ]*[-]*)(?P<%s>[^-]*)(?P<%s>[-]*END[\w\ ]*[-]*)" % ( - _REGEX_START_GROUP, - _REGEX_CONTENT_GROUP, - _REGEX_END_GROUP, - ) - _PATTERN = re.compile(_REGEX) - - _CERT_DICT_TAG = "CERTIFICATE" - _KEY_DICT_TAG = "KEY" - _ENT_DICT_TAG = "ENTITLEMENT" - _SIG_DICT_TAG = "RSA SIGNATURE" - - def __init__(self, cert_file_path: str): - self.path = cert_file_path - self.file_name = os.path.basename(cert_file_path) - - content = self._read(cert_file_path) - self.parts = self._process_content(content) - - def _read(self, file_path: str) -> str: - fd = open(file_path, "r") - file_content = fd.read() - fd.close() - return file_content - - def _process_content(self, content: str) -> Dict[str, str]: - part_dict = {} - matches = self._PATTERN.finditer(content) - for match in matches: - start = match.group(self._REGEX_START_GROUP) - meat = match.group(self._REGEX_CONTENT_GROUP) - end = match.group(self._REGEX_END_GROUP) - - dict_key = None - if not start.find(self._KEY_DICT_TAG) < 0: - dict_key = self._KEY_DICT_TAG - elif not start.find(self._CERT_DICT_TAG) < 0: - dict_key = self._CERT_DICT_TAG - elif not start.find(self._ENT_DICT_TAG) < 0: - dict_key = self._ENT_DICT_TAG - elif not start.find(self._SIG_DICT_TAG) < 0: - dict_key = self._SIG_DICT_TAG - - if dict_key is None: - continue - - part_dict[dict_key] = start + meat + end - return part_dict - - def contains_key_content(self) -> bool: - return self._KEY_DICT_TAG in self.parts - - def get_key_content(self) -> Optional[str]: - key_content = None - if self._KEY_DICT_TAG in self.parts: - key_content = self.parts[self._KEY_DICT_TAG] - return key_content - - def get_cert_content(self) -> str: - cert_content = "" - if self._CERT_DICT_TAG in self.parts: - cert_content = self.parts[self._CERT_DICT_TAG] - if self._ENT_DICT_TAG in self.parts: - cert_content = cert_content + os.linesep + self.parts[self._ENT_DICT_TAG] - if self._SIG_DICT_TAG in self.parts: - cert_content = cert_content + os.linesep + self.parts[self._SIG_DICT_TAG] - return cert_content - - def verify_valid_entitlement(self) -> bool: - """ - Verify that a valid entitlement was processed. - - @return: True if valid, False otherwise. - """ - try: - cert = self.get_cert() - # Don't want to check class explicitly, instead we'll look for - # order info, which only an entitlement cert could have: - if not hasattr(cert, "order"): - return False - except CertificateException: - return False - ent_key = Key(self.get_key_content()) - if ent_key.bogus(): - return False - return True - - # TODO: rewrite to use certlib.EntitlementCertBundleInstall? - def write_to_disk(self) -> None: - """ - Write/copy cert to the entitlement cert dir. - """ - self._ensure_entitlement_dir_exists() - dest_file_path = os.path.join(ENT_CONFIG_DIR, self._create_filename_from_cert_serial_number()) - - # Write the key/cert content to new files - log.debug("Writing certificate file: %s" % (dest_file_path)) - cert_content = self.get_cert_content() - self._write_file(dest_file_path, cert_content) - - if self.contains_key_content(): - dest_key_file_path = self._get_key_path_from_dest_cert_path(dest_file_path) - log.debug("Writing key file: %s" % (dest_key_file_path)) - self._write_file(dest_key_file_path, self.get_key_content()) - - def _write_file(self, target_path: str, content: str) -> None: - new_file = open(target_path, "w") - try: - new_file.write(content) - finally: - new_file.close() - - def _ensure_entitlement_dir_exists(self) -> None: - if not os.access(ENT_CONFIG_DIR, os.R_OK): - os.mkdir(ENT_CONFIG_DIR) - - def _get_key_path_from_dest_cert_path(self, dest_cert_path: str) -> str: - file_parts = os.path.splitext(dest_cert_path) - return file_parts[0] + "-key" + file_parts[1] - - def _create_filename_from_cert_serial_number(self) -> str: - "create from serial" - ent_cert = self.get_cert() - return "%s.pem" % (ent_cert.serial) - - def get_cert(self) -> "EntitlementCertificate": - cert_content: str = self.get_cert_content() - ent_cert: EntitlementCertificate = create_from_pem(cert_content) - return ent_cert - - def _sub_dict(datadict: dict, subkeys: Iterable[str], default: Optional[object] = None) -> dict: """Return a dict that is a subset of datadict matching only the keys in subkeys""" return dict([(k, datadict.get(k, default)) for k in subkeys]) diff --git a/test/test_managerlib.py b/test/test_managerlib.py index f7ec64a52..975e0d067 100644 --- a/test/test_managerlib.py +++ b/test/test_managerlib.py @@ -715,201 +715,6 @@ def MockSystemLog(self, message, priority): EXPECTED_CONTENT_V3 = EXPECTED_CERT_CONTENT_V3 + os.linesep + EXPECTED_KEY_CONTENT_V3 -class ExtractorStub(managerlib.ImportFileExtractor): - def __init__(self, content, file_path="test/file/path"): - self.content = content - self.writes = [] - managerlib.ImportFileExtractor.__init__(self, file_path) - - # Stub out any file system access - def _read(self, file_path): - return self.content - - def _write_file(self, target, content): - self.writes.append((target, content)) - - def _ensure_entitlement_dir_exists(self): - # Do nothing but stub out the dir check to avoid file system access. - pass - - -class TestImportFileExtractor(unittest.TestCase): - def test_contains_key_content_when_key_and_cert_exists_in_import_file(self): - extractor = ExtractorStub(EXPECTED_CONTENT) - self.assertTrue(extractor.contains_key_content()) - - def test_contains_key_content_when_key_and_cert_exists_in_import_file_v3(self): - extractor = ExtractorStub(EXPECTED_CONTENT_V3) - self.assertTrue(extractor.contains_key_content()) - - def test_does_not_contain_key_when_key_does_not_exist_in_import_file(self): - extractor = ExtractorStub(EXPECTED_CERT_CONTENT) - self.assertFalse(extractor.contains_key_content()) - - def test_does_not_contain_key_when_key_does_not_exist_in_import_file_v3(self): - extractor = ExtractorStub(EXPECTED_CERT_CONTENT_V3) - self.assertFalse(extractor.contains_key_content()) - - def test_get_key_content_when_key_exists(self): - extractor = ExtractorStub(EXPECTED_CONTENT, file_path="12345.pem") - self.assertTrue(extractor.contains_key_content()) - self.assertEqual(EXPECTED_KEY_CONTENT, extractor.get_key_content()) - - def test_get_key_content_when_key_exists_v3(self): - extractor = ExtractorStub(EXPECTED_CONTENT_V3, file_path="12345.pem") - self.assertTrue(extractor.contains_key_content()) - self.assertEqual(EXPECTED_KEY_CONTENT_V3, extractor.get_key_content()) - - def test_get_key_content_returns_None_when_key_does_not_exist(self): - extractor = ExtractorStub(EXPECTED_CERT_CONTENT, file_path="12345.pem") - self.assertFalse(extractor.get_key_content()) - - def test_get_key_content_returns_None_when_key_does_not_exist_v3(self): - extractor = ExtractorStub(EXPECTED_CERT_CONTENT_V3, file_path="12345.pem") - self.assertFalse(extractor.get_key_content()) - - def test_get_cert_content(self): - extractor = ExtractorStub(EXPECTED_CONTENT, file_path="12345.pem") - self.assertTrue(extractor.contains_key_content()) - self.assertEqual(EXPECTED_CERT_CONTENT, extractor.get_cert_content()) - - def test_get_cert_content_v3(self): - extractor = ExtractorStub(EXPECTED_CONTENT_V3, file_path="12345.pem") - self.assertTrue(extractor.contains_key_content()) - self.assertEqual(EXPECTED_CERT_CONTENT_V3, extractor.get_cert_content()) - - def test_get_cert_content_returns_None_when_cert_does_not_exist(self): - extractor = ExtractorStub(EXPECTED_KEY_CONTENT, file_path="12345.pem") - self.assertFalse(extractor.get_cert_content()) - - def test_get_cert_content_returns_None_when_cert_does_not_exist_v3(self): - extractor = ExtractorStub(EXPECTED_KEY_CONTENT_V3, file_path="12345.pem") - self.assertFalse(extractor.get_cert_content()) - - def test_verify_valid_entitlement_for_invalid_cert(self): - extractor = ExtractorStub(EXPECTED_KEY_CONTENT, file_path="12345.pem") - self.assertFalse(extractor.verify_valid_entitlement()) - - def test_verify_valid_entitlement_for_invalid_cert_v3(self): - extractor = ExtractorStub(EXPECTED_KEY_CONTENT_V3, file_path="12345.pem") - self.assertFalse(extractor.verify_valid_entitlement()) - - def test_verify_valid_entitlement_for_invalid_cert_bundle(self): - # Use a bundle of cert + key, but the cert is not an entitlement cert: - extractor = ExtractorStub(IDENTITY_CERT_WITH_KEY, file_path="12345.pem") - self.assertFalse(extractor.verify_valid_entitlement()) - - def test_verify_valid_entitlement_for_no_key(self): - extractor = ExtractorStub(EXPECTED_CERT_CONTENT, file_path="12345.pem") - self.assertFalse(extractor.verify_valid_entitlement()) - - def test_verify_valid_entitlement_for_no_key_v3(self): - extractor = ExtractorStub(EXPECTED_CERT_CONTENT_V3, file_path="12345.pem") - self.assertFalse(extractor.verify_valid_entitlement()) - - def test_verify_valid_entitlement_for_no_cert_content(self): - extractor = ExtractorStub("", file_path="12345.pem") - self.assertFalse(extractor.verify_valid_entitlement()) - - def test_write_cert_only(self): - expected_cert_file = "%d.pem" % (EXPECTED_CERT.serial) - extractor = ExtractorStub(EXPECTED_CERT_CONTENT, file_path=expected_cert_file) - extractor.write_to_disk() - - self.assertEqual(1, len(extractor.writes)) - - write_one = extractor.writes[0] - self.assertEqual(os.path.join(ENT_CONFIG_DIR, expected_cert_file), write_one[0]) - self.assertEqual(EXPECTED_CERT_CONTENT, write_one[1]) - - def test_write_cert_only_v3(self): - expected_cert_file = "%d.pem" % (EXPECTED_CERT_V3.serial) - extractor = ExtractorStub(EXPECTED_CERT_CONTENT_V3, file_path=expected_cert_file) - extractor.write_to_disk() - - self.assertEqual(1, len(extractor.writes)) - - write_one = extractor.writes[0] - self.assertEqual(os.path.join(ENT_CONFIG_DIR, expected_cert_file), write_one[0]) - self.assertEqual(EXPECTED_CERT_CONTENT_V3, write_one[1]) - - def test_write_key_and_cert(self): - filename = "%d.pem" % (EXPECTED_CERT.serial) - self._assert_correct_cert_and_key_files_generated_with_filename(filename) - - def test_write_key_and_cert_v3(self): - filename = "%d.pem" % (EXPECTED_CERT_V3.serial) - self._assert_correct_cert_and_key_files_generated_with_filename_v3(filename) - - def test_file_renamed_when_imported_with_serial_no_and_custom_extension(self): - filename = "%d.cert" % (EXPECTED_CERT.serial) - self._assert_correct_cert_and_key_files_generated_with_filename(filename) - - def test_file_renamed_when_imported_with_serial_no_and_custom_extension_v3(self): - filename = "%d.cert" % (EXPECTED_CERT_V3.serial) - self._assert_correct_cert_and_key_files_generated_with_filename_v3(filename) - - def test_file_renamed_when_imported_with_serial_no_and_no_extension(self): - filename = str(EXPECTED_CERT.serial) - self._assert_correct_cert_and_key_files_generated_with_filename(filename) - - def test_file_renamed_when_imported_with_serial_no_and_no_extension_v3(self): - filename = str(EXPECTED_CERT_V3.serial) - self._assert_correct_cert_and_key_files_generated_with_filename_v3(filename) - - def test_file_renamed_when_imported_with_custom_name_and_pem_extension(self): - filename = "entitlement.pem" - self._assert_correct_cert_and_key_files_generated_with_filename(filename) - - def test_file_renamed_when_imported_with_custom_name_and_pem_extension_v3(self): - filename = "entitlement.pem" - self._assert_correct_cert_and_key_files_generated_with_filename_v3(filename) - - def test_file_renamed_when_imported_with_custom_name_no_extension(self): - filename = "entitlement" - self._assert_correct_cert_and_key_files_generated_with_filename(filename) - - def test_file_renamed_when_imported_with_custom_name_no_extension_v3(self): - filename = "entitlement" - self._assert_correct_cert_and_key_files_generated_with_filename_v3(filename) - - def _assert_correct_cert_and_key_files_generated_with_filename(self, filename): - expected_file_prefix = "%d" % (EXPECTED_CERT.serial) - expected_cert_file = expected_file_prefix + ".pem" - expected_key_file = expected_file_prefix + "-key.pem" - - extractor = ExtractorStub(EXPECTED_CONTENT, file_path=filename) - extractor.write_to_disk() - - self.assertEqual(2, len(extractor.writes)) - - write_one = extractor.writes[0] - self.assertEqual(os.path.join(ENT_CONFIG_DIR, expected_cert_file), write_one[0]) - self.assertEqual(EXPECTED_CERT_CONTENT, write_one[1]) - - write_two = extractor.writes[1] - self.assertEqual(os.path.join(ENT_CONFIG_DIR, expected_key_file), write_two[0]) - self.assertEqual(EXPECTED_KEY_CONTENT, write_two[1]) - - def _assert_correct_cert_and_key_files_generated_with_filename_v3(self, filename): - expected_file_prefix = "%d" % (EXPECTED_CERT_V3.serial) - expected_cert_file = expected_file_prefix + ".pem" - expected_key_file = expected_file_prefix + "-key.pem" - - extractor = ExtractorStub(EXPECTED_CONTENT_V3, file_path=filename) - extractor.write_to_disk() - - self.assertEqual(2, len(extractor.writes)) - - write_one = extractor.writes[0] - self.assertEqual(os.path.join(ENT_CONFIG_DIR, expected_cert_file), write_one[0]) - self.assertEqual(EXPECTED_CERT_CONTENT_V3, write_one[1]) - - write_two = extractor.writes[1] - self.assertEqual(os.path.join(ENT_CONFIG_DIR, expected_key_file), write_two[0]) - self.assertEqual(EXPECTED_KEY_CONTENT_V3, write_two[1]) - - class TestMergedPoolsStackingGroupSorter(unittest.TestCase): def test_sorter_adds_group_for_non_stackable_entitlement(self): pool = self._create_pool("test-prod-1", "Test Prod 1")