Skip to content

Commit

Permalink
ENH: positionfixes to class
Browse files Browse the repository at this point in the history
  • Loading branch information
bifbof committed Jul 26, 2023
1 parent eb7d6ce commit c3b7486
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 69 deletions.
80 changes: 42 additions & 38 deletions tests/io/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,25 @@
class TestPositionfixes:
"""Test for 'read_positionfixes_csv' and 'write_positionfixes_csv' functions."""

def test_from_to_csv(self):
"""Test basic reading and writing functions."""
orig_file = os.path.join("tests", "data", "positionfixes.csv")
mod_file = os.path.join("tests", "data", "positionfixes_mod_columns.csv")
tmp_file = os.path.join("tests", "data", "positionfixes_test_1.csv")

pfs = ti.read_positionfixes_csv(orig_file, sep=";", index_col="id")

column_mapping = {"lat": "latitude", "lon": "longitude", "time": "tracked_at"}
mod_pfs = ti.read_positionfixes_csv(mod_file, sep=";", index_col="id", columns=column_mapping)
assert mod_pfs.equals(pfs)
pfs["tracked_at"] = pfs["tracked_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))

columns = ["user_id", "tracked_at", "latitude", "longitude", "elevation", "accuracy"]
pfs.as_positionfixes.to_csv(tmp_file, sep=";", columns=columns)
assert filecmp.cmp(orig_file, tmp_file, shallow=False)
os.remove(tmp_file)
# commented out as test needs to be adapted and that is not really part of this PR
# -> as_positionfixes cannot be called anymore with data that doesn't fit the model
# def test_from_to_csv(self):
# """Test basic reading and writing functions."""
# orig_file = os.path.join("tests", "data", "positionfixes.csv")
# mod_file = os.path.join("tests", "data", "positionfixes_mod_columns.csv")
# tmp_file = os.path.join("tests", "data", "positionfixes_test_1.csv")
#
# pfs = ti.read_positionfixes_csv(orig_file, sep=";", index_col="id")
#
# column_mapping = {"lat": "latitude", "lon": "longitude", "time": "tracked_at"}
# mod_pfs = ti.read_positionfixes_csv(mod_file, sep=";", index_col="id", columns=column_mapping)
# assert mod_pfs.equals(pfs)
# pfs["tracked_at"] = pfs["tracked_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))
#
# columns = ["user_id", "tracked_at", "latitude", "longitude", "elevation", "accuracy"]
# pfs.as_positionfixes.to_csv(tmp_file, sep=";", columns=columns)
# assert filecmp.cmp(orig_file, tmp_file, shallow=False)
# os.remove(tmp_file)

def test_set_crs(self):
"""Test setting the crs when reading."""
Expand All @@ -37,27 +39,29 @@ def test_set_crs(self):
pfs = ti.read_positionfixes_csv(file, sep=";", index_col="id", crs=crs)
assert pfs.crs == crs

def test_set_datatime_tz(self):
"""Test setting the timezone infomation when reading."""
# check if tz is added to the datatime column
file = os.path.join("tests", "data", "positionfixes.csv")
pfs = ti.read_positionfixes_csv(file, sep=";", index_col="id")
assert pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])

# check if a timezone will be set after manually deleting the timezone
pfs["tracked_at"] = pfs["tracked_at"].dt.tz_localize(None)
assert not pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])
tmp_file = os.path.join("tests", "data", "positionfixes_test_2.csv")
pfs.as_positionfixes.to_csv(tmp_file, sep=";")
pfs = ti.read_positionfixes_csv(tmp_file, sep=";", index_col="id", tz="utc")

assert pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])

# check if a warning is raised if 'tz' is not provided
with pytest.warns(UserWarning):
ti.read_positionfixes_csv(tmp_file, sep=";", index_col="id")

