Skip to content

Commit

Permalink
ENH: remove caching in DataFrame accessors (#505)
Browse files Browse the repository at this point in the history
* ENH: remove caching in DataFrame accessors

* CLN: replace lambda with function definition

* CLN: remove more lambdas
  • Loading branch information
bifbof authored Aug 10, 2023
1 parent 2371b04 commit a0e9891
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 42 deletions.
47 changes: 20 additions & 27 deletions tests/io/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def test_from_to_csv(self):
column_mapping = {"lat": "latitude", "lon": "longitude", "time": "tracked_at"}
mod_pfs = ti.read_positionfixes_csv(mod_file, sep=";", index_col="id", columns=column_mapping)
assert mod_pfs.equals(pfs)
pfs["tracked_at"] = pfs["tracked_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))

date_format = "%Y-%m-%dT%H:%M:%SZ"
columns = ["user_id", "tracked_at", "latitude", "longitude", "elevation", "accuracy"]
pfs.as_positionfixes.to_csv(tmp_file, sep=";", columns=columns)
pfs.as_positionfixes.to_csv(tmp_file, sep=";", columns=columns, date_format=date_format)
assert filecmp.cmp(orig_file, tmp_file, shallow=False)
os.remove(tmp_file)

Expand All @@ -49,11 +49,10 @@ def test_set_datatime_tz(self):
pfs = ti.read_positionfixes_csv(file, sep=";", index_col="id")
assert pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])

# check if a timezone will be set after manually deleting the timezone
pfs["tracked_at"] = pfs["tracked_at"].dt.tz_localize(None)
assert not pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])
# check if a timezone will be set without storing the timezone
date_format = "%Y-%m-%d %H:%M:%S"
tmp_file = os.path.join("tests", "data", "positionfixes_test_2.csv")
pfs.as_positionfixes.to_csv(tmp_file, sep=";")
pfs.as_positionfixes.to_csv(tmp_file, sep=";", date_format=date_format)
pfs = ti.read_positionfixes_csv(tmp_file, sep=";", index_col="id", tz="utc")

assert pd.api.types.is_datetime64tz_dtype(pfs["tracked_at"])
Expand Down Expand Up @@ -94,11 +93,10 @@ def test_from_to_csv(self):
mod_tpls = ti.read_triplegs_csv(mod_file, sep=";", columns=column_mapping, index_col="id")

assert mod_tpls.equals(tpls)
tpls["started_at"] = tpls["started_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))
tpls["finished_at"] = tpls["finished_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))

date_format = "%Y-%m-%dT%H:%M:%SZ"
columns = ["user_id", "started_at", "finished_at", "geom"]
tpls.as_triplegs.to_csv(tmp_file, sep=";", columns=columns)
tpls.as_triplegs.to_csv(tmp_file, sep=";", columns=columns, date_format=date_format)
assert filecmp.cmp(orig_file, tmp_file, shallow=False)
os.remove(tmp_file)

Expand All @@ -119,11 +117,10 @@ def test_set_datatime_tz(self):
tpls = ti.read_triplegs_csv(file, sep=";", index_col="id")
assert pd.api.types.is_datetime64tz_dtype(tpls["started_at"])

# check if a timezone will be set after manually deleting the timezone
tpls["started_at"] = tpls["started_at"].dt.tz_localize(None)
assert not pd.api.types.is_datetime64tz_dtype(tpls["started_at"])
# check if a timezone will be set without storing the timezone
tmp_file = os.path.join("tests", "data", "triplegs_test_2.csv")
tpls.as_triplegs.to_csv(tmp_file, sep=";")
date_format = "%Y-%m-%d %H:%M:%S"
tpls.as_triplegs.to_csv(tmp_file, sep=";", date_format=date_format)
tpls = ti.read_triplegs_csv(tmp_file, sep=";", index_col="id", tz="utc")

assert pd.api.types.is_datetime64tz_dtype(tpls["started_at"])
Expand Down Expand Up @@ -161,11 +158,10 @@ def test_from_to_csv(self):
sp = ti.read_staypoints_csv(orig_file, sep=";", tz="utc", index_col="id")
mod_sp = ti.read_staypoints_csv(mod_file, columns={"User": "user_id"}, sep=";", index_col="id")
assert mod_sp.equals(sp)
sp["started_at"] = sp["started_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))
sp["finished_at"] = sp["finished_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))

date_format = "%Y-%m-%dT%H:%M:%SZ"
columns = ["user_id", "started_at", "finished_at", "elevation", "geom"]
sp.as_staypoints.to_csv(tmp_file, sep=";", columns=columns)
sp.as_staypoints.to_csv(tmp_file, sep=";", columns=columns, date_format=date_format)
assert filecmp.cmp(orig_file, tmp_file, shallow=False)
os.remove(tmp_file)

Expand All @@ -186,11 +182,10 @@ def test_set_datatime_tz(self):
sp = ti.read_staypoints_csv(file, sep=";", index_col="id")
assert pd.api.types.is_datetime64tz_dtype(sp["started_at"])

# check if a timezone will be set after manually deleting the timezone
sp["started_at"] = sp["started_at"].dt.tz_localize(None)
assert not pd.api.types.is_datetime64tz_dtype(sp["started_at"])
# check if a timezone will be set without storing the timezone
tmp_file = os.path.join("tests", "data", "staypoints_test_2.csv")
sp.as_staypoints.to_csv(tmp_file, sep=";")
date_format = "%Y-%m-%d %H:%M:%S"
sp.as_staypoints.to_csv(tmp_file, sep=";", date_format=date_format)
sp = ti.read_staypoints_csv(tmp_file, sep=";", index_col="id", tz="utc")

assert pd.api.types.is_datetime64tz_dtype(sp["started_at"])
Expand Down Expand Up @@ -299,10 +294,9 @@ def test_from_to_csv(self):
mod_trips_wo_geom = pd.DataFrame(mod_trips.drop(columns=["geom"]))
assert mod_trips_wo_geom.equals(trips)

trips["started_at"] = trips["started_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))
trips["finished_at"] = trips["finished_at"].apply(lambda d: d.isoformat().replace("+00:00", "Z"))
date_format = "%Y-%m-%dT%H:%M:%SZ"
columns = ["user_id", "started_at", "finished_at", "origin_staypoint_id", "destination_staypoint_id"]
trips.as_trips.to_csv(tmp_file, sep=";", columns=columns)
trips.as_trips.to_csv(tmp_file, sep=";", columns=columns, date_format=date_format)
assert filecmp.cmp(orig_file, tmp_file, shallow=False)
os.remove(tmp_file)

Expand All @@ -313,11 +307,10 @@ def test_set_datatime_tz(self):
trips = ti.read_trips_csv(file, sep=";", index_col="id")
assert pd.api.types.is_datetime64tz_dtype(trips["started_at"])

# check if a timezone will be set after manually deleting the timezone
trips["started_at"] = trips["started_at"].dt.tz_localize(None)
assert not pd.api.types.is_datetime64tz_dtype(trips["started_at"])
# check if a timezone will be set without storing the timezone
tmp_file = os.path.join("tests", "data", "trips_test_2.csv")
trips.as_trips.to_csv(tmp_file, sep=";")
date_format = "%Y-%m-%d %H:%M:%S"
trips.as_trips.to_csv(tmp_file, sep=";", date_format=date_format)
trips = ti.read_trips_csv(tmp_file, sep=";", index_col="id", tz="utc")

assert pd.api.types.is_datetime64tz_dtype(trips["started_at"])
Expand Down
60 changes: 57 additions & 3 deletions tests/model/test_util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import os
from functools import WRAPPER_ASSIGNMENTS

from geopandas import GeoDataFrame
import numpy as np
from pandas import Timestamp, Timedelta
import pandas as pd
import pytest
from geopandas import GeoDataFrame
from geopandas.testing import assert_geodataframe_equal
from pandas import Timedelta, Timestamp
from shapely.geometry import Point

import trackintel as ti
from trackintel.io.postgis import read_trips_postgis
from trackintel.model.util import _copy_docstring, get_speed_positionfixes
from trackintel.model.util import (
NonCachedAccessor,
_copy_docstring,
_register_trackintel_accessor,
get_speed_positionfixes,
)


@pytest.fixture
Expand Down Expand Up @@ -207,3 +213,51 @@ def bar(b: int) -> int:
assert attr_bar != old_docs
else:
assert attr_foo != attr_bar


class TestNonCachedAccessor:
    """Test if NonCachedAccessor works"""

    def test_accessor(self):
        """Test accessor on class object and class instance."""

        def identity(value):
            return value

        class Holder:
            nca = NonCachedAccessor("nca_test", identity)

        instance = Holder()
        # Class-level access hands back the raw accessor callable.
        assert Holder.nca == identity
        # Instance-level access calls the accessor with the instance (no caching).
        assert instance.nca == instance


class Test_register_trackintel_accessor:
    """Test if accessors are correctly registered."""

    def test_register(self):
        """Test if accessor is registered in DataFrame"""

        def foo(val):
            return val

        bar = _register_trackintel_accessor("foo")(foo)
        assert foo == bar
        assert "foo" in pd.DataFrame._accessors
        assert foo == pd.DataFrame.foo
        # remove accessor again to make tests independent
        # (set.remove mutates in place and returns None — do not reassign it)
        pd.DataFrame._accessors.remove("foo")
        del pd.DataFrame.foo

    def test_duplicate_name_warning(self):
        """Test that duplicate name raises warning"""

        def foo(val):
            return val

        _register_trackintel_accessor("foo")(foo)
        with pytest.warns(UserWarning):
            _register_trackintel_accessor("foo")(foo)
        # remove accessor again to make tests independent
        pd.DataFrame._accessors.remove("foo")
        del pd.DataFrame.foo
5 changes: 2 additions & 3 deletions trackintel/model/locations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import pandas as pd
import trackintel as ti
import trackintel.io
from trackintel.io.file import write_locations_csv
from trackintel.io.postgis import write_locations_postgis
from trackintel.model.util import _copy_docstring
from trackintel.model.util import _copy_docstring, _register_trackintel_accessor
from trackintel.preprocessing.filter import spatial_filter
from trackintel.visualization.locations import plot_locations


@pd.api.extensions.register_dataframe_accessor("as_locations")
@_register_trackintel_accessor("as_locations")
class LocationsAccessor(object):
"""A pandas accessor to treat (Geo)DataFrames as collections of locations.
Expand Down
4 changes: 2 additions & 2 deletions trackintel/model/positionfixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from trackintel.model.util import _copy_docstring
from trackintel.preprocessing.positionfixes import generate_staypoints, generate_triplegs
from trackintel.visualization.positionfixes import plot_positionfixes
from trackintel.model.util import get_speed_positionfixes
from trackintel.model.util import get_speed_positionfixes, _register_trackintel_accessor


@pd.api.extensions.register_dataframe_accessor("as_positionfixes")
@_register_trackintel_accessor("as_positionfixes")
class PositionfixesAccessor(object):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Positionfixes`.
Expand Down
4 changes: 2 additions & 2 deletions trackintel/model/staypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from trackintel.analysis.tracking_quality import temporal_tracking_quality
from trackintel.io.file import write_staypoints_csv
from trackintel.io.postgis import write_staypoints_postgis
from trackintel.model.util import _copy_docstring
from trackintel.model.util import _copy_docstring, _register_trackintel_accessor
from trackintel.preprocessing.filter import spatial_filter
from trackintel.preprocessing.staypoints import generate_locations, merge_staypoints
from trackintel.visualization.staypoints import plot_staypoints


@pd.api.extensions.register_dataframe_accessor("as_staypoints")
@_register_trackintel_accessor("as_staypoints")
class StaypointsAccessor(object):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Staypoints`.
Expand Down
3 changes: 2 additions & 1 deletion trackintel/model/tours.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pandas as pd
import trackintel as ti
from trackintel.model.util import _register_trackintel_accessor


@pd.api.extensions.register_dataframe_accessor("as_tours")
@_register_trackintel_accessor("as_tours")
class ToursAccessor(object):
"""A pandas accessor to treat DataFrames as collections of `Tours`.
Expand Down
4 changes: 2 additions & 2 deletions trackintel/model/triplegs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from trackintel.geogr.distances import calculate_distance_matrix
from trackintel.io.file import write_triplegs_csv
from trackintel.io.postgis import write_triplegs_postgis
from trackintel.model.util import _copy_docstring, get_speed_triplegs
from trackintel.model.util import _copy_docstring, get_speed_triplegs, _register_trackintel_accessor
from trackintel.preprocessing.filter import spatial_filter
from trackintel.preprocessing.triplegs import generate_trips
from trackintel.visualization.triplegs import plot_triplegs


@pd.api.extensions.register_dataframe_accessor("as_triplegs")
@_register_trackintel_accessor("as_triplegs")
class TriplegsAccessor(object):
"""A pandas accessor to treat (Geo)DataFrames as collections of `Tripleg`.
Expand Down
4 changes: 2 additions & 2 deletions trackintel/model/trips.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from trackintel.analysis.tracking_quality import temporal_tracking_quality
from trackintel.io.postgis import write_trips_postgis
from trackintel.io.file import write_trips_csv
from trackintel.model.util import _copy_docstring
from trackintel.model.util import _copy_docstring, _register_trackintel_accessor
import pandas as pd
import geopandas as gpd

import trackintel as ti


@pd.api.extensions.register_dataframe_accessor("as_trips")
@_register_trackintel_accessor("as_trips")
class TripsAccessor(object):
"""A pandas accessor to treat (Geo)DataFrames as collections of trips.
Expand Down
31 changes: 31 additions & 0 deletions trackintel/model/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,34 @@ def _single_tripleg_mean_speed(positionfixes):
def _copy_docstring(wrapped, assigned=("__doc__",), updated=[]):
"""Thin wrapper for `functools.update_wrapper` to mimic `functools.wraps` but to only copy the docstring."""
return partial(update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated)


class NonCachedAccessor:
    """Descriptor that exposes an accessor without caching it on the instance.

    Mirrors ``pandas.core.accessor.CachedAccessor`` but builds a fresh
    accessor object on every attribute access instead of storing one.
    """

    def __init__(self, name: str, accessor) -> None:
        self._name = name
        self._accessor = accessor

    def __get__(self, obj, cls):
        # Looked up on the class itself (obj is None): return the accessor class.
        if obj is None:
            return self._accessor
        # Looked up on an instance: instantiate the accessor with it, uncached.
        return self._accessor(obj)


def _register_trackintel_accessor(name: str):
    """Return a decorator that registers an accessor class on ``pandas.DataFrame``
    under ``name``, using ``NonCachedAccessor`` so no instance state is cached."""
    from pandas import DataFrame

    def decorator(accessor):
        # Follow pandas and warn when shadowing an existing DataFrame attribute.
        if hasattr(DataFrame, name):
            msg = (
                f"registration of accessor {repr(accessor)} under name "
                f"{repr(name)} for type {repr(DataFrame)} is overriding a preexisting "
                f"attribute with the same name."
            )
            warnings.warn(msg, UserWarning)
        setattr(DataFrame, name, NonCachedAccessor(name, accessor))
        DataFrame._accessors.add(name)
        return accessor

    return decorator

0 comments on commit a0e9891

Please sign in to comment.