diff --git a/snapshottest/module.py b/snapshottest/module.py index e31b46c..8b71125 100644 --- a/snapshottest/module.py +++ b/snapshottest/module.py @@ -25,7 +25,7 @@ def __init__(self, module, filepath): self.visited_snapshots = set() self.new_snapshots = set() self.failed_snapshots = set() - self.imports["snapshottest"].add("Snapshot") + self.imports['snapshottest'].add('Snapshot') def load_snapshots(self): try: @@ -75,7 +75,7 @@ def stats_for_module(cls, getter): count_snapshots += length count_modules += min(length, 1) - return count_snapshots, count_modules + return count_snapshots, count_modules, module.new_snapshots, module.unvisited_snapshots @classmethod def stats_unvisited_snapshots(cls): @@ -146,26 +146,21 @@ def save(self): pass # Create __init__.py in case doesn't exist - open(os.path.join(self.snapshot_dir, "__init__.py"), "a").close() + open(os.path.join(self.snapshot_dir, '__init__.py'), 'a').close() pretty = Formatter(self.imports) - with codecs.open(self.filepath, "w", encoding="utf-8") as snapshot_file: + with codecs.open(self.filepath, 'w', encoding="utf-8") as snapshot_file: snapshots_declarations = [ """snapshots['{}'] = {}""".format(key, pretty(self.snapshots[key])) for key in sorted(self.snapshots.keys()) ] - imports = "\n".join( - [ - "from {} import {}".format( - module, ", ".join(sorted(module_imports)) - ) - for module, module_imports in sorted(self.imports.items()) - ] - ) - snapshot_file.write( - """# -*- coding: utf-8 -*- + imports = '\n'.join([ + 'from {} import {}'.format(module, ', '.join(sorted(module_imports))) + for module, module_imports in sorted(self.imports.items()) + ]) + snapshot_file.write('''# -*- coding: utf-8 -*- # snapshottest: v1 - https://goo.gl/zC4yUc from __future__ import unicode_literals @@ -175,10 +170,7 @@ def save(self): snapshots = Snapshot() {} -""".format( - imports, "\n\n".join(snapshots_declarations) - ) - ) +'''.format(imports, '\n\n'.join(snapshots_declarations))) @classmethod def get_module_for_testpath(cls, test_filepath): @@ -186,15 +178,11 @@ def get_module_for_testpath(cls, test_filepath): dirname = os.path.dirname(test_filepath) snapshot_dir = os.path.join(dirname, "snapshots") - snapshot_basename = "snap_{}.py".format( - os.path.splitext(os.path.basename(test_filepath))[0] - ) + snapshot_basename = 'snap_{}.py'.format(os.path.splitext(os.path.basename(test_filepath))[0]) snapshot_filename = os.path.join(snapshot_dir, snapshot_basename) - snapshot_module = "{}".format(os.path.splitext(snapshot_basename)[0]) + snapshot_module = '{}'.format(os.path.splitext(snapshot_basename)[0]) - cls._snapshot_modules[test_filepath] = SnapshotModule( - snapshot_module, snapshot_filename - ) + cls._snapshot_modules[test_filepath] = SnapshotModule(snapshot_module, snapshot_filename) return cls._snapshot_modules[test_filepath] @@ -203,7 +191,7 @@ class SnapshotTest(object): _current_tester = None def __init__(self): - self.curr_snapshot = "" + self.curr_snapshot = '' self.snapshot_counter = 1 @property @@ -239,14 +227,12 @@ def store(self, data): def assert_value_matches_snapshot(self, test_value, snapshot_value): formatter = Formatter.get_formatter(test_value) - formatter.assert_value_matches_snapshot( - self, test_value, snapshot_value, Formatter() - ) + formatter.assert_value_matches_snapshot(self, test_value, snapshot_value) def assert_equals(self, value, snapshot): assert value == snapshot - def assert_match(self, value, name=""): + def assert_match(self, value, name=''): self.curr_snapshot = name or self.snapshot_counter self.visit() if self.update: @@ -270,10 +256,8 @@ def save_changes(self): self.module.save() -def assert_match_snapshot(value, name=""): +def assert_match_snapshot(value, name=''): if not SnapshotTest._current_tester: - raise Exception( - "You need to use assert_match_snapshot in the SnapshotTest context." - ) + raise Exception("You need to use assert_match_snapshot in the SnapshotTest context.") SnapshotTest._current_tester.assert_match(value, name) diff --git a/snapshottest/reporting.py b/snapshottest/reporting.py index 26ca51f..0688db3 100644 --- a/snapshottest/reporting.py +++ b/snapshottest/reporting.py @@ -6,55 +6,45 @@ def reporting_lines(testing_cli): successful_snapshots = SnapshotModule.stats_successful_snapshots() - bold = ["bold"] + bold = ['bold'] if successful_snapshots: - yield (colored("{} snapshots passed", attrs=bold) + ".").format( - successful_snapshots - ) + yield ( + colored('{} snapshots passed', attrs=bold) + '.' + ).format(successful_snapshots) new_snapshots = SnapshotModule.stats_new_snapshots() if new_snapshots[0]: yield ( - colored("{} snapshots written", "green", attrs=bold) + " in {} test suites." + colored('{0} snapshots written {2}', 'green', attrs=bold) + ' in {1} test suites.' ).format(*new_snapshots) inspect_str = colored( - "Inspect your code or run with `{} --snapshot-update` to update them.".format( - testing_cli - ), - attrs=["dark"], + 'Inspect your code or run with `{} --snapshot-update` to update them.'.format(testing_cli), + attrs=['dark'] ) failed_snapshots = SnapshotModule.stats_failed_snapshots() if failed_snapshots[0]: yield ( - colored("{} snapshots failed", "red", attrs=bold) - + " in {} test suites. " + colored('{} snapshots failed', 'red', attrs=bold) + ' in {} test suites. ' + inspect_str ).format(*failed_snapshots) unvisited_snapshots = SnapshotModule.stats_unvisited_snapshots() if unvisited_snapshots[0]: yield ( - colored("{} snapshots deprecated", "yellow", attrs=bold) - + " in {} test suites. " + colored('{0} snapshots deprecated {3}', 'yellow', attrs=bold) + ' in {1} test suites. ' + inspect_str ).format(*unvisited_snapshots) def diff_report(left, right): return [ - "stored snapshot should match the received value", - "", - colored("> ") - + colored("Received value", "red", attrs=["bold"]) - + colored(" does not match ", attrs=["bold"]) - + colored( - "stored snapshot `{}`".format( - left.snapshottest.test_name, - ), - "green", - attrs=["bold"], - ) - + colored(".", attrs=["bold"]), - colored("") - + "> " - + os.path.relpath(left.snapshottest.module.filepath, os.getcwd()), - "", + 'stored snapshot should match the received value', + '', + colored('> ') + + colored('Received value', 'red', attrs=['bold']) + + colored(' does not match ', attrs=['bold']) + + colored('stored snapshot `{}`'.format( + left.snapshottest.test_name, + ), 'green', attrs=['bold']) + + colored('.', attrs=['bold']), + colored('') + '> ' + os.path.relpath(left.snapshottest.module.filepath, os.getcwd()), + '', ] + left.get_diff(right)