Skip to content

Commit

Permalink
Don't allow missing snapshots without --snapshot-update
Browse files Browse the repository at this point in the history
  • Loading branch information
davidshepherd7 committed Sep 30, 2020
1 parent 9818a76 commit 2036522
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 42 deletions.
25 changes: 18 additions & 7 deletions snapshottest/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, module, filepath):
self.imports = defaultdict(set)
self.visited_snapshots = set()
self.new_snapshots = set()
self.missing_snapshots = set()
self.failed_snapshots = set()
self.imports["snapshottest"].add("Snapshot")

Expand Down Expand Up @@ -89,6 +90,10 @@ def stats_visited_snapshots(cls):
def stats_new_snapshots(cls):
return cls.stats_for_module(lambda module: len(module.new_snapshots))

@classmethod
def stats_missing_snapshots(cls):
return cls.stats_for_module(lambda module: len(module.missing_snapshots))

@classmethod
def stats_failed_snapshots(cls):
return cls.stats_for_module(lambda module: len(module.failed_snapshots))
Expand Down Expand Up @@ -130,6 +135,9 @@ def __setitem__(self, key, value):
def mark_failed(self, key):
return self.failed_snapshots.add(key)

def mark_missing(self, key):
return self.missing_snapshots.add(key)

@property
def snapshot_dir(self):
return os.path.dirname(self.filepath)
Expand Down Expand Up @@ -202,18 +210,15 @@ def get_module_for_testpath(cls, test_filepath):
class SnapshotTest(object):
_current_tester = None

def __init__(self):
def __init__(self, snapshot_should_update):
self.curr_snapshot = ""
self.snapshot_counter = 1
self.snapshot_should_update = snapshot_should_update

@property
def module(self):
raise NotImplementedError("module property needs to be implemented")

@property
def update(self):
return False

@property
def test_name(self):
raise NotImplementedError("test_name property needs to be implemented")
Expand All @@ -232,6 +237,9 @@ def visit(self):
def fail(self):
self.module.mark_failed(self.test_name)

def missing(self):
self.module.mark_missing(self.test_name)

def store(self, data):
formatter = Formatter.get_formatter(data)
data = formatter.store(self, data)
Expand All @@ -249,13 +257,16 @@ def assert_equals(self, value, snapshot):
def assert_match(self, value, name=""):
self.curr_snapshot = name or self.snapshot_counter
self.visit()
if self.update:
if self.snapshot_should_update:
self.store(value)
else:
try:
prev_snapshot = self.module[self.test_name]
except SnapshotNotFound:
self.store(value) # first time this test has been seen
# There is no snapshot for this test, run with --snapshot-update
# to create it
self.missing()
raise
else:
try:
self.assert_value_matches_snapshot(value, prev_snapshot)
Expand Down
6 changes: 1 addition & 5 deletions snapshottest/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,12 @@ def pytest_addoption(parser):
class PyTestSnapshotTest(SnapshotTest):
def __init__(self, request=None):
self.request = request
super(PyTestSnapshotTest, self).__init__()
super(PyTestSnapshotTest, self).__init__(request.config.option.snapshot_update)

@property
def module(self):
return SnapshotModule.get_module_for_testpath(self.request.node.fspath.strpath)

@property
def update(self):
return self.request.config.option.snapshot_update

