Skip to content

Commit

Permalink
feat: remove unused snapshots in snapshot file (#11)
Browse files Browse the repository at this point in the history
* wip

* wip: use correct execution number in registering

* refactor: change album type to snapshotfiles

* test: refactor test injection string

* wip: always expand on snapshots in discovered files

* refactor: register assertion

* refactor: simplify snapshot files diff logic

* refactor: documentation and type hints

* refactor: simplify snapshot files count logic

* refactor: rename variables

* refactor: fix write snapshot conditional

* refactor: fix used snapshots

* cr: use explicit class type

* cr: move filename length logic into clean method

* refactor: use generator in sum

* refactor: iterate on values when counting snapshots

* cr: pythonic

* chore: delete unused snapshots

* chore: add extra line to githooks ini file

for some reason it keeps adding it
  • Loading branch information
iamogbz authored Dec 4, 2019
1 parent e9de910 commit a5c46b1
Show file tree
Hide file tree
Showing 14 changed files with 323 additions and 135 deletions.
1 change: 1 addition & 0 deletions .githooks.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[pre-commit]
command = inv lint

1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def readme() -> str:
use_scm_version={"write_to": "version.txt"},
package_dir={"": "src"},
packages=["syrupy"],
py_modules=["syrupy"],
zip_safe=False,
install_requires=[],
setup_requires=["setuptools_scm"],
Expand Down
7 changes: 5 additions & 2 deletions src/syrupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@ def snapshot(request):
classname=request.cls.__name__ if request.cls else None,
methodname=request.function.__name__ if request.function else None,
nodename=getattr(request.node, "name", ""),
testname=getattr(request.node, "name", "")
or (request.function.__name__ if request.function else None),
testname=getattr(
request.node,
"name",
request.function.__name__ if request.function else None,
),
)
return SnapshotAssertion(
update_snapshots=request.config.option.update_snapshots,
Expand Down
63 changes: 36 additions & 27 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import traceback
import pytest
import os
from typing import List, Optional, Any
from typing import Any, Callable, List, Optional, Type

from .exceptions import SnapshotDoesNotExist

Expand All @@ -15,8 +15,8 @@ def __init__(
self,
*,
update_snapshots: bool,
io_class: SnapshotIO,
serializer_class: SnapshotSerializer,
io_class: Type[SnapshotIO],
serializer_class: Type[SnapshotSerializer],
test_location: TestLocation,
session,
):
Expand All @@ -25,24 +25,34 @@ def __init__(
self._serializer_class = serializer_class
self._test_location = test_location
self._executions = 0
self._session = session

from .session import SnapshotSession

self._session: SnapshotSession = session
self._session.register_request(self)

@property
def io(self):
def io(self) -> SnapshotIO:
if not getattr(self, "_io", None):
self._io = self._io_class(
test_location=self._test_location, file_hook=self._file_hook
)
return self._io

@property
def serializer(self):
def serializer(self) -> SnapshotSerializer:
if not getattr(self, "_serializer", None):
self._serializer = self._serializer_class()
return self._serializer

@property
def num_executions(self) -> int:
return int(self._executions)

def with_class(
self, io_class: SnapshotIO = None, serializer_class: SnapshotSerializer = None
self,
io_class: Type[SnapshotIO] = None,
serializer_class: Type[SnapshotSerializer] = None,
):
return self.__class__(
update_snapshots=self._update_snapshots,
Expand All @@ -56,7 +66,7 @@ def assert_match(self, data) -> bool:
return self._assert(data)

def get_assert_diff(self, data) -> List[str]:
deserialized = self._recall_data(index=self._executions - 1)
deserialized = self._recall_data(index=self.num_executions - 1)
if deserialized is None:
return ["Snapshot does not exist!"]

Expand All @@ -65,11 +75,11 @@ def get_assert_diff(self, data) -> List[str]:

return []

def _file_hook(self, filepath):
self._session.add_visited_file(filepath)
def _file_hook(self, filepath, snapshot_name):
self._session.add_visited_snapshots({filepath: {snapshot_name}})

def __repr__(self) -> str:
return f"<SnapshotAssertion ({self._executions})>"
return f"<SnapshotAssertion ({self.num_executions})>"

def __call__(self, data) -> bool:
return self._assert(data)
Expand All @@ -78,26 +88,25 @@ def __eq__(self, other) -> bool:
return self._assert(other)

def _assert(self, data) -> bool:
executions = self._executions
self._executions += 1

if self._update_snapshots:
serialized_data = self.serializer.encode(data)
self.io.pre_write(serialized_data, index=executions)
filepath = self.io.write(serialized_data, index=executions)
self.io.post_write(serialized_data, index=executions)
self._session.register_assertion(self)
try:
if self._update_snapshots:
serialized_data = self.serializer.encode(data)
self.io.create_or_update_snapshot(
serialized_data, index=self.num_executions
)
return True

deserialized = self._recall_data(index=self.num_executions)
if deserialized is None or data != deserialized:
return False
return True

deserialized = self._recall_data(index=executions)
if deserialized is None or data != deserialized:
return False
return True
finally:
self._executions += 1

def _recall_data(self, index: int) -> Optional[Any]:
try:
self.io.pre_read(index=index)
saved_data = self.io.read(index=index)
self.io.post_read(index=index)
saved_data = self.io.read_snapshot(index=index)
return self.serializer.decode(saved_data)
except SnapshotDoesNotExist:
return None
152 changes: 118 additions & 34 deletions src/syrupy/io.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,92 @@
from typing import Any, Optional
from typing import Any, Callable, Optional, Set

import os
import yaml

from .constants import SNAPSHOT_DIRNAME
from .exceptions import SnapshotDoesNotExist
from .location import TestLocation
from .types import SnapshotFiles


class SnapshotIO:
def __init__(self, test_location: TestLocation, file_hook):
self._test_location = test_location
self._file_hook = file_hook

def pre_write(self, data: Any, index: int = 0):
self.ensure_snapshot_dir(index)
@property
def test_location(self):
return self._test_location

def write(self, data: Any, index: int = 0):
snapshot_name = self.get_snapshot_name(index)
snapshots = self._load_documents(index)
snapshots[snapshot_name] = snapshots.get(snapshot_name, {})
snapshots[snapshot_name]["data"] = data
with open(self.get_filepath(index), "w") as f:
yaml.safe_dump(snapshots, f)
@property
def dirname(self) -> str:
test_dirname = os.path.dirname(self.test_location.filename)
snapshot_dir = self._get_snapshot_dirname()
if snapshot_dir is not None:
return os.path.join(test_dirname, SNAPSHOT_DIRNAME, snapshot_dir)
return os.path.join(test_dirname, SNAPSHOT_DIRNAME)

def post_write(self, data: Any, index: int = 0):
self._file_hook(self.get_filepath(index))
def discover_snapshots(self, filepath: str) -> Set[str]:
"""
Utility method for getting all the snapshots from a file.
Returns an empty set if the file cannot be read.
"""
try:
return set(self._read_file(filepath).keys())
except:
return set()

def read_snapshot(self, index: int) -> Any:
"""
Utility method for reading the contents of a snapshot assertion.
Will call `pre_read`, then `read` and finally `post_read`,
returning the contents parsed from the `read` method.
"""
try:
self.pre_read(index=index)
return self.read(index=index)
finally:
self.post_read(index=index)

def create_or_update_snapshot(self, serialized_data: Any, index: int):
"""
Utility method for reading the contents of a snapshot assertion.
Will call `pre_write`, then `write` and finally `post_write`.
"""
self.pre_write(serialized_data, index=index)
self.write(serialized_data, index=index)
self.post_write(serialized_data, index=index)

def delete_snapshot(self, snapshot_file: str, snapshot_name: str):
"""
Utility method for removing a snapshot from a snapshot file.
"""
self._write_snapshot_or_remove_file(snapshot_file, snapshot_name, None)

def pre_read(self, index: int = 0):
pass

def read(self, index: int = 0) -> Any:
snapshot_file = self.get_filepath(index)
snapshot_name = self.get_snapshot_name(index)
snapshots = self._load_documents(index)
snapshot = snapshots.get(snapshot_name, None)
snapshot = self._read_snapshot_from_file(snapshot_file, snapshot_name)
if snapshot is None:
raise SnapshotDoesNotExist()
return snapshot["data"]
return snapshot

def post_read(self, index: int = 0):
self._file_hook(self.get_filepath(index))
self._snap_file_hook(index)

def pre_write(self, data: Any, index: int = 0):
self._ensure_snapshot_dir(index)

def write(self, data: Any, index: int = 0):
snapshot_file = self.get_filepath(index)
snapshot_name = self.get_snapshot_name(index)
self._write_snapshot_or_remove_file(snapshot_file, snapshot_name, data)

def post_write(self, data: Any, index: int = 0):
self._snap_file_hook(index)

def get_snapshot_name(self, index: int = 0) -> str:
index_suffix = f".{index}" if index > 0 else ""
Expand All @@ -49,37 +96,74 @@ def get_snapshot_name(self, index: int = 0) -> str:
return f"{self._test_location.classname}.{methodname}{index_suffix}"
return f"{methodname}{index_suffix}"

def get_snapshot_dirname(self) -> Optional[str]:
return None

def get_filepath(self, index: int) -> str:
basename = self.get_file_basename(index=index)
return os.path.join(self._get_dirname(), basename)
return os.path.join(self.dirname, basename)

def get_file_basename(self, index: int) -> str:
return f"{os.path.basename(self._test_location.filename)[: -len('.py')]}.yaml"
return f"{os.path.splitext(os.path.basename(self._test_location.filename))[0]}.yaml"

def _get_snapshot_dirname(self) -> Optional[str]:
return None

def ensure_snapshot_dir(self, index: int):
def _ensure_snapshot_dir(self, index: int):
"""
Ensures the folder path for the snapshot file exists.
"""
try:
os.makedirs(os.path.dirname(self.get_filepath(index)))
except FileExistsError:
pass

@property
def test_location(self):
return self._test_location

def _load_documents(self, index: int) -> dict:
def _read_snapshot_from_file(self, snapshot_file: str, snapshot_name: str) -> Any:
"""
Read the snapshot file and get only the snapshot data for assertion
"""
snapshots = self._read_file(snapshot_file)
return snapshots.get(snapshot_name, {}).get("data", None)

def _read_file(self, filepath: str) -> Any:
"""
Read the snapshot data from the snapshot file into a python instance.
"""
try:
with open(self.get_filepath(index), "r") as f:
with open(filepath, "r") as f:
return yaml.safe_load(f) or {}
except FileNotFoundError:
pass
return {}

def _get_dirname(self) -> str:
test_dirname = os.path.dirname(self._test_location.filename)
snapshot_dir = self.get_snapshot_dirname()
if snapshot_dir is not None:
return os.path.join(test_dirname, SNAPSHOT_DIRNAME, snapshot_dir)
return os.path.join(test_dirname, SNAPSHOT_DIRNAME)
def _write_snapshot_or_remove_file(
self, snapshot_file: str, snapshot_name: str, data: Any
):
"""
Adds the snapshot data to the snapshots read from the file
or removes the snapshot entry if data is `None`.
If the snapshot file will be empty remove the entire file.
"""
snapshots = self._read_file(snapshot_file)
if data is None and snapshot_name in snapshots:
del snapshots[snapshot_name]
else:
snapshots[snapshot_name] = snapshots.get(snapshot_name, {})
snapshots[snapshot_name]["data"] = data

if snapshots:
self._write_file(snapshot_file, snapshots)
else:
os.remove(snapshot_file)

def _write_file(self, filepath: str, data: Any):
"""
Writes the snapshot data into the snapshot file that be read later.
"""
with open(filepath, "w") as f:
yaml.safe_dump(data, f)

def _snap_file_hook(self, index: int):
"""
Notify the assertion of an access to a snapshot in a file
"""
snapshot_file = self.get_filepath(index)
snapshot_name = self.get_snapshot_name(index)
self._file_hook(snapshot_file, snapshot_name)
4 changes: 2 additions & 2 deletions src/syrupy/plugins/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

class PNGImageSnapshotIO(AbstractImageSnapshotIO):
@property
def extension(self):
def extension(self) -> str:
return "png"


class SVGImageSnapshotIO(AbstractImageSnapshotIO):
@property
def extension(self):
def extension(self) -> str:
return "svg"
Loading

0 comments on commit a5c46b1

Please sign in to comment.