diff --git a/snapshottest/module.py b/snapshottest/module.py index e31b46c..66f4614 100644 --- a/snapshottest/module.py +++ b/snapshottest/module.py @@ -24,6 +24,7 @@ def __init__(self, module, filepath): self.imports = defaultdict(set) self.visited_snapshots = set() self.new_snapshots = set() + self.missing_snapshots = set() self.failed_snapshots = set() self.imports["snapshottest"].add("Snapshot") @@ -89,6 +90,10 @@ def stats_visited_snapshots(cls): def stats_new_snapshots(cls): return cls.stats_for_module(lambda module: len(module.new_snapshots)) + @classmethod + def stats_missing_snapshots(cls): + return cls.stats_for_module(lambda module: len(module.missing_snapshots)) + @classmethod def stats_failed_snapshots(cls): return cls.stats_for_module(lambda module: len(module.failed_snapshots)) @@ -130,6 +135,9 @@ def __setitem__(self, key, value): def mark_failed(self, key): return self.failed_snapshots.add(key) + def mark_missing(self, key): + return self.missing_snapshots.add(key) + @property def snapshot_dir(self): return os.path.dirname(self.filepath) @@ -202,18 +210,15 @@ def get_module_for_testpath(cls, test_filepath): class SnapshotTest(object): _current_tester = None - def __init__(self): + def __init__(self, snapshot_should_update): self.curr_snapshot = "" self.snapshot_counter = 1 + self.snapshot_should_update = snapshot_should_update @property def module(self): raise NotImplementedError("module property needs to be implemented") - @property - def update(self): - return False - @property def test_name(self): raise NotImplementedError("test_name property needs to be implemented") @@ -232,6 +237,9 @@ def visit(self): def fail(self): self.module.mark_failed(self.test_name) + def missing(self): + self.module.mark_missing(self.test_name) + def store(self, data): formatter = Formatter.get_formatter(data) data = formatter.store(self, data) @@ -249,13 +257,16 @@ def assert_equals(self, value, snapshot): def assert_match(self, value, name=""): self.curr_snapshot = name or self.snapshot_counter self.visit() - if self.update: + if self.snapshot_should_update: self.store(value) else: try: prev_snapshot = self.module[self.test_name] except SnapshotNotFound: - self.store(value) # first time this test has been seen + # There is no snapshot for this test, run with --snapshot-update + # to create it + self.missing() + raise else: try: self.assert_value_matches_snapshot(value, prev_snapshot) diff --git a/snapshottest/nose.py b/snapshottest/nose.py index 371734d..e1714fb 100644 --- a/snapshottest/nose.py +++ b/snapshottest/nose.py @@ -4,6 +4,7 @@ from nose.plugins import Plugin +from .parse_env import env_snapshot_update from .module import SnapshotModule from .reporting import reporting_lines from .unittest import TestCase @@ -37,7 +38,7 @@ def options(self, parser, env=os.environ): def configure(self, options, conf): super(SnapshotTestPlugin, self).configure(options, conf) - self.snapshot_update = options.snapshot_update + self.snapshot_update = options.snapshot_update or env_snapshot_update() self.enabled = not options.snapshot_disable def wantClass(self, cls): diff --git a/snapshottest/parse_env.py b/snapshottest/parse_env.py new file mode 100644 index 0000000..db83a8c --- /dev/null +++ b/snapshottest/parse_env.py @@ -0,0 +1,9 @@ +import os + + +def _env_bool(val): + return val.lower() in ["1", "yes", "true", "t", "y"] + + +def env_snapshot_update(): + return _env_bool(os.environ.get("SNAPSHOT_UPDATE", "false")) diff --git a/snapshottest/pytest.py b/snapshottest/pytest.py index 2d40ca6..da38e29 100644 --- a/snapshottest/pytest.py +++ b/snapshottest/pytest.py @@ -2,6 +2,7 @@ import pytest import re +from .parse_env import env_snapshot_update from .module import SnapshotModule, SnapshotTest from .diff import PrettyDiff from .reporting import reporting_lines, diff_report @@ -27,16 +28,13 @@ def pytest_addoption(parser): class PyTestSnapshotTest(SnapshotTest): def __init__(self, request=None): self.request = request - super(PyTestSnapshotTest, self).__init__() + should_update = request.config.option.snapshot_update or env_snapshot_update() + super(PyTestSnapshotTest, self).__init__(should_update) @property def module(self): return SnapshotModule.get_module_for_testpath(self.request.node.fspath.strpath) - @property - def update(self): - return self.request.config.option.snapshot_update - @property def test_name(self): cls_name = getattr(self.request.node.cls, "__name__", "") diff --git a/snapshottest/reporting.py b/snapshottest/reporting.py index 26ca51f..9b702bd 100644 --- a/snapshottest/reporting.py +++ b/snapshottest/reporting.py @@ -7,11 +7,17 @@ def reporting_lines(testing_cli): successful_snapshots = SnapshotModule.stats_successful_snapshots() bold = ["bold"] + new_snapshots = SnapshotModule.stats_new_snapshots() if successful_snapshots: yield (colored("{} snapshots passed", attrs=bold) + ".").format( successful_snapshots ) - new_snapshots = SnapshotModule.stats_new_snapshots() + missing_snapshots = SnapshotModule.stats_missing_snapshots() + if missing_snapshots[0]: + yield ( + colored("{} snapshots missing", "red", attrs=bold) + + " in {} test suites. Run with `--snapshot-update` to create them." + ).format(*missing_snapshots) if new_snapshots[0]: yield ( colored("{} snapshots written", "green", attrs=bold) + " in {} test suites." @@ -46,9 +52,7 @@ def diff_report(left, right): + colored("Received value", "red", attrs=["bold"]) + colored(" does not match ", attrs=["bold"]) + colored( - "stored snapshot `{}`".format( - left.snapshottest.test_name, - ), + "stored snapshot `{}`".format(left.snapshottest.test_name), "green", attrs=["bold"], ) diff --git a/snapshottest/unittest.py b/snapshottest/unittest.py index b68fce7..6147a0e 100644 --- a/snapshottest/unittest.py +++ b/snapshottest/unittest.py @@ -2,6 +2,7 @@ import unittest import inspect +from .parse_env import env_snapshot_update from .module import SnapshotModule, SnapshotTest from .diff import PrettyDiff from .reporting import diff_report @@ -13,17 +14,12 @@ def __init__(self, test_class, test_id, test_filepath, should_update, assertEqua self.test_id = test_id self.test_filepath = test_filepath self.assertEqual = assertEqual - self.should_update = should_update - super(UnitTestSnapshotTest, self).__init__() + super(UnitTestSnapshotTest, self).__init__(should_update) @property def module(self): return SnapshotModule.get_module_for_testpath(self.test_filepath) - @property - def update(self): - return self.should_update - def assert_equals(self, value, snapshot): self.assertEqual(value, snapshot) @@ -78,13 +74,14 @@ def tearDownClass(cls): def setUp(self): """Do some custom setup""" + should_update = self.snapshot_should_update or env_snapshot_update() # print dir(self.__module__) self.addTypeEqualityFunc(PrettyDiff, self.comparePrettyDifs) self._snapshot = UnitTestSnapshotTest( test_class=self.__class__, test_id=self.id(), test_filepath=self._snapshot_file, - should_update=self.snapshot_should_update, + should_update=should_update, assertEqual=self.assertEqual, ) self._snapshot_tests.append(self._snapshot) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 8c53056..2956a7a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -75,16 +75,7 @@ def test_can_normalize_iterator_objects(): @pytest.mark.parametrize( - "value", - [ - 0, - 12.7, - True, - False, - None, - float("-inf"), - float("inf"), - ], + "value", [0, 12.7, True, False, None, float("-inf"), float("inf")] ) def test_basic_formatting_parsing(value): formatter = Formatter() diff --git a/tests/test_pytest.py b/tests/test_pytest.py index 06b9fb4..b8d7f72 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -21,6 +21,7 @@ def pytest_snapshot_test(request, _apply_options): class TestPyTestSnapShotTest: def test_property_test_name(self, pytest_snapshot_test): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name @@ -41,6 +42,7 @@ def test_property_test_name(self, pytest_snapshot_test): def test_pytest_snapshottest_property_test_name(pytest_snapshot_test): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name @@ -64,6 +66,7 @@ def test_pytest_snapshottest_property_test_name(pytest_snapshot_test): def test_pytest_snapshottest_property_test_name_parametrize_singleline( pytest_snapshot_test, arg ): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name @@ -84,6 +87,7 @@ def test_pytest_snapshottest_property_test_name_parametrize_singleline( def test_pytest_snapshottest_property_test_name_parametrize_multiline( pytest_snapshot_test, arg ): + pytest_snapshot_test.snapshot_should_update = True pytest_snapshot_test.assert_match("counter") assert ( pytest_snapshot_test.test_name diff --git a/tests/test_snapshot_test.py b/tests/test_snapshot_test.py index 9249478..97f44f3 100644 --- a/tests/test_snapshot_test.py +++ b/tests/test_snapshot_test.py @@ -4,6 +4,7 @@ from collections import OrderedDict from snapshottest.module import SnapshotModule, SnapshotTest +from snapshottest.error import SnapshotNotFound class GenericSnapshotTest(SnapshotTest): @@ -15,7 +16,7 @@ def __init__(self, snapshot_module, update=False, current_test_id=None): "update": update, "current_test_id": current_test_id or "test_mocked", } - super(GenericSnapshotTest, self).__init__() + super(GenericSnapshotTest, self).__init__(update) @property def module(self): @@ -33,7 +34,7 @@ def test_name(self): def reinitialize(self): """Reset internal state, as though starting a new test run""" - super(GenericSnapshotTest, self).__init__() + super(GenericSnapshotTest, self).__init__(False) def assert_snapshot_test_ran(snapshot_test, test_name=None): @@ -91,6 +92,7 @@ def fixture_snapshot_test(tmpdir): @pytest.mark.parametrize("value", SNAPSHOTABLE_VALUES, ids=repr) def test_snapshot_matches_itself(snapshot_test, value): # first run stores the value as the snapshot + snapshot_test.snapshot_should_update = True snapshot_test.assert_match(value) assert_snapshot_test_succeeded(snapshot_test) @@ -115,6 +117,7 @@ def test_snapshot_matches_itself(snapshot_test, value): ) def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value): # first run stores the value as the snapshot + snapshot_test.snapshot_should_update = True snapshot_test.assert_match(value) assert_snapshot_test_succeeded(snapshot_test) @@ -123,3 +126,11 @@ def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value) with pytest.raises(AssertionError): snapshot_test.assert_match(other_value) assert_snapshot_test_failed(snapshot_test) + + +def test_first_run_without_snapshots_fails(snapshot_test): + with pytest.raises(SnapshotNotFound): + snapshot_test.assert_match("foo", name="no_snapshot_exists_test") + assert snapshot_test.module.missing_snapshots == set( + ["test_mocked no_snapshot_exists_test"] + ) diff --git a/tests/test_sorted_dict.py b/tests/test_sorted_dict.py index b8217d8..f58711a 100644 --- a/tests/test_sorted_dict.py +++ b/tests/test_sorted_dict.py @@ -42,10 +42,7 @@ class Fruit(enum.IntEnum): APPLE = 1 ORANGE = 2 - dic = { - Fruit.APPLE: 100, - Fruit.ORANGE: 400, - } + dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400} assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE] assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE] @@ -55,10 +52,7 @@ class Fruit(enum.Enum): APPLE = 1 ORANGE = 2 - dic = { - Fruit.APPLE: 100, - Fruit.ORANGE: 400, - } + dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400} assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE] assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE]