From 8a2e2d3de5a96eaa50b9c9336d97396f8489bda4 Mon Sep 17 00:00:00 2001 From: Michael Carlstrom Date: Fri, 30 Aug 2024 20:45:28 -0400 Subject: [PATCH] Add types to Node.py (#1346) * Add types Signed-off-by: Michael Carlstrom * string around type Signed-off-by: Michael Carlstrom * string around type Signed-off-by: Michael Carlstrom * use error Signed-off-by: Michael Carlstrom * flake8 Signed-off-by: Michael Carlstrom * switch error raise order Signed-off-by: Michael Carlstrom * unify error Signed-off-by: Michael Carlstrom --------- Signed-off-by: Michael Carlstrom --- rclpy/rclpy/callback_groups.py | 3 +- rclpy/rclpy/exceptions.py | 2 +- rclpy/rclpy/impl/_rclpy_pybind11.pyi | 142 +++++++++++++++++++++++--- rclpy/rclpy/node.py | 143 ++++++++++++++------------- rclpy/rclpy/parameter.py | 15 +-- rclpy/rclpy/qos.py | 14 +-- rclpy/rclpy/topic_endpoint_info.py | 5 +- 7 files changed, 223 insertions(+), 101 deletions(-) diff --git a/rclpy/rclpy/callback_groups.py b/rclpy/rclpy/callback_groups.py index de5ef04af..bee08d611 100644 --- a/rclpy/rclpy/callback_groups.py +++ b/rclpy/rclpy/callback_groups.py @@ -23,7 +23,8 @@ from rclpy.client import Client from rclpy.service import Service from rclpy.waitable import Waitable - Entity = Union[Subscription, Timer, Client, Service, Waitable[Any]] + from rclpy.guard_condition import GuardCondition + Entity = Union[Subscription, Timer, Client, Service, Waitable[Any], GuardCondition] class CallbackGroup: diff --git a/rclpy/rclpy/exceptions.py b/rclpy/rclpy/exceptions.py index ee22e8b38..dbe63e06d 100644 --- a/rclpy/rclpy/exceptions.py +++ b/rclpy/rclpy/exceptions.py @@ -109,7 +109,7 @@ class InvalidParameterTypeException(ParameterException): from rclpy.parameter import Parameter - def __init__(self, desired_parameter: Parameter, expected_type: Parameter.Type) -> None: + def __init__(self, desired_parameter: Parameter, expected_type: str) -> None: from rclpy.parameter import Parameter ParameterException.__init__( self, diff --git a/rclpy/rclpy/impl/_rclpy_pybind11.pyi b/rclpy/rclpy/impl/_rclpy_pybind11.pyi index 0570c430c..679f6fa4f 100644 --- a/rclpy/rclpy/impl/_rclpy_pybind11.pyi +++ b/rclpy/rclpy/impl/_rclpy_pybind11.pyi @@ -20,8 +20,8 @@ from typing import Any, Generic, Literal, overload, Sequence, TypedDict from rclpy.clock import JumpHandle from rclpy.clock_type import ClockType -from rclpy.qos import (QoSDurabilityPolicy, QoSHistoryPolicy, QoSLivelinessPolicy, - QoSReliabilityPolicy) +from rclpy.duration import Duration +from rclpy.parameter import Parameter from rclpy.subscription import MessageInfo from rclpy.type_support import MsgT @@ -173,10 +173,118 @@ class Subscription(Destroyable, Generic[MsgT]): """Count the publishers from a subscription.""" -class Node: +class Node(Destroyable): + + def __init__(self, node_name: str, namespace_: str, context: Context, + pycli_args: list[str] | None, use_global_arguments: bool, + enable_rosout: bool) -> None: ... + + @property + def pointer(self) -> int: + """Get the address of the entity as an integer.""" + + def get_fully_qualified_name(self) -> str: + """Get the fully qualified name of the node.""" + + def logger_name(self) -> str: + """Get the name of the logger associated with a node.""" + + def get_node_name(self) -> str: + """Get the name of a node.""" + + def get_namespace(self) -> str: + """Get the namespace of a node.""" + + def get_count_publishers(self, topic_name: str) -> int: + """Return the count of all the publishers known for that topic in the entire ROS graph.""" + + def get_count_subscribers(self, topic_name: str) -> int: + """Return the count of all the subscribers known for that topic in the entire ROS graph.""" + + def get_count_clients(self, service_name: str) -> int: + """Return the count of all the clients known for that service in the entire ROS graph.""" + + def get_count_services(self, service_name: str) -> int: + """Return the count of all the servers known for that service in the entire ROS graph.""" + + def get_node_names_and_namespaces(self) -> list[tuple[str, str, str] | tuple[str, str]]: + """Get the list of nodes discovered by the provided node.""" + + def get_node_names_and_namespaces_with_enclaves(self) -> list[tuple[str, str, str] | + tuple[str, str]]: + """Get the list of nodes discovered by the provided node, with their enclaves.""" + + def get_action_client_names_and_types_by_node(self, remote_node_name: str, + remote_node_namespace: str) -> list[tuple[str, + list[str]]]: + """Get action client names and types by node.""" + + def get_action_server_names_and_types_by_node(self, remote_node_name: str, + remote_node_namespace: str) -> list[tuple[str, + list[str]]]: + """Get action server names and types by node.""" + + def get_action_names_and_types(self) -> list[tuple[str, list[str]]]: + """Get action names and types.""" + + def get_parameters(self, pyparamter_cls: type[Parameter]) -> dict[str, Parameter]: + """Get a list of parameters for the current node.""" + + +def rclpy_resolve_name(node: Node, topic_name: str, only_expand: bool, is_service: bool) -> str: + """Expand and remap a topic or service name.""" + + +def rclpy_get_publisher_names_and_types_by_node(node: Node, no_demangle: bool, node_name: str, + node_namespace: str + ) -> list[tuple[str, list[str]]]: + """Get topic names and types for which a remote node has publishers.""" + + +def rclpy_get_subscriber_names_and_types_by_node(node: Node, no_demangle: bool, node_name: str, + node_namespace: str + ) -> list[tuple[str, list[str]]]: + """Get topic names and types for which a remote node has subscribers.""" + + +def rclpy_get_service_names_and_types_by_node(node: Node, node_name: str, node_namespace: str + ) -> list[tuple[str, list[str]]]: + """Get all service names and types in the ROS graph.""" + + +def rclpy_get_client_names_and_types_by_node(node: Node, node_name: str, node_namespace: str + ) -> list[tuple[str, list[str]]]: + """Get service names and types for which a remote node has servers.""" + + +def rclpy_get_service_names_and_types(node: Node) -> list[tuple[str, list[str]]]: + """Get all service names and types in the ROS graph.""" + + +class TypeHashDict(TypedDict): + version: int + value: bytes + + +class QoSDict(TypedDict): pass +class TopicEndpointInfoDict(TypedDict): + node_name: str + node_namespace: str + topic_type: str + topic_type_hash: TypeHashDict + endpoint_type: int + endpoint_gid: list[int] + qos_profile: rmw_qos_profile_dict + + +def rclpy_get_publishers_info_by_topic(node: Node, topic_name: str, no_mangle: bool + ) -> list[TopicEndpointInfoDict]: + """Get publishers info for a topic.""" + + class Publisher(Destroyable, Generic[MsgT]): def __init__(self, arg0: Node, arg1: type[MsgT], arg2: str, arg3: rmw_qos_profile_t) -> None: @@ -257,14 +365,14 @@ PredefinedQosProfileTNames = Literal['qos_profile_sensor_data', 'qos_profile_def class rmw_qos_profile_dict(TypedDict): - qos_history: QoSHistoryPolicy | int - qos_depth: int - qos_reliability: QoSReliabilityPolicy | int - qos_durability: QoSDurabilityPolicy | int - pyqos_lifespan: rcl_duration_t - pyqos_deadline: rcl_duration_t - qos_liveliness: QoSLivelinessPolicy | int - pyqos_liveliness_lease_duration: rcl_duration_t + depth: int + history: int + reliability: int + durability: int + lifespan: Duration + deadline: Duration + liveliness: int + liveliness_lease_duration: Duration avoid_ros_namespace_conventions: bool @@ -335,6 +443,18 @@ class WaitSet(Destroyable): """Wait until timeout is reached or event happened.""" +class RCLError(RuntimeError): + pass + + +class NodeNameNonExistentError(RCLError): + pass + + +class InvalidHandle(RuntimeError): + pass + + class SignalHandlerOptions(Enum): _value_: int NO = ... diff --git a/rclpy/rclpy/node.py b/rclpy/rclpy/node.py index cba822bbb..8da2faafc 100644 --- a/rclpy/rclpy/node.py +++ b/rclpy/rclpy/node.py @@ -65,6 +65,7 @@ from rclpy.expand_topic_name import expand_topic_name from rclpy.guard_condition import GuardCondition from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy +from rclpy.impl.rcutils_logger import RcutilsLogger from rclpy.logging import get_logger from rclpy.logging_service import LoggingService from rclpy.parameter import (AllowableParameterValue, AllowableParameterValueT, Parameter, @@ -137,7 +138,7 @@ def __init__( use_global_arguments: bool = True, enable_rosout: bool = True, start_parameter_services: bool = True, - parameter_overrides: Optional[List[Parameter]] = None, + parameter_overrides: Optional[List[Parameter[Any]]] = None, allow_undeclared_parameters: bool = False, automatically_declare_parameters_from_overrides: bool = False, enable_logger_service: bool = False @@ -168,29 +169,31 @@ def __init__( to get and set logger levels of this node. Otherwise, logger levels are only managed locally. That is, logger levels cannot be changed remotely. """ - self.__handle = None self._context = get_default_context() if context is None else context - self._parameters: Dict[str, Parameter] = {} - self._publishers: List[Publisher] = [] - self._subscriptions: List[Subscription] = [] - self._clients: List[Client] = [] - self._services: List[Service] = [] + self._parameters: Dict[str, Parameter[Any]] = {} + self._publishers: List[Publisher[Any]] = [] + self._subscriptions: List[Subscription[Any]] = [] + self._clients: List[Client[Any, Any, Any]] = [] + self._services: List[Service[Any, Any, Any]] = [] self._timers: List[Timer] = [] self._guards: List[GuardCondition] = [] self.__waitables: List[Waitable[Any]] = [] self._default_callback_group = MutuallyExclusiveCallbackGroup() - self._pre_set_parameters_callbacks: List[Callable[[List[Parameter]], List[Parameter]]] = [] + self._pre_set_parameters_callbacks: List[Callable[[List[Parameter[Any]]], + List[Parameter[Any]]]] = [] self._on_set_parameters_callbacks: \ - List[Callable[[List[Parameter]], SetParametersResult]] = [] - self._post_set_parameters_callbacks: List[Callable[[List[Parameter]], None]] = [] + List[Callable[[List[Parameter[Any]]], SetParametersResult]] = [] + self._post_set_parameters_callbacks: List[Callable[[List[Parameter[Any]]], None]] = [] self._rate_group = ReentrantCallbackGroup() self._allow_undeclared_parameters = allow_undeclared_parameters - self._parameter_overrides: Dict[str, Parameter] = {} + self._parameter_overrides: Dict[str, Parameter[Any]] = {} self._descriptors: Dict[str, ParameterDescriptor] = {} namespace = namespace or '' - if not self._context.ok(): + + if self._context.handle is None or not self._context.ok(): raise NotInitializedException('cannot create node') + with self._context.handle: try: self.__node = _rclpy.Node( @@ -215,10 +218,11 @@ def __init__( with self.handle: self._logger = get_logger(self.__node.logger_name()) - self.__executor_weakref = None + self.__executor_weakref: Optional[weakref.ReferenceType[Executor]] = None - self._parameter_event_publisher = self.create_publisher( - ParameterEvent, '/parameter_events', qos_profile_parameter_events) + self._parameter_event_publisher: Optional[Publisher[ParameterEvent]] = \ + self.create_publisher(ParameterEvent, '/parameter_events', + qos_profile_parameter_events) with self.handle: self._parameter_overrides = self.__node.get_parameters(Parameter) @@ -256,22 +260,22 @@ def __init__( self._context.track_node(self) @property - def publishers(self) -> Iterator[Publisher]: + def publishers(self) -> Iterator[Publisher[Any]]: """Get publishers that have been created on this node.""" yield from self._publishers @property - def subscriptions(self) -> Iterator[Subscription]: + def subscriptions(self) -> Iterator[Subscription[Any]]: """Get subscriptions that have been created on this node.""" yield from self._subscriptions @property - def clients(self) -> Iterator[Client]: + def clients(self) -> Iterator[Client[Any, Any, Any]]: """Get clients that have been created on this node.""" yield from self._clients @property - def services(self) -> Iterator[Service]: + def services(self) -> Iterator[Service[Any, Any, Any]]: """Get services that have been created on this node.""" yield from self._services @@ -311,7 +315,7 @@ def executor(self, new_executor: Executor) -> None: new_executor.add_node(self) self.__executor_weakref = weakref.ref(new_executor) - def _wake_executor(self): + def _wake_executor(self) -> None: executor = self.executor if executor: executor.wake() @@ -332,7 +336,7 @@ def default_callback_group(self) -> CallbackGroup: return self._default_callback_group @property - def handle(self): + def handle(self) -> _rclpy.Node: """ Get the handle to the underlying `rcl_node_t`. @@ -343,7 +347,7 @@ def handle(self): return self.__node @handle.setter - def handle(self, value): + def handle(self, value: None) -> None: raise AttributeError('handle cannot be modified after node creation') def get_name(self) -> str: @@ -360,7 +364,7 @@ def get_clock(self) -> Clock: """Get the clock used by the node.""" return self._clock - def get_logger(self): + def get_logger(self) -> RcutilsLogger: """Get the nodes logger.""" return self._logger @@ -383,7 +387,7 @@ def declare_parameter( value: Union[AllowableParameterValue, Parameter.Type, ParameterValue] = None, descriptor: Optional[ParameterDescriptor] = None, ignore_override: bool = False - ) -> Parameter: + ) -> Parameter[Any]: """ Declare and initialize a parameter. @@ -423,7 +427,7 @@ def declare_parameters( ParameterDescriptor], ]], ignore_override: bool = False - ) -> List[Parameter]: + ) -> List[Parameter[Any]]: """ Declare a list of parameters. @@ -470,7 +474,7 @@ def declare_parameters( :raises: InvalidParameterValueException if the registered callback rejects any parameter. :raises: TypeError if any tuple in **parameters** does not match the annotated type. """ - parameter_list: List[Parameter] = [] + parameter_list: List[Parameter[Any]] = [] descriptors: Dict[str, ParameterDescriptor] = {} for index, parameter_tuple in enumerate(parameters): if len(parameter_tuple) < 1 or len(parameter_tuple) > 3: @@ -545,6 +549,8 @@ def declare_parameters( 'being included in self._parameter_overrides, and ', 'ignore_override=False') + from typing import cast + value = cast(AllowableParameterValue, value) parameter_list.append(Parameter(name, value=value)) descriptors.update({name: descriptor}) @@ -565,7 +571,7 @@ def declare_parameters( def _declare_parameter_common( self, - parameter_list: List[Parameter], + parameter_list: List[Parameter[Any]], descriptors: Optional[Dict[str, ParameterDescriptor]] = None ) -> List[SetParametersResult]: """ @@ -612,13 +618,14 @@ def _declare_parameter_common( ) if not result.successful: if result.reason.startswith('Wrong parameter type'): - raise InvalidParameterTypeException( - param, Parameter.Type(descriptors[param._name].type).name) + if descriptors: + raise InvalidParameterTypeException( + param, Parameter.Type(descriptors[param._name].type).name) raise InvalidParameterValueException(param.name, param.value, result.reason) results.append(result) return results - def undeclare_parameter(self, name: str): + def undeclare_parameter(self, name: str) -> None: """ Undeclare a previously declared parameter. @@ -676,7 +683,7 @@ def get_parameter_type(self, name: str) -> Parameter.Type: else: raise ParameterNotDeclaredException(name) - def get_parameters(self, names: List[str]) -> List[Parameter]: + def get_parameters(self, names: List[str]) -> List[Parameter[Any]]: """ Get a list of parameters. @@ -693,7 +700,7 @@ def get_parameters(self, names: List[str]) -> List[Parameter]: raise TypeError('All names must be instances of type str') return [self.get_parameter(name) for name in names] - def get_parameter(self, name: str) -> Parameter: + def get_parameter(self, name: str) -> Parameter[Any]: """ Get a parameter by name. @@ -721,7 +728,7 @@ def get_parameter(self, name: str) -> Parameter: raise ParameterNotDeclaredException(name) def get_parameter_or( - self, name: str, alternative_value: Optional[Parameter] = None) -> Parameter: + self, name: str, alternative_value: Optional[Parameter[Any]] = None) -> Parameter[Any]: """ Get a parameter or the alternative value. @@ -745,7 +752,7 @@ def get_parameter_or( return self._parameters[name] - def get_parameters_by_prefix(self, prefix: str) -> Dict[str, Parameter]: + def get_parameters_by_prefix(self, prefix: str) -> Dict[str, Parameter[Any]]: """ Get parameters that have a given prefix in their names as a dictionary. @@ -771,7 +778,7 @@ def get_parameters_by_prefix(self, prefix: str) -> Dict[str, Parameter]: if param_name.startswith(prefix) } - def set_parameters(self, parameter_list: List[Parameter]) -> List[SetParametersResult]: + def set_parameters(self, parameter_list: List[Parameter[Any]]) -> List[SetParametersResult]: """ Set parameters for the node, and return the result for the set action. @@ -823,7 +830,8 @@ def set_parameters(self, parameter_list: List[Parameter]) -> List[SetParametersR return results - def set_parameters_atomically(self, parameter_list: List[Parameter]) -> SetParametersResult: + def set_parameters_atomically(self, parameter_list: List[Parameter[Any]] + ) -> SetParametersResult: """ Set the given parameters, all at one time, and then aggregate result. @@ -868,7 +876,7 @@ def set_parameters_atomically(self, parameter_list: List[Parameter]) -> SetParam def _set_parameters_atomically( self, - parameter_list: List[Parameter], + parameter_list: List[Parameter[Any]], ) -> SetParametersResult: modified_parameter_list = self._call_pre_set_parameters_callback(parameter_list) @@ -888,7 +896,7 @@ def _set_parameters_atomically( def _set_parameters_atomically_common( self, - parameter_list: List[Parameter], + parameter_list: List[Parameter[Any]], descriptors: Optional[Dict[str, ParameterDescriptor]] = None, allow_not_set_type: bool = False ) -> SetParametersResult: @@ -979,7 +987,8 @@ def _set_parameters_atomically_common( self._parameters[param.name] = param parameter_event.stamp = self._clock.now().to_msg() - self._parameter_event_publisher.publish(parameter_event) + if self._parameter_event_publisher: + self._parameter_event_publisher.publish(parameter_event) # call post set parameter registered callbacks self._call_post_set_parameters_callback(parameter_list) @@ -1046,7 +1055,7 @@ def list_parameters( return result - def _check_undeclared_parameters(self, parameter_list: List[Parameter]): + def _check_undeclared_parameters(self, parameter_list: List[Parameter[Any]]) -> None: """ Check if parameter list has correct types and was declared beforehand. @@ -1062,9 +1071,10 @@ def _check_undeclared_parameters(self, parameter_list: List[Parameter]): if not self._allow_undeclared_parameters and any(undeclared_parameters): raise ParameterNotDeclaredException(list(undeclared_parameters)) - def _call_pre_set_parameters_callback(self, parameter_list: List[Parameter]): + def _call_pre_set_parameters_callback(self, parameter_list: List[Parameter[Any]] + ) -> Optional[List[Parameter[Any]]]: if self._pre_set_parameters_callbacks: - modified_parameter_list = [] + modified_parameter_list: List[Parameter[Any]] = [] for callback in self._pre_set_parameters_callbacks: modified_parameter_list.extend(callback(parameter_list)) @@ -1072,14 +1082,14 @@ def _call_pre_set_parameters_callback(self, parameter_list: List[Parameter]): else: return None - def _call_post_set_parameters_callback(self, parameter_list: List[Parameter]): + def _call_post_set_parameters_callback(self, parameter_list: List[Parameter[Any]]) -> None: if self._post_set_parameters_callbacks: for callback in self._post_set_parameters_callbacks: callback(parameter_list) def add_pre_set_parameters_callback( self, - callback: Callable[[List[Parameter]], List[Parameter]] + callback: Callable[[List[Parameter[Any]]], List[Parameter[Any]]] ) -> None: """ Add a callback gets triggered before parameters are validated. @@ -1115,7 +1125,7 @@ def add_pre_set_parameters_callback( def add_on_set_parameters_callback( self, - callback: Callable[[List[Parameter]], SetParametersResult] + callback: Callable[[List[Parameter[Any]]], SetParametersResult] ) -> None: """ Add a callback in front to the list of callbacks. @@ -1132,7 +1142,7 @@ def add_on_set_parameters_callback( def add_post_set_parameters_callback( self, - callback: Callable[[List[Parameter]], None] + callback: Callable[[List[Parameter[Any]]], None] ) -> None: """ Add a callback gets triggered after parameters are set successfully. @@ -1153,7 +1163,7 @@ def add_post_set_parameters_callback( def remove_pre_set_parameters_callback( self, - callback: Callable[[List[Parameter]], List[Parameter]] + callback: Callable[[List[Parameter[Any]]], List[Parameter[Any]]] ) -> None: """ Remove a callback from list of callbacks. @@ -1167,7 +1177,7 @@ def remove_pre_set_parameters_callback( def remove_on_set_parameters_callback( self, - callback: Callable[[List[Parameter]], SetParametersResult] + callback: Callable[[List[Parameter[Any]]], SetParametersResult] ) -> None: """ Remove a callback from list of callbacks. @@ -1181,7 +1191,7 @@ def remove_on_set_parameters_callback( def remove_post_set_parameters_callback( self, - callback: Callable[[List[Parameter]], None] + callback: Callable[[List[Parameter[Any]]], None] ) -> None: """ Remove a callback from list of callbacks. @@ -1195,7 +1205,7 @@ def remove_post_set_parameters_callback( def _apply_descriptors( self, - parameter_list: List[Parameter], + parameter_list: List[Parameter[Any]], descriptors: Dict[str, ParameterDescriptor], check_read_only: bool = True ) -> SetParametersResult: @@ -1221,7 +1231,7 @@ def _apply_descriptors( def _apply_descriptor( self, - parameter: Parameter, + parameter: Parameter[Any], descriptor: Optional[ParameterDescriptor] = None, check_read_only: bool = True ) -> SetParametersResult: @@ -1279,7 +1289,7 @@ def _apply_descriptor( def _apply_integer_range( self, - parameter: Parameter, + parameter: Parameter[Any], integer_range: IntegerRange ) -> SetParametersResult: min_value = min(integer_range.from_value, integer_range.to_value) @@ -1315,7 +1325,7 @@ def _apply_integer_range( def _apply_floating_point_range( self, - parameter: Parameter, + parameter: Parameter[Any], floating_point_range: FloatingPointRange ) -> SetParametersResult: min_value = min(floating_point_range.from_value, floating_point_range.to_value) @@ -1360,7 +1370,7 @@ def _apply_floating_point_range( def _apply_descriptor_and_set( self, - parameter: Parameter, + parameter: Parameter[Any], descriptor: Optional[ParameterDescriptor] = None, check_read_only: bool = True ) -> SetParametersResult: @@ -1461,7 +1471,8 @@ def set_descriptor( self._descriptors[name] = descriptor return self.get_parameter(name).get_parameter_value() - def _validate_topic_or_service_name(self, topic_or_service_name, *, is_service=False): + def _validate_topic_or_service_name(self, topic_or_service_name: str, *, + is_service: bool = False) -> None: name = self.get_name() namespace = self.get_namespace() validate_node_name(name) @@ -1470,7 +1481,7 @@ def _validate_topic_or_service_name(self, topic_or_service_name, *, is_service=F expanded_topic_or_service_name = expand_topic_name(topic_or_service_name, name, namespace) validate_full_topic_name(expanded_topic_or_service_name, is_service=is_service) - def _validate_qos_or_depth_parameter(self, qos_or_depth) -> QoSProfile: + def _validate_qos_or_depth_parameter(self, qos_or_depth: Union[QoSProfile, int]) -> QoSProfile: if isinstance(qos_or_depth, QoSProfile): return qos_or_depth elif isinstance(qos_or_depth, int): @@ -1534,7 +1545,7 @@ def create_publisher( callback_group: Optional[CallbackGroup] = None, event_callbacks: Optional[PublisherEventCallbacks] = None, qos_overriding_options: Optional[QoSOverridingOptions] = None, - publisher_class: Type[Publisher[MsgT]] = Publisher[MsgT], + publisher_class: Type[Publisher[MsgT]] = Publisher, ) -> Publisher[MsgT]: """ Create a new publisher. @@ -1803,7 +1814,7 @@ def create_timer( def create_guard_condition( self, - callback: Callable, + callback: Callable[[], None], callback_group: Optional[CallbackGroup] = None ) -> GuardCondition: """ @@ -1845,7 +1856,7 @@ def create_rate( timer = self.create_timer(period, callback, group, clock) return Rate(timer, context=self.context) - def destroy_publisher(self, publisher: Publisher) -> bool: + def destroy_publisher(self, publisher: Publisher[Any]) -> bool: """ Destroy a publisher created by the node. @@ -1863,7 +1874,7 @@ def destroy_publisher(self, publisher: Publisher) -> bool: return True return False - def destroy_subscription(self, subscription: Subscription) -> bool: + def destroy_subscription(self, subscription: Subscription[Any]) -> bool: """ Destroy a subscription created by the node. @@ -1881,7 +1892,7 @@ def destroy_subscription(self, subscription: Subscription) -> bool: return True return False - def destroy_client(self, client: Client) -> bool: + def destroy_client(self, client: Client[Any, Any, Any]) -> bool: """ Destroy a service client created by the node. @@ -1897,7 +1908,7 @@ def destroy_client(self, client: Client) -> bool: return True return False - def destroy_service(self, service: Service) -> bool: + def destroy_service(self, service: Service[Any, Any, Any]) -> bool: """ Destroy a service server created by the node. @@ -1955,7 +1966,7 @@ def destroy_rate(self, rate: Rate) -> bool: rate.destroy() return success - def destroy_node(self): + def destroy_node(self) -> None: """ Destroy the node. @@ -2152,7 +2163,7 @@ def get_fully_qualified_name(self) -> str: with self.handle: return self.handle.get_fully_qualified_name() - def _count_publishers_or_subscribers(self, topic_name, func): + def _count_publishers_or_subscribers(self, topic_name: str, func: Callable[[str], int]) -> int: fq_topic_name = expand_topic_name(topic_name, self.get_name(), self.get_namespace()) validate_full_topic_name(fq_topic_name) with self.handle: @@ -2188,7 +2199,7 @@ def count_subscribers(self, topic_name: str) -> int: return self._count_publishers_or_subscribers( topic_name, self.handle.get_count_subscribers) - def _count_clients_or_servers(self, service_name, func): + def _count_clients_or_servers(self, service_name: str, func: Callable[[str], int]) -> int: fq_service_name = expand_topic_name(service_name, self.get_name(), self.get_namespace()) validate_full_topic_name(fq_service_name, is_service=True) with self.handle: @@ -2228,7 +2239,7 @@ def _get_info_by_topic( self, topic_name: str, no_mangle: bool, - func: Callable[[object, str, bool], List[Dict]] + func: Callable[[_rclpy.Node, str, bool], List['_rclpy.TopicEndpointInfoDict']] ) -> List[TopicEndpointInfo]: with self.handle: if no_mangle: diff --git a/rclpy/rclpy/parameter.py b/rclpy/rclpy/parameter.py index 464ee7065..e1945ea96 100644 --- a/rclpy/rclpy/parameter.py +++ b/rclpy/rclpy/parameter.py @@ -23,6 +23,7 @@ from typing import overload from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from rcl_interfaces.msg import Parameter as ParameterMsg @@ -33,7 +34,6 @@ PARAMETER_SEPARATOR_STRING = '.' if TYPE_CHECKING: - from typing_extensions import TypeVar # Mypy does not handle string literals of array.array[int/str/float] very well # So if user has newer version of python can use proper array types. if sys.version_info > (3, 9): @@ -51,15 +51,13 @@ List[float], Tuple[float, ...], 'array.array[float]', List[str], Tuple[str, ...], 'array.array[str]'] - AllowableParameterValueT = TypeVar('AllowableParameterValueT', - bound=AllowableParameterValue, - default=AllowableParameterValue) else: - from typing import TypeVar # Done to prevent runtime errors of undefined values. # after python3.13 is minimum support this could be removed. AllowableParameterValue = Any - AllowableParameterValueT = TypeVar('AllowableParameterValueT') + +AllowableParameterValueT = TypeVar('AllowableParameterValueT', + bound=AllowableParameterValue) class Parameter(Generic[AllowableParameterValueT]): @@ -171,9 +169,12 @@ def from_parameter_msg(cls, param_msg: ParameterMsg) -> 'Parameter[AllowablePara def __init__(self, name: str, type_: Optional['Parameter.Type'] = None) -> None: ... @overload - def __init__(self, name: str, type_: Optional['Parameter.Type'], + def __init__(self, name: str, type_: 'Parameter.Type', value: AllowableParameterValueT) -> None: ... + @overload + def __init__(self, name: str, *, value: AllowableParameterValueT) -> None: ... + def __init__(self, name: str, type_: Optional['Parameter.Type'] = None, value=None) -> None: if type_ is None: # This will raise a TypeError if it is not possible to get a type from the value. diff --git a/rclpy/rclpy/qos.py b/rclpy/rclpy/qos.py index 6c8b18f92..ceb6d2314 100644 --- a/rclpy/rclpy/qos.py +++ b/rclpy/rclpy/qos.py @@ -14,7 +14,7 @@ from enum import Enum, IntEnum from typing import (Callable, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, - TypedDict, TypeVar, Union) + TypeVar, Union) import warnings from rclpy.duration import Duration @@ -56,18 +56,6 @@ def __init__(self, message: str) -> None: Exception(self, f'Invalid QoSProfile: {message}') -class QoSProfileDictionary(TypedDict): - history: 'QoSHistoryPolicy' - depth: int - reliability: 'QoSReliabilityPolicy' - durability: 'QoSDurabilityPolicy' - lifespan: Duration - deadline: Duration - liveliness: 'QoSLivelinessPolicy' - liveliness_lease_duration: Duration - avoid_ros_namespace_conventions: bool - - class QoSProfile: """Define Quality of Service policies.""" diff --git a/rclpy/rclpy/topic_endpoint_info.py b/rclpy/rclpy/topic_endpoint_info.py index 40301ef91..ef5c74cba 100644 --- a/rclpy/rclpy/topic_endpoint_info.py +++ b/rclpy/rclpy/topic_endpoint_info.py @@ -15,7 +15,8 @@ from enum import IntEnum from typing import List, Union -from rclpy.qos import QoSHistoryPolicy, QoSPresetProfiles, QoSProfile, QoSProfileDictionary +from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy +from rclpy.qos import QoSHistoryPolicy, QoSPresetProfiles, QoSProfile from rclpy.type_hash import TypeHash, TypeHashDictionary @@ -159,7 +160,7 @@ def qos_profile(self) -> QoSProfile: return self._qos_profile @qos_profile.setter - def qos_profile(self, value: Union[QoSProfile, QoSProfileDictionary]) -> None: + def qos_profile(self, value: Union[QoSProfile, '_rclpy.rmw_qos_profile_dict']) -> None: if isinstance(value, QoSProfile): self._qos_profile = value elif isinstance(value, dict):