Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Use of subtyping instead of accessors #490

Merged
merged 10 commits into from
Aug 15, 2023
4 changes: 2 additions & 2 deletions docs/modules/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ Available Accessors

The following accessors are available within *trackintel*.

PositionfixesAccessor
Positionfixes
---------------------

.. autoclass:: trackintel.model.positionfixes.PositionfixesAccessor
.. autoclass:: trackintel.model.positionfixes.Positionfixes
:members:

StaypointsAccessor
Expand Down
34 changes: 33 additions & 1 deletion tests/model/test_positionfixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from shapely.geometry import LineString

import trackintel as ti
from trackintel import Positionfixes


@pytest.fixture
Expand All @@ -15,7 +16,7 @@ def testdata_geolife():


class TestPositionfixes:
"""Tests for the PositionfixesAccessor."""
"""Tests for the Positionfixes class."""

def test_accessor_column(self, testdata_geolife):
"""Test if the as_positionfixes accessor checks the required column for positionfixes."""
Expand Down Expand Up @@ -55,3 +56,34 @@ def test_similarity_matrix(self, testdata_geolife):
accessor_result = pfs.as_positionfixes.calculate_distance_matrix(dist_metric="haversine", n_jobs=1)
function_result = ti.geogr.distances.calculate_distance_matrix(pfs, dist_metric="haversine", n_jobs=1)
assert np.allclose(accessor_result, function_result)

def test_check_suceeding(self, testdata_geolife):
"""Test if check returns True on valid pfs"""
assert Positionfixes._check(testdata_geolife)

def test_check_missing_columns(self, testdata_geolife):
"""Test if check returns False if column is missing"""
assert not Positionfixes._check(testdata_geolife.drop(columns="user_id"))

def test_check_empty_df(self, testdata_geolife):
"""Test if check returns False if DataFrame is empty"""
assert not Positionfixes._check(testdata_geolife.drop(testdata_geolife.index))

def test_check_no_tz(self, testdata_geolife):
"""Test if check returns False if tracked at column has no tz"""
testdata_geolife["tracked_at"] = testdata_geolife["tracked_at"].dt.tz_localize(None)
assert not Positionfixes._check(testdata_geolife)

