Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allow missing snapshots without --snapshot-update #112

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions snapshottest/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, module, filepath):
self.imports = defaultdict(set)
self.visited_snapshots = set()
self.new_snapshots = set()
self.missing_snapshots = set()
self.failed_snapshots = set()
self.imports["snapshottest"].add("Snapshot")

Expand Down Expand Up @@ -89,6 +90,10 @@ def stats_visited_snapshots(cls):
def stats_new_snapshots(cls):
return cls.stats_for_module(lambda module: len(module.new_snapshots))

@classmethod
def stats_missing_snapshots(cls):
return cls.stats_for_module(lambda module: len(module.missing_snapshots))

@classmethod
def stats_failed_snapshots(cls):
return cls.stats_for_module(lambda module: len(module.failed_snapshots))
Expand Down Expand Up @@ -130,6 +135,9 @@ def __setitem__(self, key, value):
def mark_failed(self, key):
return self.failed_snapshots.add(key)

def mark_missing(self, key):
return self.missing_snapshots.add(key)

@property
def snapshot_dir(self):
return os.path.dirname(self.filepath)
Expand Down Expand Up @@ -202,18 +210,15 @@ def get_module_for_testpath(cls, test_filepath):
class SnapshotTest(object):
_current_tester = None

def __init__(self):
def __init__(self, snapshot_should_update):
self.curr_snapshot = ""
self.snapshot_counter = 1
self.snapshot_should_update = snapshot_should_update

@property
def module(self):
raise NotImplementedError("module property needs to be implemented")

@property
def update(self):
return False

@property
def test_name(self):
raise NotImplementedError("test_name property needs to be implemented")
Expand All @@ -232,6 +237,9 @@ def visit(self):
def fail(self):
self.module.mark_failed(self.test_name)

def missing(self):
self.module.mark_missing(self.test_name)

def store(self, data):
formatter = Formatter.get_formatter(data)
data = formatter.store(self, data)
Expand All @@ -249,13 +257,16 @@ def assert_equals(self, value, snapshot):
def assert_match(self, value, name=""):
self.curr_snapshot = name or self.snapshot_counter
self.visit()
if self.update:
if self.snapshot_should_update:
self.store(value)
else:
try:
prev_snapshot = self.module[self.test_name]
except SnapshotNotFound:
self.store(value) # first time this test has been seen
# There is no snapshot for this test, run with --snapshot-update
# to create it
self.missing()
raise
else:
try:
self.assert_value_matches_snapshot(value, prev_snapshot)
Expand Down
3 changes: 2 additions & 1 deletion snapshottest/nose.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from nose.plugins import Plugin

from .parse_env import env_snapshot_update
from .module import SnapshotModule
from .reporting import reporting_lines
from .unittest import TestCase
Expand Down Expand Up @@ -37,7 +38,7 @@ def options(self, parser, env=os.environ):

def configure(self, options, conf):
super(SnapshotTestPlugin, self).configure(options, conf)
self.snapshot_update = options.snapshot_update
self.snapshot_update = options.snapshot_update or env_snapshot_update()
self.enabled = not options.snapshot_disable

def wantClass(self, cls):
Expand Down
9 changes: 9 additions & 0 deletions snapshottest/parse_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os


def _env_bool(val):
return val.lower() in ["1", "yes", "true", "t", "y"]


def env_snapshot_update():
return _env_bool(os.environ.get("SNAPSHOT_UPDATE", "false"))
8 changes: 3 additions & 5 deletions snapshottest/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import re

from .parse_env import env_snapshot_update
from .module import SnapshotModule, SnapshotTest
from .diff import PrettyDiff
from .reporting import reporting_lines, diff_report
Expand All @@ -27,16 +28,13 @@ def pytest_addoption(parser):
class PyTestSnapshotTest(SnapshotTest):
def __init__(self, request=None):
self.request = request
super(PyTestSnapshotTest, self).__init__()
should_update = request.config.option.snapshot_update or env_snapshot_update()
super(PyTestSnapshotTest, self).__init__(should_update)

@property
def module(self):
return SnapshotModule.get_module_for_testpath(self.request.node.fspath.strpath)

@property
def update(self):
return self.request.config.option.snapshot_update