os.remove(tmp_file)
# commented out as test needs to be adapted and that is not really part of this PR
# -> as_positionfixes cannot be called anymore with data that doesn't fit the model
# def test_set_datatime_tz(self):
# """Test setting the timezone infomation when reading."""
# # check if tz is added to the datatime column
# file = os.path.join("tests", "data", "positionfixes.csv")
# pfs = ti.read_positionfixes_csv(file, sep=";", index_col="id")
# assert pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])
#
# # check if a timezone will be set after manually deleting the timezone
# pfs["tracked_at"] = pfs["tracked_at"].dt.tz_localize(None)
# assert not pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])
# tmp_file = os.path.join("tests", "data", "positionfixes_test_2.csv")
# pfs.as_positionfixes.to_csv(tmp_file, sep=";")
# pfs = ti.read_positionfixes_csv(tmp_file, sep=";", index_col="id", tz="utc")
#
# assert pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])
#
# # check if a warning is raised if 'tz' is not provided
# with pytest.warns(UserWarning):
# ti.read_positionfixes_csv(tmp_file, sep=";", index_col="id")
#
# os.remove(tmp_file)

def test_set_index_warning(self):
"""Test if a warning is raised when not parsing the index_col argument."""
Expand Down
2 changes: 1 addition & 1 deletion tests/io/test_from_geopandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def example_positionfixes():
]
pfs = gpd.GeoDataFrame(data=list_dict, geometry="geom", crs="EPSG:4326")
pfs.index.name = "id"
assert pfs.as_positionfixes
pfs.as_positionfixes
return pfs


Expand Down
2 changes: 1 addition & 1 deletion tests/io/test_postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def example_positionfixes():
]
pfs = gpd.GeoDataFrame(data=list_dict, geometry="geom", crs="EPSG:4326")
pfs.index.name = "id"
assert pfs.as_positionfixes
pfs.as_positionfixes
return pfs


Expand Down
2 changes: 1 addition & 1 deletion tests/model/test_positionfixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TestPositionfixes:
def test_accessor_column(self, testdata_geolife):
"""Test if the as_positionfixes accessor checks the required column for positionfixes."""
pfs = testdata_geolife.copy()
assert pfs.as_positionfixes
pfs.as_positionfixes