def test_check_false_geometry_type(self, testdata_geolife):
"""Test if check returns False if geometry type is wrong"""
testdata_geolife["geom"] = LineString(
[(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
)
assert not Positionfixes._check(testdata_geolife)

def test_check_ignore_false_geometry_type(self, testdata_geolife):
"""Test if check returns True if geometry type is wrong but validate_geometry is set to False"""
testdata_geolife["geom"] = LineString(
[(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
)
assert Positionfixes._check(testdata_geolife, validate_geometry=False)
110 changes: 110 additions & 0 deletions tests/model/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
NonCachedAccessor,
_copy_docstring,
_register_trackintel_accessor,
_wrapped_gdf_method,
get_speed_positionfixes,
TrackintelGeoDataFrame,
)


Expand All @@ -41,6 +43,30 @@ def load_positionfixes():
return pfs, correct_speeds


@pytest.fixture
def example_positionfixes():
"""Positionfixes for tests."""
p1 = Point(8.5067847, 47.4)
p2 = Point(8.5067847, 47.5)
p3 = Point(8.5067847, 47.6)

t1 = pd.Timestamp("1971-01-01 00:00:00", tz="utc")
t2 = pd.Timestamp("1971-01-01 05:00:00", tz="utc")
t3 = pd.Timestamp("1971-01-02 07:00:00", tz="utc")

list_dict = [
{"user_id": 0, "tracked_at": t1, "geometry": p1},
{"user_id": 0, "tracked_at": t2, "geometry": p2},
{"user_id": 1, "tracked_at": t3, "geometry": p3},
]
pfs = GeoDataFrame(data=list_dict, geometry="geometry", crs="EPSG:4326")
pfs.index.name = "id"

# assert validity of positionfixes.
pfs.as_positionfixes
return pfs


class TestSpeedPositionfixes:
def test_positionfixes_stable(self, load_positionfixes):
"""Test whether the positionfixes stay the same apart from the new speed column"""
Expand Down Expand Up @@ -215,6 +241,90 @@ def bar(b: int) -> int:
assert attr_foo != attr_bar


class Test_wrapped_gdf_method:
def test_no_geodataframe(self, example_positionfixes):
"""Test if function return value does not subclass GeoDataFrame then __class__ is not touched"""

def foo(gdf: GeoDataFrame) -> pd.DataFrame:
return gdf.drop(columns=gdf.geometry.name)

foo = _wrapped_gdf_method(foo)
assert type(foo(example_positionfixes)) == pd.DataFrame

def test_failed_check(self, example_positionfixes):
"""Test if _check fails then __class__ is not touched"""

class A(GeoDataFrame):
@staticmethod
def _check(obj, validate_geometry=True):
return False

@_wrapped_gdf_method
def foo(a: A) -> GeoDataFrame:
return GeoDataFrame(a)

a = A(example_positionfixes)
assert type(foo(a)) == GeoDataFrame

def test_keep_class(self, example_positionfixes):
"""Test if original class is restored if return value subclasses GeoDataFarme and fulfills _check"""

class A(GeoDataFrame):
@staticmethod
def _check(obj, validate_geometry=True):
return True

@_wrapped_gdf_method
def foo(a: A) -> GeoDataFrame:
return GeoDataFrame(a)

a = A(example_positionfixes)
assert type(foo(a)) == A


class TestTrackintelGeoDataFrame:
"""Test helper class TrackintelGeoDataFrame."""

class A(TrackintelGeoDataFrame):
"""Mimic TrackintelGeoDataFrame subclass by taking the same arguments"""

def __init__(self, *args, validate_geometry=True, **kwargs):
super().__init__(*args, **kwargs)

@staticmethod
def _check(obj, validate_geometry=True):
return True

def test_getitem(self, example_positionfixes):
"""Test if loc on all columns returns original class."""
a = self.A(example_positionfixes)
b = a.loc[[True for _ in a.columns]]
assert type(b) == self.A

def test_copy(self, example_positionfixes):
"""Test if copy maintains class."""
a = self.A(example_positionfixes)
b = a.copy()
assert type(b) == self.A

def test_merge(self, example_positionfixes):
"""Test if merge maintains class"""
a = self.A(example_positionfixes)
b = a.merge(a, on="user_id", suffixes=("", "_other"))
assert type(b) == self.A

def test_constructor_dataframe_fallback(self, example_positionfixes):
"""Test if _constructor gets DataFrame it falls through"""
a = self.A(example_positionfixes)
df = example_positionfixes.drop(columns=example_positionfixes.geometry.name)
assert type(a._constructor(df)) == pd.DataFrame

def test_constructor_calls_init(self, example_positionfixes):
"""Test if _constructor gets GeoDataFrame and fulfills test then builds class"""
a = self.A(example_positionfixes)
assert type(a._constructor(a)) == self.A


class TestNonCachedAccessor:
"""Test if NonCachedAccessor works"""

Expand Down
3 changes: 2 additions & 1 deletion trackintel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from trackintel.io.file import read_trips_csv
from trackintel.io.file import read_tours_csv

#
from trackintel.model.positionfixes import Positionfixes

from trackintel.__version__ import __version__
from .core import print_version
11 changes: 1 addition & 10 deletions trackintel/io/postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,7 @@ def read_positionfixes_postgis(
def write_positionfixes_postgis(
positionfixes, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
):
positionfixes.to_postgis(
name,
con,
schema=schema,
if_exists=if_exists,
index=index,
index_label=index_label,
chunksize=chunksize,
dtype=dtype,
)
gpd.GeoDataFrame.to_postgis(positionfixes, name, con, schema, if_exists, index, index_label, chunksize, dtype)
hongyeehh marked this conversation as resolved.
Show resolved Hide resolved


@_handle_con_string
Expand Down
2 changes: 1 addition & 1 deletion trackintel/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .positionfixes import PositionfixesAccessor
from .positionfixes import Positionfixes
from .staypoints import StaypointsAccessor
from .triplegs import TriplegsAccessor
from .locations import LocationsAccessor
Expand Down
84 changes: 55 additions & 29 deletions trackintel/model/positionfixes.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import pandas as pd
import geopandas as gpd
import trackintel as ti
from trackintel.geogr.distances import calculate_distance_matrix
from trackintel.io.file import write_positionfixes_csv
from trackintel.io.postgis import write_positionfixes_postgis
from trackintel.model.util import _copy_docstring
from trackintel.preprocessing.positionfixes import generate_staypoints, generate_triplegs
from trackintel.visualization.positionfixes import plot_positionfixes
from trackintel.model.util import get_speed_positionfixes, _register_trackintel_accessor
from trackintel.model.util import (
get_speed_positionfixes,
TrackintelBase,
TrackintelGeoDataFrame,
_register_trackintel_accessor,
)

_required_columns = ["user_id", "tracked_at"]


@_register_trackintel_accessor("as_positionfixes")
class PositionfixesAccessor(object):
class Positionfixes(TrackintelBase, TrackintelGeoDataFrame, gpd.GeoDataFrame):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Positionfixes`.

This will define certain methods and accessors, as well as make sure that the DataFrame
Expand All @@ -37,41 +45,59 @@ class PositionfixesAccessor(object):
>>> df.as_positionfixes.generate_staypoints()
"""

required_columns = ["user_id", "tracked_at"]
def __init__(self, *args, validate_geometry=True, **kwargs):
# could be moved to super class
# validate kwarg is necessary as the object is not fully initialised if we call it from _constructor
# (geometry-link is missing). thus we need a way to stop validating too early.
super().__init__(*args, **kwargs)
self._validate(self, validate_geometry=validate_geometry)

def __init__(self, pandas_obj):
self._validate(pandas_obj)
self._obj = pandas_obj
# create circular reference directly -> avoid second call of init via accessor
@property
def as_positionfixes(self):
return self

@staticmethod
def _validate(obj):
assert obj.shape[0] > 0, "Geodataframe is empty with shape: {}".format(obj.shape)
def _validate(obj, validate_geometry=True):
assert obj.shape[0] > 0, f"Geodataframe is empty with shape: {obj.shape}"
# check columns
if any([c not in obj.columns for c in PositionfixesAccessor.required_columns]):
if any([c not in obj.columns for c in _required_columns]):
raise AttributeError(
"To process a DataFrame as a collection of positionfixes, "
+ "it must have the properties [%s], but it has [%s]."
% (", ".join(PositionfixesAccessor.required_columns), ", ".join(obj.columns))
f"it must have the columns {_required_columns}, but it has [{', '.join(obj.columns)}]."
)
# check timestamp dtypes
assert pd.api.types.is_datetime64tz_dtype(
obj["tracked_at"]
), f"dtype of tracked_at is {obj['tracked_at'].dtype} but has to be datetime64 and timezone aware"

# check geometry
assert obj.geometry.is_valid.all(), (
"Not all geometries are valid. Try x[~ x.geometry.is_valid] " "where x is you GeoDataFrame"
)
if validate_geometry:
assert obj.geometry.is_valid.all(), (
"Not all geometries are valid. Try x[~ x.geometry.is_valid] " "where x is you GeoDataFrame"
)

if obj.geometry.iloc[0].geom_type != "Point":
raise AttributeError("The geometry must be a Point (only first checked).")
if obj.geometry.iloc[0].geom_type != "Point":
raise AttributeError("The geometry must be a Point (only first checked).")

# check timestamp dtypes
assert pd.api.types.is_datetime64tz_dtype(
obj["tracked_at"]
), "dtype of tracked_at is {} but has to be datetime64 and timezone aware".format(obj["tracked_at"].dtype)
@staticmethod
def _check(obj, validate_geometry=True):
"""Check does the same as _validate but returns bool instead of potentially raising an error."""
if any([c not in obj.columns for c in _required_columns]):
return False
if obj.shape[0] <= 0:
return False
if not pd.api.types.is_datetime64tz_dtype(obj["tracked_at"]):
return False
if validate_geometry:
return obj.geometry.is_valid.all() and obj.geometry.iloc[0].geom_type == "Point"
return True

@property
def center(self):
"""Return the center coordinate of this collection of positionfixes."""
lat = self._obj.geometry.y
lon = self._obj.geometry.x
lat = self.geometry.y
lon = self.geometry.x
return (float(lon.mean()), float(lat.mean()))

@_copy_docstring(generate_staypoints)
Expand All @@ -81,7 +107,7 @@ def generate_staypoints(self, *args, **kwargs):

See :func:`trackintel.preprocessing.positionfixes.generate_staypoints`.
"""
return ti.preprocessing.positionfixes.generate_staypoints(self._obj, *args, **kwargs)
return ti.preprocessing.positionfixes.generate_staypoints(self, *args, **kwargs)

@_copy_docstring(generate_triplegs)
def generate_triplegs(self, staypoints=None, *args, **kwargs):
Expand All @@ -90,7 +116,7 @@ def generate_triplegs(self, staypoints=None, *args, **kwargs):

See :func:`trackintel.preprocessing.positionfixes.generate_triplegs`.
"""
return ti.preprocessing.positionfixes.generate_triplegs(self._obj, staypoints, *args, **kwargs)
return ti.preprocessing.positionfixes.generate_triplegs(self, staypoints, *args, **kwargs)

@_copy_docstring(plot_positionfixes)
def plot(self, *args, **kwargs):
Expand All @@ -99,7 +125,7 @@ def plot(self, *args, **kwargs):

See :func:`trackintel.visualization.positionfixes.plot_positionfixes`.
"""
ti.visualization.positionfixes.plot_positionfixes(self._obj, *args, **kwargs)
ti.visualization.positionfixes.plot_positionfixes(self, *args, **kwargs)

@_copy_docstring(write_positionfixes_csv)
def to_csv(self, filename, *args, **kwargs):
Expand All @@ -108,7 +134,7 @@ def to_csv(self, filename, *args, **kwargs):

See :func:`trackintel.io.file.write_positionfixes_csv`.
"""
ti.io.file.write_positionfixes_csv(self._obj, filename, *args, **kwargs)
ti.io.file.write_positionfixes_csv(self, filename, *args, **kwargs)

@_copy_docstring(write_positionfixes_postgis)
def to_postgis(
Expand All @@ -120,7 +146,7 @@ def to_postgis(
See :func:`trackintel.io.postgis.write_positionfixes_postgis`.
"""
ti.io.postgis.write_positionfixes_postgis(
self._obj, name, con, schema, if_exists, index, index_label, chunksize, dtype
self, name, con, schema, if_exists, index, index_label, chunksize, dtype
)

@_copy_docstring(calculate_distance_matrix)
Expand All @@ -130,7 +156,7 @@ def calculate_distance_matrix(self, *args, **kwargs):

See :func:`trackintel.geogr.distances.calculate_distance_matrix`.
"""
return ti.geogr.distances.calculate_distance_matrix(self._obj, *args, **kwargs)
return ti.geogr.distances.calculate_distance_matrix(self, *args, **kwargs)

@_copy_docstring(get_speed_positionfixes)
def get_speed(self, *args, **kwargs):
Expand All @@ -139,4 +165,4 @@ def get_speed(self, *args, **kwargs):

See :func:`trackintel.model.util.get_speed_positionfixes`.
"""
return ti.model.util.get_speed_positionfixes(self._obj, *args, **kwargs)
return ti.model.util.get_speed_positionfixes(self, *args, **kwargs)
Loading
Loading