
Commit

Merge branch 'feature/MaxiBoether/emptyevals' into feature/MaxiBoether/sigmod-revision-2
MaxiBoether committed Sep 23, 2024
2 parents dd497a2 + 9a2495a commit ba26601
Showing 8 changed files with 157 additions and 84 deletions.
4 changes: 2 additions & 2 deletions .pylintrc
@@ -186,8 +186,8 @@ disable=raw-checker-failed,
too-many-arguments, # we can't determine a good limit here. reviews should spot bad cases of this.
duplicate-code, # Mostly imports and test setup.
cyclic-import, # We use these inside methods that require models from multiple apps. Tests will catch actual errors.
too-many-instance-attributes # We always ignore this anyways

too-many-instance-attributes, # We always ignore this anyways
too-many-positional-arguments # We do not want to limit the number of positional arguments
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
10 changes: 7 additions & 3 deletions modyn/evaluator/internal/grpc/evaluator_grpc_servicer.py
@@ -317,12 +317,16 @@ def get_evaluation_result(
single_eval_data = EvaluationIntervalData(interval_index=interval_idx, evaluation_data=metric_result)
evaluation_data.append(single_eval_data)

if len(evaluation_data) < len(self._evaluation_dict[evaluation_id].not_failed_interval_ids):
num_metrics = len(self._evaluation_dict[evaluation_id].raw_metrics)
expected_results = len(self._evaluation_dict[evaluation_id].not_failed_interval_ids) * num_metrics
if len(evaluation_data) < expected_results:
logger.error(
f"Could not retrieve results for all intervals of evaluation {evaluation_id}. "
f"Expected {len(self._evaluation_dict[evaluation_id].not_failed_interval_ids)}, "
f"but got {len(evaluation_data)}. Maybe an exception happened during evaluation."
f"Expected {len(self._evaluation_dict[evaluation_id].not_failed_interval_ids)} * {num_metrics} = {expected_results} results, "
f"but got {len(evaluation_data)} results. Most likely, an exception happened during evaluation."
)
return EvaluationResultResponse(valid=False)

return EvaluationResultResponse(valid=True, evaluation_results=evaluation_data)

def cleanup_evaluations(
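The updated check above no longer compares against the number of non-failed intervals alone; it expects one result entry per (non-failed interval, metric) pair. A small sketch of that arithmetic, with a hypothetical `EvalState` standing in for the entry stored in `self._evaluation_dict`:

```python
from dataclasses import dataclass


@dataclass
class EvalState:  # hypothetical stand-in for the entry in self._evaluation_dict
    raw_metrics: list[str]              # one entry per configured metric
    not_failed_interval_ids: list[int]  # intervals that were not aborted


def results_complete(state: EvalState, num_results_received: int) -> bool:
    # One result entry is expected per (non-failed interval, metric) pair.
    expected_results = len(state.not_failed_interval_ids) * len(state.raw_metrics)
    return num_results_received >= expected_results


# Example: 3 surviving intervals evaluated with 2 metrics -> 6 results expected.
state = EvalState(raw_metrics=["Accuracy", "F1Score"], not_failed_interval_ids=[0, 1, 2])
assert not results_complete(state, num_results_received=5)
assert results_complete(state, num_results_received=6)
```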
3 changes: 0 additions & 3 deletions modyn/evaluator/internal/metrics/accuracy.py
@@ -24,9 +24,6 @@ def _batch_evaluated_callback(self, y_true: torch.Tensor, y_pred: torch.Tensor,
self.total_correct += labeled_correctly
self.samples_seen += batch_size

self.total_correct += labeled_correctly
self.samples_seen += batch_size

def get_evaluation_result(self) -> float:
if self.samples_seen == 0:
self.warning("Did not see any samples.")
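The removed lines above were an accidental duplication that double-counted every batch. A torch-free sketch of the intended bookkeeping (the real metric operates on tensors; the batch numbers below are purely illustrative):

```python
class AccuracySketch:
    """Simplified stand-in for the Accuracy metric's counters."""

    def __init__(self) -> None:
        self.total_correct = 0
        self.samples_seen = 0

    def batch_evaluated(self, labeled_correctly: int, batch_size: int) -> None:
        # Counted once per batch; the removed duplicate lines doubled both counters.
        self.total_correct += labeled_correctly
        self.samples_seen += batch_size

    def get_evaluation_result(self) -> float:
        if self.samples_seen == 0:
            return 0.0  # the real implementation also logs a warning here
        return self.total_correct / self.samples_seen


acc = AccuracySketch()
acc.batch_evaluated(labeled_correctly=6, batch_size=6)
acc.batch_evaluated(labeled_correctly=0, batch_size=6)
acc.batch_evaluated(labeled_correctly=0, batch_size=6)
assert abs(acc.get_evaluation_result() - 1.0 / 3) < 1e-9
```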
1 change: 0 additions & 1 deletion modyn/evaluator/internal/pytorch_evaluator.py
@@ -39,7 +39,6 @@ def __init__(
)

self._device = evaluation_info.device
self._device_type = "cuda" if "cuda" in self._device else "cpu"
self._amp = evaluation_info.amp

self._info("Initialized PyTorch evaluator.")
56 changes: 44 additions & 12 deletions modyn/supervisor/internal/grpc_handler.py
@@ -302,39 +302,71 @@ def wait_for_evaluation_completion(self, evaluation_id: int) -> bool:
self.init_evaluator()
raise e

# NOT within retry block
if not res.valid:
exception_msg = f"Evaluation {evaluation_id} is invalid at server:\n{res}\n"
logger.error(exception_msg)
raise RuntimeError(exception_msg)
# Should only happen when requesting invalid id, hence we throw
_msg = f"Evaluation {evaluation_id} is invalid at server:\n{res}\n"
logger.error(_msg)
raise RuntimeError(_msg)

if res.HasField("exception"):
exception_msg = f"Exception at evaluator occurred:\n{res.exception}\n\n"
logger.error(exception_msg)
logger.error(f"Exception at evaluator occurred:\n{res.exception}\n\n")
self.cleanup_evaluations([evaluation_id])
logger.error(f"Performed cleanup for evaluation {evaluation_id} that threw exception.")
has_exception = True
break
break # Exit busy wait

if not res.is_running:
break
break # Exit busy wait

sleep(1)

return not has_exception
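
`wait_for_evaluation_completion` now reports failure through its return value: the busy wait breaks either when the evaluator reports an exception (after cleaning up that evaluation) or when `is_running` turns false, and the caller receives `not has_exception`. A condensed sketch of that control flow, with the gRPC status call and the cleanup stubbed out as hypothetical callables:

```python
import time
from dataclasses import dataclass
from typing import Callable


@dataclass
class StatusSketch:  # hypothetical stand-in for EvaluationStatusResponse
    valid: bool
    exception: str | None
    is_running: bool


def wait_for_completion(poll_status: Callable[[], StatusSketch], cleanup: Callable[[], None]) -> bool:
    """Return True iff the evaluation finished without an evaluator-side exception."""
    has_exception = False
    while True:
        res = poll_status()
        if not res.valid:
            # Should only happen when requesting an invalid id, hence we raise.
            raise RuntimeError("Evaluation is invalid at server")
        if res.exception is not None:
            cleanup()             # Drop the failed evaluation's state at the evaluator.
            has_exception = True
            break                 # Exit busy wait
        if not res.is_running:
            break                 # Exit busy wait
        time.sleep(1)
    return not has_exception
```

Callers such as `_single_batched_evaluation` in the evaluation executor treat a `False` return as a signal to retry the evaluation.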

def get_evaluation_results(self, evaluation_id: int) -> list[EvaluationIntervalData]:
assert self.evaluator is not None
if not self.connected_to_evaluator:
raise ConnectionError("Tried to get evaluation results, but there is no gRPC connection.")

req = EvaluationResultRequest(evaluation_id=evaluation_id)
res: EvaluationResultResponse = self.evaluator.get_evaluation_result(req)

for attempt in Retrying(
stop=stop_after_attempt(5),
wait=wait_random_exponential(multiplier=1, min=2, max=60),
reraise=True,
):
with attempt:
try:
res: EvaluationResultResponse = self.evaluator.get_evaluation_result(req)
except grpc.RpcError as e: # We catch and reraise to easily reconnect
logger.error(e)
logger.error(f"[Evaluation {evaluation_id}]: gRPC connection error, trying to reconnect.")
self.init_evaluator()
raise e

if not res.valid:
logger.error(f"Cannot get the evaluation result for evaluation {evaluation_id}")
raise RuntimeError(f"Cannot get the evaluation result for evaluation {evaluation_id}")
_msg = f"Cannot get the evaluation result for evaluation {evaluation_id}"
logger.error(_msg)
raise RuntimeError(_msg)

return res.evaluation_results
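
This method and `cleanup_evaluations` below now wrap their gRPC calls in the same tenacity pattern already used for the status polling: up to five attempts with randomized exponential backoff, reconnecting before re-raising so the next attempt runs on a fresh channel. A generic sketch of that pattern; `call` and `reconnect` are placeholders for the real stub method and `init_evaluator`:

```python
import logging
from typing import Callable, TypeVar

import grpc
from tenacity import Retrying, stop_after_attempt, wait_random_exponential

T = TypeVar("T")
logger = logging.getLogger(__name__)


def call_with_reconnect(call: Callable[[], T], reconnect: Callable[[], None]) -> T:
    for attempt in Retrying(
        stop=stop_after_attempt(5),
        wait=wait_random_exponential(multiplier=1, min=2, max=60),
        reraise=True,  # After the last attempt, the original RpcError propagates.
    ):
        with attempt:
            try:
                return call()
            except grpc.RpcError as e:  # Reconnect, then let tenacity schedule the retry.
                logger.error(e)
                logger.error("gRPC connection error, trying to reconnect.")
                reconnect()
                raise e
    raise RuntimeError("Unreachable code - just to satisfy mypy.")
```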

def cleanup_evaluations(self, evaluation_ids: list[int]) -> None:
assert self.evaluator is not None

req = EvaluationCleanupRequest(evaluation_ids=set(evaluation_ids))
res: EvaluationCleanupResponse = self.evaluator.cleanup_evaluations(req)
for attempt in Retrying(
stop=stop_after_attempt(5),
wait=wait_random_exponential(multiplier=1, min=2, max=60),
reraise=True,
):
with attempt:
try:
res: EvaluationCleanupResponse = self.evaluator.cleanup_evaluations(req)
except grpc.RpcError as e: # We catch and reraise to easily reconnect
logger.error(e)
logger.error(f"[Evaluations {evaluation_ids}]: gRPC connection error, trying to reconnect.")
self.init_evaluator()
raise e

failed = set(evaluation_ids) - {int(i) for i in res.succeeded}
if failed:
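After the retried cleanup call, the handler still diffs the requested ids against `res.succeeded` to find evaluations the evaluator failed to clean up. A tiny illustration of that set arithmetic with made-up ids:

```python
# Hypothetical values; `succeeded` may arrive as protobuf ints or strings.
evaluation_ids = [10, 11, 12]
succeeded = ["10", "12"]

failed = set(evaluation_ids) - {int(i) for i in succeeded}
assert failed == {11}  # the handler reacts to a non-empty `failed` set (cut off in this diff)
```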
164 changes: 103 additions & 61 deletions modyn/supervisor/internal/pipeline_executor/evaluation_executor.py
@@ -339,80 +339,122 @@ def _single_batched_evaluation(
self.pipeline.evaluation.device,
intervals=cast(list[tuple[int | None, int | None]], intervals),
)

def get_failure_reason(eval_aborted_reason: EvaluationAbortedReason) -> str:
return EvaluationAbortedReason.DESCRIPTOR.values_by_number[eval_aborted_reason].name

started_evaluations = []

for attempt in Retrying(
stop=stop_after_attempt(5),
wait=wait_random_exponential(multiplier=1, min=2, max=60),
stop=stop_after_attempt(10),
wait=wait_random_exponential(multiplier=2, min=2, max=180),
reraise=True,
):
with attempt:
try:
response: EvaluateModelResponse = self.grpc.evaluator.evaluate_model(request)
except grpc.RpcError as e: # We catch and reraise to reconnect
except grpc.RpcError as e: # We catch and reraise them to tenacity after reconnecting
logger.error(e)
logger.error("gRPC connection error, trying to reconnect...")
self.grpc.init_evaluator()
raise e

assert len(response.interval_responses) == len(
intervals
), f"We expected {len(intervals)} intervals, but got {len(response.interval_responses)}."

def get_failure_reason(eval_aborted_reason: EvaluationAbortedReason) -> str:
return EvaluationAbortedReason.DESCRIPTOR.values_by_number[eval_aborted_reason].name

if not response.evaluation_started:
failure_reasons: list[tuple[str | None, dict]] = []
# note: interval indexes correspond to the intervals in the request
for interval_idx, interval_response in enumerate(response.interval_responses):
if interval_response.eval_aborted_reason != EvaluationAbortedReason.NOT_ABORTED:
reason = get_failure_reason(interval_response.eval_aborted_reason)
failure_reasons.append((reason, {}))
logger.error(
f"Evaluation for model {model_id_to_eval} on split {intervals[interval_idx]} "
f"not started with reason: {reason}."
assert len(response.interval_responses) == len(
intervals
), f"We expected {len(intervals)} intervals, but got {len(response.interval_responses)}."

if not response.evaluation_started:
failure_reasons: list[tuple[str | None, dict]] = []
# note: interval indexes correspond to the intervals in the request
for interval_idx, interval_response in enumerate(response.interval_responses):
if interval_response.eval_aborted_reason != EvaluationAbortedReason.NOT_ABORTED:
reason = get_failure_reason(interval_response.eval_aborted_reason)
failure_reasons.append((reason, {}))
logger.error(
f"Evaluation for model {model_id_to_eval} on split {intervals[interval_idx]} "
f"not started with reason: {reason}."
)
# No retrying here, if we were to retry it should be done at the evaluator
# This is likely due to an invalid request in the first place.
return failure_reasons

logger.info(f"Evaluation started for model {model_id_to_eval} on intervals {intervals}.")
started_evaluations.append(response.evaluation_id)
if not self.grpc.wait_for_evaluation_completion(response.evaluation_id):
raise RuntimeError("There was an exception during evaluation") # Trigger retry

eval_data = self.grpc.get_evaluation_results(
response.evaluation_id
) # Will throw in case of invalid result

self.grpc.cleanup_evaluations(
[response.evaluation_id]
) # Early cleanup if succeeded since we have the data

eval_results: list[tuple[str | None, dict[str, Any]]] = []

# ---------------------------------------------- Result Builder ---------------------------------------------- #
# The `eval_results` list is a list of tuples. Each tuple contains a failure reason (if any) and a dictionary
# with the evaluation results. The order of the tuples corresponds to the order of the intervals.
#
# response.interval_responses contains the evaluation results for each interval in the same order as the
# intervals in the request. Failed evaluations are marked with a failure reason.

# Metric results come from the `EvaluateModelResponse` and are stored in the `evaluation_data` field. This
# only contains the metrics for the intervals that were successfully evaluated.
#
# Therefore we first build a list of results with the same order as the intervals. The metrics will be filled in
# the next loop that unwraps `EvaluationResultResponse`.
# ----------------------------------------------------- . ---------------------------------------------------- #

# Prepare eval_results structure
for interval_response in response.interval_responses:
if interval_response.eval_aborted_reason != EvaluationAbortedReason.NOT_ABORTED:
reason = get_failure_reason(interval_response.eval_aborted_reason)
eval_results.append((reason, {}))
else:
eval_results.append(
(
None,
{"dataset_size": interval_response.dataset_size, "metrics": []},
)
)

# Parse response into eval_results structure
for interval_result in eval_data:
interval_idx = interval_result.interval_index
assert (
eval_results[interval_idx][0] is None
), "Evaluation failed, this interval idx should not be returned by evaluator."
eval_results[interval_idx][1]["metrics"] = [
{"name": metric.metric, "result": metric.result} for metric in interval_result.evaluation_data
]

# Assert that all evaluated intervals have all metrics
# Will trigger a retry in case this is not successful
# Can happen, e.g., if the evaluator is overloaded
expected_num_metrics = len(request.metrics)
for abort_reason, data_dict in eval_results:
if abort_reason is not None:
# If there was any reason to abort, we don't care
continue

assert (
data_dict["dataset_size"] > 0
), f"dataset size of 0, but no EMPTY_INTERVAL response: {eval_results}"
actual_num_metrics = len(data_dict["metrics"])
assert actual_num_metrics == expected_num_metrics, (
f"actual_num_metrics = {actual_num_metrics}"
+ f" != expected_num_metrics = {expected_num_metrics}"
+ "\n"
+ str(eval_results)
)
return failure_reasons

logger.info(f"Evaluation started for model {model_id_to_eval} on intervals {intervals}.")
self.grpc.wait_for_evaluation_completion(response.evaluation_id)
eval_data = self.grpc.get_evaluation_results(response.evaluation_id)
self.grpc.cleanup_evaluations([response.evaluation_id])

eval_results: list[tuple[str | None, dict[str, Any]]] = []

# ---------------------------------------------- Result Builder ---------------------------------------------- #
# The `eval_results` list is a list of tuples. Each tuple contains a failure reason (if any) and a dictionary
# with the evaluation results. The order of the tuples corresponds to the order of the intervals.
#
# response.interval_responses contains the evaluation results for each interval in the same order as the
# intervals in the request. Failed evaluations are marked with a failure reason.

# Metric results come from the `EvaluateModelResponse` and are stored in the `evaluation_data` field. This
# only contains the metrics for the intervals that were successfully evaluated.
#
# Therefore we first build a list of results with the same order as the intervals. The metrics will be filled in
# the next loop that unwraps `EvaluationResultResponse`.
# ----------------------------------------------------- . ---------------------------------------------------- #

for interval_response in response.interval_responses:
if interval_response.eval_aborted_reason != EvaluationAbortedReason.NOT_ABORTED:
reason = get_failure_reason(interval_response.eval_aborted_reason)
eval_results.append((reason, {}))
else:
eval_results.append(
(
None,
{"dataset_size": interval_response.dataset_size, "metrics": []},
)
)

for interval_result in eval_data:
interval_idx = interval_result.interval_index
assert eval_results[interval_idx][0] is None, "Evaluation failed, no metrics should be present."
eval_results[interval_idx][1]["metrics"] = [
{"name": metric.metric, "result": metric.result} for metric in interval_result.evaluation_data
]
return eval_results
# All checks succeeded
self.grpc.cleanup_evaluations(started_evaluations) # Make sure to clean up everything we started.
return eval_results

raise RuntimeError("Unreachable code - just to satisfy mypy.")


# ------------------------------------------------------------------------------------ #
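The result-builder comments above describe a two-pass construction: first a placeholder per requested interval (failure reason, or an empty metrics list, in request order), then metrics filled in by `interval_index`, followed by the new completeness assertions that trigger a retry when a successful interval is missing metrics. A self-contained sketch of those passes, with hypothetical dataclasses standing in for the protobuf messages:

```python
from dataclasses import dataclass, field
from typing import Any

NOT_ABORTED = "NOT_ABORTED"


@dataclass
class IntervalResponseSketch:  # stand-in for the per-interval part of EvaluateModelResponse
    eval_aborted_reason: str
    dataset_size: int = 0


@dataclass
class IntervalResultSketch:    # stand-in for EvaluationIntervalData
    interval_index: int
    metrics: list[dict[str, Any]] = field(default_factory=list)


def build_results(
    interval_responses: list[IntervalResponseSketch],
    eval_data: list[IntervalResultSketch],
    expected_num_metrics: int,
) -> list[tuple[str | None, dict[str, Any]]]:
    # Pass 1: placeholders in interval order (failure reason or empty metric list).
    eval_results: list[tuple[str | None, dict[str, Any]]] = []
    for resp in interval_responses:
        if resp.eval_aborted_reason != NOT_ABORTED:
            eval_results.append((resp.eval_aborted_reason, {}))
        else:
            eval_results.append((None, {"dataset_size": resp.dataset_size, "metrics": []}))

    # Pass 2: fill metrics by interval_index; aborted intervals must not appear here.
    for result in eval_data:
        assert eval_results[result.interval_index][0] is None
        eval_results[result.interval_index][1]["metrics"] = result.metrics

    # Completeness check: every successful interval carries all requested metrics.
    for reason, data in eval_results:
        if reason is None:
            assert len(data["metrics"]) == expected_num_metrics
    return eval_results


responses = [IntervalResponseSketch("EMPTY_INTERVAL"), IntervalResponseSketch(NOT_ABORTED, dataset_size=100)]
data = [IntervalResultSketch(interval_index=1, metrics=[{"name": "Accuracy", "result": 0.9}])]
print(build_results(responses, data, expected_num_metrics=1))
```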
1 change: 1 addition & 0 deletions modyn/tests/evaluator/internal/metrics/test_accuracy.py
@@ -67,6 +67,7 @@ def test_accuracy() -> None:
accuracy.evaluate_batch(y_true, y_pred, 6)

assert accuracy.get_evaluation_result() == pytest.approx(1.0 / 3)
assert accuracy.samples_seen - accuracy.total_correct == 0 + 6 + 4


def test_accuracy_invalid() -> None:
2 changes: 0 additions & 2 deletions modyn/tests/evaluator/internal/test_pytorch_evaluator.py
@@ -170,7 +170,6 @@ def test_evaluator_init(load_state_mock: MagicMock) -> None:
)
)
assert evaluator._device == "cpu"
assert evaluator._device_type == "cpu"
assert not evaluator._amp
load_state_mock.assert_called_once_with(pathlib.Path("trained_model.modyn"))

@@ -183,7 +182,6 @@ def test_no_transform_evaluator_init(load_state_mock):
assert isinstance(evaluator._model.model, MockModel)
assert not evaluator._label_transformer_function
assert evaluator._device == "cpu"
assert evaluator._device_type == "cpu"
assert not evaluator._amp
load_state_mock.assert_called_once_with(pathlib.Path("trained_model.modyn"))

