Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: enable subclassing for Staypoints #509

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/modules/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ Positionfixes
.. autoclass:: trackintel.model.positionfixes.Positionfixes
:members:

StaypointsAccessor
Staypoints
------------------

.. autoclass:: trackintel.model.staypoints.StaypointsAccessor
.. autoclass:: trackintel.model.staypoints.Staypoints
:members:

TriplegsAccessor
Expand Down
34 changes: 33 additions & 1 deletion tests/model/test_staypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from shapely.geometry import LineString

import trackintel as ti
from trackintel import Staypoints


@pytest.fixture
Expand All @@ -15,7 +16,7 @@ def testdata_sp():


class TestStaypoints:
"""Tests for the StaypointsAccessor."""
"""Tests for the Staypoints class."""

def test_accessor_columns(self, testdata_sp):
"""Test if the as_staypoints accessor checks the required column for staypoints."""
Expand Down Expand Up @@ -46,3 +47,34 @@ def test_staypoints_center(self, testdata_sp):
"""Check if sp has center method and returns (lat, lon) pairs as geometry."""
sp = testdata_sp.copy()
assert len(sp.as_staypoints.center) == 2

def test_check_suceeding(self, testdata_sp):
    """`Staypoints._check` accepts the unmodified staypoints fixture."""
    accepted = Staypoints._check(testdata_sp)
    assert accepted

def test_check_missing_columns(self, testdata_sp):
    """`Staypoints._check` rejects staypoints whose `user_id` column was dropped."""
    without_user = testdata_sp.drop(columns="user_id")
    assert not Staypoints._check(without_user)

def test_check_no_tz(self, testdata_sp):
    """`Staypoints._check` rejects timezone-naive `started_at` or `finished_at`."""
    aware_start = testdata_sp["started_at"]

    # strip the timezone from started_at only
    testdata_sp["started_at"] = testdata_sp["started_at"].dt.tz_localize(None)
    assert not Staypoints._check(testdata_sp)

    # put the original started_at back so that only finished_at is naive below
    testdata_sp["started_at"] = aware_start
    testdata_sp["finished_at"] = testdata_sp["finished_at"].dt.tz_localize(None)
    assert not Staypoints._check(testdata_sp)