@property
def test_name(self):
cls_name = getattr(self.request.node.cls, "__name__", "")
Expand Down
12 changes: 8 additions & 4 deletions snapshottest/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
def reporting_lines(testing_cli):
successful_snapshots = SnapshotModule.stats_successful_snapshots()
bold = ["bold"]
new_snapshots = SnapshotModule.stats_new_snapshots()
if successful_snapshots:
yield (colored("{} snapshots passed", attrs=bold) + ".").format(
successful_snapshots
)
new_snapshots = SnapshotModule.stats_new_snapshots()
missing_snapshots = SnapshotModule.stats_missing_snapshots()
if missing_snapshots[0]:
yield (
colored("{} snapshots missing", "red", attrs=bold)
+ " in {} test suites. Run with `--snapshot-update` to create them."
).format(*missing_snapshots)
if new_snapshots[0]:
yield (
colored("{} snapshots written", "green", attrs=bold) + " in {} test suites."
Expand Down Expand Up @@ -46,9 +52,7 @@ def diff_report(left, right):
+ colored("Received value", "red", attrs=["bold"])
+ colored(" does not match ", attrs=["bold"])
+ colored(
"stored snapshot `{}`".format(
left.snapshottest.test_name,
),
"stored snapshot `{}`".format(left.snapshottest.test_name),
"green",
attrs=["bold"],
)
Expand Down
7 changes: 1 addition & 6 deletions snapshottest/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,12 @@ def __init__(self, test_class, test_id, test_filepath, should_update, assertEqua
self.test_id = test_id
self.test_filepath = test_filepath
self.assertEqual = assertEqual
self.should_update = should_update
super(UnitTestSnapshotTest, self).__init__()
super(UnitTestSnapshotTest, self).__init__(should_update)

@property
def module(self):
return SnapshotModule.get_module_for_testpath(self.test_filepath)

@property
def update(self):
return self.should_update

def assert_equals(self, value, snapshot):
self.assertEqual(value, snapshot)

Expand Down
11 changes: 1 addition & 10 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,7 @@ def test_can_normalize_iterator_objects():


@pytest.mark.parametrize(
"value",
[
0,
12.7,
True,
False,
None,
float("-inf"),
float("inf"),
],
"value", [0, 12.7, True, False, None, float("-inf"), float("inf")]
)
def test_basic_formatting_parsing(value):
formatter = Formatter()
Expand Down
4 changes: 4 additions & 0 deletions tests/test_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def pytest_snapshot_test(request, _apply_options):

class TestPyTestSnapShotTest:
def test_property_test_name(self, pytest_snapshot_test):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand All @@ -41,6 +42,7 @@ def test_property_test_name(self, pytest_snapshot_test):


def test_pytest_snapshottest_property_test_name(pytest_snapshot_test):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand All @@ -64,6 +66,7 @@ def test_pytest_snapshottest_property_test_name(pytest_snapshot_test):
def test_pytest_snapshottest_property_test_name_parametrize_singleline(
pytest_snapshot_test, arg
):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand All @@ -84,6 +87,7 @@ def test_pytest_snapshottest_property_test_name_parametrize_singleline(
def test_pytest_snapshottest_property_test_name_parametrize_multiline(
pytest_snapshot_test, arg
):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand Down
15 changes: 13 additions & 2 deletions tests/test_snapshot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict

from snapshottest.module import SnapshotModule, SnapshotTest
from snapshottest.error import SnapshotNotFound


class GenericSnapshotTest(SnapshotTest):
Expand All @@ -15,7 +16,7 @@ def __init__(self, snapshot_module, update=False, current_test_id=None):
"update": update,
"current_test_id": current_test_id or "test_mocked",
}
super(GenericSnapshotTest, self).__init__()
super(GenericSnapshotTest, self).__init__(update)

@property
def module(self):
Expand All @@ -33,7 +34,7 @@ def test_name(self):

def reinitialize(self):
"""Reset internal state, as though starting a new test run"""
super(GenericSnapshotTest, self).__init__()
super(GenericSnapshotTest, self).__init__(False)


def assert_snapshot_test_ran(snapshot_test, test_name=None):
Expand Down Expand Up @@ -91,6 +92,7 @@ def fixture_snapshot_test(tmpdir):
@pytest.mark.parametrize("value", SNAPSHOTABLE_VALUES, ids=repr)
def test_snapshot_matches_itself(snapshot_test, value):
# first run stores the value as the snapshot
snapshot_test.snapshot_should_update = True
snapshot_test.assert_match(value)
assert_snapshot_test_succeeded(snapshot_test)

Expand All @@ -115,6 +117,7 @@ def test_snapshot_matches_itself(snapshot_test, value):
)
def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value):
# first run stores the value as the snapshot
snapshot_test.snapshot_should_update = True
snapshot_test.assert_match(value)
assert_snapshot_test_succeeded(snapshot_test)

Expand All @@ -123,3 +126,11 @@ def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value)
with pytest.raises(AssertionError):
snapshot_test.assert_match(other_value)
assert_snapshot_test_failed(snapshot_test)


def test_first_run_without_snapshots_fails(snapshot_test):
with pytest.raises(SnapshotNotFound):
snapshot_test.assert_match("foo", name="no_snapshot_exists_test")
assert snapshot_test.module.missing_snapshots == set(
["test_mocked no_snapshot_exists_test"]
)
10 changes: 2 additions & 8 deletions tests/test_sorted_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ class Fruit(enum.IntEnum):
APPLE = 1
ORANGE = 2

dic = {
Fruit.APPLE: 100,
Fruit.ORANGE: 400,
}
dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400}
assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE]
assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE]

Expand All @@ -55,10 +52,7 @@ class Fruit(enum.Enum):
APPLE = 1
ORANGE = 2

dic = {
Fruit.APPLE: 100,
Fruit.ORANGE: 400,
}
dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400}
assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE]
assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE]

Expand Down

0 comments on commit 2036522

Please sign in to comment.