diff --git a/deche/core.py b/deche/core.py
index 9c57a26..ec9f86e 100644
--- a/deche/core.py
+++ b/deche/core.py
@@ -250,10 +250,7 @@ def inner():

         return inner

-    def __call__(self, func):  # noqa: C901
-        # TODO - very lazy async support. Refactor
-        # TODO - fsspec also has async support - could make exists/load calls async
-
+    def __call__(self, func):
         if inspect.iscoroutinefunction(func):

             @functools.wraps(func)
@@ -266,25 +263,17 @@ async def wrapper(*args, **kwargs):
                 key, _ = tokenize(obj=inputs)
                 if self.valid(path=f"{path}/{key}"):
                     return self._load(func=func)(key=key)
-                elif self._exists(func=func, ext=Extensions.exception)(key=key):
-                    raise self._load(func=func, ext=Extensions.exception)(key=key)
-                try:
-                    self.write_input(path=f"{path}/{key}", inputs=inputs)
-                    logger.debug(f"Calling {func}")
-                    output = await func(*args, **kwargs)
-                    if self.result_validator is not None:
-                        logger.debug(f"Validating result with {self.result_validator}")
-                        try:
-                            self.result_validator(output)
-                        except Exception as e:
-                            raise ValidationError(e)
-                    logger.debug(f"Function {func} ran successfully")
-                    self.write_output(path=f"{path}/{key}", output=output)
-                except Exception as e:
-                    logger.debug(f"Function {func} raised {e}")
-                    self.write_output(path=f"{path}/{key}{Extensions.exception}", output=e)
-                    raise e
-
+                self.write_input(path=f"{path}/{key}", inputs=inputs)
+                logger.debug(f"Calling {func}")
+                output = await func(*args, **kwargs)
+                if self.result_validator is not None:
+                    logger.debug(f"Validating result with {self.result_validator}")
+                    try:
+                        self.result_validator(output)
+                    except Exception as e:
+                        raise ValidationError(e)
+                logger.debug(f"Function {func} ran successfully")
+                self.write_output(path=f"{path}/{key}", output=output)
                 return output

         else:
@@ -299,25 +288,17 @@ def wrapper(*args, **kwargs):
                 key, _ = tokenize(obj=inputs)
                 if self.valid(path=f"{path}/{key}"):
                     return self._load(func=func)(key=key)
-                elif self._exists(func=func, ext=Extensions.exception)(key=key):
-                    raise self._load(func=func, ext=Extensions.exception)(key=key)
-                try:
-                    self.write_input(path=f"{path}/{key}", inputs=inputs)
-                    logger.debug(f"Calling {func}")
-                    output = func(*args, **kwargs)
-                    if self.result_validator is not None:
-                        logger.debug(f"Validating result with {self.result_validator}")
-                        try:
-                            assert self.result_validator(output) is not False
-                        except Exception as e:
-                            raise ValidationError(e)
-                    logger.debug(f"Function {func} ran successfully")
-                    self.write_output(path=f"{path}/{key}", output=output)
-                except Exception as e:
-                    logger.debug(f"Function {func} raised {e}")
-                    self.write_output(path=f"{path}/{key}{Extensions.exception}", output=e)
-                    raise e
-
+                self.write_input(path=f"{path}/{key}", inputs=inputs)
+                logger.debug(f"Calling {func}")
+                output = func(*args, **kwargs)
+                if self.result_validator is not None:
+                    logger.debug(f"Validating result with {self.result_validator}")
+                    try:
+                        assert self.result_validator(output) is not False
+                    except Exception as e:
+                        raise ValidationError(e)
+                logger.debug(f"Function {func} ran successfully")
+                self.write_output(path=f"{path}/{key}", output=output)
                 return output

         wrapper.tokenize = tokenize_func(func=func, ignore=self.non_hashable_kwargs, cls_attrs=self.cls_attrs)
@@ -326,20 +307,14 @@ def wrapper(*args, **kwargs):
         wrapper.is_valid = self.is_valid(func=wrapper)
         wrapper.has_inputs = self._exists(func=wrapper, ext=Extensions.inputs)
         wrapper.has_data = self._exists(func=wrapper)
-        wrapper.has_exception = self._exists(func=wrapper, ext=Extensions.exception)
         wrapper.list_cached_inputs = self._list(func=wrapper, ext=Extensions.inputs)
         wrapper.list_cached_data = self._list(func=wrapper, filter_=data_filter)
-        wrapper.list_cached_exceptions = self._list(func=wrapper, ext=Extensions.exception)
         wrapper.iter_cached_inputs = self._iter(func=wrapper, ext=Extensions.inputs)
         wrapper.iter_cached_data = self._iter(func=wrapper, filter_=data_filter)
-        wrapper.iter_cached_exception = self._iter(func=wrapper, ext=Extensions.exception)
         wrapper.load_cached_inputs = self._load(func=wrapper, ext=Extensions.inputs)
         wrapper.load_cached_data = self._load(func=wrapper)
-        wrapper.load_cached_exception = self._load(func=wrapper, ext=Extensions.exception)
         wrapper.remove_cached_inputs = self._remove(func=wrapper, ext=Extensions.inputs)
         wrapper.remove_cached_data = self._remove(func=wrapper)
-        wrapper.remove_cached_exception = self._remove(func=wrapper, ext=Extensions.exception)
-        wrapper.remove_all_cached_exceptions = self._remove_all(func=wrapper, ext=Extensions.exception)
         wrapper.path = functools.partial(self._path, func=func)
         wrapper.deche = self
         return wrapper
diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py
index 2c361cf..5ffbf0f 100644
--- a/tests/unit/test_core.py
+++ b/tests/unit/test_core.py
@@ -152,14 +152,14 @@ def test_list_cached_data():
     assert result == ["/deche.test_utils.func/f4f46c47d91eea40eba825cf941ff22bdc87ce849400ed3fd85be092e43031d4"]


-def test_list_cached_exceptions():
-    with pytest.raises(ZeroDivisionError):
-        exc_func()
-    result = exc_func.list_cached_exceptions()
-    assert result == ["6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936"]
+# def test_list_cached_exceptions():
+#     with pytest.raises(ZeroDivisionError):
+#         exc_func()
+#     result = exc_func.list_cached_exceptions()
+#     assert result == ["6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936"]

-    result = exc_func.list_cached_exceptions(key_only=False)
-    assert result == ["/deche.test_utils.exc_func/6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936.exc"]
+#     result = exc_func.list_cached_exceptions(key_only=False)
+#     assert result == ["/deche.test_utils.exc_func/6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936.exc"]


 def test_iter():
@@ -194,31 +194,31 @@ def test_load_cached_data():
     assert result == expected


-def test_load_cached_exception():
-    try:
-        exc_func()
-    except ZeroDivisionError as expected:
-        result = exc_func.load_cached_exception(kwargs={})
-        assert isinstance(result, type(expected))
-        assert type(expected) == type(result)
+# def test_load_cached_exception():
+#     try:
+#         exc_func()
+#     except ZeroDivisionError as expected:
+#         result = exc_func.load_cached_exception(kwargs={})
+#         assert isinstance(result, type(expected))
+#         assert type(expected) == type(result)

-    key = exc_func.tokenize()
-    result = exc_func.load_cached_exception(key=key)
-    assert isinstance(result, type(expected))
+#     key = exc_func.tokenize()
+#     result = exc_func.load_cached_exception(key=key)
+#     assert isinstance(result, type(expected))


-def test_remove_all_exceptions():
-    try:
-        exc_func(1)
-    except ZeroDivisionError:
-        pass
-    try:
-        exc_func(2)
-    except ZeroDivisionError:
-        pass
-    assert len(exc_func.list_cached_exceptions()) == 2
-    exc_func.remove_all_cached_exceptions()
-    assert len(exc_func.list_cached_exceptions()) == 0
+# def test_remove_all_exceptions():
+#     try:
+#         exc_func(1)
+#     except ZeroDivisionError:
+#         pass
+#     try:
+#         exc_func(2)
+#     except ZeroDivisionError:
+#         pass
+#     assert len(exc_func.list_cached_exceptions()) == 2
+#     exc_func.remove_all_cached_exceptions()
+#     assert len(exc_func.list_cached_exceptions()) == 0


 def test_exists():
@@ -272,12 +272,12 @@ def test_cache_path(c: Cache):
     assert func.path() == "/deche.test_utils.func"


-def test_cache_exception(c: Cache):
-    try:
-        exc_func()
-    except ZeroDivisionError as e:
-        exc = exc_func.load_cached_exception(kwargs={})
-        assert type(exc) == type(e)
+# def test_cache_exception(c: Cache):
+#     try:
+#         exc_func()
+#     except ZeroDivisionError as e:
+#         exc = exc_func.load_cached_exception(kwargs={})
+#         assert type(exc) == type(e)


 def test_cached_exception_raises(cached_exception):
diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py
new file mode 100644
index 0000000..91d1d40
--- /dev/null
+++ b/tests/unit/test_exceptions.py
@@ -0,0 +1,108 @@
+import time
+
+import pytest
+
+from deche import Cache
+
+
+@pytest.fixture
+def memory_cache():
+    return Cache(fs_protocol="memory", prefix="/")
+
+
+def test_exception_not_cached(memory_cache):
+    @memory_cache
+    def failing_func():
+        time.sleep(0.1)
+        raise ValueError("This function always fails")
+
+    # First call should raise the exception
+    start_time = time.time()
+    with pytest.raises(ValueError):
+        failing_func()
+    first_call_time = time.time() - start_time
+
+    # Second call should also raise the exception, not return a cached exception
+    start_time = time.time()
+    with pytest.raises(ValueError):
+        failing_func()
+    second_call_time = time.time() - start_time
+
+    assert failing_func.list_cached_data() == []
+    assert abs(first_call_time - second_call_time) < 0.05  # Both calls should take similar time
+
+
+def test_successful_execution_after_exception(memory_cache):
+    call_count = 0
+
+    @memory_cache
+    def sometimes_failing_func(fail=True):
+        nonlocal call_count
+        call_count += 1
+        time.sleep(0.1)
+        if fail:
+            raise ValueError("This function fails when fail=True")
+        return "Success"
+
+    # First call should raise the exception
+    start_time = time.time()
+    with pytest.raises(ValueError):
+        sometimes_failing_func(fail=True)
+    exception_time = time.time() - start_time
+
+    # Second call with fail=False should execute the function and cache the result
+    start_time = time.time()
+    result = sometimes_failing_func(fail=False)
+    success_time = time.time() - start_time
+
+    assert result == "Success"
+    assert call_count == 2
+    assert exception_time > 0.1
+    assert success_time > 0.1
+
+    # Third call with fail=False should return the cached result
+    start_time = time.time()
+    result = sometimes_failing_func(fail=False)
+    cached_time = time.time() - start_time
+
+    assert result == "Success"
+    assert call_count == 2  # Call count shouldn't increase
+    assert cached_time < 0.01  # Cached call should be very fast
+
+
+def test_cache_behavior_unchanged_for_successful_calls(memory_cache):
+    call_count = 0
+
+    @memory_cache
+    def cached_func(x):
+        nonlocal call_count
+        call_count += 1
+        time.sleep(0.1)
+        return x * 2
+
+    # First call should execute the function
+    start_time = time.time()
+    result = cached_func(5)
+    first_call_time = time.time() - start_time
+
+    assert result == 10
+    assert call_count == 1
+    assert first_call_time > 0.1
+
+    # Second call should return cached result
+    start_time = time.time()
+    result = cached_func(5)
+    second_call_time = time.time() - start_time
+
+    assert result == 10
+    assert call_count == 1  # Call count shouldn't increase
+    assert second_call_time < 0.01  # Cached call should be very fast
+
+    # Call with different argument should execute the function again
+    start_time = time.time()
+    result = cached_func(7)
+    third_call_time = time.time() - start_time
+
+    assert result == 14
+    assert call_count == 2
+    assert third_call_time > 0.1
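
A minimal usage sketch of the behaviour the new tests in tests/unit/test_exceptions.py exercise: exceptions now propagate without being written to the cache, while successful results are cached as before. It assumes the memory-backed Cache constructed as in those tests; the flaky_fetch function, its arguments, and the ConnectionError are illustrative only and are not part of this diff.

from deche import Cache

cache = Cache(fs_protocol="memory", prefix="/")


@cache
def flaky_fetch(url, fail=False):
    # Illustrative function: fails on a simulated transient error,
    # otherwise returns a picklable result that deche can cache.
    if fail:
        raise ConnectionError(f"could not reach {url}")
    return {"url": url, "status": 200}


# A failing call raises and leaves no cached data for that input ...
try:
    flaky_fetch("https://example.com", fail=True)
except ConnectionError:
    pass
assert flaky_fetch.list_cached_data() == []

# ... so a retry re-executes the function instead of re-raising a stored
# exception, and a successful call is cached and reused as usual.
first = flaky_fetch("https://example.com")
second = flaky_fetch("https://example.com")  # served from the cache
assert first == second
assert len(flaky_fetch.list_cached_data()) == 1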