# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
"""A JSON-file-backed dict, used as the MSAL HTTP cache (see Identity._load_msal_http_cache)."""
import json
from collections.abc import MutableMapping

from azure.cli.core.decorators import retry
from knack.log import get_logger
from msal.throttled_http_client import NormalizedResponse

logger = get_logger(__name__)


class NormalizedResponseJsonEncoder(json.JSONEncoder):
    """JSON encoder that can serialize msal's NormalizedResponse objects."""

    def default(self, o):
        if isinstance(o, NormalizedResponse):
            # Persist only the three attributes JsonCache._load() needs to rebuild the response.
            return {
                "status_code": o.status_code,
                "text": o.text,
                "headers": o.headers,
            }
        return super().default(o)


class JsonCache(MutableMapping):
    """
    A simple dict-like class that is backed by a json file. This is designed for the MSAL HTTP cache.

    All direct modifications with `__setitem__` and `__delitem__` will save the file.
    Indirect modifications should be followed by a call to `save`.
    """

    def __init__(self, file_name):
        """Bind the cache to *file_name* and eagerly load any existing content."""
        super().__init__()
        self.filename = file_name
        self.data = {}
        self.load()

    @retry()
    def _load(self):
        """Load cache with retry. If it still fails at last, raise the original exception as-is."""
        try:
            with open(self.filename, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            # The cache file has not been created. This is expected. No need to retry.
            logger.debug("%s not found. Using a fresh one.", self.filename)
            return {}
        # "_index_" is MSAL's bookkeeping entry; every other key maps to a serialized response.
        response_keys = [key for key in data if key != "_index_"]
        for key in response_keys:
            try:
                response_dict = data[key]
                # Reconstruct NormalizedResponse from the stored dict. __new__ bypasses
                # __init__, which expects a live HTTP response object we no longer have.
                response = NormalizedResponse.__new__(NormalizedResponse)
                response.status_code = response_dict["status_code"]
                response.text = response_dict["text"]
                response.headers = response_dict["headers"]
                data[key] = response
            except KeyError as e:
                logger.debug("Failed to reconstruct NormalizedResponse for key %s: %s", key, e)
                # If reconstruction fails, drop the entry; it will simply be re-fetched and re-cached.
                del data[key]
        return data

    def load(self):
        """Populate self.data from the backing file, falling back to an empty cache on any error."""
        logger.debug("load: %s", self.filename)
        try:
            self.data = self._load()
        except Exception as ex:  # pylint: disable=broad-exception-caught
            # If we still get exception after retry, ignore all types of exceptions and use a new cache.
            # - json.JSONDecodeError is caused by an empty or partially-written cache file created
            #   by another az instance that hasn't finished writing yet.
            # - KeyError/TypeError can be caused by reading a cache generated by a different MSAL version.
            logger.debug("Failed to load cache: %s. Using a fresh one.", ex)
            self.data = {}  # Ignore a non-existing or corrupted http_cache

    @retry()
    def _save(self):
        """Write self.data to the backing file with retry."""
        # Serialize BEFORE opening the file. json.dump would first truncate the file
        # (mode 'w') and then write incrementally, so a serialization error midway
        # would leave a truncated, corrupted cache behind. With dumps-then-write, a
        # TypeError surfaces before the file is touched and save() can skip harmlessly.
        serialized = json.dumps(self.data, cls=NormalizedResponseJsonEncoder)
        with open(self.filename, 'w', encoding='utf-8') as f:
            # Another az instance reading a concurrently-written (empty/partial) file
            # gets json.JSONDecodeError, which load() deliberately ignores.
            f.write(serialized)

    def save(self):
        """Persist the cache, tolerating unserializable values."""
        logger.debug("save: %s", self.filename)
        # If 2 processes write at the same time, the cache will be corrupted,
        # but that is fine. Subsequent runs would reach eventual consistency.
        try:
            self._save()
        except TypeError as e:
            # An unserializable value slipped into the cache; skip saving this round
            # (the file is untouched because serialization happens before the write).
            logger.debug("Failed to save cache due to TypeError: %s", e)

    def get(self, key, default=None):
        # Direct delegate; same semantics as the Mapping mixin, without the extra lookup.
        return self.data.get(key, default)

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value
        self.save()

    def __delitem__(self, key):
        del self.data[key]
        self.save()

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)