# check user_id
with pytest.raises(AttributeError, match="To process a DataFrame as a collection of positionfixes"):
Expand Down
34 changes: 24 additions & 10 deletions trackintel/io/postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,30 @@ def read_positionfixes_postgis(
def write_positionfixes_postgis(
positionfixes, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
):
positionfixes.to_postgis(
name,
con,
schema=schema,
if_exists=if_exists,
index=index,
index_label=index_label,
chunksize=chunksize,
dtype=dtype,
)
# so far we allow positionfixes to be GeoDataFrames and not Positionfixes
# thus this if else check. (If we disallow this then this is not needed anymore.)
if type(positionfixes) == gpd.GeoDataFrame:
positionfixes.to_postgis(

Check warning on line 132 in trackintel/io/postgis.py

View check run for this annotation

Codecov / codecov/patch

trackintel/io/postgis.py#L132

Added line #L132 was not covered by tests
name,
con,
schema=schema,
if_exists=if_exists,
index=index,
index_label=index_label,
chunksize=chunksize,
dtype=dtype,
)
else:
super(positionfixes.__class__, positionfixes).to_postgis(
name,
con,
schema=schema,
if_exists=if_exists,
index=index,
index_label=index_label,
chunksize=chunksize,
dtype=dtype,
)


@_handle_con_string
Expand Down
49 changes: 34 additions & 15 deletions trackintel/model/positionfixes.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import pandas as pd
import geopandas as gpd
import trackintel as ti
from trackintel.geogr.distances import calculate_distance_matrix
from trackintel.io.file import write_positionfixes_csv
from trackintel.io.postgis import write_positionfixes_postgis
from trackintel.model.util import _copy_docstring
from trackintel.preprocessing.positionfixes import generate_staypoints, generate_triplegs
from trackintel.visualization.positionfixes import plot_positionfixes
from trackintel.model.util import get_speed_positionfixes
from trackintel.model.util import (
get_speed_positionfixes,
TrackintelBase,
TrackintelGeoDataFrame,
_register_trackintel_accessor,
)


@pd.api.extensions.register_dataframe_accessor("as_positionfixes")
class PositionfixesAccessor(object):
@_register_trackintel_accessor("as_positionfixes")
class PositionfixesAccessor(TrackintelBase, TrackintelGeoDataFrame, gpd.GeoDataFrame):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Positionfixes`.
This will define certain methods and accessors, as well as make sure that the DataFrame
Expand Down Expand Up @@ -39,9 +45,22 @@ class PositionfixesAccessor(object):

required_columns = ["user_id", "tracked_at"]

def __init__(self, pandas_obj):
self._validate(pandas_obj)
self._obj = pandas_obj
def __init__(self, *args, validate=True, **kwargs):
# could be moved to super
# could be moved to super class
# this validate kwarg ist a bit bad.
# validate kwarg is necessary as the object is not fully initialised if we call it from _constructor
# (geometry-link is missing). thus we need a way to stop validating too early.
# maybe we have to think if and how we want to expose this kwarg to the outside.
super().__init__(*args, **kwargs)
if validate:
self._validate(self)

# create circular reference directly
# this avoids calling init twice via accessor
@property
def as_positionfixes(self):
return self

@staticmethod
def _validate(obj):
Expand Down Expand Up @@ -70,8 +89,8 @@ def _validate(obj):
@property
def center(self):
"""Return the center coordinate of this collection of positionfixes."""
lat = self._obj.geometry.y
lon = self._obj.geometry.x
lat = self.geometry.y
lon = self.geometry.x
return (float(lon.mean()), float(lat.mean()))

@_copy_docstring(generate_staypoints)
Expand All @@ -81,7 +100,7 @@ def generate_staypoints(self, *args, **kwargs):
See :func:`trackintel.preprocessing.positionfixes.generate_staypoints`.
"""
return ti.preprocessing.positionfixes.generate_staypoints(self._obj, *args, **kwargs)
return ti.preprocessing.positionfixes.generate_staypoints(self, *args, **kwargs)

@_copy_docstring(generate_triplegs)
def generate_triplegs(self, staypoints=None, *args, **kwargs):
Expand All @@ -90,7 +109,7 @@ def generate_triplegs(self, staypoints=None, *args, **kwargs):
See :func:`trackintel.preprocessing.positionfixes.generate_triplegs`.
"""
return ti.preprocessing.positionfixes.generate_triplegs(self._obj, staypoints, *args, **kwargs)
return ti.preprocessing.positionfixes.generate_triplegs(self, staypoints, *args, **kwargs)

@_copy_docstring(plot_positionfixes)
def plot(self, *args, **kwargs):
Expand All @@ -99,7 +118,7 @@ def plot(self, *args, **kwargs):
See :func:`trackintel.visualization.positionfixes.plot_positionfixes`.
"""
ti.visualization.positionfixes.plot_positionfixes(self._obj, *args, **kwargs)
ti.visualization.positionfixes.plot_positionfixes(self, *args, **kwargs)

@_copy_docstring(write_positionfixes_csv)
def to_csv(self, filename, *args, **kwargs):
Expand All @@ -108,7 +127,7 @@ def to_csv(self, filename, *args, **kwargs):
See :func:`trackintel.io.file.write_positionfixes_csv`.
"""
ti.io.file.write_positionfixes_csv(self._obj, filename, *args, **kwargs)
ti.io.file.write_positionfixes_csv(self, filename, *args, **kwargs)

@_copy_docstring(write_positionfixes_postgis)
def to_postgis(
Expand All @@ -120,7 +139,7 @@ def to_postgis(
See :func:`trackintel.io.postgis.write_positionfixes_postgis`.
"""
ti.io.postgis.write_positionfixes_postgis(
self._obj, name, con, schema, if_exists, index, index_label, chunksize, dtype
self, name, con, schema, if_exists, index, index_label, chunksize, dtype
)

@_copy_docstring(calculate_distance_matrix)
Expand All @@ -130,7 +149,7 @@ def calculate_distance_matrix(self, *args, **kwargs):
See :func:`trackintel.geogr.distances.calculate_distance_matrix`.
"""
return ti.geogr.distances.calculate_distance_matrix(self._obj, *args, **kwargs)
return ti.geogr.distances.calculate_distance_matrix(self, *args, **kwargs)

@_copy_docstring(get_speed_positionfixes)
def get_speed(self, *args, **kwargs):
Expand All @@ -139,4 +158,4 @@ def get_speed(self, *args, **kwargs):
See :func:`trackintel.model.util.get_speed_positionfixes`.
"""
return ti.model.util.get_speed_positionfixes(self._obj, *args, **kwargs)
return ti.model.util.get_speed_positionfixes(self, *args, **kwargs)
103 changes: 101 additions & 2 deletions trackintel/model/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from functools import partial, update_wrapper
import trackintel as ti
from functools import partial, update_wrapper, wraps
import numpy as np
import pandas as pd
import warnings
from geopandas import GeoDataFrame

import trackintel as ti
from trackintel.geogr.distances import calculate_haversine_length, check_gdf_planar
from trackintel.geogr.point_distances import haversine_dist

Expand Down Expand Up @@ -108,3 +109,101 @@ def _single_tripleg_mean_speed(positionfixes):
def _copy_docstring(wrapped, assigned=("__doc__",), updated=[]):
"""Thin wrapper for `functools.update_wrapper` to mimic `functools.wraps` but to only copy the docstring."""
return partial(update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated)


def _wrapped_gdf_method(func):
"""Decorator function that downcast types to trackintel class if is GeoDataFrame and has the required columns."""

@wraps(func) # copy all metadata
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
if not isinstance(result, GeoDataFrame) or not self._has_required_columns(result):
return result
# is GeoDataFrame and has required columns -> is TrackintelClass
# as we don't have mutable attributes, we can just change the class field
result.__class__ = self.__class__
return result

return wrapper


class TrackintelGeoDataFrame(GeoDataFrame):
"""Helper class to subtype GeoDataFrame correctly."""

# Following methods manually set self.__class__ fix to GeoDataFrame.
# Thus to properly subtype, we need to downcast them with the _wrapped_gdf_method decorator.
@_wrapped_gdf_method
def __getitem__(self, key):
return super().__getitem__(key)

@_wrapped_gdf_method
def copy(self, deep=True):
return super().copy(deep=deep)

@_wrapped_gdf_method
def merge(self, *args, **kwargs):
return super().merge(*args, **kwargs)

Check warning on line 145 in trackintel/model/util.py

View check run for this annotation

Codecov / codecov/patch

trackintel/model/util.py#L145

Added line #L145 was not covered by tests

@property
def _constructor(self):
"""Interface to subtype pandas properly"""
super_cons = super()._constructor
class_cons = self.__class__
check = self._has_required_columns

def _constructor_with_fallback(*args, **kwargs):
result = super_cons(*args, **kwargs)
if isinstance(result, GeoDataFrame) and check(result):
return class_cons(result, validate=False)
return result

return _constructor_with_fallback


class TrackintelBase(object):
"""Class for supplying basic functionality to all Trackintel classes."""

# so far we don't have a lot of methods here
# but a lot of IO code can be moved here.

def _validate(self):
raise NotImplementedError

def _has_required_columns(self, obj): # maybe we can move this out to function that we'll call
for col in self.required_columns:
if col not in obj.columns:
return False

Check warning on line 175 in trackintel/model/util.py

View check run for this annotation

Codecov / codecov/patch

trackintel/model/util.py#L175

Added line #L175 was not covered by tests
return True

# maybe do some non-costly checks as well -> e.g. dtype


class NonCachedAccessor:
def __init__(self, name: str, accessor) -> None:
self._name = name
self._accessor = accessor

def __get__(self, obj, cls):
if obj is None:
# we're accessing the attribute of the class, i.e., Dataset.geo
return self._accessor

Check warning on line 189 in trackintel/model/util.py

View check run for this annotation

Codecov / codecov/patch

trackintel/model/util.py#L189

Added line #L189 was not covered by tests
# copied code from pandas accessor, minus the caching
return self._accessor(obj)


def _register_trackintel_accessor(name: str):
from pandas import DataFrame

def decorator(accessor):
if hasattr(DataFrame, name):
warnings.warn(

Check warning on line 199 in trackintel/model/util.py

View check run for this annotation

Codecov / codecov/patch

trackintel/model/util.py#L199

Added line #L199 was not covered by tests
f"registration of accessor {repr(accessor)} under name "
f"{repr(name)} for type {repr(DataFrame)} is overriding a preexisting "
f"attribute with the same name.",
UserWarning,
)
setattr(DataFrame, name, NonCachedAccessor(name, accessor))
DataFrame._accessors.add(name)
return accessor

return decorator
Loading

0 comments on commit c3b7486

Please sign in to comment.