Skip to content

Commit

Permalink
Automatically deserialize bytes in actions
Browse files Browse the repository at this point in the history
  • Loading branch information
dainnilsson committed Oct 7, 2024
1 parent f8b4eff commit 654ce4a
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 42 deletions.
7 changes: 7 additions & 0 deletions helper/helper/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ def __call__(self, action, target, params, event, signal, traversed=None):
elif action in self.list_actions():
action_f = self.get_action(action)
args = inspect.signature(action_f).parameters
# Decode any serialized bytes parameters
for key, param in args.items():
if param.annotation in (bytes, bytes | None):
value = params.get(key, None)
if value is not None:
params[key] = decode_bytes(value)
# Add event and signal if requested
if "event" in args:
params["event"] = event
if "signal" in args:
Expand Down
8 changes: 3 additions & 5 deletions helper/helper/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,21 @@ def _await_reboot(self, serial, usb_enabled):
def configure(
self,
reboot: bool = False,
cur_lock_code: str = "",
new_lock_code: str = "",
cur_lock_code: bytes | None = None,
new_lock_code: bytes | None = None,
enabled_capabilities: dict = {},
auto_eject_timeout: int | None = None,
challenge_response_timeout: int | None = None,
device_flags: int | None = None,
):
cur_code = bytes.fromhex(cur_lock_code) or None
new_code = bytes.fromhex(new_lock_code) or None
config = DeviceConfig(
enabled_capabilities,
auto_eject_timeout,
challenge_response_timeout,
DEVICE_FLAG(device_flags) if device_flags else None,
)
serial = self.session.read_device_info().serial
self.session.write_device_config(config, reboot, cur_code, new_code)
self.session.write_device_config(config, reboot, cur_lock_code, new_lock_code)
flags = ["device_info"]
if reboot:
enabled = config.enabled_capabilities.get(TRANSPORT.USB)
Expand Down
8 changes: 4 additions & 4 deletions helper/helper/oath.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def _remember_key(self, key):
else:
return False

def _get_key(self, key: str | None, password: str | None):
def _get_key(self, key: bytes | None, password: str | None):
if key and password:
raise ValueError("Only one of 'key' and 'password' can be provided.")
if password:
return self.session.derive_key(password)
if key:
return decode_bytes(key)
return key
raise ValueError("One of 'key' and 'password' must be provided.")

