Skip to content

Commit

Permalink
ENH: enable subclassing for Triplegs (#510)
Browse files Browse the repository at this point in the history
* ENH: enable subclassing for Triplegs

Here comes the next class :)
For more info see #490 .

* CLN: correct accessor property name

* TST: add test for property as_triplegs

---------

Co-authored-by: Ye <[email protected]>
  • Loading branch information
bifbof and hongyeehh authored Aug 17, 2023
1 parent d3ba813 commit bed90f0
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 38 deletions.
4 changes: 2 additions & 2 deletions docs/modules/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ Staypoints
.. autoclass:: trackintel.model.staypoints.Staypoints
:members:

TriplegsAccessor
Triplegs
----------------

.. autoclass:: trackintel.model.triplegs.TriplegsAccessor
.. autoclass:: trackintel.model.triplegs.Triplegs
:members:

LocationsAccessor
Expand Down
41 changes: 39 additions & 2 deletions tests/model/test_triplegs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import pytest

import pytest
from shapely.geometry import Point

import trackintel as ti
from trackintel import Triplegs


@pytest.fixture
Expand All @@ -17,7 +18,7 @@ def testdata_tpls():


class TestTriplegs:
"""Tests for the TriplegsAccessor."""
"""Tests for the Triplegs class."""

def test_accessor_column(self, testdata_tpls):
"""Test if the as_triplegs accessor checks the required column for triplegs."""
Expand Down Expand Up @@ -45,3 +46,39 @@ def test_accessor_geometry_type(self, testdata_tpls):
with pytest.raises(AttributeError, match="The geometry must be a LineString"):
tpls["geom"] = Point([(13.476808430, 48.573711823)])
tpls.as_triplegs

def test_accessor_recursive(self, testdata_tpls):
    """Test if as_triplegs yields a Triplegs object and returns that same object when applied again."""
    converted = testdata_tpls.as_triplegs
    # the exact class is expected here, hence no isinstance check
    assert type(converted) is Triplegs
    # accessing the property on a Triplegs object must hand back the identical object
    assert converted.as_triplegs is converted

def test_check_suceeding(self, testdata_tpls):
    """Test if check returns True on valid triplegs."""
    is_valid = Triplegs._check(testdata_tpls)
    assert is_valid

def test_check_missing_columns(self, testdata_tpls):
    """Test if check returns False when the required column user_id is missing."""
    without_user_id = testdata_tpls.drop(columns="user_id")
    assert not Triplegs._check(without_user_id)

def test_check_empty_df(self, testdata_tpls):
    """Test if check returns False when the DataFrame has no rows."""
    # dropping every index label keeps the columns but removes all rows
    no_rows = testdata_tpls.drop(testdata_tpls.index)
    assert not Triplegs._check(no_rows)

def test_check_no_tz(self, testdata_tpls):
    """Test if check returns False if started_at or finished_at has no timezone."""
    aware_started_at = testdata_tpls["started_at"]

    # case 1: only started_at is timezone-naive
    testdata_tpls["started_at"] = aware_started_at.dt.tz_localize(None)
    assert not Triplegs._check(testdata_tpls)

    # case 2: restore started_at so that only finished_at is timezone-naive
    testdata_tpls["started_at"] = aware_started_at
    naive_finished_at = testdata_tpls["finished_at"].dt.tz_localize(None)
    testdata_tpls["finished_at"] = naive_finished_at
    assert not Triplegs._check(testdata_tpls)

def test_check_false_geometry_type(self, testdata_tpls):
    """Test if check returns False when the geometry column holds Points."""
    wrong_geometry = Point([(13.476808430, 48.573711823)])
    testdata_tpls["geom"] = wrong_geometry
    assert not Triplegs._check(testdata_tpls)

def test_check_ignore_false_geometry_type(self, testdata_tpls):
    """Test if check returns True for a wrong geometry type when validate_geometry is False."""
    wrong_geometry = Point([(13.476808430, 48.573711823)])
    testdata_tpls["geom"] = wrong_geometry
    passed = Triplegs._check(testdata_tpls, validate_geometry=False)
    assert passed
2 changes: 1 addition & 1 deletion tests/preprocessing/test_triplegs.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_accessor(self, example_triplegs):
# test if generated trips are equal
assert_geodataframe_equal(trips_expl, trips_acc)
assert_geodataframe_equal(sp_expl, sp_acc)
assert_geodataframe_equal(tpls_expl, tpls_acc)
assert_geodataframe_equal(tpls_acc, tpls_expl)

def test_accessor_arguments(self, example_triplegs):
"""Test if the accessor is robust to different ways to receive arguments"""
Expand Down
1 change: 1 addition & 0 deletions trackintel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from trackintel.io.file import read_tours_csv

from trackintel.model.positionfixes import Positionfixes
from trackintel.model.triplegs import Triplegs
from trackintel.model.staypoints import Staypoints

from trackintel.__version__ import __version__
Expand Down
3 changes: 2 additions & 1 deletion trackintel/io/postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def read_triplegs_postgis(
def write_triplegs_postgis(
triplegs, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
):
triplegs.to_postgis(
gpd.GeoDataFrame.to_postgis(
triplegs,
name,
con,
schema=schema,
Expand Down
2 changes: 1 addition & 1 deletion trackintel/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .positionfixes import Positionfixes
from .triplegs import Triplegs
from .staypoints import Staypoints
from .triplegs import TriplegsAccessor
from .locations import LocationsAccessor
from .trips import TripsAccessor
from .tours import ToursAccessor
86 changes: 56 additions & 30 deletions trackintel/model/triplegs.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import pandas as pd
import trackintel as ti

import trackintel as ti
from trackintel.analysis.labelling import predict_transport_mode
from trackintel.analysis.modal_split import calculate_modal_split
from trackintel.analysis.tracking_quality import temporal_tracking_quality
from trackintel.geogr.distances import calculate_distance_matrix
from trackintel.io.file import write_triplegs_csv
from trackintel.io.postgis import write_triplegs_postgis
from trackintel.model.util import _copy_docstring, get_speed_triplegs, _register_trackintel_accessor
from trackintel.model.util import (
TrackintelBase,
TrackintelGeoDataFrame,
_copy_docstring,
_register_trackintel_accessor,
get_speed_triplegs,
)
from trackintel.preprocessing.filter import spatial_filter
from trackintel.preprocessing.triplegs import generate_trips
from trackintel.visualization.triplegs import plot_triplegs

_required_columns = ["user_id", "started_at", "finished_at"]


@_register_trackintel_accessor("as_triplegs")
class TriplegsAccessor(object):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Tripleg`.
class Triplegs(TrackintelBase, TrackintelGeoDataFrame):
"""Class to treat a GeoDataFrame as a collection of `Tripleg`.
This will define certain methods and accessors, as well as make sure that the DataFrame
adheres to some requirements.
Expand All @@ -40,27 +48,24 @@ class TriplegsAccessor(object):
>>> df.as_triplegs.plot()
"""

required_columns = ["user_id", "started_at", "finished_at"]
def __init__(self, *args, validate_geometry=True, **kwargs):
    """Build the object via the parent constructor and validate it immediately.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent class constructor.
    validate_geometry : bool, default True
        Passed on to `_validate`; controls whether the geometry checks are run.
    """
    super().__init__(*args, **kwargs)
    # validation runs on the fully constructed object, so any error raised by `_validate` surfaces here
    self._validate(self, validate_geometry=validate_geometry)

def __init__(self, pandas_obj):
self._validate(pandas_obj)
self._obj = pandas_obj
# circular reference: the object is its own `as_triplegs` -> avoids a second call of init via accessor
@property
def as_triplegs(self):
    """Return this object itself."""
    return self

@staticmethod
def _validate(obj):
def _validate(obj, validate_geometry=True):
assert obj.shape[0] > 0, f"Geodataframe is empty with shape: {obj.shape}"
# check columns
if any([c not in obj.columns for c in TriplegsAccessor.required_columns]):
if any([c not in obj.columns for c in _required_columns]):
raise AttributeError(
"To process a DataFrame as a collection of triplegs, it must have the properties"
f" {TriplegsAccessor.required_columns}, but it has [{', '.join(obj.columns)}]."
f" {_required_columns}, but it has [{', '.join(obj.columns)}]."
)
# check geometry
assert (
obj.geometry.is_valid.all()
), "Not all geometries are valid. Try x[~ x.geometry.is_valid] where x is you GeoDataFrame"
if obj.geometry.iloc[0].geom_type != "LineString":
raise AttributeError("The geometry must be a LineString (only first checked).")

# check timestamp dtypes
assert pd.api.types.is_datetime64tz_dtype(
Expand All @@ -70,14 +75,37 @@ def _validate(obj):
obj["finished_at"]
), f"dtype of finished_at is {obj['finished_at'].dtype} but has to be datetime64 and timezone aware"

# check geometry
if validate_geometry:
assert (
obj.geometry.is_valid.all()
), "Not all geometries are valid. Try x[~ x.geometry.is_valid] where x is you GeoDataFrame"
if obj.geometry.iloc[0].geom_type != "LineString":
raise AttributeError("The geometry must be a LineString (only first checked).")

@staticmethod
def _check(obj, validate_geometry=True):
    """Check does the same as _validate but returns bool instead of potentially raising an error.

    Parameters
    ----------
    obj : object
        Object to test. It is accessed via `columns`, `shape`, column lookup with `[]`
        and, if `validate_geometry` is True, via `geometry`.
    validate_geometry : bool, default True
        If True, additionally require that all geometries are valid and that the first
        geometry is a LineString (only the first one is inspected).

    Returns
    -------
    bool
        True if all checks pass, False otherwise.
    """
    if any(c not in obj.columns for c in _required_columns):
        return False
    if obj.shape[0] <= 0:
        return False
    # isinstance on the dtype replaces pd.api.types.is_datetime64tz_dtype, which is deprecated
    # since pandas 2.1; getattr keeps the "not tz-aware" answer for objects without a dtype
    for column in ("started_at", "finished_at"):
        if not isinstance(getattr(obj[column], "dtype", None), pd.DatetimeTZDtype):
            return False
    if validate_geometry:
        # bool() so that callers always get a builtin bool and never a numpy.bool_
        return bool(obj.geometry.is_valid.all() and obj.geometry.iloc[0].geom_type == "LineString")
    return True

@_copy_docstring(plot_triplegs)
def plot(self, *args, **kwargs):
"""
Plot this collection of triplegs.
See :func:`trackintel.visualization.triplegs.plot_triplegs`.
"""
ti.visualization.triplegs.plot_triplegs(self._obj, *args, **kwargs)
ti.visualization.triplegs.plot_triplegs(self, *args, **kwargs)

@_copy_docstring(write_triplegs_csv)
def to_csv(self, filename, *args, **kwargs):
Expand All @@ -86,7 +114,7 @@ def to_csv(self, filename, *args, **kwargs):
See :func:`trackintel.io.file.write_triplegs_csv`.
"""
ti.io.file.write_triplegs_csv(self._obj, filename, *args, **kwargs)
ti.io.file.write_triplegs_csv(self, filename, *args, **kwargs)

@_copy_docstring(write_triplegs_postgis)
def to_postgis(
Expand All @@ -97,9 +125,7 @@ def to_postgis(
See :func:`trackintel.io.postgis.write_triplegs_postgis`.
"""
ti.io.postgis.write_triplegs_postgis(
self._obj, name, con, schema, if_exists, index, index_label, chunksize, dtype
)
ti.io.postgis.write_triplegs_postgis(self, name, con, schema, if_exists, index, index_label, chunksize, dtype)

@_copy_docstring(calculate_distance_matrix)
def calculate_distance_matrix(self, *args, **kwargs):
Expand All @@ -108,7 +134,7 @@ def calculate_distance_matrix(self, *args, **kwargs):
See :func:`trackintel.geogr.distances.calculate_distance_matrix`.
"""
return ti.geogr.distances.calculate_distance_matrix(self._obj, *args, **kwargs)
return ti.geogr.distances.calculate_distance_matrix(self, *args, **kwargs)

@_copy_docstring(spatial_filter)
def spatial_filter(self, *args, **kwargs):
Expand All @@ -117,7 +143,7 @@ def spatial_filter(self, *args, **kwargs):
See :func:`trackintel.preprocessing.filter.spatial_filter`.
"""
return ti.preprocessing.filter.spatial_filter(self._obj, *args, **kwargs)
return ti.preprocessing.filter.spatial_filter(self, *args, **kwargs)

@_copy_docstring(generate_trips)
def generate_trips(self, *args, **kwargs):
Expand All @@ -128,14 +154,14 @@ def generate_trips(self, *args, **kwargs):
"""
# if staypoints in kwargs: 'staypoints' can not be in args as it would be the first argument
if "staypoints" in kwargs:
return ti.preprocessing.triplegs.generate_trips(triplegs=self._obj, **kwargs)
return ti.preprocessing.triplegs.generate_trips(triplegs=self, **kwargs)
# if 'staypoints' no in kwargs it has to be the first argument in 'args'
else:
assert len(args) <= 1, (
"All arguments except 'staypoints' have to be given as keyword arguments. You gave"
f" {args[1:]} as positional arguments."
)
return ti.preprocessing.triplegs.generate_trips(staypoints=args[0], triplegs=self._obj, **kwargs)
return ti.preprocessing.triplegs.generate_trips(staypoints=args[0], triplegs=self, **kwargs)

@_copy_docstring(predict_transport_mode)
def predict_transport_mode(self, *args, **kwargs):
Expand All @@ -144,7 +170,7 @@ def predict_transport_mode(self, *args, **kwargs):
See :func:`trackintel.analysis.labelling.predict_transport_mode`.
"""
return ti.analysis.labelling.predict_transport_mode(self._obj, *args, **kwargs)
return ti.analysis.labelling.predict_transport_mode(self, *args, **kwargs)

@_copy_docstring(calculate_modal_split)
def calculate_modal_split(self, *args, **kwargs):
Expand All @@ -153,7 +179,7 @@ def calculate_modal_split(self, *args, **kwargs):
See :func:`trackintel.analysis.modal_split.calculate_modal_split`.
"""
return ti.analysis.modal_split.calculate_modal_split(self._obj, *args, **kwargs)
return ti.analysis.modal_split.calculate_modal_split(self, *args, **kwargs)

@_copy_docstring(temporal_tracking_quality)
def temporal_tracking_quality(self, *args, **kwargs):
Expand All @@ -162,7 +188,7 @@ def temporal_tracking_quality(self, *args, **kwargs):
See :func:`trackintel.analysis.tracking_quality.temporal_tracking_quality`.
"""
return ti.analysis.tracking_quality.temporal_tracking_quality(self._obj, *args, **kwargs)
return ti.analysis.tracking_quality.temporal_tracking_quality(self, *args, **kwargs)

@_copy_docstring(get_speed_triplegs)
def get_speed(self, *args, **kwargs):
Expand All @@ -171,4 +197,4 @@ def get_speed(self, *args, **kwargs):
See :func:`trackintel.model.util.get_speed_triplegs`.
"""
return ti.model.util.get_speed_triplegs(self._obj, *args, **kwargs)
return ti.model.util.get_speed_triplegs(self, *args, **kwargs)
3 changes: 2 additions & 1 deletion trackintel/visualization/triplegs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import geopandas as gpd
import matplotlib.pyplot as plt

from trackintel.visualization.osm import plot_osm_streets
Expand Down Expand Up @@ -60,7 +61,7 @@ def plot_triplegs(
south = min(triplegs_bounds.miny) - 0.03
plot_osm_streets(north, south, east, west, ax)

triplegs.plot(ax=ax, cmap="viridis")
gpd.GeoDataFrame.plot(triplegs)(ax=ax, cmap="viridis")
ax.set_aspect("equal", adjustable="box")

if out_filename is not None:
Expand Down

0 comments on commit bed90f0

Please sign in to comment.