diff --git a/src/otf_api/api.py b/src/otf_api/api.py index 6b63a13..8766eeb 100644 --- a/src/otf_api/api.py +++ b/src/otf_api/api.py @@ -117,12 +117,14 @@ def _performance_summary_request( def get_classes( self, + start_date: date | None = None, + end_date: date | None = None, studio_uuids: list[str] | None = None, include_home_studio: bool | None = None, limit: int | None = None, + filters: list[filters.ClassFilter] | filters.ClassFilter | None = None, exclude_cancelled: bool = True, exclude_unbookable: bool = True, - filters: list[filters.ClassFilter] | filters.ClassFilter | None = None, ) -> models.OtfClassList: """Get the classes for the user. @@ -130,15 +132,17 @@ def get_classes( UUIDs are provided, it will default to the user's home studio. Args: + start_date (date | None): The start date for the classes. Default is None. + end_date (date | None): The end date for the classes. Default is None. studio_uuids (list[str] | None): The studio UUIDs to get the classes for. Default is None, which will\ default to the user's home studio only. include_home_studio (bool): Whether to include the home studio in the classes. Default is True. limit (int | None): Limit the number of classes returned. Default is None. + filters (list[ClassFilter] | ClassFilter | None): A list of filters to apply to the classes, or a single\ + filter. Filters are applied as an OR operation. Default is None. exclude_cancelled (bool): Whether to exclude classes that were cancelled by the studio. Default is True. exclude_unbookable (bool): Whether to exclude classes that are outside the scheduling window. Default is\ True. - filters (list[ClassFilter] | ClassFilter | None): A list of filters to apply to the classes, or a single\ - filter. Filters are applied as an OR operation. Default is None. Returns: OtfClassList: The classes for the user. @@ -179,6 +183,20 @@ def get_classes( if limit: classes_list.classes = classes_list.classes[:limit] + # apply date filters + if start_date: + if not isinstance(start_date, date | datetime): + raise ValueError("start_date must be a date or datetime object") + start_date = start_date.date() if isinstance(start_date, datetime) else start_date + classes_list.classes = [c for c in classes_list.classes if c.starts_at_local.date() >= start_date] + + if end_date: + if not isinstance(end_date, date | datetime): + raise ValueError("end_date must be a date or datetime object") + end_date = end_date.date() if isinstance(end_date, datetime) else end_date + classes_list.classes = [c for c in classes_list.classes if c.starts_at_local.date() <= end_date] + + # apply provided filters if filters: filtered_classes: list[models.OtfClass] = []