diff --git a/py/selenium/webdriver/common/virtual_authenticator.py b/py/selenium/webdriver/common/virtual_authenticator.py index 0210edda0a8bd..da92f51c02d4f 100644 --- a/py/selenium/webdriver/common/virtual_authenticator.py +++ b/py/selenium/webdriver/common/virtual_authenticator.py @@ -15,10 +15,15 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import functools from base64 import urlsafe_b64decode, urlsafe_b64encode +from collections.abc import Callable from enum import Enum -from typing import Any +from typing import Any, TypeVar + +F = TypeVar("F", bound=Callable[..., Any]) class Protocol(str, Enum): @@ -130,7 +135,7 @@ def sign_count(self) -> int: return self._sign_count @classmethod - def create_non_resident_credential(cls, id: bytes, rp_id: str, private_key: bytes, sign_count: int) -> "Credential": + def create_non_resident_credential(cls, id: bytes, rp_id: str, private_key: bytes, sign_count: int) -> Credential: """Creates a non-resident (i.e. stateless) credential. Args: @@ -147,7 +152,7 @@ def create_non_resident_credential(cls, id: bytes, rp_id: str, private_key: byte @classmethod def create_resident_credential( cls, id: bytes, rp_id: str, user_handle: bytes | None, private_key: bytes, sign_count: int - ) -> "Credential": + ) -> Credential: """Creates a resident (i.e. stateful) credential. Args: @@ -177,7 +182,7 @@ def to_dict(self) -> dict[str, Any]: return credential_data @classmethod - def from_dict(cls, data: dict[str, Any]) -> "Credential": + def from_dict(cls, data: dict[str, Any]) -> Credential: _id = urlsafe_b64decode(f"{data['credentialId']}==") is_resident_credential = bool(data["isResidentCredential"]) rp_id = data.get("rpId", None) @@ -192,28 +197,28 @@ def __str__(self) -> str: user_handle={self.user_handle}, private_key={self.private_key}, sign_count={self.sign_count})" -def required_chromium_based_browser(func): +def required_chromium_based_browser(func: F) -> F: """Decorator to ensure that the client used is a chromium-based browser.""" @functools.wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: assert self.caps["browserName"].lower() not in [ "firefox", "safari", ], "This only currently works in Chromium based browsers" return func(self, *args, **kwargs) - return wrapper + return wrapper # type: ignore[return-value] -def required_virtual_authenticator(func): +def required_virtual_authenticator(func: F) -> F: """Decorator to ensure that the function is called with a virtual authenticator.""" @functools.wraps(func) @required_chromium_based_browser - def wrapper(self, *args, **kwargs): + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if not self.virtual_authenticator_id: raise ValueError("This function requires a virtual authenticator to be set.") return func(self, *args, **kwargs) - return wrapper + return wrapper # type: ignore[return-value] diff --git a/py/selenium/webdriver/remote/script_key.py b/py/selenium/webdriver/remote/script_key.py index 930b699c7d79b..eb2480e3f702b 100644 --- a/py/selenium/webdriver/remote/script_key.py +++ b/py/selenium/webdriver/remote/script_key.py @@ -15,19 +15,30 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import uuid class ScriptKey: - def __init__(self, id=None): - self._id = id or uuid.uuid4() + _id: str + + def __init__(self, id: str | None = None) -> None: + self._id = id if id is not None else str(uuid.uuid4()) @property - def id(self): + def id(self) -> str: return self._id - def __eq__(self, other): - return self._id == other + def __eq__(self, other: object) -> bool: + if isinstance(other, ScriptKey): + return self._id == other._id + if isinstance(other, str): + return self._id == other + return NotImplemented + + def __hash__(self) -> int: + return hash(self._id) def __repr__(self) -> str: return f"ScriptKey(id={self.id})" diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 8a8f282ecac70..edb734c1f9766 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -17,6 +17,8 @@ """The WebDriver implementation.""" +from __future__ import annotations + import base64 import contextlib import copy @@ -28,6 +30,7 @@ import zipfile from abc import ABCMeta from base64 import b64decode, urlsafe_b64encode +from collections.abc import AsyncGenerator, Generator, Iterable from contextlib import asynccontextmanager, contextmanager from importlib import import_module from typing import Any, cast @@ -84,7 +87,7 @@ def import_cdp() -> None: cdp = import_module("selenium.webdriver.common.bidi.cdp") -def _create_caps(caps) -> dict: +def _create_caps(caps: dict[str, Any]) -> dict[str, Any]: """Makes a W3C alwaysMatch capabilities object. Filters out capability names that are not in the W3C spec. Spec-compliant @@ -104,7 +107,7 @@ def _create_caps(caps) -> dict: def get_remote_connection( - capabilities: dict, + capabilities: dict[str, Any], command_executor: str | RemoteConnection, keep_alive: bool, ignore_local_proxy: bool, @@ -135,7 +138,7 @@ def get_remote_connection( ) -def create_matches(options: list[BaseOptions]) -> dict: +def create_matches(options: list[BaseOptions]) -> dict[str, Any]: capabilities: dict[str, Any] = {"capabilities": {}} opts = [] for opt in options: @@ -273,7 +276,7 @@ def __init__( def __repr__(self) -> str: return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' - def __enter__(self) -> "WebDriver": + def __enter__(self) -> WebDriver: return self def __exit__( @@ -281,11 +284,13 @@ def __exit__( exc_type: type[BaseException] | None, exc: BaseException | None, traceback: types.TracebackType | None, - ): + ) -> None: self.quit() @contextmanager - def file_detector_context(self, file_detector_class, *args, **kwargs): + def file_detector_context( + self, file_detector_class: type[FileDetector], *args: Any, **kwargs: Any + ) -> Generator[None, None, None]: """Override the current file detector temporarily within a limited context. Ensures the original file detector is set after exiting the context. @@ -323,7 +328,7 @@ def mobile(self) -> Mobile: def name(self) -> str: """Returns the name of the underlying browser for this instance.""" if "browserName" in self.caps: - return self.caps["browserName"] + return cast(str, self.caps["browserName"]) raise KeyError("browserName not specified in session capabilities") def start_client(self) -> None: @@ -341,7 +346,7 @@ def stop_client(self) -> None: """ pass - def start_session(self, capabilities: dict) -> None: + def start_session(self, capabilities: dict[str, Any]) -> None: """Creates a new session with the desired capabilities. Args: @@ -357,7 +362,7 @@ def start_session(self, capabilities: dict) -> None: self.service.stop() raise - def _wrap_value(self, value): + def _wrap_value(self, value: Any) -> Any: if isinstance(value, dict): converted = {} for key, val in value.items(): @@ -375,7 +380,7 @@ def create_web_element(self, element_id: str) -> WebElement: """Creates a web element with the specified `element_id`.""" return self._web_element_cls(self, element_id) - def _unwrap_value(self, value): + def _unwrap_value(self, value: Any) -> Any: if isinstance(value, dict): if "element-6066-11e4-a52e-4f735466cecf" in value: return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) @@ -388,7 +393,7 @@ def _unwrap_value(self, value): return list(self._unwrap_value(item) for item in value) return value - def execute_cdp_cmd(self, cmd: str, cmd_args: dict): + def execute_cdp_cmd(self, cmd: str, cmd_args: dict[str, Any]) -> Any: """Execute Chrome Devtools Protocol command and get returned result. The command and command args should follow chrome devtools protocol domains/commands: @@ -461,15 +466,16 @@ def title(self) -> str: print(element.title()) ``` """ - return self.execute(Command.GET_TITLE).get("value", "") + return cast(str, self.execute(Command.GET_TITLE).get("value", "")) - def pin_script(self, script: str, script_key=None) -> ScriptKey: + def pin_script(self, script: str, script_key: ScriptKey | str | None = None) -> ScriptKey: """Store a JavaScript script by a unique hashable ID for later execution. Example: `script = "return document.getElementById('foo').value"` """ - script_key_instance = ScriptKey(script_key) + key_id = script_key.id if isinstance(script_key, ScriptKey) else script_key + script_key_instance = ScriptKey(key_id) self.pinned_scripts[script_key_instance.id] = script return script_key_instance @@ -492,7 +498,7 @@ def get_pinned_scripts(self) -> list[str]: """ return list(self.pinned_scripts) - def execute_script(self, script: str, *args): + def execute_script(self, script: str | ScriptKey, *args: Any) -> Any: """Synchronously Executes JavaScript in the current window/frame. Args: @@ -517,7 +523,7 @@ def execute_script(self, script: str, *args): return self.execute(command, {"script": script, "args": converted_args})["value"] - def execute_async_script(self, script: str, *args) -> dict: + def execute_async_script(self, script: str, *args: Any) -> Any: """Asynchronously Executes JavaScript in the current window/frame. Args: @@ -539,12 +545,12 @@ def execute_async_script(self, script: str, *args) -> dict: @property def current_url(self) -> str: """Gets the URL of the current page.""" - return self.execute(Command.GET_CURRENT_URL)["value"] + return cast(str, self.execute(Command.GET_CURRENT_URL)["value"]) @property def page_source(self) -> str: """Gets the source of the current page.""" - return self.execute(Command.GET_PAGE_SOURCE)["value"] + return cast(str, self.execute(Command.GET_PAGE_SOURCE)["value"]) def close(self) -> None: """Closes the current window.""" @@ -562,12 +568,12 @@ def quit(self) -> None: @property def current_window_handle(self) -> str: """Returns the handle of the current window.""" - return self.execute(Command.W3C_GET_CURRENT_WINDOW_HANDLE)["value"] + return cast(str, self.execute(Command.W3C_GET_CURRENT_WINDOW_HANDLE)["value"]) @property def window_handles(self) -> list[str]: """Returns the handles of all windows within the current session.""" - return self.execute(Command.W3C_GET_WINDOW_HANDLES)["value"] + return cast(list[str], self.execute(Command.W3C_GET_WINDOW_HANDLES)["value"]) def maximize_window(self) -> None: """Maximizes the current window that webdriver is using.""" @@ -592,7 +598,7 @@ def print_page(self, print_options: PrintOptions | None = None) -> str: if print_options: options = print_options.to_dict() - return self.execute(Command.PRINT_PAGE, options)["value"] + return cast(str, self.execute(Command.PRINT_PAGE, options)["value"]) @property def switch_to(self) -> SwitchTo: @@ -626,7 +632,7 @@ def refresh(self) -> None: """Refreshes the current page.""" self.execute(Command.REFRESH) - def get_cookies(self) -> list[dict]: + def get_cookies(self) -> Any: """Get all cookies visible to the current WebDriver instance. Returns: @@ -635,7 +641,7 @@ def get_cookies(self) -> list[dict]: """ return self.execute(Command.GET_ALL_COOKIES)["value"] - def get_cookie(self, name) -> dict | None: + def get_cookie(self, name: str) -> Any: """Get a single cookie by name (case-sensitive,). Returns: @@ -655,7 +661,7 @@ def get_cookie(self, name) -> dict | None: return None - def delete_cookie(self, name) -> None: + def delete_cookie(self, name: str) -> None: """Delete a single cookie with the given name (case-sensitive). Raises: @@ -674,7 +680,7 @@ def delete_all_cookies(self) -> None: """Delete all cookies in the scope of the session.""" self.execute(Command.DELETE_ALL_COOKIES) - def add_cookie(self, cookie_dict) -> None: + def add_cookie(self, cookie_dict: dict[str, Any]) -> None: """Adds a cookie to your current session. Args: @@ -761,7 +767,7 @@ def timeouts(self) -> Timeouts: return Timeouts(**timeouts) @timeouts.setter - def timeouts(self, timeouts) -> None: + def timeouts(self, timeouts: Timeouts) -> None: """Set all timeouts for the session. This will override any previously set timeouts. @@ -773,7 +779,7 @@ def timeouts(self, timeouts) -> None: driver.timeouts = my_timeouts ``` """ - _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] + _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] # type: ignore[arg-type] def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) -> WebElement: """Find an element given a By strategy and locator. @@ -799,7 +805,7 @@ def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) - raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] + return cast(WebElement, self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"]) def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) -> list[WebElement]: """Find elements given a By strategy and locator. @@ -826,18 +832,18 @@ def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") raw_function = raw_data.decode("utf8") find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" - return self.execute_script(find_element_js, by.to_dict()) + return cast(list[WebElement], self.execute_script(find_element_js, by.to_dict())) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] @property - def capabilities(self) -> dict: + def capabilities(self) -> dict[str, Any]: """Returns the drivers current capabilities being used.""" return self.caps - def get_screenshot_as_file(self, filename) -> bool: + def get_screenshot_as_file(self, filename: str) -> bool: """Save a screenshot of the current window to a PNG image file. Returns: @@ -866,7 +872,7 @@ def get_screenshot_as_file(self, filename) -> bool: del png return True - def save_screenshot(self, filename) -> bool: + def save_screenshot(self, filename: str) -> bool: """Save a screenshot of the current window to a PNG image file. Returns: @@ -897,9 +903,9 @@ def get_screenshot_as_base64(self) -> str: Example: `driver.get_screenshot_as_base64()` """ - return self.execute(Command.SCREENSHOT)["value"] + return cast(str, self.execute(Command.SCREENSHOT)["value"]) - def set_window_size(self, width, height, windowHandle: str = "current") -> None: + def set_window_size(self, width: int, height: int, windowHandle: str = "current") -> None: """Sets the width and height of the current window. Args: @@ -913,7 +919,7 @@ def set_window_size(self, width, height, windowHandle: str = "current") -> None: self._check_if_window_handle_is_current(windowHandle) self.set_window_rect(width=int(width), height=int(height)) - def get_window_size(self, windowHandle: str = "current") -> dict: + def get_window_size(self, windowHandle: str = "current") -> dict[str, int]: """Gets the width and height of the current window. Example: @@ -923,11 +929,11 @@ def get_window_size(self, windowHandle: str = "current") -> dict: size = self.get_window_rect() if size.get("value", None): - size = size["value"] + size = size["value"] # type: ignore[assignment] return {k: size[k] for k in ("width", "height")} - def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: + def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict[str, Any]: """Sets the x,y position of the current window. Args: @@ -941,7 +947,7 @@ def set_window_position(self, x: float, y: float, windowHandle: str = "current") self._check_if_window_handle_is_current(windowHandle) return self.set_window_rect(x=int(x), y=int(y)) - def get_window_position(self, windowHandle="current") -> dict: + def get_window_position(self, windowHandle: str = "current") -> dict[str, int]: """Gets the x,y position of the current window. Example: @@ -957,7 +963,7 @@ def _check_if_window_handle_is_current(self, windowHandle: str) -> None: if windowHandle != "current": warnings.warn("Only 'current' window is supported for W3C compatible browsers.", stacklevel=2) - def get_window_rect(self) -> dict: + def get_window_rect(self) -> dict[str, int]: """Get the window's position and size. Returns: @@ -966,9 +972,15 @@ def get_window_rect(self) -> dict: Example: `driver.get_window_rect()` """ - return self.execute(Command.GET_WINDOW_RECT)["value"] + return cast(dict[str, int], self.execute(Command.GET_WINDOW_RECT)["value"]) - def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: + def set_window_rect( + self, + x: int | None = None, + y: int | None = None, + width: int | None = None, + height: int | None = None, + ) -> dict[str, Any]: """Set the window's position and size. Sets the x, y coordinates and height and width of the current window. @@ -983,14 +995,17 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: if (x is None and y is None) and (not height and not width): raise InvalidArgumentException("x and y or height and width need values") - return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] + return cast( + dict[str, Any], + self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"], + ) @property def file_detector(self) -> FileDetector: return self._file_detector @file_detector.setter - def file_detector(self, detector) -> None: + def file_detector(self, detector: FileDetector) -> None: """Set the file detector for keyboard input. By default, this is set to a file detector that does nothing. @@ -1006,16 +1021,16 @@ def file_detector(self, detector) -> None: self._file_detector = detector @property - def orientation(self) -> dict: + def orientation(self) -> str: """Gets the current orientation of the device. Example: `orientation = driver.orientation` """ - return self.execute(Command.GET_SCREEN_ORIENTATION)["value"] + return cast(str, self.execute(Command.GET_SCREEN_ORIENTATION)["value"]) @orientation.setter - def orientation(self, value) -> None: + def orientation(self, value: str) -> None: """Sets the current orientation of the device. Args: @@ -1070,20 +1085,23 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: return self._devtools, self._websocket_connection @asynccontextmanager - async def bidi_connection(self): + async def bidi_connection(self) -> AsyncGenerator[BidiConnection, None]: global cdp import_cdp() if self.caps.get("se:cdp"): ws_url = self.caps.get("se:cdp") - version = self.caps.get("se:cdpVersion").split(".")[0] + cdp_version = self.caps.get("se:cdpVersion") + if not cdp_version: + raise WebDriverException("se:cdp capability present but se:cdpVersion is missing") + version = cdp_version.split(".")[0] else: version, ws_url = self._get_cdp_details() if not ws_url: raise WebDriverException("Unable to find url to connect to from capabilities") - devtools = cdp.import_devtools(version) - async with cdp.open_cdp(ws_url) as conn: + devtools = cdp.import_devtools(version) # type: ignore[union-attr] + async with cdp.open_cdp(ws_url) as conn: # type: ignore[union-attr] targets = await conn.execute(devtools.target.get_targets()) for target in targets: if target.target_id == self.current_window_handle: @@ -1299,7 +1317,7 @@ def input(self) -> Input: return self._input - def _get_cdp_details(self): + def _get_cdp_details(self) -> tuple[str, str]: import json import urllib3 @@ -1307,9 +1325,9 @@ def _get_cdp_details(self): http = urllib3.PoolManager() try: if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") + debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") # type: ignore[union-attr] elif self.caps.get("browserName") in ("MicrosoftEdge", "webview2"): - debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") + debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") # type: ignore[union-attr] except AttributeError: raise WebDriverException("Can't get debugger address.") @@ -1321,7 +1339,7 @@ def _get_cdp_details(self): import re - version = re.search(r".*/(\d+)\.", browser_version).group(1) + version = re.search(r".*/(\d+)\.", browser_version).group(1) # type: ignore[union-attr] return version, websocket_url @@ -1408,12 +1426,12 @@ def set_user_verified(self, verified: bool) -> None: """ self.execute(Command.SET_USER_VERIFIED, {"authenticatorId": self._authenticator_id, "isUserVerified": verified}) - def get_downloadable_files(self) -> list: + def get_downloadable_files(self) -> list[str]: """Retrieves the downloadable files as a list of file names.""" if "se:downloadsEnabled" not in self.capabilities: raise WebDriverException("You must enable downloads in order to work with downloadable files.") - return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] + return cast(list[str], self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"]) def download_file(self, file_name: str, target_directory: str) -> None: """Download a file with the specified file name to the target directory. @@ -1473,7 +1491,7 @@ def fedcm(self) -> FedCM: @property def supports_fedcm(self) -> bool: """Returns whether the browser supports FedCM capabilities.""" - return self.capabilities.get(ArgOptions.FEDCM_CAPABILITY, False) + return cast(bool, self.capabilities.get(ArgOptions.FEDCM_CAPABILITY, False)) def _require_fedcm_support(self) -> None: """Raises an exception if FedCM is not supported.""" @@ -1489,7 +1507,12 @@ def dialog(self) -> Dialog: self._require_fedcm_support() return Dialog(self) - def fedcm_dialog(self, timeout=5, poll_frequency=0.5, ignored_exceptions=None): + def fedcm_dialog( + self, + timeout: float = 5, + poll_frequency: float = 0.5, + ignored_exceptions: Iterable[type[Exception]] | None = None, + ) -> Dialog | None: """Waits for and returns the FedCM dialog. Args: