Skip to content

Commit

Permalink
ENH: Use of subtyping instead of accessors (#490)
Browse files Browse the repository at this point in the history
* ENH: positionfixes to class

* CLN: clean up some comments

* ENH: call correct methods in `to_csv` and `plot` directly

* ENH: rename `PositionfixesAccessor` to `Positionfixes`

* ENH: simplify validating of Positionfixes

* TST: add tests for Positionfixes and TrackintelGeoDataFrame

* CLN: call merge method and not merge function
  • Loading branch information
bifbof authored Aug 15, 2023
1 parent 68ac2b3 commit e2bc973
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 45 deletions.
4 changes: 2 additions & 2 deletions docs/modules/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ Available Accessors

The following accessors are available within *trackintel*.

PositionfixesAccessor
Positionfixes
---------------------

.. autoclass:: trackintel.model.positionfixes.PositionfixesAccessor
.. autoclass:: trackintel.model.positionfixes.Positionfixes
:members:

StaypointsAccessor
Expand Down
34 changes: 33 additions & 1 deletion tests/model/test_positionfixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from shapely.geometry import LineString

import trackintel as ti
from trackintel import Positionfixes


@pytest.fixture
Expand All @@ -15,7 +16,7 @@ def testdata_geolife():


class TestPositionfixes:
"""Tests for the PositionfixesAccessor."""
"""Tests for the Positionfixes class."""

def test_accessor_column(self, testdata_geolife):
"""Test if the as_positionfixes accessor checks the required column for positionfixes."""
Expand Down Expand Up @@ -55,3 +56,34 @@ def test_similarity_matrix(self, testdata_geolife):
accessor_result = pfs.as_positionfixes.calculate_distance_matrix(dist_metric="haversine", n_jobs=1)
function_result = ti.geogr.distances.calculate_distance_matrix(pfs, dist_metric="haversine", n_jobs=1)
assert np.allclose(accessor_result, function_result)

def test_check_suceeding(self, testdata_geolife):
    """`Positionfixes._check` accepts the geolife fixture (expected to be valid positionfixes)."""
    is_valid = Positionfixes._check(testdata_geolife)
    assert is_valid

def test_check_missing_columns(self, testdata_geolife):
    """`Positionfixes._check` rejects a frame whose `user_id` column was dropped."""
    without_user_id = testdata_geolife.drop(columns="user_id")
    assert not Positionfixes._check(without_user_id)

def test_check_empty_df(self, testdata_geolife):
    """`Positionfixes._check` rejects a frame from which every row was dropped."""
    all_row_labels = testdata_geolife.index
    emptied = testdata_geolife.drop(all_row_labels)
    assert not Positionfixes._check(emptied)

def test_check_no_tz(self, testdata_geolife):
    """`Positionfixes._check` rejects a frame whose `tracked_at` column has no timezone."""
    # tz_localize(None) strips the timezone information from the timestamps
    naive_timestamps = testdata_geolife["tracked_at"].dt.tz_localize(None)
    testdata_geolife["tracked_at"] = naive_timestamps
    assert not Positionfixes._check(testdata_geolife)

def test_check_false_geometry_type(self, testdata_geolife):
    """`Positionfixes._check` rejects a frame whose geometries are LineStrings."""
    coordinates = [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
    # presumably "geom" is the active geometry column of the fixture -- confirm against the fixture
    testdata_geolife["geom"] = LineString(coordinates)
    assert not Positionfixes._check(testdata_geolife)

def test_check_ignore_false_geometry_type(self, testdata_geolife):
    """`Positionfixes._check` accepts LineString geometries when `validate_geometry=False`."""
    coordinates = [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
    # presumably "geom" is the active geometry column of the fixture -- confirm against the fixture
    testdata_geolife["geom"] = LineString(coordinates)
    accepted = Positionfixes._check(testdata_geolife, validate_geometry=False)
    assert accepted
110 changes: 110 additions & 0 deletions tests/model/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
NonCachedAccessor,
_copy_docstring,
_register_trackintel_accessor,
_wrapped_gdf_method,
get_speed_positionfixes,
TrackintelGeoDataFrame,
)


Expand All @@ -41,6 +43,30 @@ def load_positionfixes():
return pfs, correct_speeds


@pytest.fixture
def example_positionfixes():
    """Build a three-row GeoDataFrame of positionfixes (two users, EPSG:4326) for the tests."""
    user_ids = [0, 0, 1]
    timestamps = [
        pd.Timestamp("1971-01-01 00:00:00", tz="utc"),
        pd.Timestamp("1971-01-01 05:00:00", tz="utc"),
        pd.Timestamp("1971-01-02 07:00:00", tz="utc"),
    ]
    # all points share the same x coordinate and differ only in y
    points = [Point(8.5067847, y) for y in (47.4, 47.5, 47.6)]

    records = [
        {"user_id": user_id, "tracked_at": tracked_at, "geometry": point}
        for user_id, tracked_at, point in zip(user_ids, timestamps, points)
    ]
    pfs = GeoDataFrame(data=records, geometry="geometry", crs="EPSG:4326")
    pfs.index.name = "id"

    # touching the accessor is meant to validate the frame (expected to raise if it is invalid)
    pfs.as_positionfixes
    return pfs


class TestSpeedPositionfixes:
def test_positionfixes_stable(self, load_positionfixes):
"""Test whether the positionfixes stay the same apart from the new speed column"""
Expand Down Expand Up @@ -215,6 +241,90 @@ def bar(b: int) -> int:
assert attr_foo != attr_bar


class Test_wrapped_gdf_method:
    """Tests for the `_wrapped_gdf_method` decorator."""

    def test_no_geodataframe(self, example_positionfixes):
        """Test if function return value does not subclass GeoDataFrame then __class__ is not touched"""

        def foo(gdf: GeoDataFrame) -> pd.DataFrame:
            return gdf.drop(columns=gdf.geometry.name)

        foo = _wrapped_gdf_method(foo)
        # exact type is asserted on purpose: a subclass of DataFrame must not pass
        assert type(foo(example_positionfixes)) is pd.DataFrame

    def test_failed_check(self, example_positionfixes):
        """Test if _check fails then __class__ is not touched"""

        class A(GeoDataFrame):
            @staticmethod
            def _check(obj, validate_geometry=True):
                # always reject, so the wrapper has no reason to restore class A
                return False

        @_wrapped_gdf_method
        def foo(a: A) -> GeoDataFrame:
            return GeoDataFrame(a)

        a = A(example_positionfixes)
        assert type(foo(a)) is GeoDataFrame

    def test_keep_class(self, example_positionfixes):
        """Test if original class is restored if return value subclasses GeoDataFrame and fulfills _check"""

        class A(GeoDataFrame):
            @staticmethod
            def _check(obj, validate_geometry=True):
                # always accept, so the wrapper is expected to restore class A
                return True

        @_wrapped_gdf_method
        def foo(a: A) -> GeoDataFrame:
            return GeoDataFrame(a)

        a = A(example_positionfixes)
        assert type(foo(a)) is A


class TestTrackintelGeoDataFrame:
    """Test helper class TrackintelGeoDataFrame."""

    class A(TrackintelGeoDataFrame):
        """Mimic TrackintelGeoDataFrame subclass by taking the same arguments"""

        def __init__(self, *args, validate_geometry=True, **kwargs):
            # validate_geometry is accepted only to mirror the real signature; it is not used here
            super().__init__(*args, **kwargs)

        @staticmethod
        def _check(obj, validate_geometry=True):
            return True

    def test_getitem(self, example_positionfixes):
        """Test if boolean selection via loc that keeps everything returns original class."""
        a = self.A(example_positionfixes)
        # NOTE(review): a flat boolean list given to .loc is applied to the rows. The mask is built
        # from the columns and fits only because the fixture has three rows and three columns.
        # If a column selection was intended, `a.loc[:, mask]` would be needed -- confirm.
        b = a.loc[[True for _ in a.columns]]
        assert type(b) is self.A

    def test_copy(self, example_positionfixes):
        """Test if copy maintains class."""
        a = self.A(example_positionfixes)
        b = a.copy()
        assert type(b) is self.A

    def test_merge(self, example_positionfixes):
        """Test if merge maintains class"""
        a = self.A(example_positionfixes)
        b = a.merge(a, on="user_id", suffixes=("", "_other"))
        assert type(b) is self.A

    def test_constructor_dataframe_fallback(self, example_positionfixes):
        """Test if _constructor gets DataFrame it falls through"""
        a = self.A(example_positionfixes)
        # dropping the geometry column leaves a frame without geometry
        df = example_positionfixes.drop(columns=example_positionfixes.geometry.name)
        assert type(a._constructor(df)) is pd.DataFrame

    def test_constructor_calls_init(self, example_positionfixes):
        """Test if _constructor gets GeoDataFrame and fulfills test then builds class"""
        a = self.A(example_positionfixes)
        assert type(a._constructor(a)) is self.A


class TestNonCachedAccessor:
"""Test if NonCachedAccessor works"""

Expand Down
3 changes: 2 additions & 1 deletion trackintel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from trackintel.io.file import read_trips_csv
from trackintel.io.file import read_tours_csv

#
from trackintel.model.positionfixes import Positionfixes

from trackintel.__version__ import __version__
from .core import print_version
11 changes: 1 addition & 10 deletions trackintel/io/postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,7 @@ def read_positionfixes_postgis(
def write_positionfixes_postgis(
positionfixes, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
):
positionfixes.to_postgis(
name,
con,
schema=schema,
if_exists=if_exists,
index=index,
index_label=index_label,
chunksize=chunksize,
dtype=dtype,
)
gpd.GeoDataFrame.to_postgis(positionfixes, name, con, schema, if_exists, index, index_label, chunksize, dtype)


@_index_warning_default_none
Expand Down
2 changes: 1 addition & 1 deletion trackintel/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .positionfixes import PositionfixesAccessor
from .positionfixes import Positionfixes
from .staypoints import StaypointsAccessor
from .triplegs import TriplegsAccessor
from .locations import LocationsAccessor
Expand Down
81 changes: 54 additions & 27 deletions trackintel/model/positionfixes.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import pandas as pd
import geopandas as gpd
import trackintel as ti
from trackintel.geogr.distances import calculate_distance_matrix
from trackintel.io.file import write_positionfixes_csv
from trackintel.io.postgis import write_positionfixes_postgis
from trackintel.model.util import _copy_docstring
from trackintel.preprocessing.positionfixes import generate_staypoints, generate_triplegs
from trackintel.visualization.positionfixes import plot_positionfixes
from trackintel.model.util import get_speed_positionfixes, _register_trackintel_accessor
from trackintel.model.util import (
get_speed_positionfixes,
TrackintelBase,
TrackintelGeoDataFrame,
_register_trackintel_accessor,
)

_required_columns = ["user_id", "tracked_at"]


@_register_trackintel_accessor("as_positionfixes")
class PositionfixesAccessor(object):
class Positionfixes(TrackintelBase, TrackintelGeoDataFrame, gpd.GeoDataFrame):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Positionfixes`.
This will define certain methods and accessors, as well as make sure that the DataFrame
Expand All @@ -37,40 +45,59 @@ class PositionfixesAccessor(object):
>>> df.as_positionfixes.generate_staypoints()
"""

required_columns = ["user_id", "tracked_at"]
def __init__(self, *args, validate_geometry=True, **kwargs):
    """Construct the frame from the given arguments and validate it as positionfixes.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent constructor.

    validate_geometry : bool, default True
        Forwarded to `_validate`.

    Raises
    ------
    Whatever `_validate` raises for a frame that does not meet the requirements.
    """
    # could be moved to super class
    # validate kwarg is necessary as the object is not fully initialised if we call it from _constructor
    # (geometry-link is missing). thus we need a way to stop validating too early.
    super().__init__(*args, **kwargs)
    self._validate(self, validate_geometry=validate_geometry)

def __init__(self, pandas_obj):
self._validate(pandas_obj)
self._obj = pandas_obj
# create circular reference directly -> avoid second call of init via accessor
@property
def as_positionfixes(self):
    """Return this object itself, so that `obj.as_positionfixes` resolves to `obj`."""
    return self

@staticmethod
def _validate(obj):
def _validate(obj, validate_geometry=True):
assert obj.shape[0] > 0, f"Geodataframe is empty with shape: {obj.shape}"
# check columns
if any([c not in obj.columns for c in PositionfixesAccessor.required_columns]):
if any([c not in obj.columns for c in _required_columns]):
raise AttributeError(
"To process a DataFrame as a collection of positionfixes, it must have the properties"
f" {PositionfixesAccessor.required_columns}, but it has [{', '.join(obj.columns)}]."
f" {_required_columns}, but it has [{', '.join(obj.columns)}]."
)

# check geometry
assert (
obj.geometry.is_valid.all()
), "Not all geometries are valid. Try x[~ x.geometry.is_valid] where x is you GeoDataFrame"

if obj.geometry.iloc[0].geom_type != "Point":
raise AttributeError("The geometry must be a Point (only first checked).")

# check timestamp dtypes
assert pd.api.types.is_datetime64tz_dtype(
obj["tracked_at"]
), f"dtype of tracked_at is {obj['tracked_at'].dtype} but has to be datetime64 and timezone aware"

# check geometry
if validate_geometry:
assert (
obj.geometry.is_valid.all()
), "Not all geometries are valid. Try x[~ x.geometry.is_valid] where x is you GeoDataFrame"

if obj.geometry.iloc[0].geom_type != "Point":
raise AttributeError("The geometry must be a Point (only first checked).")

@staticmethod
def _check(obj, validate_geometry=True):
    """Check does the same as _validate but returns bool instead of potentially raising an error.

    Parameters
    ----------
    obj : object with `columns`, `shape`, column access and (if checked) a `geometry` attribute
        The frame to test.

    validate_geometry : bool, default True
        If True, additionally require that all geometries are valid and that the first
        geometry is a Point (only the first one is inspected).

    Returns
    -------
    bool
        True if all requirements are met, False otherwise.
    """
    if any(c not in obj.columns for c in _required_columns):
        return False
    if obj.shape[0] <= 0:
        return False
    # pd.api.types.is_datetime64tz_dtype is deprecated since pandas 2.1; checking the dtype
    # against DatetimeTZDtype is the documented replacement. getattr keeps the "return False"
    # behaviour for objects without a dtype attribute.
    if not isinstance(getattr(obj["tracked_at"], "dtype", None), pd.DatetimeTZDtype):
        return False
    if validate_geometry:
        # coerce to a plain bool so callers never receive a numpy boolean
        return bool(obj.geometry.is_valid.all() and obj.geometry.iloc[0].geom_type == "Point")
    return True

@property
def center(self):
"""Return the center coordinate of this collection of positionfixes."""
lat = self._obj.geometry.y
lon = self._obj.geometry.x
lat = self.geometry.y
lon = self.geometry.x
return (float(lon.mean()), float(lat.mean()))

@_copy_docstring(generate_staypoints)
Expand All @@ -80,7 +107,7 @@ def generate_staypoints(self, *args, **kwargs):
See :func:`trackintel.preprocessing.positionfixes.generate_staypoints`.
"""
return ti.preprocessing.positionfixes.generate_staypoints(self._obj, *args, **kwargs)
return ti.preprocessing.positionfixes.generate_staypoints(self, *args, **kwargs)

@_copy_docstring(generate_triplegs)
def generate_triplegs(self, staypoints=None, *args, **kwargs):
Expand All @@ -89,7 +116,7 @@ def generate_triplegs(self, staypoints=None, *args, **kwargs):
See :func:`trackintel.preprocessing.positionfixes.generate_triplegs`.
"""
return ti.preprocessing.positionfixes.generate_triplegs(self._obj, staypoints, *args, **kwargs)
return ti.preprocessing.positionfixes.generate_triplegs(self, staypoints, *args, **kwargs)

@_copy_docstring(plot_positionfixes)
def plot(self, *args, **kwargs):
Expand All @@ -98,7 +125,7 @@ def plot(self, *args, **kwargs):
See :func:`trackintel.visualization.positionfixes.plot_positionfixes`.
"""
ti.visualization.positionfixes.plot_positionfixes(self._obj, *args, **kwargs)
ti.visualization.positionfixes.plot_positionfixes(self, *args, **kwargs)

@_copy_docstring(write_positionfixes_csv)
def to_csv(self, filename, *args, **kwargs):
Expand All @@ -107,7 +134,7 @@ def to_csv(self, filename, *args, **kwargs):
See :func:`trackintel.io.file.write_positionfixes_csv`.
"""
ti.io.file.write_positionfixes_csv(self._obj, filename, *args, **kwargs)
ti.io.file.write_positionfixes_csv(self, filename, *args, **kwargs)

@_copy_docstring(write_positionfixes_postgis)
def to_postgis(
Expand All @@ -119,7 +146,7 @@ def to_postgis(
See :func:`trackintel.io.postgis.write_positionfixes_postgis`.
"""
ti.io.postgis.write_positionfixes_postgis(
self._obj, name, con, schema, if_exists, index, index_label, chunksize, dtype
self, name, con, schema, if_exists, index, index_label, chunksize, dtype
)

@_copy_docstring(calculate_distance_matrix)
Expand All @@ -129,7 +156,7 @@ def calculate_distance_matrix(self, *args, **kwargs):
See :func:`trackintel.geogr.distances.calculate_distance_matrix`.
"""
return ti.geogr.distances.calculate_distance_matrix(self._obj, *args, **kwargs)
return ti.geogr.distances.calculate_distance_matrix(self, *args, **kwargs)

@_copy_docstring(get_speed_positionfixes)
def get_speed(self, *args, **kwargs):
Expand All @@ -138,4 +165,4 @@ def get_speed(self, *args, **kwargs):
See :func:`trackintel.model.util.get_speed_positionfixes`.
"""
return ti.model.util.get_speed_positionfixes(self._obj, *args, **kwargs)
return ti.model.util.get_speed_positionfixes(self, *args, **kwargs)
Loading

0 comments on commit e2bc973

Please sign in to comment.