ENH: enable subclassing for Locations (#511)
* ENH: enable subclassing for Locations

Here comes the next class :)
For more info see #490.

* TST: add test for as_locations property

---------

Co-authored-by: Ye <[email protected]>
bifbof and hongyeehh authored Aug 17, 2023
1 parent bed90f0 commit 9f1b15a
Showing 6 changed files with 91 additions and 33 deletions.
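
To make the effect of the change concrete, here is a minimal usage sketch (not part of the commit). It assumes a hand-built GeoDataFrame with the required "user_id" and Point-typed "center" columns, and that importing trackintel registers the as_locations accessor on plain GeoDataFrames, as the tests below exercise.

import geopandas as gpd
from shapely.geometry import Point

from trackintel import Locations  # importing trackintel registers the accessor

# toy data (made up); any GeoDataFrame with "user_id" and a Point "center"
# column should pass the validation introduced in this commit
gdf = gpd.GeoDataFrame(
    {"user_id": [0, 1], "center": [Point(8.54, 47.37), Point(8.55, 47.38)]},
    geometry="center",
    crs="EPSG:4326",
)

locs = gdf.as_locations            # validates and returns a Locations instance
assert type(locs) == Locations     # Locations is now a GeoDataFrame subclass
assert locs.as_locations is locs   # calling the accessor again is a no-op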
4 changes: 2 additions & 2 deletions docs/modules/model.rst
@@ -74,10 +74,10 @@ Triplegs
.. autoclass:: trackintel.model.triplegs.Triplegs
    :members:

LocationsAccessor
Locations
-----------------

.. autoclass:: trackintel.model.locations.LocationsAccessor
.. autoclass:: trackintel.model.locations.Locations
    :members:

TripsAccessor
58 changes: 47 additions & 11 deletions tests/model/test_locations.py
@@ -1,9 +1,10 @@
import os
import pytest

import pytest
from shapely.geometry import LineString

import trackintel as ti
from trackintel import Locations


@pytest.fixture
@@ -14,24 +15,59 @@ def testdata_locs():
    sp, locs = sp.as_staypoints.generate_locations(
        method="dbscan", epsilon=10, num_samples=1, distance_metric="haversine", agg_level="dataset"
    )
    locs.as_locations
    return locs


class TestLocations:
    """Tests for the LocationsAccessor."""
    """Tests for the Locations class."""

    def test_accessor_column(self, testdata_locs):
        """Test if the as_locations accessor checks the required column for locations."""
        locs = testdata_locs.copy()

        with pytest.raises(AttributeError, match="To process a DataFrame as a collection of locations"):
            locs.drop(["user_id"], axis=1).as_locations
            testdata_locs.drop(["user_id"], axis=1).as_locations

    def test_accessor_geometry_type(self, testdata_locs):
        """Test if the as_locations accessor requires Point geometry."""
        locs = testdata_locs.copy()
        with pytest.raises(AttributeError, match="The center geometry must be a Point"):
            locs["center"] = LineString(
                [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
            )
            locs.as_locations
        testdata_locs["center"] = LineString(
            [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
        )
        with pytest.raises(ValueError, match="The center geometry must be a Point"):
            testdata_locs.as_locations

    def test_accessor_empty(self, testdata_locs):
        """Test if as_locations accessor raises error if data is empty."""
        with pytest.raises(ValueError, match="GeoDataFrame is empty with shape:"):
            testdata_locs.drop(testdata_locs.index).as_locations

    def test_accessor_recursive(self, testdata_locs):
        """Test if as_locations works recursively"""
        locs = testdata_locs.as_locations
        assert type(locs) == Locations
        assert id(locs) == id(locs.as_locations)

    def test_check_suceeding(self, testdata_locs):
        """Test if check returns True on valid locations"""
        assert Locations._check(testdata_locs)

    def test_check_missing_columns(self, testdata_locs):
        """Test if check returns False if column is missing"""
        assert not Locations._check(testdata_locs.drop(columns="user_id"))

    def test_check_empty_df(self, testdata_locs):
        """Test if check returns False if DataFrame is empty"""
        assert not Locations._check(testdata_locs.drop(testdata_locs.index))

    def test_check_false_geometry_type(self, testdata_locs):
        """Test if check returns False if geometry type is wrong"""
        testdata_locs["center"] = LineString(
            [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
        )
        assert not Locations._check(testdata_locs)

    def test_check_ignore_false_geometry_type(self, testdata_locs):
        """Test if check returns True if geometry type is wrong but validate_geometry is set to False"""
        testdata_locs["center"] = LineString(
            [(13.476808430, 48.573711823), (13.506804, 48.939008), (13.4664690, 48.5706414)]
        )
        assert Locations._check(testdata_locs, validate_geometry=False)
1 change: 1 addition & 0 deletions trackintel/__init__.py
@@ -11,6 +11,7 @@
from trackintel.io.file import read_tours_csv

from trackintel.model.positionfixes import Positionfixes
from trackintel.model.locations import Locations
from trackintel.model.triplegs import Triplegs
from trackintel.model.staypoints import Staypoints

3 changes: 2 additions & 1 deletion trackintel/io/postgis.py
@@ -447,7 +447,8 @@ def write_locations_postgis(
    locations = locations.copy()
    locations["extent"] = locations["extent"].apply(lambda x: wkb.dumps(x, srid=srid, hex=True))

    locations.to_postgis(
    gpd.GeoDataFrame.to_postgis(
        locations,
        name,
        con,
        schema=schema,
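
A note on the change above: Locations now subclasses GeoDataFrame and ships its own to_postgis wrapper that delegates back to write_locations_postgis (see the model diff below), so the writer presumably calls the geopandas method explicitly to avoid bouncing between the two. A rough sketch of the pattern; the connection string and table name are made up, not from the commit:

import geopandas as gpd
from shapely.geometry import Point
from sqlalchemy import create_engine

from trackintel import Locations

locs = Locations(
    {"user_id": [0], "center": [Point(8.54, 47.37)]}, geometry="center", crs="EPSG:4326"
)
engine = create_engine("postgresql://user:pass@localhost:5432/demo")  # hypothetical DSN

# locs.to_postgis(...) would resolve to the trackintel wrapper; the unbound
# base-class call below writes the rows with geopandas directly instead
gpd.GeoDataFrame.to_postgis(locs, "locations", engine, if_exists="replace")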
2 changes: 1 addition & 1 deletion trackintel/model/__init__.py
@@ -1,6 +1,6 @@
from .positionfixes import Positionfixes
from .locations import Locations
from .triplegs import Triplegs
from .staypoints import Staypoints
from .locations import LocationsAccessor
from .trips import TripsAccessor
from .tours import ToursAccessor
56 changes: 38 additions & 18 deletions trackintel/model/locations.py
@@ -1,14 +1,21 @@
import trackintel as ti
from trackintel.io.file import write_locations_csv
from trackintel.io.postgis import write_locations_postgis
from trackintel.model.util import _copy_docstring, _register_trackintel_accessor
from trackintel.model.util import (
    TrackintelBase,
    TrackintelGeoDataFrame,
    _copy_docstring,
    _register_trackintel_accessor,
)
from trackintel.preprocessing.filter import spatial_filter
from trackintel.visualization.locations import plot_locations

_required_columns = ["user_id", "center"]


@_register_trackintel_accessor("as_locations")
class LocationsAccessor(object):
    """A pandas accessor to treat (Geo)DataFrames as collections of locations.
class Locations(TrackintelBase, TrackintelGeoDataFrame):
    """A pandas accessor to treat a GeoDataFrame as a collection of locations.
    This will define certain methods and accessors, as well as make sure that the DataFrame
    adheres to some requirements.
@@ -28,24 +35,39 @@ class LocationsAccessor(object):
    >>> df.as_locations.plot()
    """

    required_columns = ["user_id", "center"]
    def __init__(self, *args, validate_geometry=True, **kwargs):
        super().__init__(*args, **kwargs)
        self._validate(self, validate_geometry=validate_geometry)

    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj = pandas_obj
    @property
    def as_locations(self):
        return self

    @staticmethod
    def _validate(obj):
        if any([c not in obj.columns for c in LocationsAccessor.required_columns]):
    def _validate(obj, validate_geometry):
        if any([c not in obj.columns for c in _required_columns]):
            raise AttributeError(
                "To process a DataFrame as a collection of locations, it must have the properties"
                f" {LocationsAccessor.required_columns}, but it has [{', '.join(obj.columns)}]."
                f" {_required_columns}, but it has [{', '.join(obj.columns)}]."
            )
        if obj.shape[0] <= 0:
            raise ValueError(f"GeoDataFrame is empty with shape: {obj.shape}")

        if not (obj.shape[0] > 0 and obj["center"].iloc[0].geom_type == "Point"):
        if validate_geometry and obj["center"].iloc[0].geom_type != "Point":
            # todo: We could think about allowing both geometry types for locations (point and polygon)
            # One for extent and one for the center
            raise AttributeError("The center geometry must be a Point (only first checked).")
            raise ValueError("The center geometry must be a Point (only first checked).")

    @staticmethod
    def _check(obj, validate_geometry=True):
        """Check does the same as _validate but returns bool instead of potentially raising an error."""
        if any([c not in obj.columns for c in _required_columns]):
            return False
        if obj.shape[0] <= 0:
            return False
        if validate_geometry:
            return obj.geometry.iloc[0].geom_type == "Point"
        return True

    @_copy_docstring(plot_locations)
    def plot(self, *args, **kwargs):
@@ -54,7 +76,7 @@ def plot(self, *args, **kwargs):
        See :func:`trackintel.visualization.locations.plot_locations`.
        """
        ti.visualization.locations.plot_locations(self._obj, *args, **kwargs)
        ti.visualization.locations.plot_locations(self, *args, **kwargs)

    @_copy_docstring(write_locations_csv)
    def to_csv(self, filename, *args, **kwargs):
@@ -63,7 +85,7 @@ def to_csv(self, filename, *args, **kwargs):
        See :func:`trackintel.io.file.write_locations_csv`.
        """
        ti.io.file.write_locations_csv(self._obj, filename, *args, **kwargs)
        ti.io.file.write_locations_csv(self, filename, *args, **kwargs)

    @_copy_docstring(write_locations_postgis)
    def to_postgis(
@@ -74,9 +96,7 @@ def to_postgis(
        See :func:`trackintel.io.postgis.write_locations_postgis`.
        """
        ti.io.postgis.write_locations_postgis(
            self._obj, name, con, schema, if_exists, index, index_label, chunksize, dtype
        )
        ti.io.postgis.write_locations_postgis(self, name, con, schema, if_exists, index, index_label, chunksize, dtype)

    @_copy_docstring(spatial_filter)
    def spatial_filter(self, *args, **kwargs):
@@ -85,4 +105,4 @@ def spatial_filter(self, *args, **kwargs):
        See :func:`trackintel.preprocessing.filter.spatial_filter`.
        """
        return ti.preprocessing.filter.spatial_filter(self._obj, *args, **kwargs)
        return ti.preprocessing.filter.spatial_filter(self, *args, **kwargs)
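
Based on the tests added earlier, the new _check helper mirrors _validate but reports validity as a bool instead of raising. A small illustration with toy data (not from the commit):

from shapely.geometry import Point

from trackintel import Locations

locs = Locations({"user_id": [0], "center": [Point(8.54, 47.37)]}, geometry="center")

Locations._check(locs)                           # True: columns present, non-empty, Point geometry
Locations._check(locs.drop(columns="user_id"))   # False instead of an AttributeError
Locations._check(locs.drop(locs.index))          # False instead of a ValueError (empty frame)
Locations._check(locs, validate_geometry=False)  # skip the geometry check entirely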