def test_check_false_geometry_type(self, testdata_sp):
    """`Staypoints._check` rejects staypoints whose `geom` column holds a LineString."""
    coords = [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
    testdata_sp["geom"] = LineString(coords)
    assert not Staypoints._check(testdata_sp)

def test_check_ignore_false_geometry_type(self, testdata_sp):
    """`Staypoints._check` accepts a LineString `geom` when `validate_geometry=False`."""
    coords = [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
    testdata_sp["geom"] = LineString(coords)
    assert Staypoints._check(testdata_sp, validate_geometry=False)
1 change: 1 addition & 0 deletions trackintel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from trackintel.io.file import read_tours_csv

from trackintel.model.positionfixes import Positionfixes
from trackintel.model.staypoints import Staypoints

from trackintel.__version__ import __version__
from .core import print_version
2 changes: 1 addition & 1 deletion trackintel/analysis/tracking_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _split_overlaps(source, granularity="day"):
gdf = source.copy()
gdf[["started_at", "finished_at"]] = gdf.apply(_get_times, axis="columns", result_type="expand", freq=freq)
# must call DataFrame.explode directly because GeoDataFrame.explode cannot be used on multiple columns
gdf = super(type(gdf), gdf).explode(["started_at", "finished_at"], ignore_index=True)
gdf = pd.DataFrame.explode(gdf, ["started_at", "finished_at"], ignore_index=True)
if "duration" in gdf.columns:
gdf["duration"] = gdf["finished_at"] - gdf["started_at"]
return gdf
Expand Down
3 changes: 2 additions & 1 deletion trackintel/io/postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ def read_staypoints_postgis(
def write_staypoints_postgis(
staypoints, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
):
staypoints.to_postgis(
gpd.GeoDataFrame.to_postgis(
staypoints,
name,
con,
schema=schema,
Expand Down
2 changes: 1 addition & 1 deletion trackintel/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .positionfixes import Positionfixes
from .staypoints import StaypointsAccessor
from .staypoints import Staypoints
from .triplegs import TriplegsAccessor
from .locations import LocationsAccessor
from .trips import TripsAccessor
Expand Down
81 changes: 52 additions & 29 deletions trackintel/model/staypoints.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import pandas as pd

import trackintel as ti
from trackintel.analysis.labelling import create_activity_flag
from trackintel.analysis.tracking_quality import temporal_tracking_quality
from trackintel.io.file import write_staypoints_csv
from trackintel.io.postgis import write_staypoints_postgis
from trackintel.model.util import _copy_docstring, _register_trackintel_accessor
from trackintel.model.util import (
TrackintelBase,
TrackintelGeoDataFrame,
_copy_docstring,
_register_trackintel_accessor,
)
from trackintel.preprocessing.filter import spatial_filter
from trackintel.preprocessing.staypoints import generate_locations, merge_staypoints
from trackintel.visualization.staypoints import plot_staypoints

_required_columns = ["user_id", "started_at", "finished_at"]


@_register_trackintel_accessor("as_staypoints")
class StaypointsAccessor(object):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Staypoints`.
class Staypoints(TrackintelBase, TrackintelGeoDataFrame):
"""A pandas accessor to treat a GeoDataFrame as collections of `Staypoints`.

This will define certain methods and accessors, as well as make sure that the DataFrame
adheres to some requirements.
Expand Down Expand Up @@ -41,27 +49,23 @@ class StaypointsAccessor(object):
>>> df.as_staypoints.generate_locations()
"""

required_columns = ["user_id", "started_at", "finished_at"]
def __init__(self, *args, validate_geometry=True, **kwargs):
    """Build the underlying frame from `args`/`kwargs`, then validate it.

    Parameters
    ----------
    validate_geometry : bool, default True
        Forwarded to `_validate`; the geometry checks there only run when this is True.
    """
    super().__init__(*args, **kwargs)
    # validation needs the fully constructed frame, so it runs after the parent constructor
    self._validate(self, validate_geometry=validate_geometry)

def __init__(self, pandas_obj):
self._validate(pandas_obj)
self._obj = pandas_obj
# create circular reference directly -> avoid second call of init via accessor
@property
def as_staypoints(self):
    """Return this object itself, so `obj.as_staypoints.<method>` resolves on the instance."""
    return self

@staticmethod
def _validate(obj):
def _validate(obj, validate_geometry=True):
# check columns
if any([c not in obj.columns for c in StaypointsAccessor.required_columns]):
if any([c not in obj.columns for c in _required_columns]):
raise AttributeError(
"To process a DataFrame as a collection of staypoints, it must have the properties"
f" {StaypointsAccessor.required_columns}, but it has {', '.join(obj.columns)}."
f" {_required_columns}, but it has {', '.join(obj.columns)}."
)
# check geometry
assert (
obj.geometry.is_valid.all()
), "Not all geometries are valid. Try x[~ x.geometry.is_valid] where x is you GeoDataFrame"
if obj.geometry.iloc[0].geom_type != "Point":
raise AttributeError("The geometry must be a Point (only first checked).")

# check timestamp dtypes
assert pd.api.types.is_datetime64tz_dtype(
obj["started_at"]
Expand All @@ -70,11 +74,32 @@ def _validate(obj):
obj["finished_at"]
), f"dtype of finished_at is {obj['finished_at'].dtype} but has to be tz aware datetime64"

if validate_geometry:
# check geometry
assert (
obj.geometry.is_valid.all()
), "Not all geometries are valid. Try x[~ x.geometry.is_valid] where x is you GeoDataFrame"
if obj.geometry.iloc[0].geom_type != "Point":
raise AttributeError("The geometry must be a Point (only first checked).")

@staticmethod
def _check(obj, validate_geometry=True):
    """Return whether `obj` passes the staypoints checks, without raising.

    Applies the same conditions as `_validate` (required columns, tz-aware
    `started_at`/`finished_at`, and optionally the geometry checks), but
    reports the outcome as a bool.

    Parameters
    ----------
    obj : GeoDataFrame
        Candidate staypoints.
    validate_geometry : bool, default True
        If True, additionally require all geometries to be valid and the
        first geometry to be a Point (only the first one is inspected).

    Returns
    -------
    bool
        True if every check passes, False otherwise.
    """
    if any(c not in obj.columns for c in _required_columns):
        return False
    # isinstance on the dtype replaces pd.api.types.is_datetime64tz_dtype, which is deprecated in pandas >= 2.1
    if not isinstance(obj["started_at"].dtype, pd.DatetimeTZDtype):
        return False
    if not isinstance(obj["finished_at"].dtype, pd.DatetimeTZDtype):
        return False
    if validate_geometry:
        geometry = obj.geometry
        # an empty frame has no first geometry; report False instead of letting iloc[0] raise IndexError
        if len(geometry) == 0:
            return False
        # bool(...) turns the numpy bool from `.all()` into a plain Python bool
        return bool(geometry.is_valid.all() and geometry.iloc[0].geom_type == "Point")
    return True

@property
def center(self):
"""Return the center coordinate of this collection of staypoints."""
lat = self._obj.geometry.y
lon = self._obj.geometry.x
lat = self.geometry.y
lon = self.geometry.x
return (float(lon.mean()), float(lat.mean()))

@_copy_docstring(generate_locations)
Expand All @@ -84,7 +109,7 @@ def generate_locations(self, *args, **kwargs):

See :func:`trackintel.preprocessing.staypoints.generate_locations`.
"""
return ti.preprocessing.staypoints.generate_locations(self._obj, *args, **kwargs)
return ti.preprocessing.staypoints.generate_locations(self, *args, **kwargs)

@_copy_docstring(merge_staypoints)
def merge_staypoints(self, *args, **kwargs):
Expand All @@ -93,7 +118,7 @@ def merge_staypoints(self, *args, **kwargs):

See :func:`trackintel.preprocessing.staypoints.merge_staypoints`.
"""
return ti.preprocessing.staypoints.merge_staypoints(self._obj, *args, **kwargs)
return ti.preprocessing.staypoints.merge_staypoints(self, *args, **kwargs)

@_copy_docstring(create_activity_flag)
def create_activity_flag(self, *args, **kwargs):
Expand All @@ -102,7 +127,7 @@ def create_activity_flag(self, *args, **kwargs):

See :func:`trackintel.analysis.labelling.create_activity_flag`.
"""
return ti.analysis.labelling.create_activity_flag(self._obj, *args, **kwargs)
return ti.analysis.labelling.create_activity_flag(self, *args, **kwargs)

@_copy_docstring(spatial_filter)
def spatial_filter(self, *args, **kwargs):
Expand All @@ -111,7 +136,7 @@ def spatial_filter(self, *args, **kwargs):

See :func:`trackintel.preprocessing.filter.spatial_filter`.
"""
return ti.preprocessing.filter.spatial_filter(self._obj, *args, **kwargs)
return ti.preprocessing.filter.spatial_filter(self, *args, **kwargs)

@_copy_docstring(plot_staypoints)
def plot(self, *args, **kwargs):
Expand All @@ -120,7 +145,7 @@ def plot(self, *args, **kwargs):

See :func:`trackintel.visualization.staypoints.plot_staypoints`.
"""
ti.visualization.staypoints.plot_staypoints(self._obj, *args, **kwargs)
ti.visualization.staypoints.plot_staypoints(self, *args, **kwargs)

@_copy_docstring(write_staypoints_csv)
def to_csv(self, filename, *args, **kwargs):
Expand All @@ -129,7 +154,7 @@ def to_csv(self, filename, *args, **kwargs):

See :func:`trackintel.io.file.write_staypoints_csv`.
"""
ti.io.file.write_staypoints_csv(self._obj, filename, *args, **kwargs)
ti.io.file.write_staypoints_csv(self, filename, *args, **kwargs)

@_copy_docstring(write_staypoints_postgis)
def to_postgis(
Expand All @@ -140,9 +165,7 @@ def to_postgis(

See :func:`trackintel.io.postgis.write_staypoints_postgis`.
"""
ti.io.postgis.write_staypoints_postgis(
self._obj, name, con, schema, if_exists, index, index_label, chunksize, dtype
)
ti.io.postgis.write_staypoints_postgis(self, name, con, schema, if_exists, index, index_label, chunksize, dtype)

@_copy_docstring(temporal_tracking_quality)
def temporal_tracking_quality(self, *args, **kwargs):
Expand All @@ -151,4 +174,4 @@ def temporal_tracking_quality(self, *args, **kwargs):

See :func:`trackintel.analysis.tracking_quality.temporal_tracking_quality`.
"""
return ti.analysis.tracking_quality.temporal_tracking_quality(self._obj, *args, **kwargs)
return ti.analysis.tracking_quality.temporal_tracking_quality(self, *args, **kwargs)
Loading