From 654ce4a6cce809f2ca8cfef3e523be5f5ce80957 Mon Sep 17 00:00:00 2001 From: Dain Nilsson Date: Mon, 7 Oct 2024 16:12:33 +0200 Subject: [PATCH] Automatically deserialize bytes in actions --- helper/helper/base.py | 7 ++++++ helper/helper/management.py | 8 +++---- helper/helper/oath.py | 8 +++---- helper/helper/piv.py | 17 +++++++------ helper/helper/yubiotp.py | 48 ++++++++++++++++++------------------- 5 files changed, 46 insertions(+), 42 deletions(-) diff --git a/helper/helper/base.py b/helper/helper/base.py index 3429e142e..2739590bc 100644 --- a/helper/helper/base.py +++ b/helper/helper/base.py @@ -130,6 +130,13 @@ def __call__(self, action, target, params, event, signal, traversed=None): elif action in self.list_actions(): action_f = self.get_action(action) args = inspect.signature(action_f).parameters + # Decode any serialized bytes parameters + for key, param in args.items(): + if param.annotation in (bytes, bytes | None): + value = params.get(key, None) + if value is not None: + params[key] = decode_bytes(value) + # Add event and signal if requested if "event" in args: params["event"] = event if "signal" in args: diff --git a/helper/helper/management.py b/helper/helper/management.py index cccfee3f2..003151e8f 100644 --- a/helper/helper/management.py +++ b/helper/helper/management.py @@ -90,15 +90,13 @@ def _await_reboot(self, serial, usb_enabled): def configure( self, reboot: bool = False, - cur_lock_code: str = "", - new_lock_code: str = "", + cur_lock_code: bytes | None = None, + new_lock_code: bytes | None = None, enabled_capabilities: dict = {}, auto_eject_timeout: int | None = None, challenge_response_timeout: int | None = None, device_flags: int | None = None, ): - cur_code = bytes.fromhex(cur_lock_code) or None - new_code = bytes.fromhex(new_lock_code) or None config = DeviceConfig( enabled_capabilities, auto_eject_timeout, @@ -106,7 +104,7 @@ def configure( DEVICE_FLAG(device_flags) if device_flags else None, ) serial = self.session.read_device_info().serial - self.session.write_device_config(config, reboot, cur_code, new_code) + self.session.write_device_config(config, reboot, cur_lock_code, new_lock_code) flags = ["device_info"] if reboot: enabled = config.enabled_capabilities.get(TRANSPORT.USB) diff --git a/helper/helper/oath.py b/helper/helper/oath.py index c54016af2..383440445 100644 --- a/helper/helper/oath.py +++ b/helper/helper/oath.py @@ -142,13 +142,13 @@ def _remember_key(self, key): else: return False - def _get_key(self, key: str | None, password: str | None): + def _get_key(self, key: bytes | None, password: str | None): if key and password: raise ValueError("Only one of 'key' and 'password' can be provided.") if password: return self.session.derive_key(password) if key: - return decode_bytes(key) + return key raise ValueError("One of 'key' and 'password' must be provided.") def _set_key_verifier(self, key): @@ -163,7 +163,7 @@ def _do_validate(self, key): @action def validate( self, - key: str | None = None, + key: bytes | None = None, password: str | None = None, remember: bool = False, ): @@ -192,7 +192,7 @@ def validate( @action def set_key( self, - key: str | None = None, + key: bytes | None = None, password: str | None = None, remember: bool = False, ): diff --git a/helper/helper/piv.py b/helper/helper/piv.py index 0ad166fcd..a86c933ae 100644 --- a/helper/helper/piv.py +++ b/helper/helper/piv.py @@ -201,9 +201,9 @@ def verify_pin(self, signal, pin: str): return dict(status=True, authenticated=self._authenticated) @action - def authenticate(self, signal, key: str): + def authenticate(self, signal, key: bytes): try: - self._authenticate(bytes.fromhex(key), signal) + self._authenticate(key, signal) return dict(status=True) except ApduError as e: if e.sw == SW.SECURITY_CONDITION_NOT_SATISFIED: @@ -213,14 +213,13 @@ def authenticate(self, signal, key: str): @action(condition=lambda self: self._authenticated) def set_key( self, - params, - key: str, + key: bytes, key_type: int = MANAGEMENT_KEY_TYPE.TDES, store_key: bool = False, ): pivman_set_mgm_key( self.session, - bytes.fromhex(key), + key, MANAGEMENT_KEY_TYPE(key_type), False, store_key, @@ -264,9 +263,9 @@ def slots(self): return SlotsNode(self.session) @action(closes_child=False) - def examine_file(self, data: str, password: str | None = None): + def examine_file(self, data: bytes, password: str | None = None): try: - private_key, certs = _parse_file(bytes.fromhex(data), password) + private_key, certs = _parse_file(data, password) certificate = _choose_cert(certs) return dict( @@ -461,9 +460,9 @@ def move_key( return dict() @action - def import_file(self, data: str, password: str | None = None, **kwargs): + def import_file(self, data: bytes, password: str | None = None, **kwargs): try: - private_key, certs = _parse_file(bytes.fromhex(data), password) + private_key, certs = _parse_file(data, password) except InvalidPasswordError: logger.debug("Invalid or missing password", exc_info=True) raise ValueError("Wrong/Missing password") diff --git a/helper/helper/yubiotp.py b/helper/helper/yubiotp.py index feb37bf33..098400e50 100644 --- a/helper/helper/yubiotp.py +++ b/helper/helper/yubiotp.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import RpcNode, action, child +from .base import RpcNode, action, child, decode_bytes from yubikit.core import NotSupportedError, CommandError from yubikit.core.otp import modhex_encode, modhex_decode @@ -95,15 +95,15 @@ def format_yubiotp_csv( self, serial: int, public_id: str, - private_id: str, - key: str, + private_id: bytes, + key: bytes, ): return dict( csv=format_csv( serial, modhex_decode(public_id), - bytes.fromhex(private_id), - bytes.fromhex(key), + private_id, + key, ) ) @@ -145,19 +145,16 @@ def _can_calculate(self, slot): return False @action(condition=lambda self: self._maybe_configured(self.slot)) - def delete(self, curr_acc_code: str | None = None): + def delete(self, curr_acc_code: bytes | None = None): try: - access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None - self.session.delete_slot(self.slot, access_code) + self.session.delete_slot(self.slot, curr_acc_code) return dict() except CommandError: raise ValueError(_FAIL_MSG) @action(condition=lambda self: self._can_calculate(self.slot)) - def calculate(self, event, challenge: str): - response = self.session.calculate_hmac_sha1( - self.slot, bytes.fromhex(challenge), event - ) + def calculate(self, event, challenge: bytes): + response = self.session.calculate_hmac_sha1(self.slot, challenge, event) return dict(response=response) @staticmethod @@ -189,13 +186,13 @@ def _apply_options(config, options) -> None: if "token_id" in options: token_id, *args = options.pop("token_id") - config.token_id(bytes.fromhex(token_id), *args) + config.token_id(decode_bytes(token_id), *args) @staticmethod def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration: match cfg_type: case "hmac_sha1": - return HmacSha1SlotConfiguration(bytes.fromhex(kwargs["key"])) + return HmacSha1SlotConfiguration(decode_bytes(kwargs["key"])) case "hotp": return HotpSlotConfiguration(parse_b32_key(kwargs["key"])) case "static_password": @@ -207,8 +204,8 @@ def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration: case "yubiotp": return YubiOtpSlotConfiguration( fixed=modhex_decode(kwargs["public_id"]), - uid=bytes.fromhex(kwargs["private_id"]), - key=bytes.fromhex(kwargs["key"]), + uid=decode_bytes(kwargs["private_id"]), + key=decode_bytes(kwargs["key"]), ) case unsupported: raise ValueError( @@ -217,17 +214,20 @@ def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration: @action def put( - self, type: str, options: dict = {}, curr_acc_code: str | None = None, **kwargs + self, + type: str, + options: dict = {}, + curr_acc_code: bytes | None = None, + **kwargs, ): - access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None config = self._get_config(type, **kwargs) self._apply_options(config, options) try: self.session.put_configuration( self.slot, config, - access_code, - access_code, + curr_acc_code, + curr_acc_code, ) return dict() except CommandError: @@ -240,8 +240,8 @@ def put( def update( self, params, - acc_code: str | None = None, - curr_acc_code: str | None = None, + acc_code: bytes | None = None, + curr_acc_code: bytes | None = None, **kwargs, ): config = UpdateConfiguration() @@ -249,7 +249,7 @@ def update( self.session.update_configuration( self.slot, config, - bytes.fromhex(acc_code) if acc_code else None, - bytes.fromhex(curr_acc_code) if curr_acc_code else None, + acc_code, + curr_acc_code, ) return dict()