I’d like to propose adding a batch.remove_at_index() method to PyTorch Geometric. It would take a mask or a list of indices and filter the batch directly, without the Batch.to_data_list() / Batch.from_data_list() round trip. Below is a sample comparison between two approaches, which shows a significant performance improvement.
Explanation:
func1 filters the batch by rebuilding it with the standard Batch.from_data_list() method, which is the slower path.
func2 manipulates the batch object directly, without converting to and from lists, and is approximately 10x faster in this comparison.
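To make the idea of direct filtering concrete, here is a minimal sketch that works on the raw batched tensors only. It is not func2 from the comparison above; the helper name filter_batch_tensors is hypothetical, and it assumes the batch is described just by x, edge_index and the batch vector. A real Batch also carries ptr and per-attribute slice and increment bookkeeping, which this sketch ignores.

```python
import torch


def filter_batch_tensors(x, edge_index, batch, keep):
    """Keep only the graphs where `keep` is True (hypothetical helper)."""
    # A node survives if the graph it belongs to is kept.
    node_mask = keep[batch]
    # An edge survives if both of its endpoints survive.
    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    # Old node id -> new consecutive id (only meaningful for kept nodes).
    new_node_id = node_mask.long().cumsum(0) - 1
    # Old graph id -> new consecutive id (only meaningful for kept graphs).
    new_graph_id = keep.long().cumsum(0) - 1
    return (
        x[node_mask],
        new_node_id[edge_index[:, edge_mask]],
        new_graph_id[batch[node_mask]],
    )


# Three toy graphs with 2, 3 and 2 nodes, already in batched form.
x = torch.arange(7, dtype=torch.float).unsqueeze(1)
edge_index = torch.tensor([[0, 1, 2, 3, 5],
                           [1, 0, 3, 4, 6]])
batch = torch.tensor([0, 0, 1, 1, 1, 2, 2])
keep = torch.tensor([True, False, True])  # drop the middle graph

sub_x, sub_edge_index, sub_batch = filter_batch_tensors(x, edge_index, batch, keep)
```

Everything here is boolean masking plus two cumulative sums, so no per-graph Python loop and no intermediate list of Data objects is needed.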
Additional Suggestions:
To implement this feature efficiently, it may help to reuse parts of the existing collate function so that custom node and edge attributes are handled iteratively. It might also be useful to offer an option that returns a "negative sub-batch" (i.e., the elements excluded by the mask) alongside the positive sub-batch.
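The two suggestions above could be sketched roughly as follows. This is an illustration under stated assumptions, not PyG's collate logic: select_graphs and split_graphs are hypothetical names, attributes are passed as plain dictionaries, and every node-level tensor is assumed to be indexed by node along dimension 0 (likewise for edge-level tensors).

```python
import torch


def select_graphs(node_attrs, edge_attrs, edge_index, batch, keep):
    """Filter batched tensors to the graphs where `keep` is True (hypothetical)."""
    node_mask = keep[batch]
    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    new_node_id = node_mask.long().cumsum(0) - 1
    new_graph_id = keep.long().cumsum(0) - 1
    return {
        # Iterate over whatever custom attributes the caller supplies.
        "node_attrs": {k: v[node_mask] for k, v in node_attrs.items()},
        "edge_attrs": {k: v[edge_mask] for k, v in edge_attrs.items()},
        "edge_index": new_node_id[edge_index[:, edge_mask]],
        "batch": new_graph_id[batch[node_mask]],
    }


def split_graphs(node_attrs, edge_attrs, edge_index, batch, keep):
    """Return (positive, negative) sub-batches for a boolean graph mask."""
    pos = select_graphs(node_attrs, edge_attrs, edge_index, batch, keep)
    neg = select_graphs(node_attrs, edge_attrs, edge_index, batch, ~keep)
    return pos, neg


# Three toy graphs with 2, 3 and 2 nodes.
node_attrs = {"x": torch.arange(7, dtype=torch.float).unsqueeze(1)}
edge_attrs = {"edge_attr": torch.tensor([10, 11, 12, 13, 14])}
edge_index = torch.tensor([[0, 1, 2, 3, 5],
                           [1, 0, 3, 4, 6]])
batch = torch.tensor([0, 0, 1, 1, 1, 2, 2])
keep = torch.tensor([True, False, True])

pos, neg = split_graphs(node_attrs, edge_attrs, edge_index, batch, keep)
```

Computing the negative sub-batch with the inverted mask keeps the two results symmetric; a real implementation could share the masks between both passes instead of recomputing them.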
Usage:
Currently, filtering a batch with sub_batch = batch[mask] returns a list of data objects. I think it would be more convenient if the result stayed a Batch object. Users could then keep the Batch format and call sub_batch.to_data_list() only when they explicitly need a data list, which would streamline operations where the batch structure must be preserved.
I believe these changes could noticeably improve performance, especially in workloads that filter batches frequently.
Let me know what you think of this idea, and if any further clarifications are needed!