diff --git a/flask_app/blueprints/api/warnings.py b/flask_app/blueprints/api/warnings.py index 2903df61..3c48d38e 100644 --- a/flask_app/blueprints/api/warnings.py +++ b/flask_app/blueprints/api/warnings.py @@ -21,8 +21,9 @@ def add_warning(message:str, filename:str=None, lineno:int=None, test_id:int=Non timestamp = get_current_time() warning = Warning.query.filter_by(session_id=session_id, test_id=test_id, lineno=lineno, filename=filename, message=message).first() + unique_warnings_num = len(Warning.query.filter_by(session_id=session_id, test_id=test_id).all()) if warning is None: - if obj.num_warnings < current_app.config['MAX_WARNINGS_PER_ENTITY']: + if unique_warnings_num < current_app.config['MAX_WARNINGS_PER_ENTITY']: warning = Warning(message=message, timestamp=timestamp, filename=filename, lineno=lineno, test_id=test_id, session_id=session_id) db.session.add(warning) else: diff --git a/tests/test_warnings.py b/tests/test_warnings.py index bf478ead..1441cad3 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -24,17 +24,30 @@ def test_add_warnings_nonexistent_session(warning_container, message): warning_container.add_warning(message=message) -def test_max_warnings_per_entity(warning_container, message, webapp): +@pytest.mark.parametrize("unique_warning", [True, False]) +def test_max_warnings_per_entity(warning_container, message, webapp, unique_warning): max_warnings = 3 webapp.app.config['MAX_WARNINGS_PER_ENTITY'] = max_warnings for i in range(max_warnings + 1): - warning_container.add_warning(message=f'{message}{i}' ) + warning_container.add_warning(message=f'{message}{i if not unique_warning else ""}') warning_container.refresh() assert warning_container.num_warnings == max_warnings + 1 + assert len(warning_container.query_warnings().all()) == 1 if unique_warning else max_warnings - assert len(warning_container.query_warnings().all()) == max_warnings + +def test_unique_warning_per_entity(warning_container, webapp): + max_warnings = 3 + webapp.app.config['MAX_WARNINGS_PER_ENTITY'] = max_warnings + + for i in range(max_warnings): + warning_container.add_warning(message=message) + warning_container.add_warning(message="unique warning") + + warning_container.refresh() + assert warning_container.num_warnings == max_warnings + 1 + assert len(warning_container.query_warnings().all()) == 2 def test_add_warning_twice(warning_container, filename, lineno, message, timestamp):