Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions py/selenium/webdriver/common/virtual_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import functools
from base64 import urlsafe_b64decode, urlsafe_b64encode
from collections.abc import Callable
from enum import Enum
from typing import Any
from typing import Any, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


class Protocol(str, Enum):
Expand Down Expand Up @@ -130,7 +135,7 @@ def sign_count(self) -> int:
return self._sign_count

@classmethod
def create_non_resident_credential(cls, id: bytes, rp_id: str, private_key: bytes, sign_count: int) -> "Credential":
def create_non_resident_credential(cls, id: bytes, rp_id: str, private_key: bytes, sign_count: int) -> Credential:
"""Creates a non-resident (i.e. stateless) credential.

Args:
Expand All @@ -147,7 +152,7 @@ def create_non_resident_credential(cls, id: bytes, rp_id: str, private_key: byte
@classmethod
def create_resident_credential(
cls, id: bytes, rp_id: str, user_handle: bytes | None, private_key: bytes, sign_count: int
) -> "Credential":
) -> Credential:
"""Creates a resident (i.e. stateful) credential.

Args:
Expand Down Expand Up @@ -177,7 +182,7 @@ def to_dict(self) -> dict[str, Any]:
return credential_data

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "Credential":
def from_dict(cls, data: dict[str, Any]) -> Credential:
_id = urlsafe_b64decode(f"{data['credentialId']}==")
is_resident_credential = bool(data["isResidentCredential"])
rp_id = data.get("rpId", None)
Expand All @@ -192,28 +197,28 @@ def __str__(self) -> str:
user_handle={self.user_handle}, private_key={self.private_key}, sign_count={self.sign_count})"


def required_chromium_based_browser(func):
def required_chromium_based_browser(func: F) -> F:
"""Decorator to ensure that the client used is a chromium-based browser."""

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
assert self.caps["browserName"].lower() not in [
"firefox",
"safari",
], "This only currently works in Chromium based browsers"
return func(self, *args, **kwargs)

return wrapper
return wrapper # type: ignore[return-value]


def required_virtual_authenticator(func):
def required_virtual_authenticator(func: F) -> F:
"""Decorator to ensure that the function is called with a virtual authenticator."""

@functools.wraps(func)
@required_chromium_based_browser
def wrapper(self, *args, **kwargs):
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
if not self.virtual_authenticator_id:
raise ValueError("This function requires a virtual authenticator to be set.")
return func(self, *args, **kwargs)

return wrapper
return wrapper # type: ignore[return-value]
21 changes: 16 additions & 5 deletions py/selenium/webdriver/remote/script_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,30 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import uuid


class ScriptKey:
def __init__(self, id=None):
self._id = id or uuid.uuid4()
_id: str

def __init__(self, id: str | None = None) -> None:
self._id = id if id is not None else str(uuid.uuid4())

@property
def id(self):
def id(self) -> str:
return self._id

def __eq__(self, other):
return self._id == other
def __eq__(self, other: object) -> bool:
if isinstance(other, ScriptKey):
return self._id == other._id
if isinstance(other, str):
return self._id == other
return NotImplemented

def __hash__(self) -> int:
return hash(self._id)

def __repr__(self) -> str:
return f"ScriptKey(id={self.id})"
Loading