Skip to content

Commit

Permalink
Merge pull request #48 from jjshoots/upgrade_ruff
Browse files Browse the repository at this point in the history
Upgrade ruff and typehints
  • Loading branch information
jjshoots authored Jun 29, 2024
2 parents e62f9b6 + 146122f commit 1c8074b
Show file tree
Hide file tree
Showing 35 changed files with 373 additions and 28 deletions.
2 changes: 2 additions & 0 deletions PyFlyt/core/abstractions/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def step(self, state: np.ndarray, setpoint: np.ndarray) -> np.ndarray:
"""Step the controller.
Args:
----
state (np.ndarray): state
setpoint (np.ndarray): setpoint
"""
pass
6 changes: 6 additions & 0 deletions PyFlyt/core/abstractions/base_drone.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DroneClass(ABC):
Each drone inheriting from this class must have several attributes and methods implemented before they can be considered usable.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
start_pos (np.ndarray): an `(3,)` array for the starting X, Y, Z position for the drone.
start_orn (np.ndarray): an `(3,)` array for the starting X, Y, Z orientation for the drone.
Expand Down Expand Up @@ -64,6 +65,7 @@ class DroneClass(ABC):
>>> self.use_camera = use_camera
>>> if self.use_camera:
>>> self.camera = Camera(...)
"""

def __init__(
Expand All @@ -80,6 +82,7 @@ def __init__(
"""Defines the default configuration for UAVs, to be used in conjunction with the Aviary class.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
start_pos (np.ndarray): an `(3,)` array for the starting X, Y, Z position for the drone.
start_orn (np.ndarray): an `(3,)` array for the starting X, Y, Z orientation for the drone.
Expand All @@ -88,6 +91,7 @@ def __init__(
drone_model (str): name of the drone itself, must be the same name as the folder where the URDF and YAML files are located.
model_dir (None | str = None): directory where the drone model folder is located, if none is provided, defaults to the directory of the default drones.
np_random (None | np.random.RandomState = None): random number generator of the simulation.
"""
if physics_hz % control_hz != 0:
raise ValueError(
Expand Down Expand Up @@ -264,9 +268,11 @@ def register_controller(
"""Default register_controller.
Args:
----
controller_id (int): ID to bind to this controller
controller_constructor (type[ControlClass]): A class pointer to the controller implementation, must be subclass of `ControlClass`.
base_mode (int): Whether this controller uses outputs of an underlying controller as setpoints.
"""
if controller_id <= 0:
raise ValueError(
Expand Down
2 changes: 2 additions & 0 deletions PyFlyt/core/abstractions/base_wind_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def __call__(self, time: float, position: np.ndarray) -> np.ndarray:
"""When given the time float and a position as an (n, 3) array, must return a (n, 3) array representing the local wind velocity.
Args:
----
time (float): float representing the timestep of the simulation in seconds.
position (np.ndarray): (n, 3) array representing a series of n positions to sample wind velocites.
"""
pass

Expand Down
14 changes: 13 additions & 1 deletion PyFlyt/core/abstractions/boosters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Boosters:
Additionally, some boosters, typically of the solid fuel variety, cannot be extinguished and reignited, a property we call reignitability.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
Expand All @@ -31,6 +32,7 @@ class Boosters:
thrust_unit (np.ndarray): an `(n, 3)` array representing the unit vector pointing in the direction of force for each booster, relative to the booster link's body frame.
reignitable (np.ndarray | list[bool]): a list of booleans representing whether the booster can be extinguished and then reignited.
noise_ratio (np.ndarray): a list of floats representing the percent amount of fluctuation present in each booster.
"""

def __init__(
Expand All @@ -54,6 +56,7 @@ def __init__(
"""Used for simulating an array of boosters.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
Expand All @@ -69,6 +72,7 @@ def __init__(
thrust_unit (np.ndarray): an `(n, 3)` array representing the unit vector pointing in the direction of force for each booster, relative to the booster link's body frame.
reignitable (np.ndarray | list[bool]): a list of booleans representing whether the booster can be extinguished and then reignited.
noise_ratio (np.ndarray): a list of floats representing the percent amount of fluctuation present in each booster.
"""
self.p = p
self.physics_period = physics_period
Expand Down Expand Up @@ -118,7 +122,9 @@ def reset(self, starting_fuel_ratio: float | np.ndarray = 1.0):
"""Reset the boosters.
Args:
----
starting_fuel_ratio (float | np.ndarray): ratio amount of fuel that the booster is reset to.
"""
# deal with everything in percents
self.ratio_fuel_remaining = (
Expand All @@ -135,8 +141,10 @@ def get_states(self) -> np.ndarray:
- (b0, b1, ..., bn) represent the remaining fuel ratio
- (c0, c1, ..., cn) represent the current throttle state
Returns:
Returns
-------
np.ndarray: A (3 * num_boosters, ) array
"""
return np.concatenate(
[
Expand All @@ -156,9 +164,11 @@ def physics_update(
"""Converts booster settings into forces on the booster and inertia change on fuel tank.
Args:
----
ignition (np.ndarray): (num_boosters,) array of booleans for engine on or off.
pwm (np.ndarray): (num_boosters,) array of floats between [0, 1] for min or max thrust.
rotation (np.ndarray): (num_boosters, 3, 3) rotation matrices to rotate each booster's thrust axis around, this is readily obtained from the `gimbals` component.
"""
assert np.all(ignition >= 0.0) and np.all(
ignition <= 1.0
Expand Down Expand Up @@ -214,8 +224,10 @@ def _compute_thrust_mass_inertia(
"""_compute_thrust_mass_inertia.
Args:
----
ignition (np.ndarray): (num_boosters,) array of booleans for engine on or off.
pwm (np.ndarray): (num_boosters,) array of floats between [0, 1] for min or max thrust.
"""
# if not reignitable, logical or ignition_state with ignition
# otherwise, just follow ignition
Expand Down
6 changes: 6 additions & 0 deletions PyFlyt/core/abstractions/boring_bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ class BoringBodies:
The `BoringBodies` component is used to represent a normal body moving through the air.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
uav_id (int): ID of the drone.
body_ids (np.ndarray | Sequence[int]): (n,) array of IDs for the links representing the bodies.
drag_coefs (np.ndarray): (n, 3) array of drag coefficients for each body in the link-referenced XYZ directions.
normal_areas (np.ndarray): (n, 3) array of frontal areas in the link-referenced XYZ directions.
"""

def __init__(
Expand All @@ -35,13 +37,15 @@ def __init__(
"""Used for simulating a body moving through the air.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
uav_id (int): ID of the drone.
body_ids (np.ndarray | Sequence[int]): (n,) array of IDs for the links representing the bodies.
drag_coefs (np.ndarray): (n, 3) array of drag coefficients for each body in the link-referenced XYZ directions.
normal_areas (np.ndarray): (n, 3) array of frontal areas in the link-referenced XYZ directions.
"""
self.p = p
self.physics_period = physics_period
Expand Down Expand Up @@ -76,7 +80,9 @@ def state_update(self, rotation_matrix: np.ndarray):
"""Updates the local surface velocity of the boring body.
Args:
----
rotation_matrix (np.ndarray): (3, 3) rotation_matrix of the main body
"""
# get all the states for all the bodies
link_states = self.p.getLinkStates(
Expand Down
12 changes: 10 additions & 2 deletions PyFlyt/core/abstractions/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Camera:
On image capture, the camera returns an RGBA image, a depth map, and a segmentation map with pixel values representing the IDs of objects in the environment.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
uav_id (int): ID of the drone.
camera_id (int): integer representing the ID of the link that the camera is attached to.
Expand All @@ -29,6 +30,7 @@ class Camera:
camera_position_offset (np.ndarray = np.array([0.0, 0.0, 0.0])): an (3,) array representing an offset of where the camera is from the center of the link in `camera_id`.
is_tracking_camera (bool = False): if the camera is a tracking camera, the focus point of the camera is adjusted to focus on the center body of the aircraft instead of at infinity.
cinematic (bool = False): it's not a bug, it's a feature.
"""

def __init__(
Expand All @@ -47,6 +49,7 @@ def __init__(
"""Used for implementing camera modules.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
uav_id (int): ID of the drone.
camera_id (int): integer representing the ID of the link that the camera is attached to.
Expand All @@ -57,6 +60,7 @@ def __init__(
camera_position_offset (np.ndarray = np.array([0.0, 0.0, 0.0])): an (3,) array representing an offset of where the camera is from the center of the link in `camera_id`.
is_tracking_camera (bool = False): if the camera is a tracking camera, the focus point of the camera is adjusted to focus on the center body of the aircraft instead of at infinity.
cinematic (bool = False): it's not a bug, it's a feature.
"""
check_numpy()
if is_tracking_camera and use_gimbal:
Expand Down Expand Up @@ -96,8 +100,10 @@ def __init__(
def view_mat(self) -> np.ndarray:
"""Generates the view matrix for the camera depending on the current orientation and implicit parameters.
Returns:
Returns
-------
np.ndarray: view matrix.
"""
# get the state of the camera on the robot
camera_state = self.p.getLinkState(self.uav_id, self.camera_id)
Expand Down Expand Up @@ -155,8 +161,10 @@ def physics_update(self):
def capture_image(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Captures the 3 relevant images from the camera.
Returns:
Returns
-------
tuple[np.ndarray, np.ndarray, np.ndarray]: rgbaImg, depthImg, segImg
"""
_, _, rgbaImg, depthImg, segImg = self.p.getCameraImage(
height=self.camera_resolution[0],
Expand Down
14 changes: 13 additions & 1 deletion PyFlyt/core/abstractions/gimbals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ class Gimbals:
Each gimbal can rotate about two arbitrary axis that may not be orthogonal to each other.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
gimbal_unit_1 (np.ndarray): first unit vector that the gimbal rotates around.
gimbal_unit_2 (np.ndarray): second unit vector that the gimbal rotates around.
gimbal_tau (np.ndarray): gimbal actuation time constant.
gimbal_range_degrees (np.ndarray): gimbal actuation range in degrees.
"""

def __init__(
Expand All @@ -38,13 +40,15 @@ def __init__(
"""Used for simulating an array of gimbals.
Args:
----
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
gimbal_unit_1 (np.ndarray): first unit vector that the gimbal rotates around.
gimbal_unit_2 (np.ndarray): second unit vector that the gimbal rotates around.
gimbal_tau (np.ndarray): gimbal actuation time constant.
gimbal_range_degrees (np.ndarray): gimbal actuation range in degrees.
"""
self.p = p
self.physics_period = physics_period
Expand Down Expand Up @@ -119,8 +123,10 @@ def reset(self):
def get_states(self) -> np.ndarray:
"""Gets the current state of the components.
Returns:
Returns
-------
np.ndarray: a (2 * num_gimbals, ) array where every pair of values represents the current state of the gimbal
"""
return np.concatenate(
[
Expand All @@ -142,10 +148,13 @@ def compute_rotation(self, gimbal_command: np.ndarray) -> np.ndarray:
"""Returns a rotation vector after the gimbal rotation.
Args:
----
gimbal_command (np.ndarray): (num_gimbals, 2) array of floats between [-1, 1].
Returns:
-------
rotation_vector (np.ndarray): (num_gimbals, 3, 3) rotation matrices for all gimbals.
"""
assert np.all(gimbal_command >= -1.0) and np.all(
gimbal_command <= 1.0
Expand Down Expand Up @@ -182,14 +191,17 @@ def _jitted_compute_rotation(
"""Compute the rotation matrix given the gimbal action values.
Args:
----
gimbal_angles (np.ndarray): gimbal_angles
w1 (np.ndarray): w1 from self
w2 (np.ndarray): w2 from self
w1_squared (np.ndarray): w1_squared from self
w2_squared (np.ndarray): w2_squared from self
Returns:
-------
tuple[np.ndarray, np.ndarray]:
"""
# precompute some things
sin_angles = np.sin(gimbal_angles)
Expand Down
Loading

0 comments on commit 1c8074b

Please sign in to comment.