
[🐛 BUG] Incorrect evaluation results due to multi-GPU distributed sampler #1872

Closed
wants to merge 7 commits

Conversation

@ChenglongMa (Contributor) commented on Sep 21, 2023

Bug description

1. Duplicate sampling in DistributedSampler

if not config["single_spec"]:
    index_sampler = torch.utils.data.distributed.DistributedSampler(
        list(range(self.sample_size)), shuffle=shuffle, drop_last=False
    )
    self.step = max(1, self.step // config["world_size"])
    shuffle = False

The torch.utils.data.distributed.DistributedSampler used in AbstractDataLoader pads the number of samples to make it divisible by the number of processes. As a result, the DistributedSampler duplicates the last few samples when drop_last is False.

For example, suppose we have 10 samples and train models on 3 GPUs.

Then, it is likely that we will get the following partitions:
GPU 1: 1, 4, 7, 0
GPU 2: 0, 3, 6, 9
GPU 3: 2, 5, 8, 1

So you will find that, in order to give each GPU the same amount of data, the sampler repeatedly allocates samples 0 and 1 to different GPUs.
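
For reference, here is a minimal standalone sketch (plain PyTorch, not RecBole code) that reproduces this padding behavior for 10 samples on 3 ranks:

from torch.utils.data.distributed import DistributedSampler

sample_size = 10  # stands in for self.sample_size in AbstractDataLoader

for rank in range(3):
    sampler = DistributedSampler(
        list(range(sample_size)),
        num_replicas=3,
        rank=rank,
        shuffle=False,
        drop_last=False,
    )
    # Each rank receives ceil(10 / 3) = 4 indices, so 12 indices are drawn
    # in total and two samples (0 and 1) are evaluated twice:
    # rank 0 -> [0, 3, 6, 9], rank 1 -> [1, 4, 7, 0], rank 2 -> [2, 5, 8, 1]
    print(rank, list(sampler))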

And let's check the evaluation logic in DDP mode:

def _map_reduce(self, result, num_sample):
    gather_result = {}
    total_sample = [
        torch.zeros(1).to(self.device) for _ in range(self.config["world_size"])
    ]
    torch.distributed.all_gather(
        total_sample, torch.Tensor([num_sample]).to(self.device)
    )
    total_sample = torch.cat(total_sample, 0)
    total_sample = torch.sum(total_sample).item()
    for key, value in result.items():
        result[key] = torch.Tensor([value * num_sample]).to(self.device)
        gather_result[key] = [
            torch.zeros_like(result[key]).to(self.device)
            for _ in range(self.config["world_size"])
        ]
        torch.distributed.all_gather(gather_result[key], result[key])
        gather_result[key] = torch.cat(gather_result[key], dim=0)
        gather_result[key] = round(
            torch.sum(gather_result[key]).item() / total_sample,
            self.config["metric_decimal_place"],
        )
    return gather_result

Basically, it recomputes the final evaluation result as a sample-weighted average over GPUs:

$$ \text{final value} = \frac{\sum_{g=1}^{\text{world size}} \text{value}_g \times \text{num sample}_g}{\text{total sample}} $$

This can cause serious bugs because, in the example above, $\text{total sample}$ becomes 12 (the padded size) instead of 10.

2. Incorrect way to get num_sample.

num_sample += len(batched_data)

This line tries to get the number of samples in each batched_data; however, it cannot obtain the correct value, because batched_data is a tuple and its length is the number of fields in the tuple, not the number of samples.
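
A toy illustration of the problem (the field names below are placeholders, not RecBole's exact ones):

# batched_data is a tuple of a few objects, e.g. (interaction, idx, pos_u, pos_i),
# so len() returns the tuple arity rather than the number of samples it carries.
interaction = list(range(2048))                 # pretend this batch holds 2048 rows
batched_data = (interaction, None, None, None)  # placeholder tuple
print(len(batched_data))                        # 4 -- not 2048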

Fix

1. Implement NoDuplicateDistributedSampler

Referring to pytorch/pytorch#25162 (comment), I implemented a NoDuplicateDistributedSampler. It partitions the samples unevenly across the GPUs, e.g., [1, 4, 7, 0], [3, 6, 9], [2, 5, 8].
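
A condensed sketch of such a sampler, based on the idea in that comment (assumptions of mine, not necessarily the exact code in this PR):

import torch
from torch.utils.data import Sampler


class NoDuplicateDistributedSampler(Sampler):
    """Split indices across ranks without padding, so ranks may receive
    different numbers of samples (e.g., 4/3/3 for 10 samples on 3 GPUs)."""

    def __init__(self, dataset, num_replicas, rank, shuffle=False, seed=0):
        self.total_size = len(dataset)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        # Number of samples this rank will actually see.
        self.num_samples = len(range(rank, self.total_size, num_replicas))

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))
        # Interleaved split with no padding and no duplication.
        return iter(indices[self.rank : self.total_size : self.num_replicas])

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch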

2. Correct way to get num_sample.

We can read num_samples and total_size directly from eval_data.sampler, so we no longer need to gather and recompute them across GPUs, i.e.,

num_samples = eval_data.sampler.num_samples
total_size = eval_data.sampler.total_size

Existing Limitation and Future Work [Done]

Original limitation: the recalculation method above only works for metrics that can be combined as a sample-weighted average, e.g., MAE, Recall, Precision. However, some metrics do not work this way, e.g., Gini Index, Shannon Entropy, and Item Coverage.

I'm investigating other calculation methods, but due to time constraints I can only do this first.
✨ I've found a new way to calculate the evaluation metrics:

We can gather the data struct (e.g., rec.items, rec.score) from all GPUs first and then compute the result with self.evaluator.evaluate(struct), making _map_reduce() obsolete.

Then we get results consistent with running on a single GPU.
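
A rough sketch of what this gathering could look like (not this PR's exact _gather_evaluation_resources; it assumes an initialized process group and uses all_gather_object so that ranks holding different numbers of samples still work):

import torch
import torch.distributed as dist


def gather_struct_tensor(local_tensor: torch.Tensor) -> torch.Tensor:
    """Collect one per-GPU tensor from struct onto every rank and concatenate."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    # all_gather_object pickles the payload, so move it to CPU first.
    dist.all_gather_object(gathered, local_tensor.cpu())
    return torch.cat(gathered, dim=0)


# Hypothetical usage for a key that differs across ranks:
# struct.set("rec.topk", gather_struct_tensor(struct.get("rec.topk")))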

I will keep updating this PR.
Please let me know if I've gotten anything wrong.

Thanks! 💪

@@ -190,7 +194,7 @@ def eval_batch_collect(
        if self.register.need("data.label"):
            self.label_field = self.config["LABEL_FIELD"]
            self.data_struct.update_tensor(
                "data.label", interaction[self.label_field].to(self.device)
ChenglongMa (Contributor, Author):

Redundant conversion, because all inputs are transferred to the CPU in the update_tensor method.

@@ -213,13 +217,13 @@ def eval_collect(self, eval_pred: torch.Tensor, data_label: torch.Tensor):

        if self.register.need("data.label"):
            self.label_field = self.config["LABEL_FIELD"]
            self.data_struct.update_tensor("data.label", data_label.to(self.device))
ChenglongMa (Contributor, Author):

Same as above.

@@ -41,6 +42,9 @@ def get(self, name: str):
    def set(self, name: str, value):
        self._data_dict[name] = value

    def __iter__(self):
ChenglongMa (Contributor, Author):

Used in a for loop, e.g.:

for key, value in struct:
    ...
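
A possible one-line implementation, assuming the class keeps everything in self._data_dict (my guess, not necessarily the exact body added in this PR):

def __iter__(self):
    # Yield (key, value) pairs so the struct can be unpacked in a for loop.
    return iter(self._data_dict.items())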

        self._data_dict = {}
        if init is not None:
            self._data_dict.update(init)
ChenglongMa (Contributor, Author):

Used for "deep copy"

for batch_idx, batched_data in enumerate(iter_data):
    num_sample += len(batched_data)
ChenglongMa (Contributor, Author):

No longer used.

            torch.sum(gather_result[key]).item() / total_sample,
            self.config["metric_decimal_place"],
        )
    return gather_result
ChenglongMa (Contributor, Author):

No longer used.

if not self.config["single_spec"]:
    result = self._map_reduce(result, num_sample)
    struct = self._gather_evaluation_resources(struct)
result = self.evaluator.evaluate(struct)
ChenglongMa (Contributor, Author) commented on Sep 21, 2023:

Gather struct from all GPUs and concatenate them into one. We can then evaluate and compute the result in the same way as on a single GPU.

for key, value in struct:
    # Adjust the condition according to
    # [the key definition in evaluator](/docs/source/developer_guide/customize_metrics.rst#set-metric_need)
    if key.startswith("rec.") or key == "data.label":
ChenglongMa (Contributor, Author):

Only rec.* and data.label are gathered, because they are distributed across different GPUs, while other keys like data.num_items and data.num_users are identical on every GPU.

The keys are defined in docs/source/developer_guide/customize_metrics.rst#set-metric_need.

checkpoint_file = model_file or self.saved_model_file
checkpoint = torch.load(checkpoint_file, map_location=self.device)
map_location = {"cuda:%d" % 0: "cuda:%d" % self.config["local_rank"]}
checkpoint = torch.load(checkpoint_file, map_location=map_location)
ChenglongMa (Contributor, Author):

Fixes the EOFError: Ran out of input error when using DDP.

Refer to the example in the DDP tutorial.

@@ -786,7 +784,7 @@ def pretrain(self, train_data, verbose=True, show_progress=False):
            self.logger.info(train_loss_output)
            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

            if (epoch_idx + 1) % self.save_step == 0:
            if (epoch_idx + 1) % self.save_step == 0 and self.config["local_rank"] == 0:
ChenglongMa (Contributor, Author):

Refer to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints:

All processes should see same parameters as they all start from same random parameters and gradients are synchronized in backward passes. Therefore, saving it in one process is sufficient.
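
For reference, the generic DDP pattern from that tutorial, sketched in plain PyTorch (save_checkpoint and its arguments are illustrative, not RecBole's trainer API):

import torch
import torch.distributed as dist


def save_checkpoint(model, path, local_rank):
    # Only rank 0 writes the file; the parameters are identical on every rank.
    if local_rank == 0:
        torch.save(model.state_dict(), path)
    # Make sure rank 0 has finished writing before any other rank loads it.
    dist.barrier()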

ChenglongMa (Contributor, Author):

Please double-check if the following code needs to be fixed:

  1. In XGBoostTrainer

    if load_best_model:
        if model_file:
            checkpoint_file = model_file
        else:
            checkpoint_file = self.temp_best_file
        self.model.load_model(checkpoint_file)

  2. In LightGBMTrainer

    if load_best_model:
        if model_file:
            checkpoint_file = model_file
        else:
            checkpoint_file = self.temp_best_file
        self.model = self.lgb.Booster(model_file=checkpoint_file)

Thanks!

@@ -986,6 +984,8 @@ def _save_checkpoint(self, epoch):
            epoch (int): the current epoch id

        """
        if not self.config["single_spec"] and self.config["local_rank"] != 0:
            return
ChenglongMa (Contributor, Author):

Same as above.

@@ -48,8 +48,7 @@ def ensure_dir(dir_path):
        dir_path (str): directory path

    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    os.makedirs(dir_path, exist_ok=True)
ChenglongMa (Contributor, Author):

In DDP, multiple processes may check whether the directory exists at the same time: each sees it missing and calls os.makedirs, so all but the first raise an error. Passing exist_ok=True avoids this race.

@Ethan-TZ (Member):

@ChenglongMa Thanks for your nice contribution to our framework!
About the DistributedSampler: in our framework, the batch_size is often either very large (for example, 1024) or 1. When it is very large, the impact of the last few repeated samples on the model is minimal. However, I'm more concerned about whether your implementation could go wrong when it is 1.

@ChenglongMa (Contributor, Author):

Hi @Ethan-TZ,

Thanks for pointing this out. It is indeed a bug that the code will get stuck if the batch_size is too small.

What you said is also very reasonable. So do you recommend using the original DistributedSampler?

Thanks!

@Ethan-TZ (Member):

@ChenglongMa Considering stability, we will continue to use the DistributedSampler class. However, we also greatly appreciate your contribution. The points you raised are very important to us, and we will incorporate your code into subsequent development versions.

@ChenglongMa (Contributor, Author):

Hi @Ethan-TZ,

That's great! Thanks for the clarification. But please consider fixing the evaluation logic, especially the calculation of metrics like the Gini index.

@Ethan-TZ (Member):

@ChenglongMa
Thank you for pointing it out. In fact, we noticed this in our previous versions. In general, the main time-consuming process of the recommendation algorithm is evaluation. Therefore, the original consideration of adopting the DDP architecture was to fully utilize the computational resources of different GPUs (i.e., computing metrics on different GPUs).

However, if the results from all GPUs are collected onto one GPU for calculation, the computational efficiency is usually not high. Therefore, for linearly additive metrics like NDCG, we still use the previous architecture.

Currently, for non-additive metrics like AUC and the Gini index, we use the average value to approximate the global value. Admittedly, this is an imprecise approach, but it's a trade-off between efficiency and accuracy. Since we can only calculate rough results now, we will design more specialized distributed algorithms for these non-additive metrics in the future. Generally speaking, the metrics we use most often (NDCG, MRR, Recall) are additive. We are currently considering adopting your implementation for those non-additive metrics. Thanks for your nice contribution!

@ChenglongMa (Contributor, Author):

@Ethan-TZ. Got it! Thanks for your detailed explanation👍. I will close this PR then.
