From 20365224f9b57b794fb3c38a3c88c21490274ac7 Mon Sep 17 00:00:00 2001 From: David Shepherd Date: Mon, 3 Feb 2020 09:23:59 +0000 Subject: [PATCH] Don't allow missing snapshots without --snapshot-update --- snapshottest/module.py | 25 ++++++++++++++++++------- snapshottest/pytest.py | 6 +----- snapshottest/reporting.py | 12 ++++++++---- snapshottest/unittest.py | 7 +------ tests/test_formatter.py | 11 +---------- tests/test_pytest.py | 4 ++++ tests/test_snapshot_test.py | 15 +++++++++++++-- tests/test_sorted_dict.py | 10 ++-------- 8 files changed, 48 insertions(+), 42 deletions(-) diff --git a/snapshottest/module.py b/snapshottest/module.py index e31b46c..66f4614 100644 --- a/snapshottest/module.py +++ b/snapshottest/module.py @@ -24,6 +24,7 @@ def __init__(self, module, filepath): self.imports = defaultdict(set) self.visited_snapshots = set() self.new_snapshots = set() + self.missing_snapshots = set() self.failed_snapshots = set() self.imports["snapshottest"].add("Snapshot") @@ -89,6 +90,10 @@ def stats_visited_snapshots(cls): def stats_new_snapshots(cls): return cls.stats_for_module(lambda module: len(module.new_snapshots)) + @classmethod + def stats_missing_snapshots(cls): + return cls.stats_for_module(lambda module: len(module.missing_snapshots)) + @classmethod def stats_failed_snapshots(cls): return cls.stats_for_module(lambda module: len(module.failed_snapshots)) @@ -130,6 +135,9 @@ def __setitem__(self, key, value): def mark_failed(self, key): return self.failed_snapshots.add(key) + def mark_missing(self, key): + return self.missing_snapshots.add(key) + @property def snapshot_dir(self): return os.path.dirname(self.filepath) @@ -202,18 +210,15 @@ def get_module_for_testpath(cls, test_filepath): class SnapshotTest(object): _current_tester = None - def __init__(self): + def __init__(self, snapshot_should_update): self.curr_snapshot = "" self.snapshot_counter = 1 + self.snapshot_should_update = snapshot_should_update @property def module(self): raise NotImplementedError("module property needs to be implemented") - @property - def update(self): - return False - @property def test_name(self): raise NotImplementedError("test_name property needs to be implemented") @@ -232,6 +237,9 @@ def visit(self): def fail(self): self.module.mark_failed(self.test_name) + def missing(self): + self.module.mark_missing(self.test_name) + def store(self, data): formatter = Formatter.get_formatter(data) data = formatter.store(self, data) @@ -249,13 +257,16 @@ def assert_equals(self, value, snapshot): def assert_match(self, value, name=""): self.curr_snapshot = name or self.snapshot_counter self.visit() - if self.update: + if self.snapshot_should_update: self.store(value) else: try: prev_snapshot = self.module[self.test_name] except SnapshotNotFound: - self.store(value) # first time this test has been seen + # There is no snapshot for this test, run with --snapshot-update + # to create it + self.missing() + raise else: try: self.assert_value_matches_snapshot(value, prev_snapshot) diff --git a/snapshottest/pytest.py b/snapshottest/pytest.py index 2d40ca6..feeab9e 100644 --- a/snapshottest/pytest.py +++ b/snapshottest/pytest.py @@ -27,16 +27,12 @@ def pytest_addoption(parser): class PyTestSnapshotTest(SnapshotTest): def __init__(self, request=None): self.request = request - super(PyTestSnapshotTest, self).__init__() + super(PyTestSnapshotTest, self).__init__(request.config.option.snapshot_update) @property def module(self): return SnapshotModule.get_module_for_testpath(self.request.node.fspath.strpath) - @property - def update(self): - return self.request.config.option.snapshot_update - @property def test_name(self): cls_name = getattr(self.request.node.cls, "__name__", "") diff --git a/snapshottest/reporting.py b/snapshottest/reporting.py index 26ca51f..9b702bd 100644 --- a/snapshottest/reporting.py +++ b/snapshottest/reporting.py @@ -7,11 +7,17 @@ def reporting_lines(testing_cli): successful_snapshots = SnapshotModule.stats_successful_snapshots() bold = ["bold"] + new_snapshots = SnapshotModule.stats_new_snapshots() if successful_snapshots: yield (colored("{} snapshots passed", attrs=bold) + ".").format( successful_snapshots ) - new_snapshots = SnapshotModule.stats_new_snapshots() + missing_snapshots = SnapshotModule.stats_missing_snapshots() + if missing_snapshots[0]: + yield ( + colored("{} snapshots missing", "red", attrs=bold) + + " in {} test suites. Run with `--snapshot-update` to create them." + ).format(*missing_snapshots) if new_snapshots[0]: yield ( colored("{} snapshots written", "green", attrs=bold) + " in {} test suites." @@ -46,9 +52,7 @@ def diff_report(left, right): + colored("Received value", "red", attrs=["bold"]) + colored(" does not match ", attrs=["bold"]) + colored( - "stored snapshot `{}`".format( - left.snapshottest.test_name, - ), + "stored snapshot `{}`".format(left.snapshottest.test_name), "green", attrs=["bold"], ) diff --git a/snapshottest/unittest.py b/snapshottest/unittest.py index b68fce7..1187aba 100644 --- a/snapshottest/unittest.py +++ b/snapshottest/unittest.py @@ -13,17 +13,12 @@ def __init__(self, test_class, test_id, test_filepath, should_update, assertEqua self.test_id = test_id self.test_filepath = test_filepath self.assertEqual = assertEqual - self.should_update = should_update - super(UnitTestSnapshotTest, self).__init__() + super(UnitTestSnapshotTest, self).__init__(should_update) @property def module(self): return SnapshotModule.get_module_for_testpath(self.test_filepath) - @property - def update(self): - return self.should_update - def assert_equals(self, value, snapshot): self.assertEqual(value, snapshot) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 8c53056..2956a7a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -75,16 +75,7 @@ def test_can_normalize_iterator_objects(): @pytest.mark.parametrize( - "value", - [ - 0, - 12.7, - True, - False, - None, - float("-inf"), - float("inf"), - ], + "value", [0, 12.7, True, False, None, float("-inf"), float("inf")] ) def test_basic_formatting_parsing(value): formatter = Formatter() diff --git a/tests/test_pytest.py b/tests/test_pytest.py index 06b9fb4..b8d7f72 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -21,6 +21,7 @@ def pytest_snapshot_test(request, _apply_options): class TestPyTestSnapShotTest: def test_property_test_name(self, pytest_snapshot_test): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name @@ -41,6 +42,7 @@ def test_property_test_name(self, pytest_snapshot_test): def test_pytest_snapshottest_property_test_name(pytest_snapshot_test): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name @@ -64,6 +66,7 @@ def test_pytest_snapshottest_property_test_name(pytest_snapshot_test): def test_pytest_snapshottest_property_test_name_parametrize_singleline( pytest_snapshot_test, arg ): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name @@ -84,6 +87,7 @@ def test_pytest_snapshottest_property_test_name_parametrize_singleline( def test_pytest_snapshottest_property_test_name_parametrize_multiline( pytest_snapshot_test, arg ): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name diff --git a/tests/test_snapshot_test.py b/tests/test_snapshot_test.py index 9249478..97f44f3 100644 --- a/tests/test_snapshot_test.py +++ b/tests/test_snapshot_test.py @@ -4,6 +4,7 @@ from collections import OrderedDict from snapshottest.module import SnapshotModule, SnapshotTest +from snapshottest.error import SnapshotNotFound class GenericSnapshotTest(SnapshotTest): @@ -15,7 +16,7 @@ def __init__(self, snapshot_module, update=False, current_test_id=None): "update": update, "current_test_id": current_test_id or "test_mocked", } - super(GenericSnapshotTest, self).__init__() + super(GenericSnapshotTest, self).__init__(update) @property def module(self): @@ -33,7 +34,7 @@ def test_name(self): def reinitialize(self): """Reset internal state, as though starting a new test run""" - super(GenericSnapshotTest, self).__init__() + super(GenericSnapshotTest, self).__init__(False) def assert_snapshot_test_ran(snapshot_test, test_name=None): @@ -91,6 +92,7 @@ def fixture_snapshot_test(tmpdir): @pytest.mark.parametrize("value", SNAPSHOTABLE_VALUES, ids=repr) def test_snapshot_matches_itself(snapshot_test, value): # first run stores the value as the snapshot + snapshot_test.snapshot_should_update = True snapshot_test.assert_match(value) assert_snapshot_test_succeeded(snapshot_test) @@ -115,6 +117,7 @@ def test_snapshot_matches_itself(snapshot_test, value): ) def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value): # first run stores the value as the snapshot + snapshot_test.snapshot_should_update = True snapshot_test.assert_match(value) assert_snapshot_test_succeeded(snapshot_test) @@ -123,3 +126,11 @@ def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value) with pytest.raises(AssertionError): snapshot_test.assert_match(other_value) assert_snapshot_test_failed(snapshot_test) + + +def test_first_run_without_snapshots_fails(snapshot_test): + with pytest.raises(SnapshotNotFound): + snapshot_test.assert_match("foo", name="no_snapshot_exists_test") + assert snapshot_test.module.missing_snapshots == set( + ["test_mocked no_snapshot_exists_test"] + ) diff --git a/tests/test_sorted_dict.py b/tests/test_sorted_dict.py index b8217d8..f58711a 100644 --- a/tests/test_sorted_dict.py +++ b/tests/test_sorted_dict.py @@ -42,10 +42,7 @@ class Fruit(enum.IntEnum): APPLE = 1 ORANGE = 2 - dic = { - Fruit.APPLE: 100, - Fruit.ORANGE: 400, - } + dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400} assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE] assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE] @@ -55,10 +52,7 @@ class Fruit(enum.Enum): APPLE = 1 ORANGE = 2 - dic = { - Fruit.APPLE: 100, - Fruit.ORANGE: 400, - } + dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400} assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE] assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE]