-
Notifications
You must be signed in to change notification settings - Fork 612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[🐛 BUG] Incorrect evaluation results due to multi-GPU distributed sampler #1872
Changes from 5 commits
105fc0a
4e420e7
301ea24
45d7265
96e2930
8fcdadc
eea88ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,12 +14,13 @@ | |
|
||
from recbole.evaluator.register import Register | ||
import torch | ||
import copy | ||
|
||
|
||
class DataStruct(object): | ||
def __init__(self): | ||
def __init__(self, init=None): | ||
self._data_dict = {} | ||
if init is not None: | ||
self._data_dict.update(init) | ||
|
||
def __getitem__(self, name: str): | ||
return self._data_dict[name] | ||
|
@@ -41,6 +42,9 @@ def get(self, name: str): | |
def set(self, name: str, value): | ||
self._data_dict[name] = value | ||
|
||
def __iter__(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Used for for key, value in struct:
... |
||
return iter(self._data_dict.items()) | ||
|
||
def update_tensor(self, name: str, value: torch.Tensor): | ||
if name not in self._data_dict: | ||
self._data_dict[name] = value.cpu().clone().detach() | ||
|
@@ -190,7 +194,7 @@ def eval_batch_collect( | |
if self.register.need("data.label"): | ||
self.label_field = self.config["LABEL_FIELD"] | ||
self.data_struct.update_tensor( | ||
"data.label", interaction[self.label_field].to(self.device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redundant conversion because all input will be transferred to |
||
"data.label", interaction[self.label_field] | ||
) | ||
|
||
def model_collect(self, model: torch.nn.Module): | ||
|
@@ -213,13 +217,13 @@ def eval_collect(self, eval_pred: torch.Tensor, data_label: torch.Tensor): | |
|
||
if self.register.need("data.label"): | ||
self.label_field = self.config["LABEL_FIELD"] | ||
self.data_struct.update_tensor("data.label", data_label.to(self.device)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
self.data_struct.update_tensor("data.label", data_label) | ||
|
||
def get_data_struct(self): | ||
"""Get all the evaluation resource that been collected. | ||
And reset some of outdated resource. | ||
""" | ||
returned_struct = copy.deepcopy(self.data_struct) | ||
returned_struct = DataStruct(self.data_struct) | ||
for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]: | ||
if key in self.data_struct: | ||
del self.data_struct[key] | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -31,7 +31,7 @@ | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
from recbole.data.interaction import Interaction | ||||||||||||||||||||||||||
from recbole.data.dataloader import FullSortEvalDataLoader | ||||||||||||||||||||||||||
from recbole.evaluator import Evaluator, Collector | ||||||||||||||||||||||||||
from recbole.evaluator import Evaluator, Collector, DataStruct | ||||||||||||||||||||||||||
from recbole.utils import ( | ||||||||||||||||||||||||||
ensure_dir, | ||||||||||||||||||||||||||
get_local_time, | ||||||||||||||||||||||||||
|
@@ -46,6 +46,7 @@ | |||||||||||||||||||||||||
WandbLogger, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
from torch.nn.parallel import DistributedDataParallel | ||||||||||||||||||||||||||
import torch.distributed as dist | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
class AbstractTrainer(object): | ||||||||||||||||||||||||||
|
@@ -577,8 +578,11 @@ def evaluate( | |||||||||||||||||||||||||
return | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if load_best_model: | ||||||||||||||||||||||||||
# Refer to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints | ||||||||||||||||||||||||||
dist.barrier() | ||||||||||||||||||||||||||
checkpoint_file = model_file or self.saved_model_file | ||||||||||||||||||||||||||
checkpoint = torch.load(checkpoint_file, map_location=self.device) | ||||||||||||||||||||||||||
map_location = {"cuda:%d" % 0: "cuda:%d" % self.config["local_rank"]} | ||||||||||||||||||||||||||
checkpoint = torch.load(checkpoint_file, map_location=map_location) | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the Refer to the example in the ddp tutorial. |
||||||||||||||||||||||||||
self.model.load_state_dict(checkpoint["state_dict"]) | ||||||||||||||||||||||||||
self.model.load_other_parameter(checkpoint.get("other_parameter")) | ||||||||||||||||||||||||||
message_output = "Loading model structure and parameters from {}".format( | ||||||||||||||||||||||||||
|
@@ -608,9 +612,7 @@ def evaluate( | |||||||||||||||||||||||||
else eval_data | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
num_sample = 0 | ||||||||||||||||||||||||||
for batch_idx, batched_data in enumerate(iter_data): | ||||||||||||||||||||||||||
num_sample += len(batched_data) | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused any longer. |
||||||||||||||||||||||||||
interaction, scores, positive_u, positive_i = eval_func(batched_data) | ||||||||||||||||||||||||||
if self.gpu_available and show_progress: | ||||||||||||||||||||||||||
iter_data.set_postfix_str( | ||||||||||||||||||||||||||
|
@@ -621,35 +623,31 @@ def evaluate( | |||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
self.eval_collector.model_collect(self.model) | ||||||||||||||||||||||||||
struct = self.eval_collector.get_data_struct() | ||||||||||||||||||||||||||
result = self.evaluator.evaluate(struct) | ||||||||||||||||||||||||||
if not self.config["single_spec"]: | ||||||||||||||||||||||||||
result = self._map_reduce(result, num_sample) | ||||||||||||||||||||||||||
struct = self._gather_evaluation_resources(struct) | ||||||||||||||||||||||||||
result = self.evaluator.evaluate(struct) | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gather |
||||||||||||||||||||||||||
self.wandblogger.log_eval_metrics(result, head="eval") | ||||||||||||||||||||||||||
return result | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _map_reduce(self, result, num_sample): | ||||||||||||||||||||||||||
gather_result = {} | ||||||||||||||||||||||||||
total_sample = [ | ||||||||||||||||||||||||||
torch.zeros(1).to(self.device) for _ in range(self.config["world_size"]) | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
torch.distributed.all_gather( | ||||||||||||||||||||||||||
total_sample, torch.Tensor([num_sample]).to(self.device) | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
total_sample = torch.cat(total_sample, 0) | ||||||||||||||||||||||||||
total_sample = torch.sum(total_sample).item() | ||||||||||||||||||||||||||
for key, value in result.items(): | ||||||||||||||||||||||||||
result[key] = torch.Tensor([value * num_sample]).to(self.device) | ||||||||||||||||||||||||||
gather_result[key] = [ | ||||||||||||||||||||||||||
torch.zeros_like(result[key]).to(self.device) | ||||||||||||||||||||||||||
for _ in range(self.config["world_size"]) | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
torch.distributed.all_gather(gather_result[key], result[key]) | ||||||||||||||||||||||||||
gather_result[key] = torch.cat(gather_result[key], dim=0) | ||||||||||||||||||||||||||
gather_result[key] = round( | ||||||||||||||||||||||||||
torch.sum(gather_result[key]).item() / total_sample, | ||||||||||||||||||||||||||
self.config["metric_decimal_place"], | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
return gather_result | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused anymore. |
||||||||||||||||||||||||||
def _gather_evaluation_resources(self, struct: DataStruct) -> DataStruct: | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
Gather the evaluation resources from all ranks, e.g., 'rec.items', 'rec.score', 'data.label' | ||||||||||||||||||||||||||
Only 'rec.*' and 'data.label' are gathered, because they are distributed into different ranks. | ||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||
struct: data struct collected from all ranks | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Returns: gathered data struct | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
gather_struct = DataStruct(struct) | ||||||||||||||||||||||||||
for key, value in struct: | ||||||||||||||||||||||||||
# Adjust the condition according to | ||||||||||||||||||||||||||
# [the key definition in evaluator](/docs/source/developer_guide/customize_metrics.rst#set-metric_need) | ||||||||||||||||||||||||||
if key.startswith("rec.") or key == "data.label": | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only The |
||||||||||||||||||||||||||
gather_struct[key] = [None for _ in range(self.config["world_size"])] | ||||||||||||||||||||||||||
dist.all_gather_object(gather_struct[key], value) | ||||||||||||||||||||||||||
gather_struct[key] = torch.cat(gather_struct[key], dim=0) | ||||||||||||||||||||||||||
return gather_struct | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _spilt_predict(self, interaction, batch_size): | ||||||||||||||||||||||||||
spilt_interaction = dict() | ||||||||||||||||||||||||||
|
@@ -786,7 +784,7 @@ def pretrain(self, train_data, verbose=True, show_progress=False): | |||||||||||||||||||||||||
self.logger.info(train_loss_output) | ||||||||||||||||||||||||||
self._add_train_loss_to_tensorboard(epoch_idx, train_loss) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if (epoch_idx + 1) % self.save_step == 0: | ||||||||||||||||||||||||||
if (epoch_idx + 1) % self.save_step == 0 and self.config["local_rank"] == 0: | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refer to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please double-check if the following code needs to be fixed:
Thanks! |
||||||||||||||||||||||||||
saved_model_file = os.path.join( | ||||||||||||||||||||||||||
self.checkpoint_dir, | ||||||||||||||||||||||||||
"{}-{}-{}.pth".format( | ||||||||||||||||||||||||||
|
@@ -986,6 +984,8 @@ def _save_checkpoint(self, epoch): | |||||||||||||||||||||||||
epoch (int): the current epoch id | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
if not self.config["single_spec"] and self.config["local_rank"] != 0: | ||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||||||||||||||||||||||||||
state = { | ||||||||||||||||||||||||||
"config": self.config, | ||||||||||||||||||||||||||
"epoch": epoch, | ||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Used for "deep copy"