@property
def test_name(self):
cls_name = getattr(self.request.node.cls, "__name__", "")
Expand Down
12 changes: 8 additions & 4 deletions snapshottest/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
def reporting_lines(testing_cli):
successful_snapshots = SnapshotModule.stats_successful_snapshots()
bold = ["bold"]
new_snapshots = SnapshotModule.stats_new_snapshots()
if successful_snapshots:
yield (colored("{} snapshots passed", attrs=bold) + ".").format(
successful_snapshots
)
new_snapshots = SnapshotModule.stats_new_snapshots()
missing_snapshots = SnapshotModule.stats_missing_snapshots()
if missing_snapshots[0]:
yield (
colored("{} snapshots missing", "red", attrs=bold)
+ " in {} test suites. Run with `--snapshot-update` to create them."
).format(*missing_snapshots)
if new_snapshots[0]:
yield (
colored("{} snapshots written", "green", attrs=bold) + " in {} test suites."
Expand Down Expand Up @@ -46,9 +52,7 @@ def diff_report(left, right):
+ colored("Received value", "red", attrs=["bold"])
+ colored(" does not match ", attrs=["bold"])
+ colored(
"stored snapshot `{}`".format(
left.snapshottest.test_name,
),
"stored snapshot `{}`".format(left.snapshottest.test_name),
"green",
attrs=["bold"],
)
Expand Down
11 changes: 4 additions & 7 deletions snapshottest/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
import inspect

from .parse_env import env_snapshot_update
from .module import SnapshotModule, SnapshotTest
from .diff import PrettyDiff
from .reporting import diff_report
Expand All @@ -13,17 +14,12 @@ def __init__(self, test_class, test_id, test_filepath, should_update, assertEqua
self.test_id = test_id
self.test_filepath = test_filepath
self.assertEqual = assertEqual
self.should_update = should_update
super(UnitTestSnapshotTest, self).__init__()
super(UnitTestSnapshotTest, self).__init__(should_update)

@property
def module(self):
return SnapshotModule.get_module_for_testpath(self.test_filepath)

@property
def update(self):
return self.should_update

def assert_equals(self, value, snapshot):
self.assertEqual(value, snapshot)

Expand Down Expand Up @@ -78,13 +74,14 @@ def tearDownClass(cls):

def setUp(self):
"""Do some custom setup"""
should_update = self.snapshot_should_update or env_snapshot_update()
# print dir(self.__module__)
self.addTypeEqualityFunc(PrettyDiff, self.comparePrettyDifs)
self._snapshot = UnitTestSnapshotTest(
test_class=self.__class__,
test_id=self.id(),
test_filepath=self._snapshot_file,
should_update=self.snapshot_should_update,
should_update=should_update,
assertEqual=self.assertEqual,
)
self._snapshot_tests.append(self._snapshot)
Expand Down
11 changes: 1 addition & 10 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,7 @@ def test_can_normalize_iterator_objects():


@pytest.mark.parametrize(
"value",
[
0,
12.7,
True,
False,
None,
float("-inf"),
float("inf"),
],
"value", [0, 12.7, True, False, None, float("-inf"), float("inf")]
)
def test_basic_formatting_parsing(value):
formatter = Formatter()
Expand Down
4 changes: 4 additions & 0 deletions tests/test_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def pytest_snapshot_test(request, _apply_options):

class TestPyTestSnapShotTest:
def test_property_test_name(self, pytest_snapshot_test):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand All @@ -41,6 +42,7 @@ def test_property_test_name(self, pytest_snapshot_test):


def test_pytest_snapshottest_property_test_name(pytest_snapshot_test):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand All @@ -64,6 +66,7 @@ def test_pytest_snapshottest_property_test_name(pytest_snapshot_test):
def test_pytest_snapshottest_property_test_name_parametrize_singleline(
pytest_snapshot_test, arg
):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand All @@ -84,6 +87,7 @@ def test_pytest_snapshottest_property_test_name_parametrize_singleline(
def test_pytest_snapshottest_property_test_name_parametrize_multiline(
pytest_snapshot_test, arg
):
pytest_snapshot_test.snapshot_should_update = True
pytest_snapshot_test.assert_match("counter")
assert (
pytest_snapshot_test.test_name
Expand Down
15 changes: 13 additions & 2 deletions tests/test_snapshot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict

from snapshottest.module import SnapshotModule, SnapshotTest
from snapshottest.error import SnapshotNotFound


class GenericSnapshotTest(SnapshotTest):
Expand All @@ -15,7 +16,7 @@ def __init__(self, snapshot_module, update=False, current_test_id=None):
"update": update,
"current_test_id": current_test_id or "test_mocked",
}
super(GenericSnapshotTest, self).__init__()
super(GenericSnapshotTest, self).__init__(update)

@property
def module(self):
Expand All @@ -33,7 +34,7 @@ def test_name(self):

def reinitialize(self):
"""Reset internal state, as though starting a new test run"""
super(GenericSnapshotTest, self).__init__()
super(GenericSnapshotTest, self).__init__(False)


def assert_snapshot_test_ran(snapshot_test, test_name=None):
Expand Down Expand Up @@ -91,6 +92,7 @@ def fixture_snapshot_test(tmpdir):
@pytest.mark.parametrize("value", SNAPSHOTABLE_VALUES, ids=repr)
def test_snapshot_matches_itself(snapshot_test, value):
# first run stores the value as the snapshot
snapshot_test.snapshot_should_update = True
snapshot_test.assert_match(value)
assert_snapshot_test_succeeded(snapshot_test)

Expand All @@ -115,6 +117,7 @@ def test_snapshot_matches_itself(snapshot_test, value):
)
def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value):
# first run stores the value as the snapshot
snapshot_test.snapshot_should_update = True
snapshot_test.assert_match(value)
assert_snapshot_test_succeeded(snapshot_test)

Expand All @@ -123,3 +126,11 @@ def test_snapshot_does_not_match_other_values(snapshot_test, value, other_value)
with pytest.raises(AssertionError):
snapshot_test.assert_match(other_value)
assert_snapshot_test_failed(snapshot_test)


def test_first_run_without_snapshots_fails(snapshot_test):
with pytest.raises(SnapshotNotFound):
snapshot_test.assert_match("foo", name="no_snapshot_exists_test")
assert snapshot_test.module.missing_snapshots == set(
["test_mocked no_snapshot_exists_test"]
)
10 changes: 2 additions & 8 deletions tests/test_sorted_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ class Fruit(enum.IntEnum):
APPLE = 1
ORANGE = 2

dic = {
Fruit.APPLE: 100,
Fruit.ORANGE: 400,
}
dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400}
assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE]
assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE]

Expand All @@ -55,10 +52,7 @@ class Fruit(enum.Enum):
APPLE = 1
ORANGE = 2

dic = {
Fruit.APPLE: 100,
Fruit.ORANGE: 400,
}
dic = {Fruit.APPLE: 100, Fruit.ORANGE: 400}
assert SortedDict(dic)[Fruit.APPLE] == dic[Fruit.APPLE]
assert SortedDict(dic)[Fruit.ORANGE] == dic[Fruit.ORANGE]

Expand Down