diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/_help.py b/src/azure-cli/azure/cli/command_modules/keyvault/_help.py index 525b4c5b159..7fc77f3465a 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/_help.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/_help.py @@ -906,6 +906,20 @@ the secret will be downloaded. This operation requires the secrets/backup permission. """ +helps['keyvault secret copy'] = """ +type: command +short-summary: Copy a secret from one Key Vault to another. +long-summary: Copies the latest version of a secret from a source Key Vault to a destination Key Vault. + This operation copies the secret value and its metadata (tags, content-type, attributes). +examples: + - name: Copy a specific secret from one vault to another. + text: az keyvault secret copy --source-vault SourceVault --destination-vault DestVault --name MySecret + - name: Copy all secrets from one vault to another. + text: az keyvault secret copy --source-vault SourceVault --destination-vault DestVault --all + - name: Copy a secret and overwrite if it already exists in the destination. + text: az keyvault secret copy --source-vault SourceVault --destination-vault DestVault --name MySecret --overwrite +""" + helps['keyvault secret restore'] = """ type: command short-summary: Restores a backed up secret to a vault. diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/_params.py b/src/azure-cli/azure/cli/command_modules/keyvault/_params.py index 5718281f2b4..3cc398dd516 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/_params.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/_params.py @@ -568,6 +568,20 @@ class CLISecurityDomainOperation(str, Enum): with self.argument_context('keyvault secret restore') as c: c.extra('vault_base_url', vault_name_type, required=True, arg_group='Id', type=get_vault_base_url_type(self.cli_ctx), id_part=None) + + with self.argument_context('keyvault secret copy') as c: + c.extra('vault_base_url', vault_name_type, type=get_vault_base_url_type(self.cli_ctx), + options_list=['--source-vault'], help='Name of the source Key Vault.', required=True) + c.extra('destination_vault', vault_name_type, type=get_vault_base_url_type(self.cli_ctx), + options_list=['--destination-vault'], help='Name of the destination Key Vault.', required=True) + c.argument('name', options_list=['--name', '-n'], + help='Name of the secret to copy. Mutually exclusive with --all. If neither --name nor --all is ' + 'specified, all secrets will be copied.', + required=False) + c.extra('all_secrets', arg_type=get_three_state_flag(), options_list=['--all'], + help='Copy all secrets from the source vault. Mutually exclusive with --name. If neither --name nor ' + '--all is specified, all secrets will be copied.') + c.extra('overwrite', arg_type=get_three_state_flag(), help='Overwrite existing secrets in destination.') # endregion # region keyvault security-domain diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/commands.py b/src/azure-cli/azure/cli/command_modules/keyvault/commands.py index c70ae0b600c..ca5a2085e41 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/commands.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/commands.py @@ -204,6 +204,7 @@ def load_command_table(self, _): g.keyvault_custom('download', 'download_secret') g.keyvault_custom('backup', 'backup_secret') g.keyvault_custom('restore', 'restore_secret', transform=transform_secret_set_attributes) + g.keyvault_custom('copy', 'copy_secret') # certificate track2 with self.command_group('keyvault certificate', data_certificate_entity.command_type) as g: diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/custom.py b/src/azure-cli/azure/cli/command_modules/keyvault/custom.py index 9cc5f7e3916..1d6bb068d7c 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/custom.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/custom.py @@ -2517,3 +2517,124 @@ def set_attributes_certificate(client, certificate_name, version=None, policy=No if kwargs.get('enabled') is not None or kwargs.get('tags') is not None: return client.update_certificate_properties(certificate_name=certificate_name, version=version, **kwargs) return client.get_certificate(certificate_name=certificate_name) + + +def _copy_single_secret(source_client, dest_client, secret_name, overwrite, is_single_mode): + from azure.core.exceptions import ResourceNotFoundError, HttpResponseError + + try: + # Check destination + if not overwrite: + try: + dest_client.get_secret(secret_name) + logger.warning("Secret '%s' already exists in destination. Skipping.", secret_name) + return None # Skipped + except ResourceNotFoundError: + pass + except HttpResponseError as e: + logger.warning("Error checking secret '%s' in destination: %s", secret_name, str(e)) + if e.status_code == 403: + logger.error("Access denied checking secret '%s' in destination.", secret_name) + return False # Failed + + # Copy + logger.info("Copying secret: %s", secret_name) + s = source_client.get_secret(secret_name) + + new_secret = dest_client.set_secret( + s.name, + s.value, + content_type=s.properties.content_type, + tags=s.properties.tags, + enabled=s.properties.enabled, + not_before=s.properties.not_before, + expires_on=s.properties.expires_on + ) + + logger.info("Successfully copied secret: %s", secret_name) + return {'name': new_secret.name, 'id': new_secret.id} + + except ResourceNotFoundError: + if is_single_mode: + raise CLIError("Secret '{}' not found in source vault.".format(secret_name)) + logger.error("Secret '%s' not found in source vault.", secret_name) + return False + except HttpResponseError as e: + if is_single_mode: + raise CLIError("Failed to copy secret '{}': {}".format(secret_name, str(e))) + + if e.status_code == 403: # Forbidden + logger.error("Access denied (403) for secret '%s': %s", secret_name, str(e)) + else: + logger.error("Failed to copy secret '%s': %s", secret_name, str(e)) + return False + + +def copy_secret(cmd, client, destination_vault, name=None, all_secrets=None, overwrite=False): + from azure.core.exceptions import HttpResponseError + from azure.keyvault.secrets import SecretClient + from azure.cli.core._profile import Profile + from azure.cli.core.commands.client_factory import prepare_client_kwargs_track2 + + # If neither a specific secret name nor --all is provided, default to copying all secrets. + if not name and not all_secrets: + all_secrets = True + + # A specific secret name and --all are mutually exclusive. + if name and all_secrets: + raise MutuallyExclusiveArgumentError("Specify either a secret name or --all, but not both.") + # Validation + if client.vault_url.rstrip('/') == destination_vault.rstrip('/'): + raise CLIError("Source and destination Key Vaults cannot be the same.") + + profile = Profile(cli_ctx=cmd.cli_ctx) + credential, _, _ = profile.get_login_credentials(subscription_id=cmd.cli_ctx.data.get('subscription_id')) + + # Use standard client kwargs for consistent logging/telemetry + client_kwargs = prepare_client_kwargs_track2(cmd.cli_ctx) + # KeyVault clients handle this internally or differently sometimes, mimicking _client_factory + client_kwargs.pop('http_logging_policy', None) + + dest_client = SecretClient(vault_url=destination_vault, credential=credential, **client_kwargs) + + # Fail fast if destination vault is not accessible or does not exist + try: + # Perform a lightweight call to validate vault accessibility. + # A 404 for a dummy secret name means the vault is reachable but the secret does not exist. + dest_client.get_secret("azure-cli-validation-dummy") + except HttpResponseError as e: + if getattr(e, "status_code", None) == 404: + # Vault is accessible but the dummy secret does not exist, which is expected. + pass + else: + raise CLIError(f"Failed to access destination Key Vault '{destination_vault}': {str(e)}") + + secrets_to_copy = [] + if name: + secrets_to_copy.append(name) + else: + logger.warning("Copying all secrets from source...") + try: + source_secrets = client.list_properties_of_secrets() + for s in source_secrets: + if s.managed: + logger.warning("Skipping managed secret: %s", s.name) + continue + secrets_to_copy.append(s.name) + except HttpResponseError as e: + raise CLIError(f"Failed to list secrets from source: {str(e)}") + + copied_secrets = [] + failed_secrets = [] + for secret_name in secrets_to_copy: + result = _copy_single_secret(client, dest_client, secret_name, overwrite, bool(name)) + if result: + copied_secrets.append(result) + elif result is False: + failed_secrets.append(secret_name) + + if failed_secrets: + logger.warning("Operation completed with failures. %s secrets failed to copy: %s", + len(failed_secrets), ', '.join(failed_secrets)) + + return copied_secrets diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py b/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py index a1df39f8ac5..ec69755c7db 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py @@ -2778,5 +2778,96 @@ def test_keyvault_mhsm_region(self, resource_group, managed_hsm): self.cmd('keyvault region remove -g {rg} --hsm-name {hsm_name} -r uksouth') +class KeyVaultCopyScenarioTest(ScenarioTest): + # Filter User-Agent to prevent recording mismatch between recording env (Windows) and CI (Linux) + FILTER_HEADERS = ScenarioTest.FILTER_HEADERS + ['user-agent'] + + @ResourceGroupPreparer(name_prefix='cli_test_keyvault_copy') + @KeyVaultPreparer(name_prefix='cli-test-kv-src-', additional_params='--enable-rbac-authorization false') + def test_keyvault_secret_copy(self, resource_group, key_vault): + self.kwargs.update({ + 'src_kv': key_vault, + 'dest_kv': self.create_random_name('cli-test-kv-dest-', 24), + 'secret_name': self.create_random_name('secret-', 24), + 'secret_value': 'mysecretvalue', + 'new_val': 'newval', + 'secret_name_2': self.create_random_name('secret2-', 24) + }) + + # Create Dest KV + # Use simple creation to ensure speed and reliability in playback + self.cmd('keyvault create -g {rg} -n {dest_kv} --enable-rbac-authorization false') + self.addCleanup(self.cmd, 'keyvault delete -g {rg} -n {dest_kv}') + self.addCleanup(self.cmd, 'keyvault purge -n {dest_kv}') + + # Set secret in Source with tags and content-type + self.cmd('keyvault secret set --vault-name {kv} -n {secret_name} --value {secret_value} --tags tag1=value1 --content-type text/plain') + + # 1. Copy specific secret + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {dest_kv} --name {secret_name}') + self.cmd('keyvault secret show --vault-name {dest_kv} -n {secret_name}', checks=[ + self.check('value', '{secret_value}'), + self.check('tags.tag1', 'value1'), + self.check('contentType', 'text/plain') + ]) + + # 2. Copy all secrets + # Add another secret to source + self.cmd('keyvault secret set --vault-name {kv} -n {secret_name_2} --value {secret_value}') + + # Run copy --all + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {dest_kv} --all') + + # Verify both exist in dest + self.cmd('keyvault secret show --vault-name {dest_kv} -n {secret_name_2}', checks=[ + self.check('value', '{secret_value}') + ]) + + # 3. Test overwrite protection (default behavior: skip) + # Update source + self.cmd('keyvault secret set --vault-name {kv} -n {secret_name} --value {new_val}') + + # Copy without overwrite (should skip) + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {dest_kv} --name {secret_name}') + + # Verify destination still has old value + self.cmd('keyvault secret show --vault-name {dest_kv} -n {secret_name}', checks=[ + self.check('value', '{secret_value}') + ]) + + # 4. Test overwrite + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {dest_kv} --name {secret_name} --overwrite') + + # Verify destination has new value + self.cmd('keyvault secret show --vault-name {dest_kv} -n {secret_name}', checks=[ + self.check('value', '{new_val}') + ]) + + # 5. Test Mutual Exclusivity + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {dest_kv} --name {secret_name} --all', expect_failure=True) + + # 6. Test Source == Destination (Should fail) + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {kv} --name {secret_name}', expect_failure=True) + + # 7. Test Non-existent Destination (Should fail fast) + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault non_existent_kv_12345 --name {secret_name}', expect_failure=True) + + # 8. Test Non-existent Secret in Source (Should fail) + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {dest_kv} --name non_existent_secret_123', expect_failure=True) + + # 9. Test Default Behavior (Implicit --all) + # Add a unique secret to check implicit copy + secret_name_3 = self.create_random_name('secret3-', 24) + self.kwargs['secret_name_3'] = secret_name_3 + self.cmd('keyvault secret set --vault-name {kv} -n {secret_name_3} --value {secret_value}') + + # Run copy without --name or --all + self.cmd('keyvault secret copy --source-vault {kv} --destination-vault {dest_kv}') + + # Verify it was copied + self.cmd('keyvault secret show --vault-name {dest_kv} -n {secret_name_3}', checks=[ + self.check('value', '{secret_value}') + ]) + if __name__ == '__main__': unittest.main() diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_unit.py b/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_unit.py new file mode 100644 index 00000000000..f3787734044 --- /dev/null +++ b/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_unit.py @@ -0,0 +1,190 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +from unittest import mock +from azure.core.exceptions import ResourceNotFoundError, HttpResponseError +from knack.util import CLIError + +# Mock the logger to prevent actual logging during tests +with mock.patch('azure.cli.command_modules.keyvault.custom.logger'): + from azure.cli.command_modules.keyvault.custom import copy_secret + +class KeyVaultCopySecretTest(unittest.TestCase): + def setUp(self): + self.cmd = mock.MagicMock() + self.cmd.cli_ctx = mock.MagicMock() + self.cmd.cli_ctx.data = { + 'subscription_id': 'sub_id', + 'headers': {}, + 'completer_active': False, + 'command': 'keyvault secret copy' + } + + # Patches + self.patcher_profile = mock.patch('azure.cli.core._profile.Profile') + self.mock_profile = self.patcher_profile.start() + self.mock_profile_instance = mock.MagicMock() + self.mock_profile.return_value = self.mock_profile_instance + self.mock_profile_instance.get_login_credentials.return_value = (mock.Mock(), mock.Mock(), mock.Mock()) + + self.patcher_secret_client = mock.patch('azure.keyvault.secrets.SecretClient') + self.mock_secret_client_cls = self.patcher_secret_client.start() + + # Source Client Mock (passed as argument) + self.source_client = mock.MagicMock() + self.source_client.vault_url = "https://source-kv.vault.azure.net/" + + # Dest Client Mock (instantiated inside function) + self.dest_client = mock.MagicMock() + self.mock_secret_client_cls.return_value = self.dest_client + + def tearDown(self): + self.patcher_profile.stop() + self.patcher_secret_client.stop() + + def test_copy_single_secret_success(self): + # Setup + secret_name = "mysecret" + destination_vault = "https://dest-kv.vault.azure.net/" + + # Mocks for verification check + # Dummy check raises 404 which is expected/success path for connectivity check + not_found_error = HttpResponseError(message="Not Found") + not_found_error.status_code = 404 + self.dest_client.get_secret.side_effect = [not_found_error, ResourceNotFoundError] + # First call is dummy check (fails with 404), second is check existence (fails with ResourceNotFoundError -> OK to copy) + + # Source secret + secret_obj = mock.Mock() + secret_obj.name = secret_name + secret_obj.value = "secret_value" + secret_obj.properties.content_type = "text/plain" + secret_obj.properties.tags = {} + secret_obj.properties.enabled = True + secret_obj.properties.not_before = None + secret_obj.properties.expires_on = None + + self.source_client.get_secret.return_value = secret_obj + + # Result of set_secret + new_secret = mock.Mock() + new_secret.name = secret_name + new_secret.id = destination_vault + "/secrets/" + secret_name + self.dest_client.set_secret.return_value = new_secret + + # Act + result = copy_secret(self.cmd, self.source_client, destination_vault, name=secret_name) + + # Assert + self.assertEqual(len(result), 1) + self.assertEqual(result[0]['name'], secret_name) + self.dest_client.set_secret.assert_called_with( + secret_name, "secret_value", content_type="text/plain", tags={}, + enabled=True, not_before=None, expires_on=None + ) + + def test_copy_secret_already_exists_no_overwrite(self): + # Setup + secret_name = "mysecret" + destination_vault = "https://dest-kv.vault.azure.net/" + + # Dummy check 404 + not_found_error = HttpResponseError(message="Not Found") + not_found_error.status_code = 404 + + # Pre-check existence returns Success (means it exists) + self.dest_client.get_secret.side_effect = [not_found_error, mock.Mock()] + + # Act + result = copy_secret(self.cmd, self.source_client, destination_vault, name=secret_name, overwrite=False) + + # Assert + self.assertEqual(len(result), 0) # Should be empty list as it was skipped + self.dest_client.set_secret.assert_not_called() + + def test_copy_secret_already_exists_with_overwrite(self): + # Setup + secret_name = "mysecret" + destination_vault = "https://dest-kv.vault.azure.net/" + + # Dummy check 404 + not_found_error = HttpResponseError(message="Not Found") + not_found_error.status_code = 404 + self.dest_client.get_secret.side_effect = [not_found_error] # No second call because overwrite=True skips check + + # Source secret + secret_obj = mock.Mock() + secret_obj.name = secret_name + secret_obj.value = "val" + secret_obj.properties.content_type = None + secret_obj.properties.tags = None + secret_obj.properties.enabled = True + secret_obj.properties.not_before = None + secret_obj.properties.expires_on = None + self.source_client.get_secret.return_value = secret_obj + + new_secret = mock.Mock() + new_secret.name = secret_name + new_secret.id = destination_vault + "/secrets/" + secret_name + self.dest_client.set_secret.return_value = new_secret + + # Act + result = copy_secret(self.cmd, self.source_client, destination_vault, name=secret_name, overwrite=True) + + # Assert + self.assertEqual(len(result), 1) + self.dest_client.set_secret.assert_called() + + def test_copy_all_secrets(self): + # Setup + destination_vault = "https://dest-kv.vault.azure.net/" + + # Dummy check 404 + not_found_error = HttpResponseError(message="Not Found") + not_found_error.status_code = 404 + # We have 2 secrets. For each, we check existence (fails -> copy). + # Side effect sequence: DummyCheck -> Check(sec1) -> Check(sec2) + self.dest_client.get_secret.side_effect = [ + not_found_error, + ResourceNotFoundError, + ResourceNotFoundError + ] + + # List secrets source + s1 = mock.Mock(); s1.name = "sec1"; s1.managed = False + s2 = mock.Mock(); s2.name = "sec2"; s2.managed = False + s3 = mock.Mock(); s3.name = "mgd1"; s3.managed = True # Should be skipped + self.source_client.list_properties_of_secrets.return_value = [s1, s2, s3] + + # Get secret details + def get_secret_side_effect(name): + m = mock.Mock() + m.name = name + m.value = "val" + m.properties.content_type = None + m.properties.tags = None + m.properties.enabled = True + m.properties.not_before = None + m.properties.expires_on = None + return m + self.source_client.get_secret.side_effect = get_secret_side_effect + + new_secret = mock.Mock() + new_secret.name = "sec" + new_secret.id = "id" + self.dest_client.set_secret.return_value = new_secret + + # Act + result = copy_secret(self.cmd, self.source_client, destination_vault, all_secrets=True) + + # Assert + self.assertEqual(len(result), 2) + call_args = self.dest_client.set_secret.call_args_list + self.assertEqual(call_args[0][0][0], "sec1") + self.assertEqual(call_args[1][0][0], "sec2") + +if __name__ == '__main__': + unittest.main()