def _set_key_verifier(self, key):
Expand All @@ -163,7 +163,7 @@ def _do_validate(self, key):
@action
def validate(
self,
key: str | None = None,
key: bytes | None = None,
password: str | None = None,
remember: bool = False,
):
Expand Down Expand Up @@ -192,7 +192,7 @@ def validate(
@action
def set_key(
self,
key: str | None = None,
key: bytes | None = None,
password: str | None = None,
remember: bool = False,
):
Expand Down
17 changes: 8 additions & 9 deletions helper/helper/piv.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def verify_pin(self, signal, pin: str):
return dict(status=True, authenticated=self._authenticated)

@action
def authenticate(self, signal, key: str):
def authenticate(self, signal, key: bytes):
try:
self._authenticate(bytes.fromhex(key), signal)
self._authenticate(key, signal)
return dict(status=True)
except ApduError as e:
if e.sw == SW.SECURITY_CONDITION_NOT_SATISFIED:
Expand All @@ -213,14 +213,13 @@ def authenticate(self, signal, key: str):
@action(condition=lambda self: self._authenticated)
def set_key(
self,
params,
key: str,
key: bytes,
key_type: int = MANAGEMENT_KEY_TYPE.TDES,
store_key: bool = False,
):
pivman_set_mgm_key(
self.session,
bytes.fromhex(key),
key,
MANAGEMENT_KEY_TYPE(key_type),
False,
store_key,
Expand Down Expand Up @@ -264,9 +263,9 @@ def slots(self):
return SlotsNode(self.session)

@action(closes_child=False)
def examine_file(self, data: str, password: str | None = None):
def examine_file(self, data: bytes, password: str | None = None):
try:
private_key, certs = _parse_file(bytes.fromhex(data), password)
private_key, certs = _parse_file(data, password)
certificate = _choose_cert(certs)

return dict(
Expand Down Expand Up @@ -461,9 +460,9 @@ def move_key(
return dict()

@action
def import_file(self, data: str, password: str | None = None, **kwargs):
def import_file(self, data: bytes, password: str | None = None, **kwargs):
try:
private_key, certs = _parse_file(bytes.fromhex(data), password)
private_key, certs = _parse_file(data, password)
except InvalidPasswordError:
logger.debug("Invalid or missing password", exc_info=True)
raise ValueError("Wrong/Missing password")
Expand Down
48 changes: 24 additions & 24 deletions helper/helper/yubiotp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import RpcNode, action, child
from .base import RpcNode, action, child, decode_bytes

from yubikit.core import NotSupportedError, CommandError
from yubikit.core.otp import modhex_encode, modhex_decode
Expand Down Expand Up @@ -95,15 +95,15 @@ def format_yubiotp_csv(
self,
serial: int,
public_id: str,
private_id: str,
key: str,
private_id: bytes,
key: bytes,
):
return dict(
csv=format_csv(
serial,
modhex_decode(public_id),
bytes.fromhex(private_id),
bytes.fromhex(key),
private_id,
key,
)
)

Expand Down Expand Up @@ -145,19 +145,16 @@ def _can_calculate(self, slot):
return False

@action(condition=lambda self: self._maybe_configured(self.slot))
def delete(self, curr_acc_code: str | None = None):
def delete(self, curr_acc_code: bytes | None = None):
try:
access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None
self.session.delete_slot(self.slot, access_code)
self.session.delete_slot(self.slot, curr_acc_code)
return dict()
except CommandError:
raise ValueError(_FAIL_MSG)

@action(condition=lambda self: self._can_calculate(self.slot))
def calculate(self, event, challenge: str):
response = self.session.calculate_hmac_sha1(
self.slot, bytes.fromhex(challenge), event
)
def calculate(self, event, challenge: bytes):
response = self.session.calculate_hmac_sha1(self.slot, challenge, event)
return dict(response=response)

@staticmethod
Expand Down Expand Up @@ -189,13 +186,13 @@ def _apply_options(config, options) -> None:

if "token_id" in options:
token_id, *args = options.pop("token_id")
config.token_id(bytes.fromhex(token_id), *args)
config.token_id(decode_bytes(token_id), *args)

@staticmethod
def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration:
match cfg_type:
case "hmac_sha1":
return HmacSha1SlotConfiguration(bytes.fromhex(kwargs["key"]))
return HmacSha1SlotConfiguration(decode_bytes(kwargs["key"]))
case "hotp":
return HotpSlotConfiguration(parse_b32_key(kwargs["key"]))
case "static_password":
Expand All @@ -207,8 +204,8 @@ def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration:
case "yubiotp":
return YubiOtpSlotConfiguration(
fixed=modhex_decode(kwargs["public_id"]),
uid=bytes.fromhex(kwargs["private_id"]),
key=bytes.fromhex(kwargs["key"]),
uid=decode_bytes(kwargs["private_id"]),
key=decode_bytes(kwargs["key"]),
)
case unsupported:
raise ValueError(
Expand All @@ -217,17 +214,20 @@ def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration:

@action
def put(
self, type: str, options: dict = {}, curr_acc_code: str | None = None, **kwargs
self,
type: str,
options: dict = {},
curr_acc_code: bytes | None = None,
**kwargs,
):
access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None
config = self._get_config(type, **kwargs)
self._apply_options(config, options)
try:
self.session.put_configuration(
self.slot,
config,
access_code,
access_code,
curr_acc_code,
curr_acc_code,
)
return dict()
except CommandError:
Expand All @@ -240,16 +240,16 @@ def put(
def update(
self,
params,
acc_code: str | None = None,
curr_acc_code: str | None = None,
acc_code: bytes | None = None,
curr_acc_code: bytes | None = None,
**kwargs,
):
config = UpdateConfiguration()
self._apply_options(config, kwargs)
self.session.update_configuration(
self.slot,
config,
bytes.fromhex(acc_code) if acc_code else None,
bytes.fromhex(curr_acc_code) if curr_acc_code else None,
acc_code,
curr_acc_code,
)
return dict()

0 comments on commit 654ce4a

Please sign in to comment.