From d943ee180684617f526ff79b3b32d015ae6bdf0b Mon Sep 17 00:00:00 2001
From: infinity
Date: Fri, 10 Nov 2023 02:06:27 +0000
Subject: [PATCH 01/28] implement node/layer/graph-wise sampling

---
 examples/clustergcn_nodeclass.py             |  50 +++
 examples/fastgcn_nodeclass.py                |  76 +++++
 sgl/data/base_data.py                        |   6 +-
 sgl/dataset/planetoid.py                     |   9 +
 sgl/models/backup.py                         | 278 +++++++++++++++
 sgl/models/base_model.py                     |  97 +++++-
 sgl/models/homo/__init__.py                  |  10 +-
 sgl/models/homo/clustergcn.py                |  11 +
 sgl/models/homo/fastgcn.py                   |  24 ++
 sgl/models/homo/graphsage.py                 |  21 ++
 sgl/models/homo/sgc.py                       |   5 +-
 sgl/models/homo/sgc_dist.py                  |   4 +-
 sgl/models/homo/vanillagcn.py                |  17 +
 sgl/models/simple_models.py                  |  59 ++++
 sgl/operators/base_op.py                     |   2 +-
 sgl/operators/graph_op/laplacian_graph_op.py |   2 +-
 sgl/sampler/__init__.py                      |   8 +
 sgl/sampler/base_sampler.py                  |  13 +
 sgl/sampler/sampler.py                       | 317 ++++++++++++++++++
 sgl/sampler/utils.py                         |  35 ++
 sgl/tasks/__init__.py                        |   4 +-
 sgl/tasks/node_classification_sampling.py    | 127 +++++++
 sgl/tasks/utils.py                           |  42 ++-
 sgl_dair.egg-info/PKG-INFO                   | 175 ++++++++++
 sgl_dair.egg-info/SOURCES.txt                | 123 +++++++
 sgl_dair.egg-info/dependency_links.txt       |   1 +
 sgl_dair.egg-info/requires.txt               |  10 +
 sgl_dair.egg-info/top_level.txt              |   1 +
 [remaining 176 entries: compiled __pycache__/*.pyc caches]     | Bin
 204 files changed, 1499 insertions(+), 28 deletions(-)
 create mode 100644 examples/clustergcn_nodeclass.py
 create mode 100644 examples/fastgcn_nodeclass.py
 create mode 100644 sgl/models/backup.py
 create mode 100644 sgl/models/homo/clustergcn.py
 create mode 100644 sgl/models/homo/fastgcn.py
 create mode 100644 sgl/models/homo/graphsage.py
 create mode 100644 sgl/models/homo/vanillagcn.py
 create mode 100644 sgl/sampler/__init__.py
 create mode 100644 sgl/sampler/base_sampler.py
 create mode 100644 sgl/sampler/sampler.py
 create mode 100644 sgl/sampler/utils.py
 create mode 100644 sgl/tasks/node_classification_sampling.py
 create mode 100644 sgl_dair.egg-info/PKG-INFO
 create mode 100644 sgl_dair.egg-info/SOURCES.txt
 create mode 100644 sgl_dair.egg-info/dependency_links.txt
 create mode 100644 sgl_dair.egg-info/requires.txt
 create mode 100644 sgl_dair.egg-info/top_level.txt
 [create mode 100644 entries for the __pycache__/*.pyc caches omitted]
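The subject line names three sampling granularities: node-wise (GraphSAGE-style neighbor sampling), layer-wise (FastGCN) and graph-wise (ClusterGCN). The two examples below exercise the layer- and graph-wise models through the new sgl.sampler and NodeClassification_Sampling APIs. As background for reviewers, here is a minimal self-contained sketch of the node-wise case on a SciPy CSR adjacency; it is illustrative only, and the function name sample_neighbors is made up for this note, not part of sgl.

# Illustrative sketch only, not code from this patch: node-wise (GraphSAGE-style)
# neighbor sampling keeps at most `fanout` randomly chosen neighbors per seed node.
import numpy as np
import scipy.sparse as sp


def sample_neighbors(adj: sp.csr_matrix, seeds, fanout, rng=None):
    """Return {seed: sampled neighbor ids}, drawn without replacement."""
    rng = np.random.default_rng() if rng is None else rng
    sampled = {}
    for v in seeds:
        # Neighbors of v are the column indices stored for row v in CSR format.
        neigh = adj.indices[adj.indptr[v]:adj.indptr[v + 1]]
        if len(neigh) > fanout:
            neigh = rng.choice(neigh, size=fanout, replace=False)
        sampled[v] = neigh
    return sampled


if __name__ == "__main__":
    # Tiny 4-node cycle graph as a stand-in for dataset.adj.
    adj = sp.csr_matrix(np.array([[0, 1, 0, 1],
                                  [1, 0, 1, 0],
                                  [0, 1, 0, 1],
                                  [1, 0, 1, 0]]))
    print(sample_neighbors(adj, seeds=[0, 2], fanout=1))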
diff --git a/examples/clustergcn_nodeclass.py b/examples/clustergcn_nodeclass.py
new file mode 100644
index 0000000..e912174
--- /dev/null
+++ b/examples/clustergcn_nodeclass.py
@@ -0,0 +1,50 @@
+import argparse
+import networkx as nx
+import torch.nn.functional as F
+from sgl.dataset import Planetoid
+from sgl.models.homo import ClusterGCN
+from sgl.tasks import NodeClassification_Sampling
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description = "Run ClusterGCN.")
+    parser.add_argument("--clustering_method",
+                        nargs = "?",
+                        default = "random",
+                        choices = ["random", "metis"],
+                        help = "Clustering method for graph decomposition. Default is the random procedure.")
+
+    parser.add_argument("--epochs",
+                        type = int,
+                        default = 200,
+                        help = "Number of training epochs. Default is 200.")
+
+    parser.add_argument("--seed",
+                        type = int,
+                        default = 42,
+                        help = "Random seed for train_test split. Default is 42.")
+
+    parser.add_argument("--dropout",
+                        type = float,
+                        default = 0.5,
+                        help = "Dropout parameter. Default is 0.5.")
+
+    parser.add_argument("--learning_rate",
+                        type = float,
+                        default = 0.01,
+                        help = "Learning rate. Default is 0.01.")
+
+    parser.add_argument("--test_ratio",
+                        type = float,
+                        default = 0.9,
+                        help = "Test data ratio. Default is 0.9.")
+
+    parser.add_argument("--cluster_number",
+                        type = int,
+                        default = 10,
+                        help = "Number of clusters extracted. Default is 10.")
+    args = parser.parse_args()
+    device = 'cuda:0'
+    dataset = Planetoid("cora", "/home/ssq/test_data/", f"clustergcn_{args.cluster_number}")
+    model = ClusterGCN(nx.from_scipy_sparse_matrix(dataset.adj), dataset.x.numpy(), dataset.y.unsqueeze(1).numpy(), device, dataset.num_features, 128, dataset.num_classes, args.clustering_method, args.cluster_number, args.test_ratio)
+    test_acc = NodeClassification_Sampling(dataset, model, lr=0.1, weight_decay=5e-5, epochs=30, device=device, loss_fn=F.nll_loss, train_batch_size=1, eval_batch_size=1).test_acc
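The example above drives the graph-wise sampler: the graph is decomposed into clusters (randomly or with METIS, per --clustering_method) and each step trains on one induced subgraph. For readers new to the idea, a minimal self-contained sketch of that decomposition follows; it is an illustration under my own naming (random_clusters, induced_subgraph), independent of the sgl implementation in this patch.

# Illustrative sketch only, not code from this patch: graph-wise sampling in the
# ClusterGCN style partitions nodes into clusters and takes, per batch, the
# induced adjacency block plus the matching feature/label rows of one cluster.
import numpy as np
import scipy.sparse as sp


def random_clusters(num_nodes, cluster_number, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    return rng.integers(0, cluster_number, size=num_nodes)


def induced_subgraph(adj: sp.csr_matrix, features, labels, member_mask):
    idx = np.where(member_mask)[0]
    return adj[idx][:, idx], features[idx], labels[idx], idx


if __name__ == "__main__":
    n, d, num_classes, k = 8, 4, 3, 2
    adj = sp.random(n, n, density=0.3, format="csr")
    x = np.random.rand(n, d)
    y = np.random.randint(0, num_classes, size=n)
    assignment = random_clusters(n, cluster_number=k)
    for c in range(k):
        sub_adj, sub_x, sub_y, idx = induced_subgraph(adj, x, y, assignment == c)
        print(c, sub_adj.shape, sub_x.shape, idx)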
diff --git a/examples/fastgcn_nodeclass.py b/examples/fastgcn_nodeclass.py
new file mode 100644
index 0000000..e550504
--- /dev/null
+++ b/examples/fastgcn_nodeclass.py
@@ -0,0 +1,76 @@
+import argparse
+import torch.nn.functional as F
+from sgl.dataset import Planetoid
+from sgl.models.homo import FastGCN, GraphSAGE, VanillaGCN
+from sgl.tasks import NodeClassification_Sampling
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("FastGCN")
+    parser.add_argument(
+        "--hidden", type=int, default=128, help="dimension of hidden layer"
+    )
+    parser.add_argument("--dropout", type=float, default=0.5, help="dropout")
+    parser.add_argument(
+        "--layer_sizes", type=str, default="128-128", help="sampling sizes per layer"
+    )
+    args = parser.parse_args()
+    device = "cuda:0"
+    dataset = Planetoid("cora", "/home/ssq/test_data/", "official")
+    # model = FastGCN(
+    #     dataset,
+    #     hidden_dim=args.hidden,
+    #     output_dim=dataset.num_classes,
+    #     dropout=args.dropout,
+    #     device=device,
+    #     inductive=False,
+    #     prob_type="uniform"
+    # )
+    # test_acc = NodeClassification_Sampling(
+    #     dataset,
+    #     model,
+    #     lr=0.1,
+    #     weight_decay=5e-5,
+    #     epochs=20,
+    #     device=device,
+    #     loss_fn=F.nll_loss,
+    #     train_batch_size=256,
+    #     eval_batch_size=256,
+    # ).test_acc
+    # print(f"final test acc: {test_acc}")
+    model = GraphSAGE(
+        dataset,
+        hidden_dim=args.hidden,
+        output_dim=dataset.num_classes,
+        dropout=args.dropout,
+        device=device,
+    )
+    test_acc = NodeClassification_Sampling(
+        dataset,
+        model,
+        lr=0.1,
+        weight_decay=5e-5,
+        epochs=20,
+        device=device,
+        loss_fn=F.nll_loss,
+        train_batch_size=64,
+        eval_batch_size=64,
+    ).test_acc
+    print(f"final test acc: {test_acc}")
+    # model = VanillaGCN(
+    #     dataset,
+    #     hidden_dim=args.hidden,
+    #     output_dim=dataset.num_classes,
+    #     dropout=args.dropout,
+    #     device=device,
+    # )
+    # test_acc = NodeClassification_Sampling(
+    #     dataset,
+    #     model,
+    #     lr=0.1,
+    #     weight_decay=5e-5,
+    #     epochs=20,
+    #     device=device,
+    #     loss_fn=F.nll_loss
+    # ).test_acc
+    # print(f"final test acc: {test_acc}")
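The commented-out FastGCN block above is the layer-wise case: each layer draws a fixed number of nodes with probability q(u) proportional to the squared column norm of the normalized adjacency, and the sampled columns are rescaled by 1/(s*q(u)) so the aggregation stays unbiased. Below is a standalone background sketch of that sampling step; it is not the sgl.sampler implementation added by this patch, and the helper name layer_sample is invented for the illustration.

# Background sketch (not from this patch): FastGCN-style layer-wise importance
# sampling over the columns of a symmetrically normalized adjacency matrix.
import numpy as np
import scipy.sparse as sp


def layer_sample(adj_norm: sp.csr_matrix, num_samples, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # q(u) is proportional to the squared column norm of the normalized adjacency.
    col_norm_sq = np.asarray(adj_norm.multiply(adj_norm).sum(axis=0)).ravel()
    q = col_norm_sq / col_norm_sq.sum()
    sampled = rng.choice(adj_norm.shape[1], size=num_samples, replace=True, p=q)
    # Keep only the sampled columns and rescale by 1 / (s * q) for unbiasedness.
    weights = (1.0 / (q[sampled] * num_samples)).reshape(1, -1)
    adj_sampled = adj_norm[:, sampled].multiply(weights)
    return sampled, adj_sampled


if __name__ == "__main__":
    n = 6
    rand = sp.random(n, n, density=0.4, format="csr")
    adj = (rand + rand.T + sp.eye(n)).tocsr()          # symmetric, with self-loops
    deg = np.asarray(adj.sum(axis=1)).ravel()
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))
    adj_norm = d_inv_sqrt.dot(adj).dot(d_inv_sqrt)     # D^-1/2 A D^-1/2
    nodes, adj_hat = layer_sample(adj_norm, num_samples=3)
    print(nodes, adj_hat.shape)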
[GIT binary patch literals for the committed __pycache__/*.pyc cache files omitted]
100644 index 0000000000000000000000000000000000000000..9a9dd261186c27d432d1200c48926882ec0899d2 GIT binary patch literal 3968 zcmbVP&668P74M!IjYgllt6lGUQ%SK-j0;hCWd|w+#z251F2GrvvI`ScU}{!9tDW^| zM(%0ZyJE(N$W*vc_)eutd-VUpkyD?jf{U*lb0GX)OWLn+LaKiA=5-XO8y*ELl zQDyjDd-75E_6lSFq?@xp2RH9yWCBDm!DH6t=4uV>uES@tQr9)Ia<`1D6?=n9w_@h( zxH_nHYle5?`k>Km44U2MV5z&r*&QZI!hOMn8`dUvcUkZUtX+PMH?f^I*QGl$8193q z?W7|7AW%UTsyMH|CDCx7!8c3U{KcCY1`7n`?R|&gK-pVHyp`y7>+mQ_TEZ@LAaHqYHKeYCF07s zvN&C*y`)QV`Ye>X`kNpgg&#_pN?poE!%)gHZR+`cl)w+Zq{Hd8p-g*WmW4CV`NKf% zXS!tEq}`d%^qQZH2ELL(l=#@l@5Mou&81I+I1*>1oBl!>-|vS>C<7JFx2evzh{Pe( zz=J+<6<>v!BDU>nCkx{}?a%?WBhyqtcM2Q5`ad-28sP;6UUBR=)6n-{u`k=5{d5p^ zvg}Fc!JWI;90tXg`{a0Z13+vcg$OK6pNbN7|W2;D;VoSgFfy04v;P;1UgCR?2A>{)RpYiI0k_f6h#%Lm>U%zf)Ts`Y72|73^n;N z?sMyew-jb39O$)ka+@RUO&}ZhRt~HixpA6sW^#kr($4tfeWixktxl&Osr}LJb}t=t z;xLubZWzh5lPNjsz56umbzslVAjniGJH{lPKNP1fJ@NOw|9iW6pq$+mfg&j`UlT#d zxO=G_1j-m(mxeM*lq5gOZ_~VO;zfI;i$9XFu9z)}z2F$&Y?0>4tJyLLV=I+!!X5h<(hiuD{iOUmyo59a6cy9u)bN1xcBNR#p z|B206vV4o$*CYxJA9sipE@1%o|O+ z7c^xHO>>8|JfmT=Sjrny*XY=s;G~9IBCd~jL(6}h)3SAwz2w=ob9z>$^}KP+4;saC zzDzoT;(6H6Emoj=Q!VAFL(se>%EhW$&fR<^Z{@YTp0DOOnehJX6>H`UE208R|F*a^ zURJH?YQC0VqFVNdCGR;*t-a(&f8(rJ=WP6@x|FXQM63&gm`Cgr*t1xFzMfQP*Bk5x zo;+jYnfCuWtG&Tbbw_>&A)vX`y!!ep{7k%d%kxCoU7IsL+6mvE-?S~w4>dp1*0bIA zhIUXP6BEKH9B4%3H)zXBh`M2FkhT%$+6{(SCUi-qvbV3RchjW*FibKONh=v@J3I(GvbrPa+h_OGVAd`@^jgChzWZf3#0zNf*AouJ<>%6WN?ey^w?(ybZd zfqvVdzcNSv2j?P_(Eq2iacR(xECaN&gGx~Ytrm4if%#*8@oE#=8U@g-(90b(j6U7M z#kwXC<~|0Ng^OI}LE`3(+!HLXk^aXPvIohUxd%i7lN)R2Ix{a7^)*(!mIlDfS}knZ6>0;oTSM(l&LNQ7KBYnt%`g=gkq_^ zJnM1HtA0t+UlF-W_tg<|7ym?w9FSN|_jT3x2dNmLo%40w_n(Y{cs65tTI~Uau$nrW`DK;%xJkq(av4N7<~?&S zjAMKoyh9U)Tp^BnM)UbVuQ!F>vziTyS6vL-YkF7BYc lF(x{w3qNxVa~9M3dbek$`Z0xe)&!wUhGrj)ex0w`{{>lR{ZRk_ literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/amazon.cpython-37.pyc b/sgl/dataset/__pycache__/amazon.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e3c8e8866e94e090ed71a6a74ab12e81bb20ecc GIT binary patch literal 2570 zcmZWrOK%)S5bmD$?8Dw94h~5W4oILCWEXPaKoAPgCBaVD zh68IUr-&bbkn9`$1Aa?i5r_PRoT%#AwZY7)YihcCdg}4jSNn0d+ahqi@%VG`w@=7l z_^`YjQ0~Aee*nP=rwK`EWYI>qBgf3SkqdV#@lrqX&6=GAX&8k@cam1xj@oG_>ZIMM zo32DFl-woU<=zqDo(Si5w94tXr0+k6uaM0?)$Uyx7h_P`n>iPs#wsp_>N~ny>?KT! zn6vFkB6OSQ`&p91qsnz)=Cfj^*JPaWJY{8(OcdkI+ao-Cd7)o;m75?0i6|$L#d}Z5 zHsua?kHW|n4)?f!MCLSdK@WInbdR@q`-nt7@9-|%1HQsn;T}FE4vAXe#Rctsl*Y3> z({3@&RsO;n47JyoG+KTb?>KBU;y#Fqa=Pcso`OKm>CaSAWmRNu9ov<~?R8i`fv_v5 zBD?OKR)l$#4brd3aZuSO^btOVFj_F$oPZ(DX5Z4`C{K&260(#bd}elS5zLJ&PQ^x+ ztBrD66uDG_-`$79nK-@9TiENd{4663y3Zh0I5yEJc$M zR>W#tYS#o^drb!PI%@*L*pA4AjFng{w3w3dBx4hP00;jVCq@gI5~V^FGVYWj+13tP zs~wr=3ar7j0yaX%MoC7v4P>pTH=ugHYrcQm|`rl=iW6LVks`{brw7 z5O#Z0Uur(bL?T8?=GkOavY9ggZ1!Dk=VhUt-F%XP&obU;d?Mu)I90nzzAt2-%5`{> z_@PF;f{QD-6;a>B_!7*O76@X^E-wuK?~YURaTgT!nfaX7^wb4J7}<3YUrY1d9jv~Y z`=lN`sr#;c6<(GbAoL15%Qu*TWjncW^MRg~`! 
zhdUEBo^B1GE{5fJG@i=gPL|!=64_{+#&T~6Q5>dm2_<9DGlZkItd&Yg${iT~S2i%cOQH>ycHvwWsV|N4XV{1#p`6jMhH4>wr6TsJzOpyvnbf zDyYJW9{VSxvbp;#s9P$4b3^XUtzW*V+lM{XI&N33s*U9HfMmBEqB?sm`8y?bhmzTv z>Q)^S@fGgl{Q-6cSo^8-kbF-v>!F$XhTJC)VIDFOY)ULAC|{?X%!KwK^L#o|5i$uX z47i*j%uusIC32=jOdtiIpgu;FsgL5ZYC_9F4V;XQI?LAWtIVwVO@T;t-|_ zm_g)e8@h*fI81B*-dBd(|3Y)61(M2<_1pz~Cal!J*gb z3bkOShrl>xro}rWQ~2--;?9O)aPc-Mrj1zg>c8*-=%}46P*&4eDhnEivu&HIkO->T zjjs~{sNR@l&?z>6h9`W(kn9G~In+UR1F($@BJZF;FhV&Pl0w6Ec*f%XlEi~0ho3Ki z-uW(o@m|ycIBx~kY<*b==lTU$fjvy({}p3-zm5F}tC2C?Vk`yF06{X{VeIiVPMQ@% z`8t5uK^Iep20I22a)1ZjLxFUvyNiZ05=jCL?v#Lmk+0z<YN`LGZLd9!cN1nqaA=5I>%?GiOo++Pk`N{#1{(`QS=&`Jty%qT zRkdfPE#)G;;46~{2!WYP=PB|8d4ReNx$rCORn94O@9a+0)X`B@sdRq6^PSEnlOciY z?I*vL{~ZwWAAET6V({=w82Sqkf(TlWlBOQ*bw6cpE=U8odqr3d(t%s^i>Qp#*y*eo zmZNl3j?-~DNhjqrol^3Uh(Lr-i3nx9_0t1EKPL0RckmT*Jf}8zsIq1S%ILTj^4FQk zv@~;OC(Yx6E13y?k{8mBM15WrH9VTyMsB`pHujFnDp8kQH$`r^=-*!A*)Lw`8(#Vt zgd`~yB=yAXYjQ#b6T#Cs^(7Ob7(69gng*apB6fNxhGO)TqysS)6SzlWDh}Wte@z%k zhv3O8HvFi}Hg#o#W>uT|nYT!6*qgQ7e;7ZGIu0L({xgV{3i>5gGXnwL(!Wwejn|T` zcjdR9@DE}A8p3Z_OU{BVZ3z!sALO7VS5fO<)6elKhB1UO5(JE4$8*od%erjVMk=jh z_{`?+E~rOUR?4HQHb;8hG_^8P+}{i`XM8dr+uJ>S zb|6T4#}>_{IsmgrF!UG%p%MJvpko@+%|UPFV(;OS;*Q;gdI;lJF!VhThFn1!&OCwA zzhc+q?~uBG_}3nU-aodl^`GRjkV~WLDqm{8VGF?NJg|PP8_Uk>yaK;fcFskv)SGas z4T}0)syS7M@T9)~T>JnRr??d>ppWxwFsD5b#M`{SGydfrFS&p_q1bO+5VWJu-9V&~ zGln?ZyAYgWHEx2hy2V%BJWxmQvbqn#PWesV!E_s*)T+!NHyjf+_nh@MLQ`?#mLF;c zXP|%Dm4wIW=I+k?n`3X>e8~wc7<|j+;Q=V<1Qd^V_&g)m^pJF(fx_8RS29pYXWrJa zm&?uO=PP3x{XvqP=4Q2iw165(^lG_UtK_t*-gzXe<*Llo;{>9al$nO|aY#xad(UXH z3x$my0X;eYeguX<$*a6FDgb@%_Yv^ds_4VOLAw@I03bq30u6Mu^E$t!h|KSijODrx z`0}5WbYn_3_spanyO>YK0PmlY&!A&;9>A=C6f&~)DXC`f68*0Sgu35n zARTo@#|jG$8UfS!M%n@3Ladi2MZQ6W1M_nPAwB@j(F;p2t#|(DJNMp!!dg{*TJ$Pw zqzz21mMhEZN^0u^E)~{;4XSn7T*7n%Gl)y;L$9$+z_cCw?vIWE)lYDN$&1{~Cw*o9 z6laD~oJ!zBAurU=a3X@!cm=J?`l>#+{<1FY$Za~8`RU5o7#F#@1cPEX?^JqfJ#}K~ z?Lq-^*cHSN;dK2L2tsFc2mpB%fWqiKI;9@WOc65A`D^#i(oqAx0~ewaYsVo5-NAV0$EwJ9eG6(mXff=>cQvCBZkzjmDkNb3D6gQW906hH;>gkQ5fDFANp&BP zj~t^uLV+-aVsTW4#@qOY?+1IjkM?YTyF6yM`^CF(#I8yp0AA#64);}ZtFwR=*zFYl z%W;;^N7%WrRynsr&PxCe&?vWK&Y!HaqF-^%!A1~2=y&GQ;m3hOC3w*9Q6TBs$*%P* zWl=yYyrJp4xCOb|bxZubMo>j`8hLc+!(dSu9k|u`%noY3z?NwlN=L9ZZ1UyfLaGn( zl=nf{SD}+Fu#+v2a9nTg!Tb9EFTnpLg?YBG?AKssu{L?3UB`Cx)K`bg6Ozsg_B&5p ZuZceFWAqCwDAz2<0J6v%L$iF-{~t8Yo=gA$ literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/amazon_product.cpython-37.pyc b/sgl/dataset/__pycache__/amazon_product.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8580934d3bfdecde08774b4272a95193b0dd09bf GIT binary patch literal 3674 zcmbVPOK%&=5$^684u?-Ytd&jj(4p%ws-wn$FM^`VJUXyP?o+YCTW-9lyjS{8szGux9WyjsF-|^j`GdT&Cm<&AC9mph+G$Z@$fMtx+vm~R3j1p;N zq-X9!k5kA%cC3sX6f@(HzQ&pjsRXIa37mwrgas=#DyonU5>M!wvK_~deCap>_frdS zu8xzt*6&wyqRhE+`fiKm0=fss9XUmH`;Y#nQzwAii`jO-FJy`#Z8)&KI zNH3KrHcVxvUeHx$G87fK#SElWLzgbHDz&IdM~f$p+??FHLt!j_4ANBN=b^$W zXWZcCmYGwH1Vo^n(R+XaL|$x%oaX6}v&mX!;vzNxvs3hXR z)8dKk@HNgrNtX33?KL}Y{z^-#HY~RFq9Fei86b!AAwSZyX$aY z0LWQ$RoRGd)MbMTTv1OEOCY^3-w6CvcH%I(h1;ZfPm7FJPpSi#!!^e5-MkUR9XCjC zfiRgoJ_XautI}k@bliNjD{iRjC)YZQo1$U~yu|3-Pi;F@4!%hN(p^aDB4nr(KstkD z{1sm2Z|sn1M31cImbQIU;O)fEAX9b3!x4x>vxS(=V4{O6I{N0eKm>)TQ$>GhQr9(U5_Rvd2ub^$*t zgU(mpUJ?s=bK*xY`uyha^NXE0>iQzQDM!nfukCENZm<33^221a-+zAT<&`_((`Vt@ zv)1|-Jig?*Fl6<``dV=J&F$W${^pmRe!u(j$=xr8tHIjF2A99x>%ox64}+_HKYZNn z%HGvW*IMB#HyErB?>@V~cH6yw?WtP693_u_wHJh&QLoj1wYIwb=*qJ%*LGjtfA!>> zUL>zR`hTpmVI8ALASjg8jYa5!b4r9CNr9kj*kc<>>55;7-??qIDjY zL+_~3`W$Qw@g^DI5n>1jcLTVf8SND~TOwl&pBQ245vH^t!T_?D2;`Q|i%XLFJI`(;u$nrmTUrBc3aJ 
z7qZGh{gB)v-!8mnc{Q(PwX7^>vg$s)L3aMWPS$~6JR{G^29dJ|b6FW?)#ZFr)~68UdIx5hZ{DQ6pJ8GAkjGmXYors z5Q8vM?xey2U<-dw6@U%Ju`)!wtBg(@h*cbA#)%g}xPkbk3cmEhR2g6bI9hbJRrzro zZN2oOG!_CwiuR+*0Kx?~8-$4}K74Xt2UOevPv$9^dkuUC-azma3#HdL$dBsx&n;agj9m0Z7l*Jqn!4cXd#aS4hM0JLkt7F&*n^;2L&AyjAJxs1~^T7=asSnaj6zN3*XiyR`!cVIJFA+x#$L`J(MaQhFKhhl&(_Qr+MQVh9719<_6n$It28g>CZmGnfk4BqZGEOZ)3cd+f(kf}4RSUwDszze-ddR)J` zDx<&lwA?6;83~4Qf&3cj9+2_UF-0s+7@|EPh^Jp`_B6%ZfQ8ao$Oyc~0$pI|p|qII zM(0nt!BgC{VF-L1fuUQg2@_j!S{q7Wc7#9!PT& zJx+5I-dnSIZMN!@k2`=@3j-(3P5PjvIgoyk>i-Bj3hI*{UvDgF{%7A9tMC@uLW)7h PhPSDX6L7#Poi+XqHzL&h literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/amazon_product.cpython-39.pyc b/sgl/dataset/__pycache__/amazon_product.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2d968db74e6f20bb23924efe686bbcb57553855 GIT binary patch literal 3704 zcmai1OK%&=5$^684u?-svQ~~Zjw3AsV;Mx0yzxUZY$txNVTZP3$MO=p7>w9MX{dRr zyJ?BcOb(WT%_Z2)c@aR$&G{Ai5jpiWC!cfE#i{NQ_1a`JL_epey1M$Ss*i51Rz~n# z*u3w}*AV)b8XSL282knb7f?||aez8`i#tS*jnM2Aa6V^sih9ltOP%t;T%}WiITBRE zTBoMxjG!J)bS8A)3>x8NXEL1XOkwl@(E_#JB5HXPJ4WXW#m`Z*_#Spf?IxCm2h2^@ zp;Kwc)Vu2nmwTc~WG&eU9Ok*y>G^>tD>UAUf*6J(?y7H(Kk7W{;cciq)WHIuafi|SgERfW!$O0cE9@8UfJB}au!f}|| zPfmkwU5|4wx1E!?Vx7x^L&3Pxa(EK>LY{X{9LsUmyvSp&@DARW9l=~ba(ucC@?X1w z?v)(jaiO|RQ<~fhdeV&C(32*MV*yGj+3&C{IZiilId`0QXn%2WJr2D^&NmmIKX|+l z47V5gTCk|pfxBMJ?dcpCw3Q551x_&mg;$|Mc!pH5g-uL8jAl+K-5Q^}i`Bw|hd`@r zuI>IYbk+djf5REpcp%hxYp3`wAOY-YwwvaD)unx>Hwc1Ol#KREJpzwUd6n!<7VS#q>pHQEui~bAQG@m0j3u~H&G-94evmg5r z7?-(Q4)qy;p0t*vt?;gF`VKQCQGddeSXsUv_*``3FuAG5|AvA_OD9dDgeeL6t+%cR zan}v_P3Vk2&2=DJxGar#61CjY=^^=m$!V9k8g@}zu=YD9@6UI?DQ-E>@3ZH>W z83hnepcvo5&HR<@5smDT)mqcBz+^v#b!-kQSy#9?05jMKP^?VDw3^D6N7i8qEa#wb z1pq`wpPbb3NY0UYPH1YwSL%RT(|*8stQpv6J=*WPv#GjU@aSugnX#vif^-8ky`EW9B`^B1pfuD5S3e}3^{vbwSH zyM-5@-435V3zwg@S1!|d!F6HD(({$&;Lht?{e_LyE8UHa-is%9E)SQ2&3m7Pww|4ap@8JG291Jx66K_s`F4ats|VUd6abks<@fEqXS>(NYj|p&RBBBhMR+Z_|Kt2*d!(d zD)^B)E@K;)|3^DIdlH<7vQ8j!3Ph?{s1Z2=Pg0l?1TY5~gF;Y&yT>W+7buxUIZjWk zF!c&kSddQuy|V~vOV`C&)H7gY&murlAua6Mdq}nY5|lEOih&5qtgM(puFOibyoOUU zjk2<+9%@k8BRlw;<;)Z{@M(qFN?O^i@1eWstLZOER?TW@EiH?Qw7P?@qs_ms0HZMB zw_{)wH7+9l&n1NQ?+d6!V|OwwgOs|MN-gjsJ0*}2ODRt4wEAZ|n-Env`59WX{p z8l#4o-u)<@NE<4TuAu0OiNwsec=!UNY!aixLgYo4>^`D^P!mh?ccBttGsht4OB4R0t-rK_|8chw8H#A}kUp2Xg1@3Y5w;W})< zTTmf94K8PrY5WPc$qe+$dZvMY0d5G+_`w2aB)9=T$Y7_L#6f!K>g1us!Bd2yvV!)% zH_%Bj#T$gJDkcXk4l)>p>;*pXaYlr>ThO_{IDsdH@@vP@lp53kAB2peYw$|#X=YZ) ztY`*yuwaKoSZ{D-iz0=r@&Eu>b@fIMlpSqU`J5l|9<8F}qO>q(2b5skyq{L{b# zUNz@O$w15jar5b%re$*w&xFUKd5DRhD-nlTR!2|vP|e+jN}g%Q;$fHsUg$-_qZ$XY ztfIE3`^{p`OjJ5&zgJ%X{R5I-I--olF=MpHgz@wZHcnH{S0IQtph9pz)A&d;Rjs*YtY$JYkKT}N7>H4o*VI|@(->O QlMT01%?t4JDsCA60f`~KrT_o{ literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/aminer.cpython-37.pyc b/sgl/dataset/__pycache__/aminer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a800d8072575c5b67b49cd943189b47ff225be9c GIT binary patch literal 3443 zcmZuz-ESMm5#PN#9*>`*DA}@*52)!^*dkO=q-{{BMw>`+4LGuEsj)-AxYFDeb=L8Y z+B;elbL>MUGTgFMZ5^4uIcaBq1ij1P|GOcPyU6cE>S$g-!vh6}p3Br)Y3?=nYDpl7XGD zJg9Uk1}=ovL9J6Wup8C~OPwXoZZT05-Z>MVtex4NhTwawReFva*=~z#_qU%0kyL;R z`=enf0aSh`Q>o%xsyjTy=W_nJ8w=C#5kF=s+GR7oX zhYQxR1Q(XD`*z1Rs|DR0=v5FdN!NEeu2}n$9dJ<=mGe@kC_Pyc)qhx`ChF%-rz|UC zNi@#c8Shj9FN>yuYhp#Lp0iF}Toc!^E`7-y)@eY)8_;Y(n(4x@JH+C4N9kdV&!;jP zN$vW68vB0BRvyIEcYv?9-xM+LwtTn$uQ?qsL*897^E_ks|j3LyIoW zG5LWPR(rE&0YN_-gG@tf@&P8nr&h``80N413D1~srFQCM?993-WIVGj5eTW9u_whd z!g_*1EqfP=7`@;GN5yg4Dk-``I|9zpl|=P?g5V}yNJh|H6$$P65H?7CU%B+TI#h8_ 
zCJA&P2l~TqdYI^fIoh(dlgRKuH4WdWJD{j2`F=0#CW-I=o4xq)=3zXLn@MuKxp(W% zd*SJLGwFw$vx_%#3ERU{bsa3;!br%ndOE~Jaj0D- z(~*i6_9G?9b{}ITKfy${JL7^M*=jX~70>veIjm`~GuX^pBg;DLY>li|*nLh%C+|h0 z!DFd5;)4wt%7Kj1WFx?T^t2lWV&foIfQO@W!-$|b&OE1kVG%o&kuHDM4M%cEDa6sr z&i1VxfB*A)JO1|0o!g(@-ru?3veY`r)!Q_^xtPcR`ub}MQurmQ`VnvmxrFg5_xR+k znQ8t%b&o`%A4x^SSfdUAIwPXi*c3K^#@6nnxw-JMaZGzi54Yf(pveW*9yypS&@X z{UzI%@xRc3tW3jbKx&(}$9E0;7SEmhntw~ah)<$0?uwt&h2$K@CyjZe&ykBeR)Mcz zP+%teAP?ZX@gKi}AGg_~&!$M3fRP8**^#S0oWp{zvkO@95`NCCEfzHw7>azem)gR~ ztgo3cd1Re^&Hm|}an6pb59k)8;HfiyXLi$#sg*gIO}wz0`+mu$&hb-2TL8U_Q5;V| z??AHQ)ccAdLwT2cS~7Qg>rA*9kQ978TItl0>NLGb5wec3wK)Rnn8>7RtE-Ij9sh8F8S6@S3 zr88O8Xo?+ZS$oEhw~d}l;~O&$by%KdOGE|iUktYKDt}Q0rMctCea6*)=3k-;?HY70 zLaxpsaQNl{OG~hY3w?N2&MJ^#DQldSrpsv=XDgz1X8rBXw0YV{s~5FwIcrkpf54&- z9G2FP%Fisd&DnH?v&m{|XDj5%bV;Bp5EWVj;QqAokbTY~>!I2CJ-g2y<}q8})qHFc znJyvu`4OskGAX~zgOkz*?NJ)kWs`J*NXYSX3V-!IOpAb}h%R8PU?h~+`z&LR>QjX2 z1p@Dq|J^!v_VT=kH}XSJAg^2P_f0yy@@gm{z9?^@h^D=GFueJQSQ3AiPrg5YF|!xi z-U{PhH%x8ojh(e zbb0sgcAnz*6g^jz#b6@QP^chEbz@Ng_Wb=n+`F@5RMv$wR=q>*#F0cnRo$pB6{QgD z3}s|8lD5%3pcTp}D%KC?B{UVYVf8R4u9Kr?J_Sfs61}*Ca>F$1s-N@@Kv-S4enWW z;Bb=a?xH1`g&ky~cxo+E=R~E&_jTF#2eBBTUGjC+_m4;2FvqB0fmE-VX5j88ejYY= z{d@O!wr_vDuYLfm`ZZ1TY!n5owo&`@fbno6+O+gky-QdsS&C|wdPsD&S=-ffzjhg5H4Zi&t-KXg!^T5CMT?D#Q sMY%*T&vDpkkJ2DaOn+*me5q^><>9fFyWl1zDD&hg6g&h(6%MBPe=^fi82|tP literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/aminer.cpython-39.pyc b/sgl/dataset/__pycache__/aminer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..051a7a09d6aa2f2a2a93612deabb25bbe5fa45c2 GIT binary patch literal 3497 zcmZuz&2!tv72gFw5ClI&eb};^q}8<354ID^$+S)nt|ysDaXsV6o=8ne+QTf&qzNkKTI=-xmny5c# zQ{JfoZioc~*F{q-K4+bVSQ5*4FMP!u)@ee&E6{6T^s;qn*d0ox-R>ycPn0e_mGMZb z3ItQ+9_Sm zZ7Xw1vEVzaz0tF<&HS+ekeTS6{t}bm6Dwmm7uKKoW1h3X%IwU^+0;5N3C9`kb9mtnlSAw7yb<;EmUb63gy0M=O<= zmQskV<*m)zTfy#w`&+^0*4;aw-r3#ytYxV+u&W=?M5-6l_W|jvuPI4^m}crHprzy% z#%tW;<9BA}`TyKK5^4UBWHgLL8UUa-B5RdRUE(Iod3SqTJ zK|nF^y2K6DJ-iShzrYEpx$KtD$5&=@Ugmj2K8imh-`fBX^*m$8i0H?bu#RoKot&Q_ z+KcEe^(hUFOJ7dfUrt&jZ71naJCBnv){WpykwApF;qcX`CYiA1zDVi(Sp=Bcad+KBVO{t;jj6+5{Tq5j*_nU6&*@WVtm}3 zNB*36>6{htR52(}(_JhH;m5Dx(M|U7^NA%0MrK%JPno(ghXr3_$FSuIJf2%OS-f;c zp?hF(nJt{$`Uc6DzSr0{?C;K$b9Uh41SIH*Gdtzl#LAuArhV`>ydSWMb1(#73B$!G zgV%vX!-@AbL$2~p_@rXa^wyYgF>nI*@aRpqFE8blo9q%tdB%YihQlR}%sr;Pc|O|* zT9r>Y^ePke;9Jnx#5aYbBK+q#o#dGENflDnphcZHAlduS!pJ^ph$=W)=H+$#8mq{z zv?8w=4S|7n^%wl$26kAOt!}{5EMFiFP%==O=lhBpV6*pAP}M)@U!n$$nzSZHrY<0? 
z{!3PY_AWHwxt~{Y?n2(2Rwj#?kCjzXpIU$Y{bXs}%xcH=d@*05tp5dzZ#XP#ys(bi zoK2QF8^4{|`7(JdTM(!R#Pa}AI9dMVa_krH>xk|zfOGJh+8lB49(!S>pIk@i((eI= ziQC4u=Fd!G(-kDcAV#H6$NqU<99Px}M`=;}CJBYHkk4Mx6;MCKbSY>lvP&4t7%AoZ zKFis|1_kah!taFt%{p*)iZs~bCfOn1TkY#bihK_%E~H5RCdzBpO9sPR4{1ZH8!op_ z9gOjh=C@~dW7{{Qq}PqoTL9*(-jIzh(++KgmaLG?qvUu&Dsw#Y#|L20JbXQI{2FM9 z@gEG(kF<68sMXZ|_Pxy_`R^#YxG0+8Sfc$!g-H8nXQjH}rjnfo zd&wvkx@M@MD0W-v_IjhiDC%bDG7JUs=tUB~H&O+LR%`cJvA{hl-n_3ZwWs+kg_{U5 z_0yYI0CI`xHJRmbpEvNcPJMfcf6sO}VLaXdrGR~_$^BE$E@p=>o%)Vz``kHoojJG1 zE$;reX4h;7m93bYe8s9^t#j(w*q=A?Lv=oDUS>gvn5M{TEmDs~g(nEK9|VI$jL=vG zx)uZnqi$4C)NjbytEOeR_eoHM(7oXPXIq4Go2vSm-HGYk9)3XIhI#$RGKA=`wk7y zHeArAoTP1P3AAGx11kO69frLpkyKl>+ed`CW&%3sOZ0=83VZEY511BRsS5{(alg=H z2|c2rw;!QDHN9pL{8!#2pnp{*I$&{51ExL7!YDNztda79l1AMkNNyIUzeVZFTzYD% MJhWCd3j_220Ir@_j{pDw literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/choose_edge_type.cpython-37.pyc b/sgl/dataset/__pycache__/choose_edge_type.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73fbc656308dcca39e6efaa7219ac177080e49ba GIT binary patch literal 2972 zcmaJ@L66(U6`mo76s484+O^`{b<;vl(ljbsZ|WWz6kDUHcM}9~Hb%Vdwgggw)=aeO zQlvC9veyO$6qV5?$6k7=3#^YVddi`P{uEw&@?XfQ-y^katx=T342O?1!#D4J-#7D% zR;$5qy*vF-eXzmUf9U1ns-f|76!R}sl1ZMhuzK>4m+wyK;OQjpScGEig&t?p4SiKp zo~rFIDV{shlm0Ui21?*2TMf2A)?urw2HBQ<8_$JYlTFyyV8<<6&t2J;>##IS3-uLM zZ_Lwq?#T_giLWhrO?IELu&vhR8>f!`TD2Dpp6~~(fBhwVXNP^K5cfvL7T(cplBmLu zBAZ1K%ylJ)DzZ-|FluFIIvsEtg3Cd1QSPFcO;kBMXO_RPPd+|6_^j|O{jl$74->$}i@!w)-1$4x6#j31w0n|`)vhtq z-3K4t`(g6rY}X8vT^ZZhD7!m2$ugs&%dxj6Pl}BL%ya~|A6R|6Tolb@1S-ZGTsQH= z@*0N7LELl1$+! z3zXQXj90j()gz^il994yjH+sl=Gyv9jk8B;AAubWVyiA90sJ!4M#Z>52m;(2X zz9?Fkp@`Dixb&lFO|+7ufsF=9Y?cwCl+#s5p^(wL`3I`pIpJSXoLm%v2mu0f@9dA3 z=WM|dx2JRePwWdWxpbZdxwjANyoPAienL_6Frv2LAF#Ox_=p9UF2I^t_gqMkv%fOw z{mr%B3x~7&%=+gT38)3RdmiKgML8Ge^{VwHv&=T~U;z(4WnZyr@0i&&+ss|EAP@fG zWughiw2W)CjQ3Ctf#c76z%bHU;dsHn1kvFTI~5wq{j)O_gzaS-3pY_|;R1;lsB&(- zuREBk&x`BoY?5S%Vw7gGB2q*>35}Z4^-W`spgzdPlRJI98*;{H4sY_OZ{01CeUZRh zdr3Bk6LSZRiy@^CPp{pL)9=~d7o$}6RGQ6(C%yNcw)#CvX^+TjZuVRH4a{74I!=d5 zcc~pbj&(Xpho%tJD%O5!ls>L(pnIIkY+U%D*Jz;hI*ot^5Vk7h5&{a}WZEj(uWOQk z5;mG*8D7UYqtaKjG$p6QodM<@fxD~lu0C}$d!qE}2!yKOQMh@DO+&&&E)=a)J$}`- z3@B2vMNX77XkQK9oA_YfLd9Gk_!n;AwB1cN;7zBA_YQZQ23XPHT}Qu-UeGO`7W}l5 z1wsj3Bng+Y@F8Au=;$dkogA<}WqR8}?-%?J0Poy6dk4@5-~jz?NQIL_B(_tcIGL#ftuw86$RzYTcDUAO{V@Yv`A_c)x3hFuj32M9ZdS1`PsjvSs4e{4{DIr4t6T;BQ z>m(8P*)%+6(<4Nk80;UuEO|&l{y#)&_InrM0gu*7!7K?>Mz0X#*q#(3$%ePf+~^xL z-M6SB(3NN^+0$?6?~~zbc#G|e$zP#)NtMp)+d#NIj?YHpIEl21O_s(F6BQY&CMFE% z5e-tSZ{i=}7O~@hO#C`Rx)n=40e(ggmKZPuZ1V$D%n^VJ_~CZ=4KAEESQ3=rA-*`} zE<#F0ywSVpKcGQ$5kv@VB~0$4L8>oZDd@l>Yjnhs4LZ;an{y^xb1vI+2gRkGHFrT3 zz7TV9*3AXz)5M-o0tPok2WT$+r4%}F=T)41es)lW@ICZL?JkGQXH9rWtj9RwSE7U+^!(&7J7uPO4RGQ22e+-!dG)072|$+H%JX}S9#Pw#+M?9BI=71 zU`G=U^tY)ZTo*pJ6Kolp6j8fWl?NEtSY7Q3bet$$xK{z61g3QQa^^H1t0-D=?s6|I mg1vDjXNkH)2P8xPBeVs);|S5Mb=S6UZq^0A@tvR@@c#ld*wHBf literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/choose_edge_type.cpython-39.pyc b/sgl/dataset/__pycache__/choose_edge_type.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..846154101741dd027e4deb6b5993a89506e76e38 GIT binary patch literal 2962 zcmZuz&yU;26`mOmDM~A?wQFU6q-iP#sT;M3*EI^%2iFMd-2?%gZLxM-lt4hx;zTPf zMM}fb`Uj{>WuVEiy><(%duo85a_Aq!Yft_ca&f;ml-61!Bsd&C&KtgY@B6-C)M_;t zj_Z?8)$I+&{znh1M-303;L87tCT1~DSWw-0!0odL1nwenMlz5iH*h(NoxoEy<*M2) zi{+V!-Pn5}17As;WUImE$9351szJ8J*v6TR*WxDZYp~;tTW3zZ9$$f_X)QEXR--Xb z`^=3u;!V75#aH9b3l^-aEAcxgLVvH;=M0hX2duaC3bC`po+#!0VQxzI;dGp+(hEbA 
zg(1v!6%SNso{eGD%1m|I=j5_D1nv&5d${sVGzB|nhQH)9ZbZT7^bLO4b4oWKCqq+S zJ=p*3Abj}eKOThp2lqe!;`4_GN2P1%qdlQrOn@+#Uknq3^c~Ywh97>tcbtvXUY<|( z9(;EH659FW5f&>efscR){&DIfWDRT`YT>p(l3PHk-*p7W|ZJ zzu@Ph5HC4r6ZDY|=kOI}P*c-$Y!KQ*8`E?+nXW7;FNHEdiV1DL zf;Fv9mCjWhDr0k0HET52_K_N8r)nRG9rh!mmYD#KS-y^jamjt09b8SW$=;J?>mn6l zIvv?K%GOw`c-S{#KZ){1hN$GrqmB-xjcob9(G=pCe@l6Ca7p9{FqpXpS(|g@t(bX# zWnc4{$Kr)wxcjiqYRFdYx0E#(<7;z%o54dAa?WE1kWGwpE@N4+f3Vp7r(@igf&;>n zGvl4Z7jW|n=iD!R%5x#l>s9X?W|(Od{u~kffqly+@0&H#EF5}9CA`B+r6w%XLax$6 zK0-49n!oA-$tZ1w<~jcclt)nPL~0cHZ(3%&-EO8g(eqZXO9k!{o%3tGmVO7HDP0|<1Eo9E_n$;M z9j1f4l+-KNzRZ>WQ)PnyL}{FjN)Kck_LaUu9#92BT7_f6Lh0q1HY)DbH3@vX!bQZ-!Qh-KG$1 zcGoYb?mfJ)Rc4L{E=b20P0>J>JqbpLCTz{uQb^e7+!XplctFW`TJn=hCo^}A$}`scXm>u9Jj`7}YYhHs0*+P!S_ zc$h|Jn5DnN(42&j@iwFYi{@Kg{}RtVw`}Iq#{(UWkF6S9u4a`1TJiEAAQOMARDO@n zg8n3tZo#o;ch13^_X=<(2BRtqG9fyQ2)=5033=E?PG%nRRWt5FJOU8C77Iv4t*8|q zW-CYnR;0lvX_y?DdV%kO>dvfQ)C+ma^xr3s@Wr~76si9MnP?PsE6hieN0TFp3X<^3 z@{tm3bFX*Nd*1AIm#F}6)~sAw!q`leauk{4QYP8pmaT*S3Aq!(^iQcFtaN2y!@}cR_zEQD52*EvAeXCLp?r}{{eR5nL z1PR(jEi7)MGeGhG#b=b>~ z`+pKm&2grK_KL}NFIi4iI#t0&HMa7r25qjI7E@gsR(ZDVZPjsO#j*@x#ZUIfY3NQ`7U_2v4P+WT0EU+PH(mG9heD$;9kmGYf42zv z2VYi~3(7}uS3iIdM9`FEG-7n2JCSQ`UgW`(rGD0mT6WDzgDi|ftGj7C>qMQb8+Ef@ z)XVx&pOV`|c*1{3gfGLn6Kx23mke4@;8)1bfEw?%ipw!5ot;9+PhuTcQV(3yD-Tny zWGwjJB$cKkilaO&;8hnUu=c#1nJX&JMUnBUOedO)#n%a*y}B?iys8}#l0;OHh>6Wd zWRD70cn`zKk*@GX>mixb$OAnPq1AoS7M+JAYKgAs!7~tju>sHU5phY>hA1wX;8qb& z^|(-mCDl_l92tM%QL_47eB$k(4|hP+RM5kg+SHK8IsJuds##6u?9{25aJFFm46?4> znjCm@S`+Tq4#-wbPJ`Mxqu=9O2zMLqjv(NNyE9-agtH^39G}dtFSEFv$C=#D3%y-U z%d${fiW{>qS$aGgbWQu)IGxH*m4W~~*d&z-;20BdKFJ}WT%qgcN~wxOR+U^7gO{-$ zSH`pXHvS^{<|<$1g7bZuOBHLmTxxTz;z`aY;uwzpF;1-(axE*3DwNx;WV&Zu^w_wn zC^YzkH3eXVoF{2qRh<7xo(@LiB9o)4dN8_s`^%f@$?>S#Pe;qjk5<7A%M;ZH*E?`m zT@Zu@@b?1kQlHK)tvnsB;!G&+feW8L+_&JaUIL-XDHP#=36##Mdq#eP!u{@?F-W|# zW4c+z<0KX`nkzAp*CO@W36ltFqh6=UORt&@OhZ) z5U@w^8=L&MK*128U)a{V3})JFGJAilby>YP8tqT?c)B}G zifmMklkrrI_VfJauFR8h7OTS%(U~ z&1peX z)Q<2T2TfZCaBe94Is5gCrgO5X+ozqnU3ZXu?veZ*m+01FZery-&U)IlFHy-;g`xKG>la!H&Xif(~}8$xK2DV;~r_X&)SToCC_iGGK-t z8+u~c(e4xb#fHKbRExw{W7uRoU8rPi+#;8iaR7G;@ov0)nw2MDF2RDhGY-rR;|egF z*5_YYYEy6Ef-8@KPI`;JdmAlnDfT50osh@s9W(@R8qZ-28AlaI#z~6QbnK=hIoTg; z6XGJ*C*V+M?N+7xhN(S6FE#?S{u1sONp%ebp_{Y~5V#CbaOq35PZ?M?5fsmm((TO`* ziP*LryA6B}U69`ZWFv#9k5C|MLNi#Jx(d1p&q=(s;_zTa;TPe*d2t9}#+O|H;j@6v zuAX-BosR`7tnqF{|oS+kHR{eORF`o45xaMR(23u l`dRdbtrL>Ux%Ugr&XM4QMTS1acCrC<0b&8`!u)&D`4?AfnbiOQ literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/coauthor.cpython-39.pyc b/sgl/dataset/__pycache__/coauthor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfaa038867bdd45b6c571723d555c2119ce36f9d GIT binary patch literal 2624 zcmZWrTaV<#6|VBF?Y>R#X2YQ;i4-ZID2U4{5gHIHfv}JYud8O-v%YOr zH8a!LJkTrgg!lnbq?t$f75oAoc&yhv@fY?f=ajo=W(T+G__(TEK9}!&r<2jBPvCm( z@dx799wGn3hpU$d5C4Rrz68Mur#UHT?9fhkW6#d{u@84A4~kygvukc17Ev5o-OKyM zARZLMcvy_$Q8A9kl-wuW=fM-g0}-v=c!$#mWZL@;zC!k=)cE&hQZGOm>{nd8m*_+Z zJ@w3}ev~sQ63z~?T$ll`j?27)M_rlF&X@Jd+>%MjtAeRI&otxR+bN!X^+Mn9Qu`nT zi76+s!zbU71Ij(_KZ#;jcs$^}CuB`yAM}t%Ru6ce51x>?$A^3b_mGeI4&0;fh)3c+ zcyhyp?^elDFDhxAR6TWOa}#t{rQ7e~w^7I8!cd=pXeg&&Q902N&o%v)YO0-vterEr zak#q+>*o-C?KR}cU(<%Lpm9O=8gdpk?m7JepCTB27z0kg7H@y*$Oz7k-Fk7ldT|rY zy)r4pURmiqwXEw(YQgWWqI6^Rd^$A!Pm_Ep-j@=b@L`jXGK6DH$XHfFV3|bI&8=Ei zsZdIENnv%O7s~iH?k4DxU~aQ*JQzC^rI3jho25R}GRaDo@e?@ut0cEt#I#TvRY<$1 zM1EjAwAgsEsx;Vxx&4uvh_N(Jlw#~h@^m_1RE3x;^?3f^{)cbmrzdlDn9nx|m~Wk% z)u(a{rVn7KAqYZ4_&rC5G@z>+TT5qKKU0c(HW%svjCWzEJ0LVUgG3xT9Ho2aos<7S 
z^8V|dI}m(#-wX@Ik~HC>$WpfQW`NMCZ`?}N#yhIA5}cOFG2@w(yKtQG^Xgd0DV5K` zlf*A{;w4-h<5t9e7v?{}oN_=AXSI7__p^JRu^#tAp7q#FqdCNLUL;BjX`wWS3_hjvO)yOO5Xi}T z_$nA;{90zEmOkiHw~K+hlzA5j2HLei0D`wHhPhT)VuX7179etU^CBb{o1Tfo2=*e1 zZ5)09MHBFnu7L+y+B&VogSPjN6mqmRF(>2T>Dvb+W#&{g2P48bHSpt-=e#qxm79UxbS$#Ng*Fi`GJOgLMRwk*^w2o+z|iYu0_3nMhF9Qp z^#TY&C$taXxCtQf=p8zy4$O=ZE?1~)^Hyrf0pEfPG3mmvX!#Q;*liv;@`Wp?0nD^s z8EU8NKaB(Z#XE3q#YlwO>h8z60L1TQWxdpUz}FeyvuwNv%nvnE-UZwv$H;e3AiqMX zSTef}x`{67zPDxja7*^ri(_)VUAzfLY-$8T;DpZV_GOJ+Yb{_0_C1OJajfmr0X8qJ zRmM!8v7+Kjpi*Xrj6Gf^dADMFg9#ymfGmA#F=T-uXL!)RQQQP!Mw{M~iadvIcuCcm zCW(A)8zz2UO+bWn96Ge`!tlZ%+_9_EiRo2phF#NmC@saBpw7}qxsY$;DSrcDZbB=Y zVJn*<;n?=t0r-!iLo--K^>G5QA< RmF<;709xn_p;tcd{vVcCodo~@ literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/custom_dataset.cpython-37.pyc b/sgl/dataset/__pycache__/custom_dataset.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..902fb0af8463d50c2b0e0e78f2ff3375f98fd2c7 GIT binary patch literal 7129 zcmd5>TW=i6b?)lh^jtY4hc_*)OxBn6IJ=~=abzb^HtQATUBxoBmS}5tqD{Ls)g*_U z=^j?sNaSp{A4~-{0t8WB36j+3c13m3F0HZC7>O>ec!S?S+25UDtKHx7c55FZGw(%ewCL8vT{_ zN`JMz+Fxt0G2>l>yS(ts;DvDM+-k3*?(rh(#c&z*4bC1Itf%=>1VT9&E~j(ffc1Mcs4F&)Yxa-=FYQX+Y0S`A?K{^psW{xvdc@n{LE+< zc!gK-_IQmi;9dNwVH@od7PqL%w}y#~`~G`zKNbueJ~g+t_c{gQh#$R401Hp@7a*yf z8orU5piG*nbBugjbJu|Rn+lub!{ z(ppeOA7A(pl&?zq<;$}{sH|QXsnQRF-Y~o^L@ZR@5BWi;JAM7Fa+2X76rv2a%Jcnh z)Rn$3JgPPZBJPAq66Sm54+43ZD97j6U$yDasLuBf!YCAh4Eq-bAMbM4>E*68I#5Ztn0|+<^}BP$R@jcVQQOF&HNm6loclfWgo@eNY(tyO@_eAg>Cf%iCw0RX_;Dn<&m(|Ft-x10`6bHljvh=*=*FD zUBKnsm#dvF_I?1Y6D? zIT_2$%*yP{k%nX!IMGKq!!K}(#hcQ;+LEypZs)it4l&9RZK5eVa=Eo(9Gl6$EO47n z|H+d{D^9c%)o5JuM0(uOz5%rie&MK?6=rz4n&)q3c zd6=~{p0z}B@bVhjV2M{Yj5Q-G&Tyy_hv!VZKEp$|WlriGm3igZMDMGza#2kkUVC9e zC$)=()Oo>Bs>9n^@Oy4rNQ)`Mh=srQvXWc`wpIO3UQ!9vaa% zY(uUbmoV?onUPhPF>1)wv~tei3)gs!-VZ6Fq3_3)4~!3u$oxRJeBZcle2|;M;%>{? zrQgUo-HLEv$zn#Yg-=3p)+CS7Ob+8=k2ep4Ct)*a?uUUKim>@4?8sP*ZooC-_>SPqGTy#&HaHW^dmL?t50)FJ@z2G`hmWZD!s}uK>P0nMqjcJ+~6Li9; zbJiTs(~X*)IDR8PKIE*UkbGw#&Xnu>GWPvX=-iHqz2I@!qvd=;%R$U!r;42^^XcQ( zx+tPae3eLx$XAFEgb+0l`0B#|UcL5h_+%lwaipEPSfl|9M3!ij9r12onbAPGfjH;~ zr^L%Bq*uE-Gow{NN$`t#G&ASpf&9S$jMJL|^ zF<6Z);g4MMHM0ilnQ+o3wOuzobKPv9RKZ`#T*p|G)%EB(*FB!LqdhEUC2#HE(e9v! z0^EU>AFGs!$1?{fcL$GWj*V`~9B!q~u`AxeVZhPLOm6G?ebjB#EsleQ{>P|0sN4Aw z+PC12-Caa^m~^>|ZGwep6DwXu%fcjBded0xywogOZXREL7po8iki~|xLJ`qmCAO;d=t1-P&m3)D@vw8Bwk&-zMFk)4waNq|C|3ABV1VC` z`49-tR~6qs83w()N5@{}VQT%15wp+gztpC9&ibH{YiXymC6)P zHFzQB2@#Lg$`z;LdY1PU-P1wi^hX;{zeIve;UsHYyb;cXGezp^yQbh!Jh z+Af4Hpk?7#riPbx+Y5_1yadoJp{yY$%H=W0tnfNt1lX+dCBBSzO@qz_7{H3EPL!#k zvv?DdXxIsKyhgCikN3wIF2GGRr@fpN|Ron;!>^Mc{3LTw!{XJYap#cet-x?2~wE& zb0Uw3G>Op30v6G>rj1V<95Ma~PeN|lfZHML2**UYarMgXPLU&^)OCzit_TsNqNz6` zuVmr5g(vwMh%{1cl};-fl8w^}Mt%m{Vj{>R@1YP6Q*w3%VVnP9x{x2%-^YhIxCDK5 z08z_RMf*Fq@80>|oxR)l$M{c6etGuvp;N0n(&Jz5)cuv#F4roH$aM6LZtpEzqLZ(aH;Z3J8YM?hM<1{PBoJrk#C-^A|`c^zb; zwKu3i8;vH~a_`z2t0^{Rr#(dtUh_4(?d5dPmoX6;ZGyt0S@+f)P0_cb+c+RAQHuszXC_e&F)@ibauSlMM z!cK2c?~d`&>zONya6xdra6N#VHw{@rTy=kGAe!R6Vf@ngJHTvaobZQ`C@XL`^G;tK zf4>fB?E#|VJG}MyZyK3*^7ndd5#vjE$}@5owA}wOBe!nJOf4N8Af!9O>}5#zqpb2X z94N! 
zfEMUr5%WLST3E&`4a~MOqXocyogDnZyAY9KX$Ci0bmYvZ|jLV(`QR!`-4o+Iq>;>A)fd)nAHq0O}`GyZxyD6L(rr`OU~W)h_hBubWejznowCrYJsQzuH-IIfJQiPGkeH(#14jVhB{ z62gadwJdR$Il*M!oN858L6HRlO9fd*Xa2nqU zk0hhgW$&ms4fc7Wmcx`8m^*6CO;5o97g><13qc?#@zZw|;u~;2f@0Lv_mTS|^hV)7 zp-_K-fcQ}@V>p-yiZ9sTnkV*yDLU+61d=Or>+3ptU5Ojyi9apfBrMX2`-8Xsfu^MS zT{imKoO>MmC*PTZn4p9pD1{0RthzbH%urCG`m zL39v`Ki1_DZgUgm$~;XZg^NEWq6634RQfiNcZlqOw7!rBsX8$} z+=q2!dr5}|Ntp5`AEWh`;8 zNxAlFA_F1)L7ZGdcbqM?Haq9|-a&s*9R!?%Pp)NEBo3!OUDmp1dUo+=B=~oU+$Hiw zkckPC_38}iV-%9V0x?*_rlfciNpS^ebQ2eGwuWDIlS(%F8m32yGX9oKJnXvVTBuD& zdPtv_SWVY$N~|%H$JW_8W~iVC3AVY37K|A8u4B%cSw~VnlX}mi-~a!~_gvoG`d?++ z;`_8Iv`abBdvqVwf5r3iwEFs-vzbqKq zr+tPy=yr~Kp?Di#szRJ>(XF_` zaA&_o_iS5_ae1fzuR&i{335eh<+HuG+ZU862)Ytii<1Fp-WtkoFUb=`eX&-eVFYdS j^l_?#%RN)x&HTdREy_al#^rY)bjwtSvDabr_1*sh=~~#z literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/custom_dataset.cpython-39.pyc b/sgl/dataset/__pycache__/custom_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae0b1cbd149fb370b2194c7d2872ba14ed47c863 GIT binary patch literal 7049 zcmd5>&2J>fb?@r0>G|aFn_SWIL{2Qzk;t_XtPN+Q7}l;<>ue}RqAhE%a=SUzB!@fG zJ*=+blA~@TFgLId3B0)(wi9pdE!&QQBR z4nBCA^}4>g-h1`&-tQHIdc9)c_od^v!hih(!}xb9O#f^YzKN9l41^n;^^A7@W^GfK zt+uV-PTy^pSl;IK%k7HJo4snk)~@yI?Yb^oy+(hoJ=brxo4RcG=KBlnh5llDQJ0|l)@Ts}Az1t}f$KpqM62d}C zz7LYxso@)`3Cg6ITIYUPfTV=FEx(J=M1ft z)VQdo?h|&O`qq%@NDXd4C~M~vV)g0k(XhW4idMYe3VUHcjAYX4;x9T0dR^Yyj|K9B zp=?Q7o6R{@_VI-uVG&eCzkGQ%2$j_fBUSls&>M!gg@}b}`XN6Eb)&DpRZcP-ghEuo zR(ZbPjk?nJg-6AufrvX{l7#u1`GY_nCd%&=%Q`6`B2VoS7K!$!Y=yfH8kRs?t zewUwOR7kCJWnYGgr0ix%*-6;jS9TQiLscfpbCKt+3g}HVvqc$F? zQlgKnWwTLhT8YcCFEyS~Wo9aA3_1?tRb+BbU!LUjH0l{F`#gRak?r^ban_T^M$6ND zCp~7E(VEkK*Yox(u;%Q(ld;Uqtjx|FX-IZ~6Mlp<{RF34d{x@BH5p4`iiS)aqL(Au zL{oO;a%gJ;P_-*ixcid4eFPi$P#z7tw1S*ojCHc(iBfu^ZfM`PhFFF zsmDv}2B@cLc^ziSG;*mU52IF!QOhI;uddR5Ec4pBv1(-HDGqhw@R*6$rg-SO%t@W2 z3a>pd(fW$4Uer>DH=dZ7llny?b)GQf=3xCS*g-cfrR9{N$K2m|Sw%L%ZJsyJ&7XcR ztB#s-;bJkZq*bzh9~jXM+mK6-+1bySk=2+n+K|g>?VQ0fuJ8uEKQKPTqGz=q)}l)O zzHVGcEr$WNjdP0`(ZW}Z$1M4$>o|1ue2Ap2=KRj4vqQg;bGjAbVAF0Iy&9f`;;coM zrIj4U!ya!P1}9-FXzhoA9Ez}Y5_V)PMmJzWVIf-Bq$gc;4!iOYws+hMr$BU1TQIqo zMwtd3H3EL*=dIv6QkIAxDytLslubrvw2fg}=o56psB_jT#_2|_P8`3UA1^ZMQAjp6 z5NFEueHr`yC$y7CSz2#2PVh73;`-)vPEij~o^`Jwa(7*<3umwD1s;`+1P|t+%HmUAq(=*r1CFE*&D&`vc znyiV(yK2&lVLRGz(6IQC6?c%dY3LyX(_rOCEM;PE>Ja6oVQ=aP>6*;pR_dG>;#)Wv zIEIgImu6=)&8OXLG0-vd!~?L*=0HJdBO93bI}JHLl-k|hv>RoIH2c@OicA9J)Z zV{xy1x@qatC~Ek4ce6?mgZzi@10wkSm=A#heO2@Q<6+RtTeM$R9+n4tSf1sZn;eX| zL1dQ*xl2_qyniR`^?+F}9W40+LEA;c8e*+oyvG_edcgXMVK&^h+ej z1az{t#g{@m1jt-BKwVf^`KSlhQDKw5HLB@Jl-ZZFCcNDF}L zSCMMkIGa188~xw`Y%rNzc^Ck+!Y+~$1aDy*5l7@!S|i-NrANU=T~&78xpmv${qWt} z@2gdM`R}~t=XLM+@4k2Y*4;OERrB_(J4Gjd_dD<2yDey|iajC(@ids#)=aFTs9v5R zS3SE|H@jG;Z>|w}31qXBA0t9Hf;NlbL_#9JM?{}5z2o!-)Z32OeuR`XKnxfk!j>>l zgezBO)^>s{H%;+*^i{41;hdrwt3(#cLV6P^At)q`6zilzk0r^9(1Ve^!J3$G{D_ap z?PD`JJA&|zAW8D0_{YfMSk6-gxd^p5nLz*6?R$6s{O<1U_aKf&tt8Xb$s+C4%4Bop zdyBkQBE=!vGz>O0j2`{|tfJ5VXa<;(qobttBvTA+COGg*D2wOdKy4rTXdJ`mi9OQG zyY<@w(Bd@`V2j8f5xF!sddancA}J;a_?21x3wJXUOubn9{~rd{7V(t0(iSn;Xmyt7 z^YJfjWx~I-rIMe6c8h1aUfCXU;&Y5&j=bkel-}^8i zIu6!&6w&dI7`5FpK6)*45d%9Z>?TYG(Bn-*mf;%RUl?$Rc&{73F#ZmxnHk4>T_Y=T zH}i^b*MO2epiO*(x9W~DyYV@Wf1dLk;-@FZ&Tenq;whFb$+C}m0QG=t$6P8w~{`kL1{tR{< z*A-P>dqnzx>p9EjAaN5K&KDYf6&h;2vISm8PXj1>9?!xmOfG$qHnsMsf#&m%+3{au zghh<7g!l54?m+7r+yBRu_PQ<`uf*5EYbjl(@vwemM%)KB@cwUr=-CMnHgZeA_0T`p z2ONRikADMhD}}sASk30ed@ggkv~ObmR;NZY#p7Z$%1w=xt>vY36*pYQUzzC&p0!ov z_D#$(^5C|TuBMA5-(&OmI`mtg(r-@dXNrHG(offARw3IrAe*CaW=Jnn{0}67bE6zq 
zY5}neG{YLL6Nkv-MxFz&f02t=ZNx0ltC7xO&KJ}5^Ky1gE@ABDBDQ%6vCYcGYI-ew zX)3l^du*La#5Nl`wkf9@I<{HoxbvCBHXA?O(6J5P&y8)YFwiZEvRuJ*rm6Cf+T0D+kl@xe%4oBNpq8u9LGN1{A&f&x-6OKqmmCM#q zdE%h+*eZt?Qz&(`d1ZLA_PC>hR9y%HL5ZKfeh@cdb_6-08S6&|jGz~VDhV0W16co$ z8nkUM;Lo10zcr6-+zF7`-DILiFH@OY*R|W6atzu*3o6KHxhq}>dk8}MD@Ab-bbh0cUy?W#J8zP?pIHXxP=^IxF}4-H!16gAUX)e zw{(7lTiZmrG7tA?#>HDiwEwz8xwnbjCGuw=o1e>pz@H$mYTAhx>n`4)s`_{b;;ySB z+Y2)643!4re1jnr1M4((R7~6cU&Gz=A&w4@-B%< zfNTcw?#IK3tD45cU7ioH6m&Yn{;-ETO58Qmbt6Pj^CW#|D5maA-P8PovWz9}Eh%PR zNn{|TKZuh{+>HaKR;R}-?hEt>^+CWn_~f&!%EaN+r^{IPOph-9f&_n$2<@hPlW1_G z!34!PA0s6{2Qk=^O`+`uLfZRWWjwY7(fGyVSneQ_knoT%l`w&MO>xM&k}U9INF9ZHr^jZYJsw295y-1$GYRlQgM literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/dblp.cpython-37.pyc b/sgl/dataset/__pycache__/dblp.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fdc5b99ebc6de95650663405e9cd2f44edf780f GIT binary patch literal 4367 zcmZu!TXP$?6$Y@&7_IMKYhY$pZpIxdFpqt6m>~O0eb)(0O#Ob@ZsHBt<3QE z+W!0gt80w?hZ?7UHi&mnk`R?-l80=-+k)q`)wWGvptUV8%uR#0< z;ZQrBQTi}ex-jeveWh(^J<#QD9QAzVM_pe#o|nd+*Rqrg9kfVeRRV!FeN>Gq8oAIx zR=&_t$vdb&pE6y8`p-JF8~1uaH|T^dOWP`rQ(YN$(uZE$^TxqY7m{HZq+Rk9`lBxu z3njUUDrFhZaK3*Lk9fvBky@#pu_JL@$Z+PXOj;-Kr_{;Vqv8>9xEVVxW!4G*y@gg8 zrGiqG3@2powuG9)s;Yr%T#z`fO>NVsRZ|7@p#f*ml|*$tqTm(~afEYJMH+TJFNlKF z^A!1kUK*;n>nDkyJBmk|CAwgy>Sd42;dy;O@>M7GJ)^2Fd#UOKkr&7VEcdt*8mi=_ zev%T|vbCM~Ved>(HIKdEl_k&XhMgqwynnH8Ufp~c5B$v}+26c-{pPjs@L)6PhnrKO zH%X_B;h|c=)bFAsRaA@%UKcB(%56Tbk#rlU`m}_e^JU-n(!(JtiW%)FKOL#4D`+bI z=u3uq4<-2_Dl*OymweAsODW88#6ROOi?z;R4j2ctz0OwY=nK2gY3le|G#WhcRWt52 z{m>uyQJORZ{6{#5KsI}^0zDk1O(TNhnDd**=DzqkBuh+ii8>!I8qF?p zT+oPo{G3M6gB_DzLW*OA(<7VonQEj9eM+U{!l#psPbaN{B9Bl%LZ$0x8j;Z2fmsKg zRJAa8K?BnBk_KuOOmCu0#fGP@8?k4HbWXjQ@MIP$MtIWK#5>8ABhJ}=^*-%lNinexUYYK+HW8Vf zS+o|uhW|}AvG*Sv-U9d?l;Xh{{5B*TPTVgTvZZ^%CndADyUwJ8f}OCAOJ54otdNzq z*aeQ_lmjyiheI5xb6h;n1?`F_9QqY$bpN0CX$>|hn>!A--&5>Qk{t0#1(H>vP3>Td zcp%;D(8lO6smls@S?Xpr{OYU7@iZr^8copyEo;yC{UTd;+ zzAsw@nqhwu*#dq`tDt&hE6~_;AzM5ypRjC>iZQqRQxz56V!z#QJC*`<1nWh#@tW*p#LzRXbh!Gv)fgd zXL}_&=OzS_c)fTONnJIgfy`H1>U6uK!6*cnk~4YqJPF?&sl10~Yv&-JP)g$z080h~ zy^^FIm3qTCA?P{VRIi?2v*&k!cziRVHta}=eNsP405(dr!+}Q_#5v4Y=Y1_yPxJf= zJN3RM9lm>d1O5W7Btyk48}OozUsWvdm(ZfN%Nv3yqRXPe7fiAF70a>8ygcJ_f!r2& z-LeIK`8KoDZ0^T+o z^tRb}0Nb1QRi5RRzde9An<5C&>4ZAEiNTi(CS+Xfm%ZaSRj%~fm0P+Aw{GZQ$}{hCXj{HFy5Y($DW;8 z_4L}(W|pENQ^|$GTuG>sxaFVV59tFJPCiq`fdl6EdUj>WMyu|5{q_32ejokou-Wt( z{=RweUi8`qWB;Va`9BAZcTn;Km0*G=Y{+{SFKN5yn7&G{g0Gdh!)mW;#_Yr!)_OHV zJBdH6_v(hOB#mLS*EF=7w1x}41uN-)7y~Q|!q0x?ET8j7LcX zqW;rJMKZf4!_hr77t7{eCZe5Cg?XgV>FoUE#w~oCqX$VKqfi8cIEnhi>-=A3*2sSk zB4Rx*SkKyret+XVTUgM`?p4rs(00XIRHdJAzpy`L18$^k_PnTuzSRu*w zWfYDgsU7G$)P6rp2a$}@ek47}tKA?_SrByHu3dImDw8#gpiLjyq>83q=pa*H=%|++ z)L$-{@u2?a9oowVgSa1uN!Qkn%rd3xqfp%ovOzG3N4k=al34Z0Rpx(0Q?XI<8>kd3 zcwq_tN9#T>SYRnzIR!hmPAUZ#);bgRDV$2V1-oB8rWszrPHKgH%70^{<)hS58iL`t zoV~6k7r<4vQB5iWXSb~#`gEI;Op69QLf3QI4~T+qK+G|YR90!&3xYU}RS-y*nkyrj z^`ktGN^c2ByIfbyQoS0GZGvDwN+TJnC@|`3KTtA^(;yayVD~Ug3{?wMlq(`(Oec?$ z!E-_7BKG={>sk=>lQ7SN;7j(O*S7Cv!)QCt54LY#yYYH*bhw@GC)+cjw@Ih1(UDxk z(r=*Tq%q?bcb@t7n$_SAx1LR!B;nRvr>>S%7pN7}ic#VLFSs{#?rnTpz@L`9zWqMj|9NQ0$cz5&T{6KqzCPnL~#7iYYr z75V%*t)L4#A>AxtoglpKJEYT0D_z;A{5q+8Jl*Bl_7vom26G0%Rg8Mgh{>Xk@TW_K?Xd zb6W6Cc7j7ZMVJ)URhBNDjZn~-K4lB1us&x3nU2dbc8m|BHX z*u)24Gxtq4bq+?DS3z-6s+j9QqS4g*j3IY=r+iv7dwQEpxG2~GJUseRs1}u?c9mV2 zQJu}e3Nyo{8Ol8&-n^LcfmYR14!x>0dhi9t*wi4ztsToZ!?3rvhBa0wl+1YC)%3DNE5lwWiA#I9amuGN0+n`MzQqpoaayDwgnD zSqC*B+knP?OU3euf69slk`a6V9n!DP_J;JLG3WGS}snU7k*L4IV}{c$F=F2B3@0=JhuMytLgfrty(7w#agjWH|8%{dc|R?{lq%@ 
z2WQg_&L*!ayVxKPQ;Pz3Db4v1dQ3Ndw~_kg_cFVTUV&`quw$FUZ@&*I^M7APpwaI` z6cc5wJ=P^NU+Ic8%J) z0z}upr=2W~fJ9`N?qkIS=oc*dN?_)@?2WJb3Z)_eUdDXF$G{G%w|zgI3a z=F+vZJ6HQ>d*!;EP1qz42H7|jx?x6RQF2=e`~C57oB(QxGkNqP3Ev;fvWHe{_pn@0 zQ<(z%67f*4)K*+>BD!KE0 zjri4kcA*&`GVIV-~#xgFeDFp7bTad0Pgq$OTJH#2L(49-61E|DLb)G8S?S-e%x(m{RAK8 z^8pUe`kizQ94p}XTjhMhQ%_ZdYXGH%`^g#DfKb&v*}~4+FOl>x%pYivkhGy{21y&z zh5w6eIJtByiGby2kg|lB290-soA6XhFLm8A#MTXry!l>G#@@}~*896V*WbOPt7(=} z80qQ&9tN}P<-JT@9|9y3dX0n$9QiHC1?*$6C&6}Idye+2B%*=+bFBYDBGBy@UkiT$ ziy3tH>)mU-%h*9DRvoeYu(UMhs9f&dSWV3>(9U|^sdL2xh*lX679 zgQ@a9Ro|wHOsZ`lVe@K`lyu9S;=vK{1X4$GoB9Zm$sbZh=~EH{(#_f1q#q>-UPm1o zFxA(HGFgtmWCJCC1(nCWDmkfF-|*WNi~9~0&-Q+89$MW+UCr_>Lbcj4*p|?wc1Lml zK@!QIp&@rn4Z62``{6dI4?Sg|*Eqtju<+FAv~lCkYhxHs@pL*%AcIGa2v4 uX_%BR!o~CBV-+X4dHWeL&*|AAk;?eKS|0b?99_nE^qlZ`1FtUgv;POQPf7Ry literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/dblp_original.cpython-37.pyc b/sgl/dataset/__pycache__/dblp_original.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73e22997df25d6ab853919834061acb257a4dd4f GIT binary patch literal 4532 zcmd5<-ESMm5x>1Veu|VRik7W>UX7wK0hyGQG+~O>S(#UY{uMZWmN0dMyuV`QN3Gd)aaO^P>9P>8^41nbj=3 z!md`Q$#h{X6F%vuO-Gl1B~l@i2SFO>m0!n?#VC=-)KenWOX4V%K|l4U;TW{VPl6~4 zdXdoXqfn)~_?vMWCUFq;iShX_KX0f%qalofU7NG6!|kri>mrALUKBR0SYzIBn z-rm~UZf)K9;NI4UPke72jD>6&c1D>{d4acQPV}6yQ3P_wX%NQXjQuClOOMAw5e_0w z!!$(uwiis&y+qFRy!HO()=%$xal(aaZqD{m-mN?DZ@#yEZ!rwmMi{-jd3XE0#b6L~ z+HN2u!8vEQy)~bo3Hv$FpKs>{Qg}Rk7IFh>n#=^0LWw~3_h>r69vmjwgHdjJ)bd6~ zmrbL2Ak%OV_CsJC#;J&+5W4q48OFeR6p)R2Qlv5z&w{8qli<^UcpHs9pawPjG6eWu z?-&sFg5m9Ca2w~v6_$=k8UScG&5S3MGf<2K2jbX}w#ufvD~D>ww8&&+lBC`(h3(^E zbI1D~b$KXZMYw6+;T~`@mPC4yajLy!5DWxXf(?#EhA4q*)=Zq0+H%bDoR~EXMMM}G5hdQsK z5Gpov+EpSNNQ%>PiK+@!Roaf}lJEDUKq=qTL|zyc07tPiEe3r5aVs8Ar^^Eo;2H|$Z@t#P^ICr!{o`pFL0d%1w3o@H`^Np7 z{yI)GjD0vpKNTwVLw-a&G^HgbFR?q*g7G=y^oO=>T%78~<{yBGvQb&~bi-`@eN+nph?1<*2tc~Q z{;U(0GFYXxZ$GywodVpf({$uA<|Wy&E|5O(*-hI3qFu7Z*I_C0aZcRSK6O&} zGzXdUJfAs#MYi zZFVs{R$topt#8@W^)IZ@I;otmI;o-s?)&wVWwf9>S}SNl6SQi)o>d2SR^bf%s`xKo zw@e@0X1#g6G9Pyq;~2S^%WCs^_r7)I4A+d#t4F8lI(#MU5Ssr5XI(#8%T}{WR!2UJ zGH*3Dtgj$Vwze3hmKhyDi_`08)=494?BC$^1NU27)(mEi{f$egU7L@+acOKF=)IX; z%hrcBuk&l)kgu+1YsAxk>s$M2?F+guZ*%LQ1g?M2tkbtM1`A-|^$vRPW^9N)Xgrek z?EE&ZcXWQ1YU!NGShY);qb^^hiLdBp>e9loOUeR5?;@4$@fdsL4o|u*^3}xMs&XH^oa0(X5vwku-oI?DdmGJ8_K7Wpbz;JWzB_r9uBd zmjxdR1V}E9OmIWkLqe$w^oD`#at%8+ug_e3!$^7&8#D{C`FW_2_h`Dh=o4`ALmJDe zaTEediUg3GA8D6B)J$oo;|ElYse(QozpvRY^pbs`3H)=rKHg)JJfM6)6~&xOafV*0 zl#Jci7IS>VDy)J3P4Fn& z)5_&&v2e?D+Jvle%M<`$%+@}G`(ZC_la@^FTTR3P^52wykmA1oUm_*u>G9zE38-gb z>O3Hv&mht1G*{&xh+WzDM|hVbu=={<`%fo9G-Iy|zR#0BhMkFw0AVU+5i2C687Aq_ zMSz8rC_R>xO2{{A(|W4d-ZXn*BH7RHOc#qm;aX}sK9X4XB67kTNa#Y>{`q|jU8>=Ub1 zet}i3Zj0&iL?(P1q^*Lk{ZgbtCLj5!uNQt5KM>dJ2?X_&-*6p-MHztXIu{)C1wM)qvZdIfIx9KI}y_?JX^&B|-;eC>Q; z4VY27(W{86sEOj5#miqaUg6cJ&J)&~1GQ)h(MKV=0){n=1!Fd=D-*oTpr0GO$bvnM>4%P`SRlV|MjaurUOB$EhF2kc{=G(oGL0LJlzZdWJL_+6-Z>)!wlS!;*KcpWeMaJ^FNMp)$00YMrb(`#6iWCq*r8z`YjC(_AC9^1Qp+71 zTh@)`zD&bG7=*w$j8hRsA#4x88OFeR?316mQlv5zkNv2XYw)p8vW>+qP=lF)3<18| zKLCV%e|R$)+{Auyg{c$L1^^mKx$}fF2gPWxD-H~6yJouEa;UaUi(Ez~N$TEESUw)M zw%p%SmxmHwgwy8k?Eoj^Nn{r}r`k!T5%dtQ@*fMHt4o4#9AiZmxi^B@VO)2U*)fFV zZFew@1Jbc`&YJG{shj${*kO3AAD^j361S56XzL;L2zI=oNXCRw&?ZIN=b_=D3odAj z$(&%E${1dDNzU1F!0)k;3M#Z++GDdXInfn6vh)c2r1YipSxQ=RL-^IY6 zM(Nf}*KQJvhpn>-*ck6j-v12CnK;74kN+=W=?gl$#v@#u5gw}D|BBAx_G$ds+uLh% znN0f0ejCBAqdEsW?+oQc?5aDrr$f2fq0r$WEfC~zI0XFndGxbe-Ik*r@?Jehfz3+> zlnZqA3~=4Jgnv-HkOrttrGTO~QDt^!c~;7vSjTq8vO?yZ6quD3j-AXJ6**h8PS|I} 
zQ$i_o3$feoN`{%OieC1Nj^0H*ifW4Brz@U^Sde<2E*eMHPL%juFPzz9CMoTZFY1zt zLYUbqYDbA^ASq|dIjZVZHBdcgy6SmBZGVOT^>7EJy zroW2a3}X+0(MyF&y^!yd47rB&pEJ3Nk7?)CE9mNZ>F;@iFcMzukA>=hrbutgK3p+& zULOAs8_7;%BVCEi?ZoaLmQB=|r1?5;VC37&HqUqTeDQa|c)IwSS^T@WE~!AOl2;lP z$UgX=b;MEzue3(?6PvD80G)MP1v!s#Nq(%$WEJ9e%Qm2BhkWtdI#zhH7nA>OVTvOA zELJ+*{?oBb+FS)o6~$=tcj%T{PoPb{^+z_evgs1|xxHpRVINt4{NOYDsBq+@_HiL~ zj*HN^#0y#BFBW&66;A9Ec2rEu$Ca#jTs^^wI{n7QbaXf!Evv_gF zMQ&aq?n}I7!&@@aOEcaDtXE1Gvl6fD+Q&;L&@p4Yin+^|=GM&IvYA^xp9|hO@UC3q ztsCA7Z@}6=996SQR)ux*&w#)YtbuK{a|uKlHFG(mWwfvGg)i-+Ilh?9WwpB&YO{us zvGmNoYkkEY)kEv3emduh&WQ~EH)%hHxHSnLmYMK7B z>7$p^vV|GFbVdi|d@Ng>jaTnlmv*md>|IJ%5h3Ayu=>x~Z}(2TFFS+mxa zHS2Tel7TX#Qp->Wz~}htiFMS>nxkucW!L%2mbSsGIa<4b*p(UWl?$|0AosoOO13(* z`3hhCk|OkSwn8$!SHH3!Auf+zq;tM*Dx_TDU4jmJJUISH97?rQ1dO-(P{+O!kRQ-r57ZqM6 zMH9hv0U?Z}<89(l6hn|La*A`G>Y_LH)vmljlPLJeS5dv77MKb=9Jk((w1BRPcsdpm zcXw^8iLTH|_7DwqAwky}C1EUoKmzcxlLK8&5msY`d1<1H158Xc+t+qH zk;}x3q!bTz0ryi~P}8xlVwz`e*P0#Z;vt@4NKySHiKM{;xhzN$U5H~WE|Wc7z{^7y zRq6-3x+eHgAnS3lZ!#8g8E%TYOpg=jE?2N(>uMgmn?}>~ps!ho#ZPm9e4mCJXMF-r zeoVBYnnWR>q!16i`M!1tM9q{gbUmM{301JC>jj!^!!8*BP2gYL_VCb>zlI$78>%Ro z{NuFHvPo22(c^yE*8RL7%va?n;2 zU&8lG`B&biadSkfcst7!UafMsZhqHW!<;=y$uiHhC7fDZc5?)b$p#S8{en@(JmXH4 YyE(M)(cUWB*9ttM4v&Ff6q|*A1BBSnO8@`> literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/facebook.cpython-37.pyc b/sgl/dataset/__pycache__/facebook.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de2d3827c0708f01d5887187af9ecb576cb9444b GIT binary patch literal 3227 zcmZ`*OLH7a5w5C!_dF!aT5Q1rvVn-5Aa=&=iHC9U@Y*b~A=vc>HWO@gw5M95k$SpY z)h#T|_HaNVmJ31f52Pc1h6@*>4twFGD@R}8%j%htC7>rdJ16UvmHFl4C!J2n@Vs;Q zB>8)bv47H^Ftid|`Z73Ur+nyu4`s{nu58QB8N(gBat&W!uFG5a2A?yV4MXUDTl-Jq zQSvg+5469#;9iUVK>PXb?qoEH)61{n*h{kJ$X#ecU$V1>*JxF-ir4H^oQqmWeuvfk zLU2}DHQTpOF%mg7_{gnoP_K4C{hFNzwRORNLn9&H7T&gG&`Rvgf~8u6UPn9SQIRO+ z;W+I_(IkVVB1IbL%|hj)q%4!hwowsRd!=@aA+_6BRc}Vw(KM=5Jjo(#6phljEEm$> z#_2@9A>E3WCr8mZ$r2S;$zq#uvBg9lLk%p9hzp`BDJx=ouC~i0-PLv$PZMpcJg=ZT zIac5PUm6(U4F!OnAZj#>qQA1&y}@2SO$KFoIC%E-!*|n@<3Tx22MeGKmOkqjC#nmB zw8XB{k$DjV_QKsrlzukZC1okxgc!sW9&|L4SOCLy@~l$N~>B4JZ92O>g9 zac&}m%yJ=MS!+i-V@e$}=d*eLvw6=^_pnCY2hrW=3gWPmc6W0%jRCC47^%66`WDS6 z<6fCdt>G`ZW1+XrXJ1|EeYNM6y;j;27U-@<`**-x*ZxAx1!9|@a>Te@^IDuD+D+tI zJF^FSRaKPV9}LWmrU=$-+>a+q#Fm4i8ehQ{IJ%CiH}9^#fwQ$!s7Y2SN`pPCah!Fe zQf*VL-?%ekXPI2d7^M%wxXruV=d+E~?N^glSEkrODQT|!1(=ny_!hf>^J*@|7Mml% z?OPJExh1VRE-!8HPGz4vwS##MB+eGwcjs>H(mYUi3sE-r=Dzgi0rvGOzxGZ!W`|W! 
zd$nJONRM_M)S`Bz|0jRms*n)(TQWEmfByNreX?G)&pUOiZc|!)&a(GxR&@_T^*7Gu z9nNN()mq&#?pT*DeVG> zbsNXql=swkNbXZc%}{4`Ks2gOaIBlgFO#&?A*qi-A(LZmsr;3;MtQ0OBQ+CShJBSx z#(S0KV=dIK=E+So`ODa&@G2=B8D9sKLT(FC5%LXgTU~J%023IYu{V?LBFaW|S$YUz z1W6NE=tWQ{B|i}AA1ky0W55NKQHwAJ6cEZ#Mu9PQ?nrS#Y2@11B6PWguj9SS#av+P zG5d6Au;v3V|f6t)>kiTmo{dqfz`Z)tg^&Ys( zf8)_^(;3a~{W49^NNrED;;7m-<#O8;)omnzDpA=zpyRu=!S{#|3e^S?sz(6ww5U$B zgUtZ%7J~ zyZmhcFyw7Oa6|ZFwsn=)ZnP{&rUW+?5HU6;1S;wedM#?DsDwn34x?zA%Oe!dNVlWt z@F-535z|v>A2^0ps*}dJ6tC(5k&lRMg6PhoXBj1Fir#EhiBQC;e>C;mG;pMVX;6IL z4cboI@m>F}89^>_U`8}D08S&NcuRZ|&#CWR;#MnUc-}pD zn*6Q7*gxp!`eWng1AK)9kxcTG^|@)q*y`K7mN|Xb$h^LXR;2zo=m)0HO2cuZ-!QzL zHpi`gYuxU)$1D96&Ym&pNcWUUH)+hR{;K59S;xD;PHelwwew8HlU*>4?OY~L;xaCh zvSaDiWIv5m63b|3lqR|%^P?=y@muD4UB#Kq$5Am!M`a}IIRjep`r|Bqifu57(Wo!v z+UIPCOILcQLElPj>C51h&3WGeAIgT|UD=eaQwBZTas{m~SLGeF!RO3o{SfEBtNo|( zAbFYR``X`GaIZzTr~Q0qXEYea>BU#E?FCtM;4V%=4e?}e3!gbz$x2?axi}M*ko+F2 z__^S$v?{h|&(RY(75K=lY*4RqK>doH1(kKqe?vVXz6QRgWH_1Fo(9XK^*Sx>6o->U zDG%FeKZ-^fEEXv`fnJ}ee2^4HQrkG1#N}?G9b-`K*4EYQQFb_vN)?Z?2n$7nG%kvT z^tW+3lGmgg(Q@Y~8YWqy;xbt*6E2n*$zvP?6C>h+s7#8I*p931B1w0&oyFrs+bYjX zoSiOM-~C@FFv430c*U_|eZnYu#a?%MyZJcj6~#gC`Lhqdo1Pr^iecJY0H(LRx9;Rb zwPBDZT?rA!za*O6HShG!Qm5|HdIL^l7LPzncWLhPJ{ZY|e8vsRgtSimia!Cqa5MJy z>*nQ_*<||0G6apg_=>kdhOAoqjS9B&bMLhtg!$;|;CTxtb`u48d_n@_)TdG@O{uPk;$1HuI9Y&0O#+|=M)%mn~LLe^-T!0Gf|_gR1^(6 zR()&Mp-Qz)?tV*MxX&`qehuRbAB1t6x4F-!YnO^&4!WT&xe&5Tt?j?TT;3JmVCT5I zic7J`BjK0X;(odDo6T%@@I__adEF9gSq(gi&=BJS~kyGRikQBbp4oR@7t_w zpZt}xS&OsjJLO8%GWW15U1~pOpWr5Dt>3k>ux_{57NimvYO}e;S+@EfE56ub+`ONF znwtx^6@>@5U=o*yDk;d&!$iAftcD37m&jpKT+o<)n|dN-vgAps2o2a-YI&hGeV*qEUH*ZQV3}nWTjd>G;SZ zGC9_k%8#@)$Wt8{sTtTX>`^ir?v|PlwNN{nC%19tFCCH~S`csA^U%B9>t_;ux;h0;C;0xR#5^Os!qkbDC={j)^I-)qiM0o13<{IW&J{XNGS z;14?%VxqULsPHp@R6hWB@gF?et;?tBgI}cys;$jYHaRReO)lLu8Fv$*p-fcv0ONL%Q-4YSqTbqHxw zP)c9`8vuu7wxjKfRS=)I1HnU^KHv54n_ff{2ZltQLk7@ZUUW_L{~oPDX;QnB(O^GK z)KijtK$10-N?nQy-IoA~DGKY@^46OjDA^SyJAG%_2TAv^9HoURHBI)ZFP}2{H7z#j dmolyPAJvw6Orf!$Wa9HCZplX}*5vEfe*v%@MUMag literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/flickr.cpython-37.pyc b/sgl/dataset/__pycache__/flickr.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1308335de7c0cc4f8b314b86d9b68602402b81d GIT binary patch literal 3590 zcmbVOTW=h<6(%{ev$L}otz_Ab5+k+y;C738727GCxV95Jj$OsJ8egR>91O-w?rNmD zndDm9?Mxpc1L;eV_BkjZJr?MX=H>`hzfHs!r4GBub$w)4rZN{q$Y*CsrY2sD~+nJG8$= z159mNd0FdO%%%>lzC^pYR{`0jH6uIJqxF{vZr7kqC|%m3b5Pd4MKQDF9rB0SU#|_9Nk~l}ZWRK*8;*%)H`vfCt zgX}vw+NeN+5j55YeO}V}wCZY6wuJ==m8BhC62v%yz z&jPt2RK=&DM&(Bfhnr*84F?;$6sw zC`Vr}6j-7?{0Fp2^i%$sJhvYr{tk3^EM=u4Q}$*WCZI(gZ2L6i0ASMnK)D*bT6rDJ z^@?T+Ujpg1cO?pi?5A;dP1l9Kr-LS)13l<<^yKnMyH}#LA4K9B2-8{R_Qb5KYSWEG z*?h2_URkx?Uh9P3O=}c^?kC58bibK8^ea>#eFjPBnOY5iVgkwf9^T?FWS^9Rb%N;Deb;H8kBWmJL45Y$>Fak|IH(DN1prgy6E*j^v7Ss!&`C0 zK2~SWZ^-q}|*afl#}+5CjIgBx@d6;dgdlfJUR^>13u^8X0(e#O%AUKAbh%lBEXI46s!@ML3lg8Dw$#ot{hY}o|+)a34KD>qxfu=5r5bwd6wXfHAo_fe2U2ao(bhKf1-7Y`{&A(dI7s z?rzZ@x8>sA@q8|C>$tdvk`HYpPi%RA!u$qC#XLrnMR_uxH^%u6ZRt|)MxpQb^Xup< zl#ul|_BVzzxB7;!=h}0rN*D)2_BIZcZK7i}OlSzPQ#l}H5`2@`2r|}dn<+lf3f5$rtj}_T|_DXNnSyCi{FrX>Q5d`y#`X~f~gx43`2Q2e&M^%8wq^Yua zx~;5!8u1VGC_Bwq0>K2L6C7T$SSSlnuSfI#hVt&F$?y?NM9R4i7{+HR3#by@Y7}Rx zy7Jr4Ifve}W;M10p?aY-7GoMC4mWm_QS2wcwp2~$z8q&v;eoQilK3^f>?eA5ozWrl zw^=yckg8@XW+WL4RhwCtvbNY*+4`)?iGpV=5(*C$;RA))QS)-A@!;36f%qk42rq(7 z*<=y_6uaaENScdAJ2X!EL+%c6Ue;4njg->p*{1NHs{tyb=&{Oz{@stD1-b z34>roAy5It+Xa#KUd8x->kz=XMJ+HVh!`dZ12aZr7nQ`+A42yV?ybq;ATrxR1b$Lg~Wyj+BQMiuJMmIoY zfxPPQ@B7O0{Wzr~aLK-E`2O=z5S2ZKA1W8#7bveB^D}z7CI-B#v#)eE2bpRfu+Kh= 
zB4ESDgBw1E`AI!UGao~ln;;0if{R;jqtdY5Q|_sfS}$^~S5*OX0A~8JmaD1g>IlI# zHq9$2Cxg%!F@8(;H$*?o#$AyGT)?}QCD&-?H+BC4u->lrFWqPG(1wQ_LsiFr*a5_t yTYA#`v=3UkBN;}*uzM5g%t;TgS4xJ~xmQXRUeoSJ=m6uwqcq9{)UJWs*1rMqXtTZm literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/flickr.cpython-39.pyc b/sgl/dataset/__pycache__/flickr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a06a0c053f1cb7d54f5db74f38838370ca5bd4d GIT binary patch literal 3620 zcmai0OLH8z5e6``v$IbwDM|`$#gS4eN1KY4k`+gaA}Q9}v?RukC{fA6Q-2(OZ20&Hnvl>@L~)3^#fZDad|Q#-E#YF$>6Xo$8&$g1y}4Sq26H5= zM~!a7%voVGn(9s&+74UMbay(M>CRwu5782JULxwSsV%E}jN(VAU49Kaqs2B>rF+~@ zSAnQ4CY0Uwr7xIl6V*u9Lyt3`di@||sz#HII85MECOy6O@T1S8AF&7>LtRV}oTK#> z>SJot(#uNMVm5VX`6b%I-4f6)tr*&&Ra$$A;DmMBfYGH*It63pD`cZ?6*Qez&fPHR zt#jq{&A3ZC^U6*7{h$~4VQ;dx-e1$YSRjLtA8v=&ONYCSXKQ){3q{oC{ zg|)oYm#adRJPPJijwjMEkm{sY9K`bmEN0x7Z0|?aliUwtFQA(s__-e%s^Uo|q$b<8 zvIPtK%8vbrDVryW1Z}j?S6Ed%uNV44c-||t`{Vi5Bx3VIJez-X@4>lnv^g&Z;k?!Z zu64e!U?&~%8mQfbCIFuZ;S;2e9c*LrW;|7B*_oWMhc(CELtxa_H5PschykGYuQ_>>t0Z-dfRw7bmuO&eY*5v+%M*|LkCY|%j z)mYX>e1*A|XqIp-R#mTtfsnl);=QEc!dkX7u`JSyn~{16&Jh`4xs{5y9$Vh(*5C1|&x3B6jI0nkjKS$~GR_zT$~ z2DM|SGccZQ^B({oKMkE~dIvxSs{lf^F}TrBpZdTaT;t+hXhL5(BI9@Wlsgb`V2~p! zy1=fFWK{e+5ES4+NeU{!)y46ZRVh<(WqzI}J<%B?$pF9#V3{BGe#JJ^ zgv)D_0Qr7Eul+86zn8@QfJfKl_{@bKkHpdny>RXWap$SHx8Z-Va6ewSd8tb;E=7;U z2v{^+&poBjpI-FiPab;@A3Tu|^e#Odup6HR%L`p#8E!7DugJlNlz+S=?#pwZ-VojH z&%Dtm(co?rvR|k(7gpudTXMKAhqsn5Fa757{osC|-@Scf^YMVKf0jJC#PyBxQ@Z=U z?%sp0?HmAvpMePovyqrU>%PX`xz!{lEc~#c;HK^^7E&TQ6+?EaSVhz?>R@-jw zpxfxn*)K?5&l_1ItIDaYzJ;%%XMbCUiok|<%fcvbT|nX==MgsV=g`a6_H; zX^h6Fxa0R|Q;*B&bI>54{{3_e*-?+@(Rr9Xq}9v?^bHo_VZDqKO)VC<(-BOS zg2*2*gGObWBv}q(8bAtF4iKq?&?Hv)2=_rLv%0f};$hRFxNgeTQ@Sd(KE>n}KZ46z*f3 zWys-}Zaj><7%DDR(Y+@}DO0$wEbuLUM~l6ySJxOFFmHnegH@?2=EHE3kx-R=e^b^v z8!21ss+`bYVWCiXpa}0P%-(i5caj5s2^)wGbO_Iam)T?%{}8+61W;8o)51Rm&jbH^ zG7jqM%0cN@ zo>YGBr0{$D>RNA?j>{lOw4g(9F|+tAc@IX1xMX~C5*XhqHrp20SH$%l8LL2H3&pH1 z13XVvJugb=5CXTS>Yn#(=!eB0Q!gqPZX8NrIpSyaa$O?u4|Vrj-AzHK8hh1fkA)#r z#|K5lNzI}ozA2y^Kp4ec+;j=9+R!ZblzXZuwr7oHRSBE~YSoW`S(Ovf(V2u_)H3uv zC?|y!88ZG*f7C@UNJkx!`dq-x79+o@neXfWGf)S2bg1d9z*U>d+_7@9FVk0z~Q%kWO# zo^*ztNq5+ttPR&Vd&Yz-yfY@ev~_9^*9Cvhdj2I&Vmm#q-DfhH?ty9Tlp=kasH93& z&(dh&HJ_0ALyXm-OWZ>l3#uWM_!UAN8VBkdXtZB{EbG6DPFTv>)fg>!SA!0Us#+e zyJq{&DMn(q1|NI11M1f72CL9e6TYBo)!^l_LD z;y5edqgawh^yXBSqqM5h#*y(fQG1nkjSIEc_)~Aj#cUEQnPf$bjp9+BRMk@YX_9B+ z4e9;3IX8~SX_3lArOQpCv*6hj2r#p_;2j>w+4IVBpp=M;o$kR58ug8jtA8^A1nbfXl|=NJ&|2FqfM7V zgz+y~ZSI-he66vkzw+CN(+oTo{&aul4!6JvKIRK<5M>GbEU5WYU)hUHq#rgN#|tKG|4s#m@Pk*uF&|_^Wm3JZJK&7)A`(Wi{2#ewh{Up-hfqk;$H= zt^T_@pg8;CJ55`-6e8XM*r7#{DE)p|l!I)p!L};O9*&y=k@l8qU@3RAz;eq)g9lF3g_ETssum zx2#31ta2*3p$b44cX*cve7?T4`)bOSu@oyP7>%+20&{gq`~kbbrPW+m57+{!Y~L1; zEo@;gV773;yUICtYZvoeNZbc(-&=UKNAp0v2Z*bMzX*iC2(hoPg4#djm>sFG_UoXI zkO-YRtgYG=!JmUgOCj~`w?ueq{pA;n_Q|?xpLgn3-KGTkf)(#MOmz<;`FGA19nR*P zYOU^=i&z&Py!Mfz!)m&k?M zQ;8g>fL1ETX>~~xhg(;}F=|yDUwtA}`LX6F+B$mGyQ8fF!AuoR5@e9<_L6C;T~*4_ zo{l~&i}A;4QI%5jkF`@4X{Bw1xTNsK_Y^F_RokeA+7YwKRQo^w<@-{WQodv6q{o@+ zbsH2ol=S3xNFGo+El^x_NHkU_IM&OPS7~19hzvoE5b3eDWqG9SQJL$|NX^8SVUN;m zyr(oDYfJ8Gp5BI5unIwnvWlXT@pa%T;*JGsMSO!hcGubh+ANIF*mbU52HS{CrH9}~ z=rr+#_5<}%@B>T!^9p(Z96&8{ef3zDI;GC^v{uK?0yBQ+1+U zY&K4Hiy&j_K2YcvWid}>sl8+h6bkJlCy!Y*lRTf=`9+3pUBUM zP>RP^*cYnluxT%-C-CuYnuB1~8)%+s8$-82wtl|*#-(RJ;TfGXhDuGRgM2v?Cf09~*UW+;?sv>c$qd1vkL; z&XT+tG3|s7fMWPWoiwhcSe5S*`H09Si0&+#l~J1KXvwaM5o$O!ji!8?zKs+xy-C2k zVcTuHffsC<5#$jUPDBqwSE>E7>YG$Qp!HM|wKvU12YD)gOp+gvWCKl5pYlQf6#!zI zy(X*N`~Lwn-;&CkzOx#Gq(4(xUYROm`pD+usd``A&~V$dGi0g9b~|Et 
zzIX7e^s6Rg|D>PGkBgsQ;43AFWRhoWz-=og&cNlh%p3St77PNkA`2&tLBsYrSu|-5 znwEF7)}%dXPdbCnWNom<*<&U>>7O#`r_B>*urB#i)(u`_CAQP$#(S)i=^mKoP9f7r ziB8H?cOBD|#bKTmiHvn&+SBJ*tWx~!jzlkMZk~N3^au?JPPqudO*Z5ywg zO;e=;EM~$u9_O%CtjH>6bE=AAT9#>T(0H2Yz0!Erl*X^^YBuA1Hi@-L#(4}y@i0ru zaw+{h$;R@M^mbgY8^@zGPgSDRrIKi=U@VWY2~3NLYs5M&HL+daxMiB{8aGcSsc}_N zXzZLW*4+HM^{w#Q@-H|v*87X&7wpx${k>w6_RI31|Mc;bA7;l#{c@D`mq6*)H`bdT zs}79OqRSw{_%}q0`}Uo$)#vH0>^9^yV)+QPbcgmHx4=j~;tOt(C8Tp2R{Rk#g&T2q zUPa60^6C5yB?O3@_{z6HMyz6=Z7jHEXZ$a0Ul17nE)LwMOl_kZKH(JQ)VTY_IEO7& zav008>Ix%z_f1HSK_+~`7467QQFpwH9%RKZ z$;yXdR&(mbh1F^rt7Xc1f$u|Cy-Me*rd#MJ3F}Pwe17X|S9`-81l3>C#CJd>EJ~(; zF9<2ltWS_h&IRo1>=3NW;qY&_;*~go zr(3UecILf3t*7Nr`+a++30yTF^^)-lv1Na%M=RI@N9rE*Up}(>E(FG#s&TFr?V{_{ z_vOq~W?XXbYe$8{tOD$}Fun{y7tVK^Cb``d=^?x5V$TbKF|R zrMSZuh-v$dglyqRXMri+)W5?7TLgy1mctLn?8ZE<15J%h%sy<-hJTZr@Kq z?X5*Sib4V$FHQ7JrRAU@)6dhaK&UjnPShv`;!-(E%a=58uyxTLBc;Xh#VCsqsJ2gIBiizsSBo8Ut=E%XOK{VFKSk}*yXK7ZN zhz^1jBGV(|sN&E#!y+>cE42fgmOV_zqdjf-$OyG-czPXl;mR`!uO*?F@pT|B;;sM{ z5#Qjh(-B+1p+JvC<3%W5`rwd`x_StRgk9@)xErKU$)5}L)deU4MZkv=tV2))SO|FXO}E>K5;|4+ zU0~w@Dfk%?!l2q9auWoAo=o+z@t|y!nkGTUmIXjzkQezpRfX}BDNra)0H+}AD#~)| z*EBu)ILT(I1!K2a=c%`;^Fty}h*0E*D(J=5Orx$osALH6Z)=BeRXP|xHV(S3gKYhJ z5kdzkDKjuY4RAw7^7%c0YlCm{E&dkZ7x5MVxFJF@zr6~6*Gd`0TLPMD5U~~}>?x`o zx=kvuD7nP3iQ;%t$Qe>sD%M8;j zL&dgD9hOj4X|t^@>S+0L9VFzPhTxG)U+9Nhwil7agBele&}EvSD0?;y{D^c=t~CC1 zJba#|>I0JeoFp5ll6n*odd~n6TMgE6<<^@WsL~bXJac2!2T5%cNC%yYw=Cz|2mmLZ~y=R literal 0 HcmV?d00001 diff --git a/sgl/dataset/__pycache__/imdb.cpython-37.pyc b/sgl/dataset/__pycache__/imdb.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efdde3371008875747c8cd60b495769ab0e914b0 GIT binary patch literal 4255 zcmZu!&2t;K6$h}(<&q*Piuz8{l+Lti*)~$Y#!YA3Oqxh>-Em~sc4LQaW9l?W&=jur^q1FLHLB$)a>mm~{Q>vDL0iewVdM&#@!hYjJJgcz@>| zd<*wR!_Wt@@GC#{ReW7_hL6x(%A0$!^mjUGC-GBsR)_ncr~HoedO_%SX|36hbfRC9 zq4HUqOV$>WOChbk)wU#uRyH&$NQdO=+igd_^cm}MxgaYirFPMG{gSNyO~{&DJh9tl ze?ive(g{1}?F#6ITsCx7uE^CB)~?AlxsLCm+>o32)<0u5YcD~^m!aE0KGcOl{3P(T zD+A?s(^%<3$9x>mOJmP#S;~c|T41OI0&V)J3RN^RWym|onCPI-3y!bE*0y%zUN7hd zov>?RkbZO&RH&W&8kOV|D`goD<4^oC&zL7tE44FrEKUm<&#W_;K6NtoxOhyXZpKba znKeBE?r1QLtm*P4ZEHf zL_z9#%Aw}UP{mz8N&MW9-msHCN_4?Y)vF#IpXc@c$XA`z_mW{4q`K^-suM(BAP=$J zlTK);l9&2PN@UB{cH)OURi_;a`6N5)49K}kp4SaKN#c3`WM6%I`%ygbx0B>x`|kCd zZ-hsO+ets%p31#V>TL~=)C$)521;TAiQA$k8lu9-RT6J&rcg`hMPK%PFFhKfB4^f) z^3#!ux^y!1qc0ibU6kZ|sK^+{T=IQOt)wu^G5?gqJk}2c^i?zNHT}>Z_)(fP1N=ugi9j}cu>w6DrA;G(;+XTC_H9|(3jIizKj?%be^)6u z;_~j!^W9}b)IhbeV8j%LGD`K z+;_HyDU4q5$i~`=lcBaB$3X<|z$_18H`PhIlr6}|E-O|eMjFgL@COhlF^NQ!`FPo= zbCKhM8synVjaNV*#(2t35kHTG6sL%s$2KWqR6*1f`jptlg-<41pG;Z>MQ)(Jk4o2G zXhcG52euq^Qq{uX1r12gOB$$EF#QN+Dz?qXn?~&UA)QliEdlHKEP-%;GJc zJHjO&D0wn|8ijF3UZw4fYx8j}kA>OLe1}c&ETE9w zPZl6q1=>^(uMrQV`!TdJI!tP~waOx{z+{Zag8s@Y}p5a75zfvw!RzDMy zma&g&6NJ?%0_BYVMI6|7^GuG2{w_FhQ;Dq~ngo9##+w-Xs(dXBlC&ES zhHrUjoA`G4_(~q%(`>SJEsVRJFnJ5ajQfH_6z+FvI^DrWB6osNJ||0_VuX4V44Ma* z)7j6>yxIQXIMU+j!&XC=_ipXvSJqwCpt*JujD1}gsvt^reSV?c_3r)l_RU?Rur2^v zbsuRvj(ngC)rtC;F$Ty)3x|$7o{}HNhkj%-skQ(xfn_kJi#KlNx25eL25D>4WJO(_ z*$t_>o5GARm$~}&+EDpVf_Rh|vwGbq%~_Vd)~{U^)BS|`ZZ{5ds&35v)DQZP@`=V! 
[GIT binary patch hunks for the added sgl/dataset/__pycache__/*.pyc files omitted: compiled bytecode caches, not human-readable.]
zycla|m?s&eC!;4J9HV1tt9wKHWmh7hr`j^PADr{waBJu^?o7yVqDU<*C;l7^rb20FuEv zNigzUBrz}q8cjiV5C@?qr33v#c zpKc|r4K{;%$?<2l;-!7WIa3y#-8upc%pH4rXHThdacgU9kf^=MP8TZN8sj3_VVoE{ zzTPjcFYFrP-x@`Qisi!84M52sVB?N0yZC5#ER#%0vbv5_Db$|GR6FFwl?jDtzaay- z*8-A1#LB{iAMWsg`*^3T3rAT@IwJ(~tp1q2E=e={q4$JrTm1?tOvNd zgZIehEL*+FiW@f=H}6B#%Bj$DgJG_%yyF_M(=jclw1+$Jtipztp)B_=}c*yk>e0@6s2+2D8|# zJ_~S`^J=}UAJMrRqA_c#3uV0w$_6mpE1OWIBbv|rS|=+ijpzLEE@!h9&ZgJYm2$38Ep}+krNtBBC|J`bmfut7dA&X z81Vw#Ds1XlZ*d^QHykK384k@E5Uu39%Xnt2tYgWKFiG zajffcHi5$aR{lnpdjI_T$_N(PgBQW(UOHJavP?f4M;XHdJ1=pb#n zJ84GS$hov5FhSQo|Fua7g?3{x zz#aj8kT{6=K-)5ZqV0a3>W1lk5+{Q_g$vOXvxpD0C3iKC--H`~fzE_k1*xC$HU5#+ z=51h(V_igmzKjsGRI? z0nUaEg0EdmSfYXRSF>toW;L?QkXF-3>yB!E!!uvC0Bn@h`dHoAwGoQ}I)wLC@c9^& zZd5rRg42E>?@&*jGC$MS?Ly__agq(@?YcveLfne@hHxaO;>)>#)3|>N*O54kq@cr5u4!js@h*%Qn$|B zQ4VQl{WKA-&GX=1^YuKFA1{*PH~nMv&9wSgU^w@az z{K^=p)EC|!z%PWP_@>V-#?aH9s07R?a+>7I;IBF1(#0s8R;jXoyAs|ViXjErA0<-w M|Df7xf?Bu#2j!OqApigX literal 0 HcmV?d00001 diff --git a/sgl/dataset/planetoid.py b/sgl/dataset/planetoid.py index 613cbab..3fa1bd7 100644 --- a/sgl/dataset/planetoid.py +++ b/sgl/dataset/planetoid.py @@ -104,6 +104,15 @@ def __generate_split(self, split): train_idx = range(self.num_classes * 20) val_idx = range(self.num_classes * 20, self.num_classes * 20 + 500) test_idx = range(self.num_node - 1000, self.num_node) + elif split =='fastgcn': + train_idx = range(self.num_node - 1500) + val_idx = range(self.num_node - 1500, self.num_node - 1000) + test_idx = range(self.num_node - 1000, self.num_node) + elif split.startswith("clustergcn"): + cluster_number = int(split.split("_")[1]) + train_idx = range(cluster_number) + val_idx = range(cluster_number) + test_idx = range(cluster_number) elif split == "random": raise NotImplementedError else: diff --git a/sgl/models/__pycache__/__init__.cpython-37.pyc b/sgl/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e80b5a4b0f665408ef212af6d8c270a6243ab39c GIT binary patch literal 127 zcmZ?b<>g`k0&@WEU(a+6KNzEzNkB`sH%PfhH*DI*J#bE;!EX_%^1DWs{h#3GK^cxHS literal 0 HcmV?d00001 diff --git a/sgl/models/__pycache__/__init__.cpython-39.pyc b/sgl/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a57627d14c1000d33570937964a2ca3cc658e9e0 GIT binary patch literal 131 zcmYe~<>g`k0G|E|a`|gXo@HcMV^d4VQY^(`7)rF3ax5+sl9B9*q|WSg&+P2Z z&u(=uxm@&ILI$D(C{Pd|3gfdoh{K@K_Oe~a&91h$R#`_NNzsF`My^@Ju}N) zNrofj;@Pb?)z#fquU@@+?|ZN2^-85^;PcP_{cZmrUonh-6xDO$Z6P~?PVSr z4Y2R!@SXMYo`dh4SMZAX&VOiRjD`cYN^#|NuyE(*Yj58IL%kN|#}B_b3`=kqBs4-Z zG9FrE>kyKQP})N^#iOaoz5q8@g^>a(C2l#|=V%Cx{Df+i!-h*X^TyxEJp1>5B>;^4*SV?rgfl z9eD?W3_851x;w)loZc+!>*Q|9-N5ACK4%^;bF7vh1iq&S>UD!KUOLjcr+jJCBlP^H z8Z9Uei!%&6DZj-jyIAak>$ZB$AaLEkFn;`%E1Scy7NHgARVfHunc0DRWrvwt8A7*SwA>Cx0I+U(#lwSa3? 
zlE!AIqulTTbmmlg?aM=pxUSo3iOqJ?+lupC2CSPphbB$V^Cp(cPnMk5iBO9WW6>;I zjtI02dT@STsMFUzgII4vOy-s)Vtg*NJ?o+QJ7$>K%0?D9=||ScimY+&&=}_r4b-^^ zj$Q=0GRFlhh{>n0U$PLZy1;2%B>4i!!R@yuN?4l+w$_IlH#>f92>CQ&O={hsHW-FA zY_gr9g6M1AK`q?$Yup_#CPL^AI<=imY>1U7COU?q+SqILqAGK&3dyq|aq*pIZ_j^2 zsiBI^wKyLP+hM=CALl$jY_>Mz>|n1CWqpP(JO`32hWa93&gq?SbS+9`B;n)BcrM`b z8L|)-7NiOdN3jkr99fF#TgR^jS&xh{?G@A-5;SRo^X;L!-&9_zCS6b1zt?R&r6m>F ziEi@7FOfG(csuyQJTG(4rQC4Ka?`N{Z-aQo2%89-&fIoXjoR_pykLy2#|Bo}9$fm+ z_|R-y9<<{T?B|yKE#sj%%Al4RXQK?C!F+O&jaoL!Vn&(uwHDWlA99vOIgJufMLL-M z`%5EiYZ=AhD1Tt>X71>X09jb!ux`cH>eu27_etGSFXBCQnM5`W_r{NlS7{Af!~V{- zU&XbE06U{TFM>o!t5Br6(v!s4Ag1~Mw2@|XBEwZGsn)NUZ5iy zH2XdrQPHKB*z2)gcHP~*W>4P{o2*ti%vVTONjP9!NvwR!@AX`Fs_=DFo#jKq8f|d_ zC7`Ku%yQLPaH@_iez1k6^uq`F33?zCm+zesr%zoeG;BYk@ybK9kwsmQ3hFu31*4#z z*Vv@tcxSzeSAAqQ3f_XZ2uM)$PI*iCE_tWDWqg+fB2=JX=i&vC&uh2de(Q}};wm%k zy++W_3~o9u~-uc0yz6#xe4HCUj=018k^zy$FeQg?(F#K)vS0RsaZ*y}_GWksD< zfO>`&xA=v!8#MbnJy2Js2pu6>dqbXQFo6Q@Den>br0ZP!w?zmAU69 z%aiK949RH)>;zudS`@_*-#c**Q{hZ~?qpN)o8<;d;DB)ZZxIy27+LC<(Gr&lVJ7nd z3uJ7nGFn%-@ouAN`(qtEb7WaoXZuJ^!e<&?M)s9x`w638!^i{-CyqYb$MK`n7IN-- z=BD~4u1^?^z_u|n+OtRJI=5zXsz@~yQES_;LCxX3xSdqgS_c}%jb;eQ-5q$rQL+CT z`uiYd<(LJZ^tE<8c;3Lfk7@9UvTjGV82(FaA=)}L9|LRcemC97aQ{4npM-pOL*Ou_7WYuq`ic3uzGayrU+`rRiv+`%EM{U zu-su^^ky=RB|VMMh=8T$r(i0+0!|zRz0IPq#zh`zfsvk)DY}#NX(%XW}?Sp)mP9-6D6QQB8qXWMeDqIo?!Cg$-+oM z17|_2N=C(Lao33cR71c=fr4T?sN(VI^L6c?C?EiNRR%i9-h5$Co6zB*5B zG`h}!g;O`ogQxcV@d%4D&Z96c2>u7=PMcT3OqTkuNQW>DY&`3N0nCKp7~c$E>?jB1 zl@a8HkQm`J<9#Q}3g*g2<_HK3s4E|353NyQ@B{Rf+qTr7V~gg1*5Sg?2hm&72d?uo z*G1t38ye$HdMYBga&TQ3tf99&dixjllJrmt3+&^enOv7;{8CiLTkNF}748^eG0OZA z@(zD2vBbybAfxB_v8g_o?Mp^4jVkkfcFd^a70w&Wltp%2)^{VM5Io~|v4|I7c=YIp zCgc2s^HL&%A8Hw3tiSxJq)<6Q3U6vDpuc~g<$a##(PxlDREV555q=R)wNYJb2I0*c z>vsqdd=)ROGv?eFs!2n6#*hV#axmOeE#Do$#s#VUp&q7Yl$cG>bd_j&{4|mlR11GD zwufd1CPAR|)bM;Ats|6)Vf^Y*Tkr`0x^e18Zx4wOHJw$Xv<4iO%;@0R8e+G>m&4kI z?+iT6~wvb&W=Cu8Wle6Bi$koV=zH6H}SX@#G9;fZ1wBxA}i8v6ecUN-&pY5 zZNC+E@A*i~z#nVNMOij-gq-ev!;x>GH+730QMMB~i(Mo1zK0SJDH)~%sOHGuqPYYc zS54}R2q>3OFPn?_%UDZ@1IoC*XkL=1V6x>7pW+?;RMk81qT6!FDz0fhPnZ=|jZ15gDu6nq7>H&DbmIA*n6&hXE% zi)AN=@PYSRDCvhDg`a?(PQr(|jA2P~g(h-_I#-xYa)mhw1oGHe<@huwuhW;Gk}+Jz zF!N!60REFRhOpZ1BoJVI5(qG*RG!NfR&+G*8yG=H12q1MxLPri(!QL|0GU8$iI^1{ zDv2-Si;_>2WYSEM&y`Hyxu-LIazkLi1OYIxR?NPsGa!>Skwp^qqc{CCl&3Ou3^gS~ zC#e014AYq5$xrZEZ;4MvFAP{XC$3&_kFE_^PtpcNkydbC48Nsnp`jB?a=)Oi3(B;E zSLME~uH(Kfi6@w@(0XVEcae(9uw}4>B(AxY+0Lsrp31Tf*V(OHM0c4>M*I_)Jh-+k z2T$8bLu9^hi1skZbRG*@9c2+_<{ueix{F@n92^N9DHgD_|2Gkw>id5M=$0Qj9XPW6 zm!mvdE{}>)VO)-i}ng$VdJta^AB?rBQiQfzzmt z7b0Q&a&Z(<1!zV}8KNi&5snudO&A2p=cUO-C18}?I7=)yW)o<8P# z9bb}_Y_HdX&Ea0pgPpFyCfw_KzPEDFnC}z$y2-uEwPQN3VZ0OiuifwVdhER0L8QyR ze=#{{sP*AMc&7>f0F2)pw2&ff*Pc3%SUGs{ZKPxa^f5q|y@q(Y>mh;Y`87#g*WjnX zpHIl%9S&BG_n6!e;xQK`{{$>?!Dw4E*(Ytiwa1L@ZT|t%QI%8<>|C2a*;(|Rf6IM9 zuM8ah^3-kOosj*C1FS$Y-}OZ zgMDf|A(!>##8-BMiMSQpt5+j;P)fnbYiosmC01^;x)ZME2#Z7N^l*-u+N*!BnySJx8UIE zo+u*-JDgCpR<6-h#%5mwAgoIWEe1nV6hLZeKfJqMXL^#E;Krd>a*{qc^&Pr0v)! 
zmuT=`QPRvWv%W>MCTOo{>4^8JGfqG_LTDO&A?rAddWjeFpiaX!u9U3lCm}Y%l|Mz9 zBzhMSbH4zj#`YFU`l0^t^RpqfbBuqE(U=xhlOKL_D3*Z0{8P~yQCUt>wK(EKeIBie zn-^zkO*6Ll(C^6@O?0BlM?_ifvPQ)7(_pRt#Pdh6mgHqM)~cN#g%oe;#ZB?n5_(JV z*5CacZ~cGZtqL(DG1>$>hRa)LcTO`{gJrbsoW$D#`wIRt43Df1OMzTwy zNCJ`&1OnM?A|XOOndpoF{S`F)1SMDp`B`Dx^L7r|8{~a#p7tJG`gGO%T!7rmlVt^v z`_f4~91rHGcGQF$D8UOLuz38Z>S6{-b@frLKYn1`f>IPyWHwleo z_ca!H524a_u(}s^dpa>1=WfYQ1my5upH>sy2_rLmJvEH6a}`UD-xOSf1oc@86O#OC OLfKxry5uf>=YIh_oM0vZ literal 0 HcmV?d00001 diff --git a/sgl/models/__pycache__/base_model.cpython-39.pyc b/sgl/models/__pycache__/base_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f2b17da6ad54d5ba0620c131c6db354c42c61a GIT binary patch literal 9501 zcmeHNTaO&ab?)l!nd!Oig%oe5<)LK59!FZr@+CsTD3VE$axAU|kr3jEl-BHY&+H86 zw$(knY)lB zvd+}4t4>v&I_Ep5*(jF_2A)$tx#9olcMan|SXg{)6y87z{tCo1Jk>Mm$*=00byFpI zt8bcyGWvEsBQ<6(+t1Z=lDB$JKVQ!)uZWuBw5H;zTSl$)1pFAAH5ISD?T5Y|ZVo;FS|e-(eu&CyFly+)cl-C< zaEO1px$RD?LVpmT%-0tmYg!1dAmbZ#Cm ze;cC*I)~liH3&|F7^_wGP<}PW2LV% zH&HlTbDxKV)U~>{r|uehLyj%8^PVTi&6WucyKC-K6mo;ow4U$fc(^rt^8j}T`_Xlq z!$GqV*4JnI<=*oK8h=6^+n$kG9>uUUPI2490oxYNL&AdKG0|lFCNAt=;ee z&7`A1rR;hM^>R=dFamhwQjcu=X z2Gwxeukv=hJl%lKpj{np<3e0|qA)N7)xuq?6?7RTT_JfMBrd$)=#Bliv>xhMU5|6Y zuod(VbV1DR>veGxl(maP`#K&{fLD$LhQ>F+VnN#@674$HrsT zGCjz~BWTjDbJD3MruAx3LQ!uJ7T3 z%S;Am`Wth{zLi`F?1>o;Yi4Y2yb{~IF$bpp4&Kw>CXuU?+?p%Yjm=?y^!gugYzHKZ zV{(voYs(ul-Lj1rBCHH^A3zuY1?dIfd0tX$)S1u|1#o65;l zPvchCp9^(k5VT-G7L~VtnPUpuzsDST*WUtp4JkMSVkjGzeH|K<*PTnSt%BNp?${;N z^0Z8YM&E~JD!4Qo;~w)R*WDR6dP$2YY5gL`)*B?06`3?HPt|_Y@AX`F;RaCYB+OAQ z=aB+xc}JBh&WclUEHR5sR4qO1%qQ3aNs)lJE?9s9<;|P#zVp^iv6*&j-0L;vCI~Br z!8{ogDSr={W!wN{K=Z)_Nd{_wL;^F236Z=bG%vvhn z>i+oCI9MFvWDa=b+=X-xl;-9fJL3&Be6TR%mk+Ob=?s^&HRYE{HvSc4IHiE3Dsu!$ z0(@YLA(s~54`tX+N`h@o;M$BxWW7QF6oh)*3@Kmw8pcbs9wkp#Ne&&?i>MnSi6$T- zsVHKR-3y18@F^&ScTxD~EHnNqAy5YqL@Giw@A0Y4VfX{w8o%!kfDgpF>zB}Sm@z7~ zdq@G#fsEOy4P21r1#7RgSEd)iEObSL&}*;nv)B?4W?S{8l4` zYa|w$I>)^)hAGb8fdy{|hnb|(-Qyghn$+%*gI_k&sV2o625}CTM%)?PhU`c71SZ-u z_Dz^s=1s{H`zrnGDC1doO#M*=O?%$h%V8nIE1G=8Zx0YWYvp>6E( zUBU+dNe=@2bF*chH(KU3k%H`midc|^LI;WDV+;_ z%7^MOce>Jq&!`v`S_H58O-*&6sUz!^>y041adq<+)wZwerC>PLP2U~hyn`i0z~h`D z8X9mxG!-J6$ze33zlU~cNE-WkG0`m^uOAeofSS2({C#wJjP&0~#_sPNzG9yOVdAx` zy)oQVJ!fCPg+buaT5Ug!9RL&fvV(SlZ)k@^W!Ee4Po)i3$Mb{Oh8xsRbON~_`|x9>uS_62SF0qd{w9arB$0*vh`IUpW@I1h=NB0TP)}bh%!|+$Y-zq>C zz&+|NsxxeXfjo`BRkeHegzc{>{U!#jl{F(Dn!r=PPD0BcI{<}ncESgmhAvq>)&Frw zO#dMp<_dR?zurIt7bm(2Tvtq0u%>C)r6&D(h?86=+XtJ%RSTr*5eBjun|vvA&TNkDCO+YP>M7pP5_-0EhjO* zmps`CQi`GkIWgYBsE42^K{1DRY=)$sJj1iu6hnnk_z>=KB|NX7OYIH_D(;u%sXJzH z4q6z$x*9^y(><4VrQe9yZUwJNJA?r3mbmXwJ)!x~3~s}rw^=h-hYe6&`>xV0^vcL6 zs+;L%BWT{e>~zMzh04mIpF`sZch3Gxs31h~h|!&o4Y3v(p!D1$Bcu*5c=@w%09lXY zDn=P-@W21S+{^7bqUCcazXZ#WpOm7)r0f;;l~+6u2M@uz#kx{bxAQNUH;=f^ZzAUd zYp=Lhf;BL@mB|WpbP1NJ60V9}Sc7GearRci^{7G)KHR{(tFTl?w-~L$+QDW;EBNCo zPjQuJVP}$geyMi8v-2B?2bDK9!PwIuztutLK_eKkl?3f59aRJ=z5yrB8}{8<0|E&z z&Aa(t+m~>Ui~3qM*dC61UUdsW3_ZNt@qF*n)B5aH852>DwC$zpF@skz-wESa?{#`T z4&G_QE#TPSnnrD_eJFgkGzN&G1Yu*)M4+WrJrdr%^z`L-5!4DW#sIO(Dx8mwhcJxi zS0(&cg0{BFE~M> z?gWR^+Wiw!K-fc{clXs}^bY-jsiS?I?|y_?-l$s~brnRIb@o)-IEhKM=m5{y`ezQ1 zvZcR@Qf$Ii#yvHj(Dv6(>#wu2&T#}lOhtP2| zj8_JTxI-1CtBVUW!3yFGq@?X@(Rg(Xtl#J8nE>u&m?%DmwPIToTBbAb{QGgivUNkoRGE-~&eBKHqS0S%8~ol{9< zZu*P33CmfSbp$R9~29odXJX`p|UHD9O|xAp_MbF+3>6 zR}h=CCJ=d1xi)Gp!eM~ffP#b>;hmkpTil1=2#|eoCjs${-qDp{OfQRkPGWjF-t6FB&7HW3NN1^EIY>ZhK)pO|5FzkWJ3+l$e%dKm$| zt+ zrxBa{9r@t#L)=yoB>8ROn?Pr5X(BB>=@I_#(AinKDaXi5lfsxj%yT-XfKc=^5gL&g zoHtN>HUV8mXkzRIo6%>kl@e_B3*=A6W+EEmYKYAI2|G)m;UtXpU+8fNV@aec!C2K3 
z);+^n$&oJMEKWDWS^xO|IP3okXO)Q;iOW))7@qDv$1hU0y%;N=tNM`TL_tuJyV5Jk7y71f5`205}H^gjd{aUj_b ze1ZuPJ(7rw`eX{^5USwo`fDKg29uHh3;^hXZ)$g@VGR9<9gbn@t7M9xMLuERW5yHe zpJb?W0-1D&6|{P~OLB*#N75%5kPJyiBphnN;Owx-fLQWTPcs*ggd}5<$0Q$v;5&F{ z*N^SH9sgc!Juygk*~p-wzEARiWS8X6NC;tcM8c4r{&SK`B*ZKFCnP^5`3sVtfnZa% z8*uhzBbuxg+x{YQPm$z%$QO%G3N}9oJEpC&&m5bdhc7(aec`hu@(aZ+mwCuMLEx6J zoiz8!*o9(9skn+1d;_#aB*R=W(Dy{~VJqT#s`sifGVY z{wLBhUidEX0$XY;BEAUL@&Dg!3B4es##uh598aXzf2W5Y=2WzY!nd4rX#0DQ`q*8k?FbysJihsOU%&VC@r4U@ z2kr0k2TA;ft&1|Df9J*~+D$#XS<11%RC|7weYF*q!`7fMq zMO9URx36ldj(7Dd$8)*?&TQICyE-24(*dJwQlz7Nx?2y?9;d@`mV`Jw%;M*Xem^wn zFwJ5e7Nc;#D8^>H-H!EP`fO{Y`|bFNX@79*=IyPUAKt$CGu~mRx4b32D|lSGaTi_T z6k_Ha%eg#(MrRmD?o6Whib~_2OvZ`Ubh-_qD9zI%ifkjzqrD_9COR>;Ihqu!nU=&^DMq@tZyP(= zC@y}J5f)$2ypcUWY@2@89*$I!nYM|^22nZN8b8-nY;h5fxrELU z7ln`ChVaCOY>Mgi^QPQd3f&XLk46&uc<$pd56~%5ICJQEAm^@veow?(&JoWC-n=ps z4Opkt@xWKk)0+MSy#Cp5tarE5$98C6-+M);wW`(|#>GC=aG(p_1PQ}*npjW8MQnpf zp6-qGFe>#THG-2k3iV8X2g9=4pb<%Y9giVhI^rrMu8ZmSUz7H{4SM8n)JXH>1uXR~ z3|MBC@!2GJiZ2>nt6{)-x8fxMH8G@IUFuU8e_<&t&}2E(tJEA ztQ)I=t-=GxBVAY@eg^6}$!rxG_%hb&d)U)3PQj*4NLQEB&DZ2xFt3bBY9iv64w-Z?Yb?& zqMg4z*mW}2QS>er8*(kD`mIh_a<#-xtEPWI2h}5T>Ir;1vdt(un8exQ;A#}9Q4brF ztFz^GNTVOpvI~pQ=_OePmX({}esq7$-a>|^cw>DAt|u=x0zU{C26~{c8ZlBc@1R>@ zD6nybpY8%bJ;uul6c<<*i1>4Vg~36=NQ37&{QD2PGFyi zooWt{BZxajzzOH@ouWGPl-P8R#V5}1uO9j2gv=wCJO>|rLDpST-E!ruLIAmdJ^f>x z2nEv(SbJ9Hxvf89JwF6sS-0Oa){n>IBv&OWR&<>A6I)e!G=_(^stXs+2_UVP@2AT8 zIm4B$=9~wN{s0G95iOO@R$$GO{6EEGC}17(Vnpf&$Lq8SwXNZZG(>HwG4|lUf-n;l zjx$+^W2wi7zlI-61(5oqyk2VO*6$FcN058q&b;f+%$-$^DTE2E+!I>u4SY}mQh2IL zzI2b=gOIi7D#0s|D6tz><-a}LDI?{~88`V1e@6SjM*nyB7iV`bR}7H2C=I`qaGOf! z)xXhz{z0K*B%-{pFGE}qT$K)6HCs zU;ms)jcnA9%xJF|#)nJWi9_oFdJHj-wH2xUF0smfZyhgXpoE~`!DHS;=OBsH#hc<1 z`b}|FzRl*_#C+Ny(=D*4CE3mNPq3opb1>^4VXP^J%9sslL};`=Z|N%RB`tjV5Rakw z3&h2mCd!$}T)Kz3i7%_jM0LZlDV6&oKyAZy7HS)=uDU+QXq98M{`aDF5*?2A!etFo-hx=DmBDtJkE1U!XW3Yi2(DwcLyB$M5W%XFysXwjd_w7h za?Y%7(SeIFID87(N?VZ~^hfAyeZ8Y?qLi78qcYrEUWxrHR=orMrA5yM|L6Eprm<>S zn)qURd2Nez{GTpID5>A4g$f-dX2wa*Uwul^-ZetLN*}q5@bmCa;R7#0KSrioVCFkz z$ubX$8hx2job#JnH1~( zDb;h(c3O2!?xAm>vwo>2&H5qLus$_;u645{*KMMaPa$@w#?d6~IyTB!+yyPdQ7#w~M^yC^q$rWba0iL->9_1j`iVN*5vJdnIlzQsv zD}@}O+{1StZ$eHF$SDr4Ft&NbMVc(v9SLv;Gy0oRc`1R$lp6i%UskU6D1J&hjl%_O zgku!EYY*P$!Z}P-KY{;#idRUG7d{;2{j%bIjdvL>+u=@GyN4()b$S?bwH59>Tovo5 zmTLI1&cm@zRNABeHaP9G)nvItS(<(O7M)lB0&bUBEv*b%7 literal 0 HcmV?d00001 diff --git a/sgl/models/__pycache__/sample_models.cpython-39.pyc b/sgl/models/__pycache__/sample_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d488bdc98f3cd06856fa47e2ff30a4f198b950e4 GIT binary patch literal 5267 zcma)A%WoUU8Q+;*E>|S2xUwqSZP0FzC(~GRQow1{G>^)Wn${5lxDS+Qx>$2o)Y9aV znq4{yOI|7$Ii&>(v}ft)f2YSe73iVxt>?6j`}@A(Q&xZuF~fP!e6Qd4&Bm85RUD+P z2YbnjYmV~|YAhxXjrWlAzoJl%5*?>u|3yP^TXx-sD{Pnl=Rhb| zc@I|_xTEqPI*pPls{sGLT2U4Jmp^wrrxDP(=JKA7`-gN$uQMppUN+pTgh+SOZoiX+ zxIoy6KTGs$VV-u=POQVC7akTxKi}TkiuG~&V6&%NTk*Yo>-AeV-`c$S##=Yv2CPI6Q0 z4T@PyOPX0Kdb)XNs=J+DT>L!A@}4&Sr%Bp6EHn{e)^yTJa|kKwPsL>}W3!}~cjR@? 
zczMz}&;gd{6)I@)BM3Q)nw>b$qv(NJ1vbw6%w+1Zu`m^2z?uK{pB&{y}RGSc1Ztv_c5JVFB`8L z7l&ZQkuG!%U8ttR#CR$$ViOFq^q{A^5%Z=l^)q%xp&n>hvtx@ZR3oadBj?0VM{I!V ziWolk6vgM|&?KFkCXxVY15Ez&|4bq^GJOzY-|>sO&$Bk!)&F|l7S zX#yT2yWJ^SeN^cH3^5s{H0uuvU)@U98U;(IDFz zD6GigwWnyDOkNTyv71O`(q2IY1UPM>#8sz&VYS^8)RdgGJZ`x}RNdPj%J7w;7v6X$ z{8H3${cm778B7K&g2A|Xq33G+NHBTlDZq4{^mP=yjK%pi6pr+Td?ID=m2~BQr5F51 zmgK+XlMB_AKs>y)EN^_n=2jDht+KS#BfXxXKS zaW#`p2UF%sLi6&FZY4;h{0z+pw>rQaiV_YqP zp@DIMoR7)O{z+>c=kM<}$>64mE>3NdSAKw^a86)=1YreC@C>qxSm{myh6IHtFiMP- zN1hVboD*@|`DA14lZx^TIh}Fz#zLQKSB**pq7!J;ui-SXm`-Q@8JT6K@)2|T4nWPg zt!8fgxZh7QWs$R_9nx>u@nx}&a##b?=72@KnI^g|1B_+Rb$L`TO%~yiw!---qUeoqRE;-UKBm`R8 zfBXaqJ~=3K3^$Us^g6BqFn3kj1o=!Rg=7T0bZK|DKE6RHZ zMK?aK`<92|*m!`4oF@0R5vl$TrY>kLGmW(SZR82lJaG*ys{jJm#jr|Do8STPlfwQ~hNs2o9{Zu%Wz4&5JDKdl+~FIVY_T#OjKso+k2Anu;bTQfD@S_=BxND)(W4 z=!0V#_Qcls0m!i?>Cvss>hjU%nRC0F8Ul_% zKe7f__be*!(k^>ckYm5d{CtcKxKKNpIWiFAq1(HXZCSo`_-W}xp0g4u36!{p@ zvVnqgm0E5KUIza%atg9C*@lHGIkVAC*r*tjKXMqkc*cmVZ&Fk6No0eBcSpB?c$oP| zj7Hc8BJ_qfyd>G|Xy6Y_F)R)f4pZ~+UYsW?>}A|VT$~GMU%!QGv)KZm{RvDlr5q*Qp?aO@$EAVq@L4{Jl+sKcHf6Ky0KRqVHG8IR!b6=gW;JzWc0rMGUVlFn3X8 zOOXkw0A*(ILsS57U~U_bw7|aGV4pko3b*aGOQ+-j5UqfD6BmJ)4WM=3)9ZFdo8oBN zh1eJX%-peHz{4ubXuEPOCi7QD^m-_LApxW0$^7BpW)|~Ne4m&dhZD33`-oxZ{GAw7 zqFM>es_rg*QrPo>5*NIA- z^kl(lAI}=o9jvWfyv6?GuK?y&6QmJ+hG+dl@Vs8ptO2YLKSc+}czP2xQ#lx9d|uqq z@6r&tF}w6V)TW3YOo=ijW`&+9jHxc%`MmZJE;o^L(kiF?toW7)@T)xuT=tD?3pzAK z``|pn*Rj;XJSn3MioO=>32ul^j;ARKs?~y;re3+x#e4hyU3hI)G$GIM z1y-o{?;)m)n}yX83;Y~GSV=EMw&aW`uBO<5udodIitvOYqG!i|{8JaWLk~v;@pu?Y zCEs+$@@PZ3j23Hw+>D9|Un!>JgPebi0%j`iJNd>44>vOK z6LG(!UnoE=VjVo+PQ>p8Xf!6evh62OTjh7-? zGX8_OGe~kxaR-9QAXS7&Vb5Vy7dX5wG03z;(oF5=hNLjRHq`pBVQq~hGg z={p2{I*9&+!JMV+o!0;!f9j$%moG}{GnG4E^7+{%=i$jyx0Cu8bRE9kMHj=f3NnkL zzbxEqZHb3GJb9lj1eSBy`GgMmEfo3$>)iRB!?7*{ptbe6 h+_*+g-UM%3g?@(=hq0Pb8n2g{sJUzPwe|I_{{eC^k(&Sj literal 0 HcmV?d00001 diff --git a/sgl/models/__pycache__/simple_models.cpython-37.pyc b/sgl/models/__pycache__/simple_models.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84e059c47a671ffc994848d7b221d96502b7c424 GIT binary patch literal 9192 zcmds7&2t>bb)TM(-Nj-7f=dYe5JiCusSuV0P>wB?6oryS8jKwfW<*7F;E&00X8@usi~@k^zS?OBm0_V{RbcNvHWt0g}N?{q4az?3wQVAzO znaC)Upj5+2P$n}<4U}3q1P7#=^n8D_)Ai$CTiJ1}rXB<`Xh&%zSDHzxu8(TpiUJu2k6V%d zNz~lfOqC~tcq3A!pxce&P)!tlC@<~E#-^HFZgqn6!zfNVQq9OHiBi9t50or;%1xry zx|+f){G|7ILk8W=L{+;IGfUE_o2bfq6r_IGY%Ad}2sPpRjaHB(zW;aDuU}i*?6jk$ zB-vhicyHyk)~7p5$wq6b-3g;svXnI2T};(HEp|VZ+>Q#jU>b!brbWrFqO6KpyMO5f zdlqxvH3XX@k)`b84stS&V&@Vz$HF$v?21%u+56laZcOD-u_3PHkg|kAx!eqOW+!;k zjHKU-o9i9f_7`kawz%^33Aj>0Bd>W0UdsmkE^CxHwv zV6cR9vV<#UMBTn(Puu;ACmEZyhkiqsVkii7sp?axdXOz&)qH=u7qrY9En#}YyOs}3;)g!IF4>toD zh2Lz*PLeD`$$Ai?Mm`XQ?aM-za}=^daSCMFtIxEoOi7L`7g%s?^^KDNG=ACzlxaPa zWT(3Qd>tIeHY2+fN;%1b!Xl?wP$3swEvRyagtIKDqWQiKm6uDofqWEl06T!3uRq7W z{GWjl4>rT3^Q=)1GW?MjNf-e>uY*1WpLd>T)?W;kRxyBvrMrh<>5bE232@{cf}_$g zIC56yB`!_K7x_~*yZI{WW+N}Nrg?FNwL{n=UnXIm#Z?y9So{u)W2~Wuzkz%NM&<}3 zzxO;WDgZ`OPUI^rewPJBXNV>8CLfC_f=8$t{phGi@Tz z7&m)jt)$JHvNmsCSPqi(#KH=^VD|(t{w8G0F%t*oj_+tg$4$NcQd-`x;KZ?KapKUN z!P^OK!q}_wEBJb$YVhDS)D0fo@JCJnpnSx1xHhiY(4;WpUnsS5J8_h#o8vw_4h##< zG1bra;;7c&AtNYPFf{CvD0!~vUm3ZCp>F4gHh;-T-A5f2unLC8432Cg^VpGN5^hVX zsVZk@YOd2-mf2M$H$0dk^;6aAM)&ngql!ulahQzrGKs9c=VJ|_nAb;~)hx7@>% z)4Q0dpJyiL!tD7(F5*+UghDw44@>>Jo0 zqg0B^R~3pwB-2mjH#qt!%0&r$pQC5T^Hjx=q&1L>M7}U3Ty?Ho`Kpp*$Q_uH5!_jLy7?23lt_ z-j0^fKV@H;1YA5miP0iS3DCRG90f~b zx9qe%;)=h={yDLVs5_FBBRfRaF(usoe1}GPcE#R!s=m_Mfc4pId=SC<=X@1Gpc_tzUa;BezrwbM(x 
zJ?#*U-q$}vtL(I%7X24bc6oAdxd!Gt{FK%WoJ#EPf!MXRZ4b9n$ag@`XYv;6s$^y| zYOh0fMDAE{uKJfo%;C4&Z!*x&$AHP~N>UogMbe|=DIwhC`yEac+Bh-P#;_b#j`;Fb zHG99;N}DUer;&UR$wt&oWhegJbpES%mm4%NndsWFPVOcDiQ=df44hDWY{@GrJg7bK z7x1tk4f_1kj(!=qiE|L=;e?SDdQ>|%(<;BaU}^8{!tUW#*bknT@Gu|H^)J4*2d<$H zN*G}hZzeP6p@5@CU>{wQvngWYsvyyxnD?R4qz#HBIctTz&E%*o&9P;_9_#W77E$7)1bl%R)I)aUFPHS(oV3?JeG zah{1%?ICMicn_5n2FSh*i%sl`okgDdh<#k&wMj>`z~1r(I({Hb%MvJMEv>UzJj7X5 zE?Iyx&fN%_F^*_LprkLNrf~>1dM&!{d?)Yn4ILEinW$1ff~V>$_i-~wawPg=erT^Z zj-wD!?9Y)CYN&;vNe!#--x#^>u}%sxgjugYL}llJaZn2FybXk;{bOGQUhaV_Y8KDl%K2@u!{9 z)cq!MW=BoWRZ5by26B`e4xt8+LI}F0!RQd3CaH!X7J`4H_l182h-OydhqQg>~pA{J;%@Cnijd*MNngFhcFUMiu&(TA6KxViH{B!NO z90l=M(wui;IBgaTIU zVH>u%hstc&VZf%UX72Ga<7l+BnZI{u)fj1CoySEz;uxGlj7CL$@`%sC zk;}ADbu>uZplDx>YY@Xpc+b{>Ehy&03ob;5KEx;sS+qDn-Pn_p%~;3#8N*gJSZc-7 zv1pG8;a=p2>`B#etpFbn*D=Ml)R6%1W47eZn1Qpv{n1d*KjL+Es^U%wf3vt$n1!j; zzj%BPhI&l}V4{ALg+$GSC>Dw|3#;tIl%Qe$BJ~P}!WM1_2V#=0nPv=h*ZO!u-lY~F zeW%giS=~T)NdNt=vBJx^X$jIz+_`Pb)EHH6(&Pn97^H!!EDo>(n*5PW@_J0-vsV^DXdmpN^9>EzGQiTXkXz= z!t0~}q77_e(`NhY+PclDj-d`bXm~}7Cv^D|MqHRM_J(E_=VI&!XCwFI$OUr-@aV-s z!Fb`C7cG9_?4ZrM0#`s0>jOZWKz^S2z64aP-CTR`* zHxLGb5<0L`v1iNf&UaEQI)vW-CCq*8w7>BK$Stl(c~i(tHfXc zN;vVjUUhbz?Rm{Gxc>SJxLN28U&Y_sS;h!-EDn!xOOH{4!w*gQ0=+*y+PnE;tAbDi zf|H-|3KyrcK8l5^q!JDof{_L{vMBC^sy35mj62-80nE@Bvl9{A)kwl;lL2uQ=XhuWz%^F%oTjYr*HJhtRh;cvR&b ia|eA}phZGAnNC04T&Mn0{YrhVUanW`H|wv~U;ZE6008^| literal 0 HcmV?d00001 diff --git a/sgl/models/__pycache__/simple_models.cpython-39.pyc b/sgl/models/__pycache__/simple_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dcfdbc540795a34fb31224c76718da3b68ffe98 GIT binary patch literal 6970 zcmds6O>f-B8Rn2&?rOD?<&FFq+f`gU%_g;#G#@R{z;$9rX;N8nZ8re~qztvBq*Yej z^^hCK!m5YZKu$q>Dv+XwRe>J#2YTqe|DdNq4+VM%{1r4 zR_H|bmh-?_bwP2%0w{%yQUJvZi=Y%UiU&$5oC0MkqZC0Yhtr@;XOt2sm2d`>nT#?8 z%4~Q9lp`6X9E#gkZSE2FX06oO4>u>fNq&eyWUY$OTD8Bl)y7Pet@yK6vy(Pjah;;bTLFV9TwEf#g)X&H(X}4fJ$ug* zR@d%YzLUDW!k&eZmlnJB199E@_}s4BwYI!o3D2g|a&Nlpbj6;%C-#J}?mpkOdKJDO zTDPrzu~M5+g{0Gtq_nYwD*1jRZlu1ip7jrH&-XW5Z9nd8D?5(W%*{Xs+ff?H;NS~K$RMU9QkR25_pZ$!!q+U+O~)l{E{Dx@u0-&E5}%~p`!iQ=Rs)troy zDD~Uwi5onP*c8NZw5)? 
z`}@|vFD-7iwxh))xx0A#-Q`Qo`=2i+8_mV-Rv0yt#iX&_##YVU<@S9^6?hc?96C$P z3ePT!iYSZn;J>>an1xxm}TpExSkcP&FP^BB}10 zE2()|M5kOTL>>7&_^c60zY{msTXNgqwS_n%{=CYp2}1&}MdJ*Gb`xUD0(;!DmDJD;tJK({oALD!}sXuf#k| zQ9l*Cg|6Kzb_@DjXS+6hz+HL7C7~q;`jQk{tt5#SlFQ2(bTH2}sDG*jqNX}=a<>yj zJCSm?qac>kn6d8}QOs(^H(`uEJ2f9{Y{+OMNTbm2KT@KxS_#5dJEyK5ZN#wQ?I>(u zrk;2}n5qbg-wR}T9E&AYbe3>o3{`vnu`TS!o>Q@RP8}9+_SSxpUKPVt;8kUx=GB3f z`KsdkcRNATJkhqM)h2&{1=LJtw)Lpl^!=e79m7+pbxhGwLkVS75|!yvK|66BwZL>9^^x}E+na%m!nf+Ol_X1WxDJe|o-f43Xn5ZZzMN6`3dKbDvR58yo0*YZ zOP^?9SvE|vZG_dz3ID>4k~F+a-N;m z?7)TjA>YR}3O)cJ4L+|ugeI*~gAw_{=t2V=pxJXjISy2S1Eke)jkMf76MoT?l9i}n(v=5d)zb+}Q70_N zmVzWbwDl3M=Msb?{w8mrJCL*>f$<%ky0~efhoq%fLIknr5kZ)u;OUf(f|v1Xt!&)- zRg8^WU-ZZB0O&fB4V<4eZIC4x8}9q2a$9kfs23)^c+kIV&Or|z6hG9FJh#gd$6sQ8qRR+Kvh>(R*xk7}e&+|?L_$&HiZ8Nzi5-E} z1lbY|B9iHS`6g-1;fKUE@Gx57Ra^;U10wBA$+3kbwT^3~Q9f)5*UNp3Qu(9MWdX~9 z3d#ag4!GL83fjBo-V38RZPbJ24P?iScw?m|lqVyE*7%SbXcfG)RgC)BAd7U2|6IX5 z^U*)$XpE(PL{6DKoI1FT$)-tr(6+#j?xIZ6$*Y!NW`RhZOl~kFR-L;LAy=BoTn7SN3ig^U9EP6Z zi4V@?*0XT{F5hZ&C7Grxn^HNMUa2K6I!y(e~Urtqd*5%M~o#HoQJ zhQ+WnQlFRA{EbdCZ7c`(BY88D^{Ac7R{V`-X#MkNOLeL!Zv+P3Q&h=6(2c6S>xAM{ zfMJT7=z;hGH51gYBcLwci-1o=LPW$VV;roJ(lXyVY3XX~?m62AjQ_(!#jOTI9v!d6X0Y_7X{t2^ZnG=D>5sc^bn@6GcYJO}` zE0`>4=lqT3n+w~WBwg4HK8qHDbOC`9cj>*>LX&b!p1Y&NK)CX|;y44R$j z9Vt=fKj)8=w2A75jj4poH+1Qwj~h-;cxSzyn7oVztMfN?YQogDSWV)x7AI=@?I5jh zuAta{B~-=hh!iksb#4N)FneCOc`I7}l{zzFT3bfVjxGZBOD6Cc-`#A4K!1H|QD!oe zefB%}2vy8^+iUT}N?{W-lk@0!Oi{ugk1FQHoH#D#5i(IVU_5X0@M7oeIB^f$pRk@c z+B4%NrI|2DrpP^dHnKIb%fRp89ODAd1-1n0etWCX)h7Y4Ue5DCDQY{d66%l?D3=02 zU~@NuMvT}@r_r^9nnmQRcbY81`AjbJi7qgBW}>{E2(_lK+_grKco zKYoK}I$&G+$fhcGE{^@~C)OH75MKKGKf%D9fb{+H3Sf}wk9`WD_#l%>Xv6D%0y-oz zp!XmfBQ?lHkM7R$-4LfG`dz2%=ra-RF?Sw4M$1bORyjr>H#lPcCU3K2UX+)US4nuD z-4eUej7ldw4?)k$$tujGtgniVcyMN{Xfhl|F#kE;T$)^h(62Yh+#E&B9xr*4v^uVl z)<H9jwKL@)e38QWwm4ihIQ$8O zV)BglFv8jG1OXqyVEGHqKRbE;82mjmc{U*U3A}*9@&oSlCKr8fGMgMAU5&lxCQsDp zdxtE?WWm1!U8ijS4MKMoCWvc*p*DEsjXu%GkiX!jpF$BqQOON30%$<&co}rW{sEmy zqs9)TT`5jh@!Sg?#Bm$=`~EKOv!qOOh<;&L3Dbi_01h(2>QXOFB1b*Z7#QH_9_!4MW}RA!7)r8mt<@D&PGqcRd% Rr+T(}u6n#$td^@U{0FbYIClU5 literal 0 HcmV?d00001 diff --git a/sgl/models/backup.py b/sgl/models/backup.py new file mode 100644 index 0000000..dd5ae08 --- /dev/null +++ b/sgl/models/backup.py @@ -0,0 +1,278 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sgl.data.base_dataset import HeteroNodeDataset +from sgl.tasks.utils import sparse_mx_to_torch_sparse_tensor + + +class BaseSGAPModel(nn.Module): + def __init__(self, prop_steps, feat_dim, output_dim): + super(BaseSGAPModel, self).__init__() + self._prop_steps = prop_steps + self._feat_dim = feat_dim + self._output_dim = output_dim + + self._pre_graph_op, self._pre_msg_op = None, None + self._post_graph_op, self._post_msg_op = None, None + self._base_model = None + + self._processed_feat_list = None + self._processed_feature = None + self._pre_msg_learnable = False + + def preprocess(self, adj, feature): + if self._pre_graph_op is not None: + self._processed_feat_list = self._pre_graph_op.propagate( + adj, feature) + if self._pre_msg_op.aggr_type in [ + "proj_concat", "learnable_weighted", "iterate_learnable_weighted"]: + self._pre_msg_learnable = True + else: + self._pre_msg_learnable = False + self._processed_feature = self._pre_msg_op.aggregate( + self._processed_feat_list) + else: + self._pre_msg_learnable = False + self._processed_feature = feature + + def postprocess(self, adj, output): + if self._post_graph_op is not None: + if self._post_msg_op.aggr_type in [ + "proj_concat", 
"learnable_weighted", "iterate_learnable_weighted"]: + raise ValueError( + "Learnable weighted message operator is not supported in the post-processing phase!") + output = F.softmax(output, dim=1) + output = output.detach().numpy() + output = self._post_graph_op.propagate(adj, output) + output = self._post_msg_op.aggregate(output) + + return output + + # a wrapper of the forward function + def model_forward(self, idx, device): + return self.forward(idx, device) #直接走下面的代码了 + + def forward(self, idx, device): + processed_feature = None + if self._pre_msg_learnable is False: + processed_feature = self._processed_feature[idx].to(device) + else: + transferred_feat_list = [feat[idx].to( + device) for feat in self._processed_feat_list] + processed_feature = self._pre_msg_op.aggregate( + transferred_feat_list) + + output = self._base_model(processed_feature) #model training + return output + +class BaseSAMPLEModel(nn.Module): + def __init__(self, dataset, prop_steps, feat_dim, output_dim): + super(BaseSAMPLEModel, self).__init__() + # self._prop_steps = prop_steps + # self._feat_dim = feat_dim + # self._output_dim = output_dim + + self._pre_graph_op, self._sampling_op, self._post_graph_op = None, None, None + self._base_model = None + + self._processed_feat_list = None + self._processed_feature = None + self._pre_msg_learnable = False + self._norm_adj = None + print("BaseSAMPLEModel: consider different model for general purpose") + def preprocess(self, adj, feature): # + if self._pre_graph_op is not None: + self._norm_adj = self._pre_graph_op._construct_adj(adj) + self._processed_feature = feature + else: + print("do not normalize the adj") + self._pre_msg_learnable = False + self._processed_feature = feature + def postprocess(self, adj, output): + if self._post_graph_op is not None: + print("Not Implemented") + return output + # a wrapper of the forward function + def model_forward(self, idx, device): + return self.forward(idx, device) #直接走下面的代码了 + + def forward(self, idx, device): + # processed_feature = None + # if self._pre_msg_learnable is False: + # processed_feature = self._processed_feature[idx].to(device) + # else: + # transferred_feat_list = [feat[idx].to( + # device) for feat in self._processed_feat_list] + # processed_feature = self._pre_msg_op.aggregate( + # transferred_feat_list) + if self.training: + sampled_feats, sampled_adjs, var_loss = self._sampling_op.sampling( + idx) + transferred_sampled_feats = sampled_feats.to(device) + transferred_sampled_adjs = [adj.to(device) for adj in sampled_adjs] + + output = self._base_model(transferred_sampled_feats, transferred_sampled_adjs) + return output, var_loss + else: + transferred_sampled_feats = self._processed_feature.to(device) + transferred_sampled_adjs = [] + for adj in [self._norm_adj, self._norm_adj[idx, :]]: + transferred_sampled_adjs.append(sparse_mx_to_torch_sparse_tensor(adj).to(device)) + output = self._base_model(transferred_sampled_feats, transferred_sampled_adjs) + return output + +class BaseHeteroSGAPModel(nn.Module): + def __init__(self, prop_steps, feat_dim, output_dim): + super(BaseHeteroSGAPModel, self).__init__() + self._prop_steps = prop_steps + self._feat_dim = feat_dim + self._output_dim = output_dim + + self._pre_graph_op, self._pre_msg_op = None, None + self._aggregator = None + self._base_model = None + + self._propagated_feat_list_list = None + self._processed_feature_list = None + self._pre_msg_learnable = False + + # Either subgraph_list or (random_subgraph_num, subgraph_edge_type_num) should be provided. 
+class BaseHeteroSGAPModel(nn.Module):
+    def __init__(self, prop_steps, feat_dim, output_dim):
+        super(BaseHeteroSGAPModel, self).__init__()
+        self._prop_steps = prop_steps
+        self._feat_dim = feat_dim
+        self._output_dim = output_dim
+
+        self._pre_graph_op, self._pre_msg_op = None, None
+        self._aggregator = None
+        self._base_model = None
+
+        self._propagated_feat_list_list = None
+        self._processed_feature_list = None
+        self._pre_msg_learnable = False
+
+    # Either subgraph_list or (random_subgraph_num, subgraph_edge_type_num) should be provided.
+    def preprocess(self, dataset, predict_class,
+                   random_subgraph_num=-1, subgraph_edge_type_num=-1,
+                   subgraph_list=None):
+        if subgraph_list is None and (random_subgraph_num == -1 or subgraph_edge_type_num == -1):
+            raise ValueError(
+                "Either subgraph_list or (random_subgraph_num, subgraph_edge_type_num) should be provided!")
+        if subgraph_list is not None and (random_subgraph_num != -1 or subgraph_edge_type_num != -1):
+            raise ValueError(
+                "subgraph_list is provided, random_subgraph_num and subgraph_edge_type_num will be ignored!")
+
+        if not isinstance(dataset, HeteroNodeDataset):
+            raise TypeError(
+                "Dataset must be an instance of HeteroNodeDataset!")
+        elif predict_class not in dataset.node_types:
+            raise ValueError("Please input valid node class for prediction!")
+        predict_idx = dataset.data.node_id_dict[predict_class]
+
+        if subgraph_list is None:
+            subgraph_dict = dataset.nars_preprocess(dataset.edge_types, predict_class,
+                                                    random_subgraph_num,
+                                                    subgraph_edge_type_num)
+            subgraph_list = [(key, subgraph_dict[key])
+                             for key in subgraph_dict]
+
+        self._propagated_feat_list_list = [[]
+                                           for _ in range(self._prop_steps + 1)]
+
+        for key, value in subgraph_list:
+            edge_type_list = []
+            for edge_type in key:
+                edge_type_list.append(edge_type.split("__")[0])
+                edge_type_list.append(edge_type.split("__")[2])
+            if predict_class in edge_type_list:
+                adj, feature, node_id = value
+                propagated_feature = self._pre_graph_op.propagate(adj, feature)
+
+                start_pos = list(node_id).index(predict_idx[0])
+                for i, feature in enumerate(propagated_feature):
+                    self._propagated_feat_list_list[i].append(
+                        feature[start_pos:start_pos + dataset.data.num_node[predict_class]])
+
+    # a wrapper of the forward function
+    def model_forward(self, idx, device):
+        return self.forward(idx, device)
+
+    def forward(self, idx, device):
+        feat_input = []
+        for x_list in self._propagated_feat_list_list:
+            feat_input.append([])
+            for x in x_list:
+                feat_input[-1].append(x[idx].to(device))
+
+        aggregated_feat_list = self._aggregator(feat_input)
+        combined_feat = self._pre_msg_op.aggregate(aggregated_feat_list)
+        output = self._base_model(combined_feat)
+
+        return output
+
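# Illustrative sketch, not part of the patch: dataset.nars_preprocess above keys each
# sampled subgraph by a tuple of edge-type names.  The preprocessing loop assumes every
# name has the form "srctype__relation__dsttype" and keeps only subgraphs whose endpoint
# node types include the prediction class.  The concrete names below are made up.
predict_class = "paper"
key = ("paper__cites__paper", "author__writes__paper")

edge_type_list = []
for edge_type in key:
    edge_type_list.append(edge_type.split("__")[0])   # source node type
    edge_type_list.append(edge_type.split("__")[2])   # destination node type

assert predict_class in edge_type_list   # this subgraph contributes propagated features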
+
+class FastBaseHeteroSGAPModel(nn.Module):
+    def __init__(self, prop_steps, feat_dim, output_dim):
+        super(FastBaseHeteroSGAPModel, self).__init__()
+        self._prop_steps = prop_steps
+        self._feat_dim = feat_dim
+        self._output_dim = output_dim
+
+        self._pre_graph_op = None
+        self._aggregator = None
+        self._base_model = None
+
+        self._propagated_feat_list_list = None
+        self._processed_feature_list = None
+        self._pre_msg_learnable = False
+
+    # Either subgraph_list or (random_subgraph_num, subgraph_edge_type_num) should be provided.
+    def preprocess(self, dataset, predict_class,
+                   random_subgraph_num=-1, subgraph_edge_type_num=-1,
+                   subgraph_list=None):
+        if subgraph_list is None and (random_subgraph_num == -1 or subgraph_edge_type_num == -1):
+            raise ValueError(
+                "Either subgraph_list or (random_subgraph_num, subgraph_edge_type_num) should be provided!")
+        if subgraph_list is not None and (random_subgraph_num != -1 or subgraph_edge_type_num != -1):
+            raise ValueError(
+                "subgraph_list is provided, random_subgraph_num and subgraph_edge_type_num will be ignored!")
+
+        if not isinstance(dataset, HeteroNodeDataset):
+            raise TypeError(
+                "Dataset must be an instance of HeteroNodeDataset!")
+        elif predict_class not in dataset.node_types:
+            raise ValueError("Please input valid node class for prediction!")
+        predict_idx = dataset.data.node_id_dict[predict_class]
+
+        if subgraph_list is None:
+            subgraph_dict = dataset.nars_preprocess(dataset.edge_types, predict_class,
+                                                    random_subgraph_num,
+                                                    subgraph_edge_type_num)
+            subgraph_list = [(key, subgraph_dict[key])
+                             for key in subgraph_dict]
+
+        self._propagated_feat_list_list = [[]
+                                           for _ in range(self._prop_steps + 1)]
+
+        for key, value in subgraph_list:
+            edge_type_list = []
+            for edge_type in key:
+                edge_type_list.append(edge_type.split("__")[0])
+                edge_type_list.append(edge_type.split("__")[2])
+            if predict_class in edge_type_list:
+                adj, feature, node_id = value
+                propagated_feature = self._pre_graph_op.propagate(adj, feature)
+
+                start_pos = list(node_id).index(predict_idx[0])
+                for i, feature in enumerate(propagated_feature):
+                    self._propagated_feat_list_list[i].append(
+                        feature[start_pos:start_pos + dataset.data.num_node[predict_class]])
+
+        # 2-d list to 4-d tensor (num_node, feat_dim, num_subgraphs, prop_steps)
+        self._propagated_feat_list_list = [torch.stack(
+            x, dim=2) for x in self._propagated_feat_list_list]
+        self._propagated_feat_list_list = torch.stack(
+            self._propagated_feat_list_list, dim=3)
+
+        # 4-d tensor to 3-d tensor (num_node, feat_dim, num_subgraphs * prop_steps)
+        shape = self._propagated_feat_list_list.size()
+        self._propagated_feat_list_list = self._propagated_feat_list_list.view(
+            shape[0], shape[1], shape[2] * shape[3])
+
+    # a wrapper of the forward function
+    def model_forward(self, idx, device):
+        return self.forward(idx, device)
+
+    def forward(self, idx, device):
+        feat_input = self._propagated_feat_list_list[idx].to(device)
+
+        aggregated_feat_from_diff_hops = self._aggregator(feat_input)
+        output = self._base_model(aggregated_feat_from_diff_hops)
+
+        return output
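# Illustrative sketch, not part of the patch: shape bookkeeping for the stacking done in
# FastBaseHeteroSGAPModel.preprocess above, where per-hop lists of (num_node, feat_dim)
# tensors are packed into a single (num_node, feat_dim, num_subgraphs * num_hops) tensor.
# The sizes below are made up for illustration.
import torch

num_node, feat_dim, num_subgraphs, prop_steps = 8, 4, 3, 2
per_hop = [[torch.randn(num_node, feat_dim) for _ in range(num_subgraphs)]
           for _ in range(prop_steps + 1)]

stacked = [torch.stack(x, dim=2) for x in per_hop]   # each: (num_node, feat_dim, num_subgraphs)
stacked = torch.stack(stacked, dim=3)                # (num_node, feat_dim, num_subgraphs, prop_steps + 1)
shape = stacked.size()
flat = stacked.view(shape[0], shape[1], shape[2] * shape[3])
assert flat.shape == (num_node, feat_dim, num_subgraphs * (prop_steps + 1))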
diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py
index 6c93988..5325f87 100644
--- a/sgl/models/base_model.py
+++ b/sgl/models/base_model.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 
 from sgl.data.base_dataset import HeteroNodeDataset
+from sgl.tasks.utils import sparse_mx_to_torch_sparse_tensor
 
 
 class BaseSGAPModel(nn.Module):
@@ -62,10 +63,104 @@ def forward(self, idx, device):
             processed_feature = self._pre_msg_op.aggregate(
                 transferred_feat_list)
 
-        output = self._base_model(processed_feature)
+        output = self._base_model(processed_feature) # model training
 
         return output
 
+
+class BaseSAMPLEModel(nn.Module):
+    def __init__(self, evaluate_mode="full"):
+        super(BaseSAMPLEModel, self).__init__()
+        self._pre_graph_op, self._sampling_op, self._post_graph_op = None, None, None
+        self._base_model = None
+        self._evaluate_mode = evaluate_mode
+
+        self._processed_feat_list = None
+        self._processed_feature = None
+        self._pre_msg_learnable = False
+        self._norm_adj = None
+
+    @property
+    def pre_sampling(self):
+        return self._sampling_op.pre_sampling
+
+    @property
+    def sampler_name(self):
+        return self._sampling_op.sampler_name
+
+    @property
+    def evaluate_mode(self):
+        return self._evaluate_mode
+
+    def sampling(self, batch_inds):
+        return self._sampling_op.sampling(batch_inds)
+
+    def preprocess(self, adj, x, use_subgraphs=False):
+        if self._pre_graph_op is not None:
+            if use_subgraphs is False:
+                # We don't transform _norm_adj into the form of sparse tensor, as sparse tensors don't have strides
+                self._norm_adj = self._pre_graph_op._construct_adj(adj)
+            else:
+                self._norm_adj = {sg_id: self._pre_graph_op._construct_adj(sampled_adj) for sg_id, sampled_adj in adj.items()}
+                self._norm_adj = {sg_id: sparse_mx_to_torch_sparse_tensor(sampled_adj) for sg_id, sampled_adj in self._norm_adj.items()}
+        else:
+            self._pre_msg_learnable = False
+        self._processed_feature = x
+
+    def postprocess(self, adj, output):
+        if self._post_graph_op is not None:
+            raise NotImplementedError
+        return output
+
+    # a wrapper of the forward function
+    def model_forward(self, batch_idx, device, **kwargs):
+        return self.forward(batch_idx, device, **kwargs)
+
+    def forward(self, batch_idx, device, **kwargs):
+        sampler_name = self._sampling_op.sampler_name
+        if self.training:
+            if sampler_name in ["FastGCNSampler", "NeighborSampler"]:
+                sampled_adjs = kwargs["sampled_adjs"]
+                n_ids = kwargs["source_n_ids"]  # source node inds of the last layer
+                sampled_x = self._processed_feature[n_ids].to(device)
+                sampled_adjs = [sampled_adj.to(device) for sampled_adj in sampled_adjs]
+                effective_batch = batch_idx
+                output = self._base_model(sampled_x, sampled_adjs)
+            elif sampler_name == "ClusterGCNSampler":
+                batch_idx = batch_idx.item()
+                sampled_x = self._processed_feature[batch_idx].to(device)
+                sampled_adj = self._norm_adj[batch_idx].to(device)
+                effective_batch = self._sampling_op.sg_train_nodes[batch_idx]
+                output = self._base_model(sampled_x, sampled_adj)[effective_batch]
+            elif sampler_name == "FullSampler":
+                full_x = self._processed_feature.to(device)
+                full_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(device)
+                output = self._base_model(full_x, full_adj)[batch_idx]
+                return output
+            else:
+                raise ValueError(f"{sampler_name} hasn't been implemented yet!")
+        else:
+            if sampler_name in ["FastGCNSampler", "NeighborSampler"]:
+                full_x = self._processed_feature.to(device)
+                num_layers = self._sampling_op.num_layers
+                sampled_adjs = [sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(device)] * (num_layers - 1)
+                sampled_adjs.append(sparse_mx_to_torch_sparse_tensor(self._norm_adj[batch_idx, :]).to(device))
+                effective_batch = batch_idx
+                output = self._base_model(full_x, sampled_adjs)
+            elif sampler_name == "ClusterGCNSampler":
+                batch_idx = batch_idx.item()
+                sampled_x = self._processed_feature[batch_idx].to(device)
+                sampled_adj = self._norm_adj[batch_idx].to(device)
+                effective_batch = self._sampling_op.sg_test_nodes[batch_idx]
+                output = self._base_model(sampled_x, sampled_adj)[effective_batch]
+            elif sampler_name == "FullSampler":
+                full_x = self._processed_feature.to(device)
+                full_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(device)
+                output = self._base_model(full_x, full_adj)[batch_idx]
+                return output
+            else:
+                raise ValueError(f"{sampler_name} hasn't been implemented yet!")
+        return output, effective_batch
+
 class BaseHeteroSGAPModel(nn.Module):
     def __init__(self, prop_steps, feat_dim, output_dim):
         super(BaseHeteroSGAPModel, self).__init__()
diff --git 
a/sgl/models/homo/__init__.py b/sgl/models/homo/__init__.py index cf4643a..2774fcc 100644 --- a/sgl/models/homo/__init__.py +++ b/sgl/models/homo/__init__.py @@ -6,6 +6,10 @@ from .ssgc import SSGC from .nafs import NAFS from .sgc_dist import SGCDist +from .fastgcn import FastGCN +from .clustergcn import ClusterGCN +from .graphsage import GraphSAGE +from .vanillagcn import VanillaGCN __all__ = [ "SGC", @@ -15,5 +19,9 @@ "GAMLP", "GAMLPRecursive", "NAFS", - "SGCDist" + "SGCDist", + "FastGCN", + "ClusterGCN", + "GraphSAGE", + "VanillaGCN" ] diff --git a/sgl/models/homo/__pycache__/__init__.cpython-37.pyc b/sgl/models/homo/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d1d9941205125e4ae42fec31c73ec244948a1b1 GIT binary patch literal 715 zcmZvZ&2HN;49D%*j$g&crm=lbq!ESW+lb?NW1lU zeC;l~@3hlMX1!H{|N4=XMN+zoqR^u8>(5OwX<61Y8!w|r@SZ;NM4~KZqor-IZRQrX zv;&T?gRb_#6Lzt!eei`n?C1aj;Wmc43teF!dpd$hxPyH?fPrv;Cwd4&;Sfi93}fLg zPV^L}!aY3IGnf%as;>s>L=DwQjnza=)v21Pv$6+g>O8gPZ+_cU_S|l~Wbx_tGeKYA zOR+xGrg|vI<0MxcaMEN+i<{mgD_Q^BtHpInwom4V%Iq1hbF)vDyenMd!R(9rXaz}~ z@9xt@@{vPd^G$`AQ>gZ+h`P<a2Y>~eg@%>g8u`%D{G`YRN#~-mN qG1gmEpyAjy&a!G#?X&E1_tCT^B_r{CNIZwcn)>KPdQjjC^q6U-!jJBN~aqN|2Ix zgd-b3z#L#`BZ!zo9NQRT<_Hs;LdrbG%;u0Y$2hSC6wC=u?F?qjDL%4um@{Yi*q*=% zbB<5!DV$1SL^Js+;jAetv6%k!p2b2Aky_qdzx6O>c-L-@z3V=T0guRC`jysX_pt!P)a!lOI1{X>+hzTV$eOZ8@v_mMB$k4a*Zkc1Aa z4@eS{l!T7A&q*d;){CbQncCu>H^Yos_kaEJ*ASTZd;TK?&TP2U5W3E6xgdm&vQ!qr zZGG#w8YDic8{Y!g9pYQxh1qP`z>sYo8leQ4fz(hht~M?E8e*+$M6JJxeokgdew~+h zJKL7deJCq+eUA5^N@uXNyQW1)Yr9hGZrdHSezw0Kk{5pEi0w|W!{1kL{LR? zno^3kr@Y)x{l@lHkcVm5*nx`jPTHYlLPRK{Ya${UUioP(Ix~{QcTkBW)NLQCm9bJ! z4v%L%uawjf^udGQ)&Vzi_U7p0^xaV@q=L}HDX$dImb{o~UM)UU8P-_~_K9Mifs!Pp zf~1~^ZpbMWfe5d=seeNPk_OO0$MrT>yGZjwl=u%@+F`*mD*>z)X2BZx+}Q-^g_cdq;|aUFccon1ko zljzvR2C*-TIY1oB3$9k&N|poYIot;Fb00EPC%)UxTEbv6%cWhE!acYTv0{~VkulyWQrT3hfq;JK(z%5 zP<0t=GK_sAf5ziQnai;;=i}LAx~DEbjm=z*kqKomSB~$&&a+}vUFvNp(t%}gK7>A^ z1G+;$Ki-UXw1%Ac8aHa(sJoz&5cw#0E*Wz@hS*jL?Ez!wE3Vp1a|1BdHknDKfW7Bf z164vbzBRpz3Ms~AkjR)e{&CWEeSk1(ScN$>BtOMS3ugpv{e&=2~kPPsJ}=bgw9Q4OatT3O?^IN-+f44+(k&~+^mdulZ#l#zTGwAkPD=obbsTn zyuU-aEfs%(iWxg7S!+h~&3NaV`J5Ef=@{_*d^wZ97{G4=hbKVr9?iT#CqOU-70Vf; z*i&BR=YAjiDyYIdWH18}i0}$TD8sFvM|qh97UIt!L{u=Q@dTVP=Qj&(s7cgZ&BfCn~29Cp#xzll2k7T@Ht0AyhE zG;yJ68>uz*a|!3`&DwIVJ%ST%XwCIz#p^~BoAYe#j`$Z`l2^1}I&y)LYUzSPuZ)XD zVXbsDP6eIMRV?wr7SEL>&lJm);8@B_t{oU<06aKrM+Wr#CfR-V>rKqzgc!( zS28o^B3sN(U#ZLQnOUiYM$fpGgVzZ>=Rk;t7Tq8C% zCppN;hyDZmQlNi{rvN<^{R=*IW@S}HDa>d{4rhknH{2J)VN77WoBSvmkB~ocvK=rd z-+<_spg7@Fl7c3bqU}hha1*z)UFj8m;&-+ugCa~qN{$Hkd2mg5An4jnA|B4kIQknR zkukN8UZh&gzkT)d@%OJPE+j1N9jCQSvpg-2R9c_?RKwcnhg6%R?@s1vQA?p<#f4dh zLAJwxF!X0&2$E1v5{F0c$%68Lhu6KteNQ}+c#ytl!-s%>ZBBi7ejk4I^|J|x9hQ|U zQkh=~mR0i9x>?<1y@%-dAA-0b`gc%<7}}DxvvN4)&JBRl`nN83@90~fd)&W<`Af2f zneniW8}Men3EGfHzeZ}+l9hiOwsaMN^Z-!u&RIAgBeY=Q8qV%$IoNsL(bjJR-n(%j zR&4shTDqpcIp+Z0A4r>+V*Nyy=LaJ7GvM~OOYBfq)e2#=BWRdX8`Wfn zrF>;wR6V28ykt3lV+TxQ|K{Zqk6^>vLD>`xi)xh6)|H~Pj;U;)y@q$u5hS5oJuTz{ z>Pro*8XH6TT$GIGg^i)$x-n?0PavXtf@&XB(#uN@rQ{dF23%a^nE;4-jCI#$Y{S-L zbpI4o1frjUBJ?qh>4**-$9+Zv>d@G6=svyL`#&9HU*TI&cvRubI92!&0iXf1GA#sS zHfDH>Mxs4n?7T_kW~WOEFmB4pgp}Kou#3i2O)zR#9@2n`y1s~LN6QJo@&sCHx@oED z4ljk>WA|~d0uobGDLws{C?qfG<7USNkla!C^gp3=jvtu~_`0L1%ed?A9db50*xc?A N-;hQi2DJOR{{Zu-XvzQp literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/fastgcn.cpython-39.pyc b/sgl/models/homo/__pycache__/fastgcn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d684c2daa6bf442a3b0ffbdbc7f43cecb229530c GIT binary patch literal 
1465 zcmZ8hOK%%D5GJ`#$&&0o(l(d2hoFFhSZaIf!Ehf*>x<(S&~1T&*eoe~4erAvH;9eR z$u4p-(EbDan7_nRfgXyUdkc!DGgK5CzRZ@c|fL(z> zZZ9kpm-!5oBpDSXVGRmr!z>X7v0=h1U&Nl(w| zZ|o#oke4{5psekbIX3&;4OMREg0hC8Y>wrHw$cLPHF4MNWDy&1xc^Rg~2=d&Hv6uC0HV63legdPf_ z$bgNTQp_}RQN5J8&BFSq6j@vsDmNxeYJbVNMChh%Iy>ERk?C5gbqt|gy3PjPR+U!c zT-}(rTJl3=A8%TXAh`)In%Y?16_)4X*afHtR_A5SOYzK&xIu1~^&viq@X4s|Dvm{U zS{fH9SvzJMx51wP=lw|3Go!y7f9{ z?BBDk{|})+lB^fnh|fb1GXez+CRenl*A#e152j3i0$!J{kpRa%aAdz&W;Jd?I>N!P^*tAG*dC^Z?W}#;MhLB{_FPjyQG-?GfiEU9MIqzHN=sVVWyc z3aQi@8Bg8&7}8gPTtI3W9Cb=3o1Ps9@ZWF;csyJ2cy=>3jn%pj@T>tY<+jmg_LeV* vQMQ4Hz6szOhRr^@6>4>`;(+mN@fmNS<157CT)MMraTkBQ25}fs{{{a6$SQN% literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/gamlp.cpython-37.pyc b/sgl/models/homo/__pycache__/gamlp.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a96215cc0eb721e6246f2c193ca558cda60215eb GIT binary patch literal 922 zcmYjP&2H2%5VrHT>C%NFm4Xm2un2Zv00FHKd)aL(^?>Ay<*waqz)4&?ktoV-EA_+! zaDyYSloJxCz5*v^oL!0|Pdk;x_4kCkyMj{ zGD;B#Iw(RGdK~Jgh*|7$q?2O629(T*jAe37WTNOgWT_mSk#Twpk;s^Kqa)$e+3fIi zQAwqN*k3fd&gbGv*;8eg%G9TA#v9@VkLb@Uh#$35#9*rXo7?x>C1a2>%ql?1L!jxz`e zDOK6?Wm9mC4RIiHhx4Tt z&T;;e{CPfIRfU>5cR4+q&0p%P&r_G_slUykubO6}&}L$;EIzLkrb9#MesG^2&@cCW zxhLDy#xbT@ib8ST4LRy=G~!*(FB_qIkB`-jHcFOCYbeSdMD=D&?E}OQ(O~5UFde*U zl2X4u9`CS6pmWmeoZMB&;x6{&KcCAB0|NS;eMt|{T|rw!U0FBz*8$%ZfA-E<^zU)o aAbOSDV<7ybo9wCmsi literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/gamlp.cpython-39.pyc b/sgl/models/homo/__pycache__/gamlp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e95d2ee289975e3efd6f74a6ab72ce23346a0cc GIT binary patch literal 932 zcmYjP&2H2%5VrHT>9$*eC*csaA&td0y|&*qZblAkL1Nf!C4mSQ2D z;#{&3Ggh)3MPZe(F=E>u(=JB@ zDFbo}LNo1EPP)iRb(Rf{tqgOuG_Fg}B(Iq$bT_P;+BCQRst`h!Jzh4NDIAQurrj1} zOU0dI?3(<2F}2yQ__3yk^O@WUL!9$S2wz4vXn}JD373PZyutzhX!>Uz;^Ju zNlJZuINr)01ItOza&oVa#j5P-M$Q%5!0~-$FX%p|D;Nr|E9)lz_4hOJ|K2Bc|0?gs aE`8l!^bL;Q3$#Q17ybi;RPAm6 literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/gamlp_recursive.cpython-37.pyc b/sgl/models/homo/__pycache__/gamlp_recursive.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a98f1bef88e1829537eb15c1ea5dd82a6f2c18a8 GIT binary patch literal 962 zcmY*XJ#Q2-5ViN~HkVvPkq-(Q%9G%(p+tm4h!f~8iAaI8jpeMJy_LMXKHG~3+Afh& z@gJb1=hwJGqRL;OV(iT&;FV{eXYBEtH$EDTh6u*9)m#2EKrsE5592J`qQr4v*M7bu}5s`|M z!YM==sT@3x+hoX7D7&pD`0Dld6rBhAHy9ZJHN1V4V!u8Ht#^U?q^CZ45KXyV_VGV4T~-nZ^gDZ%_guJv sPEpe+J-Kg+Z`Q&+o`u(H3;Qqp3!p~=TL1t6 literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/gamlp_recursive.cpython-39.pyc b/sgl/models/homo/__pycache__/gamlp_recursive.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c627c1e2ade42ddd776f61ef356bc9ecb0dfdbf GIT binary patch literal 972 zcmY*YOK;RL5VrH$bla^sw6q5<99M$f3r7SLAy%N>wxS25Tr4+svyq&{YdeUd+*+w8 z{sSC)=hygz#HoLQ6Jw{3Wg?H};mmyVO+rVbAp&`}de46a2>o%+x_7xnn1$XBMO4LE?CnS-)gT*SG(#+A$t7Y5#}{FivcU-&r#Daujj3!-=$srDP!QzzYZwwMg~7gY%5JDTG~*2$>sE5592A=pQr4v*M7d?z5s`|M z!ZAe}sT@+f>0^~HT*Tpq$SJT-Do=f4BhAH`9ZJon?Z+ri*Vsba_IChiQwBm%=A#tO$!r5DdtXeAij4ukD^x$$X z1n{&6<$Ja3tbL~(l>;~${sNR@?;irwy8wLB1D`xlrra&Nzn0TwC82$Pv*-A+OE<75 rY8s^{_bc&h^6x$*tNydS+n8^^hSMWYx15ViLs*`y8PP-wYuCK4nEu82Z~h=lqj<$$yo%Z=U5O5UtvdnJnY)=E8b zFSdjU^Dgfo0!;6GTu& zDwbEAr$~SM~!mjA-|C3RGNj<&eSWwA0Be8_@d;inc;eQq5<2Tmz7rXqb!T1lj2BP z%L{p;7g&5_NFL2T1t&>P1<50kUXybwVv$@8^7xu0Bu{`=&&RWa)7x;K-0a`J4EAe( zk#!Kk{}fgX%YFdoh@&mJjJ`!JZKF2+4obMVrGkp+OVkpUxK!Xi*##Zi{$&DK&-LMK 
z(xz=3_L(4f`EdaS!EW}UH);%p_zs;X%qLc=b3fFk*33GotVUTAb;L zzAz*B8fCRXW4aSPpgZ*QHq2qNX>c5wK3;?;#wj+w;+15~4;gmSD2%rmTQ^*FnNSrZ z=@>1fQXtbF)<8F=W*6fp=uk_nZsFCFjE4O&Nc{--PCCAmd(8}<@_YZ~YjG~iwI(E)9{P`05^5^g#Z8m literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/gbp.cpython-39.pyc b/sgl/models/homo/__pycache__/gbp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1765c3769437808404d56fbe9d0ad3295aaeda94 GIT binary patch literal 962 zcmYjPO>Yx15ViLs*`y8PPzZ3~Oe9DSTq_C{A`y}}JlnHx-k1*teFFCK{Db@!5%LR@O-mqo3(LL-Cy1bm zlr*Ol<48qioW~)KRZ^yT8sbEC%5L7JWJ*LTI+sLrBwfdOCb}~+%5I?&8Bssj=T^?9 z?@kUYAr$~ShqZFW0l$#uM4GwO&QvSFA0BY6_`Kk&so{EgtN|O&ic%~2NfwKxlj2ZX z%NO!k&$0NXVFPIPIXFpjDo7rQ^opEO5sT!qm&aEmA$bDSIzFE6pWKA=tMh4 z=2;5`k8uhshGjp3bHvextfTKyLz}3He}Ez`Zm6Il`WiKaB`y`1CqvMq>8=yNI<5<6 zlO}EAkY|D*fg{s}iqIc5hC+0U&Zk`IB_H*CYHKZx=^&UfR;&ugn86nPfN3My0tI4~ zHptL#1+26SM365aryx2jgW{P_tW;;duT7V&f}bO2&MjVJEf1c$=~FnyYpvR0T;| zMsukY$n+Ghw{D2hGjymVDsJKJlZ=M^2!wtFY{xCz@x5XOPx&Vsdp5knR!`eX{1~?@ t7!!AuvE%;}2~!Gp?>Ec#+1?E-{CY}F5GA=^(n@mDpfMU0=&^u+SaN$doF+}8i|rQBO+Y~~OG;jY`(cu6#Kz{N z6gm0OTaL~#e<%l@0u*S`zu;46R#04^6gV2rketUiPkJ^Qr3A*O=2-m|5b`H3wj%`R zCz$p*7)d16q@W8*(Px?!!6NAVK!-)Nh&n&iagi(%N@hexGJZ>BtfFw zCNie}-igXrr>`nA7e%d=flx5}`2;i1gjMsW$G;xEIIg795W0UPYAv!{lrtmh(_iWg zZ`lq+iNk&eMv(=TWWi+ejx1@1puY&-k&rAxtnd5X74xUF7k_uB$mb_7Z=)ag_kP&- zK~^_eZxbf?oxlWAh zdF0}joJBWD8yVTsvmVSo&_{k}OZEtf1P^1lDns1yfwr;iT{AeHx&cInEo%ceYU4F+ zL)pKMZg$!Pb`gG05_nX4;=^U5k*Sfo5V{dg@dCKJrF$6#Eklf!djn{-6n9Psp_h6Wxy%cn0^quF=$l7)V?Qq? z0LU+t!Pfa$s*5~RKt`YQ?!}yM8uS&U&9`7szIGQ3p$}+Ecj%BZ8q>!tWsL69s|WvQ ze;gVV{7ZvpO&<)<3rv-wP@MaeBlbq4Kji$p5qgv9>H&EhB$?8Biv~3_rsgr;*C9jl zu|}tX9_^ZOGVc2U2x!s~2_5Qq{{|x)Nlk2y{}CJJ?1tlq|4HA`YEa>R7jqT#TDZ#C g$(Md&kGqO@TD~LA_YO9%JH%g%MGA(%8GM3&0c0Fqq5uE@ literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/graphsage.cpython-39.pyc b/sgl/models/homo/__pycache__/graphsage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03eeb7376bb9c3fcbb277afff25926676362a245 GIT binary patch literal 1480 zcmZWpOKTiQ5bo}I?Ce7;ArdDcBm@E`3)ZfZvryvLW-qc5*xN#A#?vinM9)LrJ;oBF z(<=BRi2s3f^k3>zAcx>{ZYIR3>di_E>7io>xLDaBUr=wMw%*EvGuIPkyR1 z%&@)?B`)(2lq7R1$efApYjQ+Ag6=#J9g$p-dH9-yWF7&Ip6mK_h~iIwn*9F#`{=vf zoo{#FF>tsbtD7uYhP+0*0}_Ck-#}SnX-gLDlnE->HMFITZUPb9(#uGMBD#Y8Q?h`a z4cl16qI1n!!Xq2E_!_c zXd8;|b#yaqJMb3an}T3bhb}y76n1SxzD!k@1v-w>YLCFYSmqHr{nTm7Yt@O@1Nu5YBk@emps;3OO<6Y%-f3?5Ois z6LYL4c&joPt0u=dH4f8EKhRb;xbdIF|Jekm%tFZs;kD1wp?VIT526bDwnKo z18=Qzt0VMo@LEgmQS++So()5uRi&}I$t+LBsS8kbtxofj=i-$caD!aV%VRu(6mf&H zDL8^^v@|YIvUJQ=uFqcp$NgB;C~={Y>Im?q7QEK=0KSkV7kS}&0Jv@}+WK?Is2^b3 zA%d%+9B!7XbzRPV*_yvlE!C zF!$8a9^Wp+KLIiRBD#*@E=%YKudxi>MTP%bqZr3mriJ9(^*A=PQD_f1KWkF8jQILB zUIJrXwkMNHty<)HJjRqB43YvlX7J-k=y221q>j# zKfsZ{_I zk!|!f>Sz~r@i&;;#a*cMCF%%ET)HK&58N>9nIKsAqJ+y}KVAcKY9@O=wHqys89>ly zj8#qL7&F+=kC`@-l{hJDwLt>@PKedE#0b)X$Oll4)j_tzCswMZA8FHSW}Vd5XG_T) z6IJa;?Z)X%z|ek=u|jcc8T(FtPiCvOmNRSDv&H=Eg}VAQv!$AasupwYtl>*L)mLT@ zx@NE}Izk^scWFjHk5RDcZh%RO!!*2>jQJ5mb{mCwpRx6Zt3DG-!I(XdLMjELKY*ga zO{lq#_#ryfX9&x}FG(^Q-Y-DD$G~#hvz*@QVDOZG`d`jgwFUzH$zHf8m~KIQ?pkA~ o|8x%8hr74QT3X9Xx%&*E{VmfTAy1zj_M^U{NEQ{^r{RhJ0JRm|h5!Hn literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/nafs.cpython-39.pyc b/sgl/models/homo/__pycache__/nafs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6161ab0a693b0179dd0270ed3176c76f93ddd060 GIT binary patch literal 867 zcmYjPO^?$s5ViBsre%vjEQrINxK)Do#t8w7R$C$Mwn#l7X)x~Ou@U2v^^;xnGT6^Iz<@U6fs4&n%cki; 
z(j|-)Rbd%xaG)D8-AI-rV^-@10l0fUSDG9%2nRBsKtWb|`63q?sa9^Nn_4qtr8X{E zNp6`aN;j;xR&RZLOWlC6OmSlvyCr|d({)|SsWF@Ba(3}TU45CFTupsfgSC2E@s*kA zt7ZUG6Iccvp~t~}n$WK!6m7C6FpiN-#Y@SU8!|+QvI(=9Meom55zn_D6j{)PP!#KH9p}|w`=|RmDrH1j{$xgN>Sgt{OZtKQO o{;TYj_jm7#r8I`;a{m$h_ye3CG2(Q|tB&YvieOQnBkG^qlw0}c_+IDu+|!(l$3p_~Z1Y~8Muhyt%{~$UZ(!L|NP-BeNKJD} z5l1Sj<2(*HR!N=aX~2o<)&0Cr$()E(^lpggNxF;kO!SvzlKlcEGNFEa!mV7+-+nl2 zgj7&CUNmK8T~&OPrIFTF%^Irr7F;V{RD3-*Twh*je09OCJA)&XN4I1}1r^avmdCdwA$bBqVjnN(?+Tm{1D`1@B(N_aIpSzbcG1_UrES#4-(XS~ zw}I$O)Do7sbVuOXbA9aZu$Bpe)aNA-K>T!5cU#t#hB|foIt^81%Ml|F+Rbb^Up!Y=pJ%pIv%qSx*TBA*>MJvVOEXv&6QK|2U7FGB z5m+|e%b%n;%9_`bF+XIeZL1KE7~5>Q>Nacx&(P<9<@R{R4j z=a;-fV#P18;$DX`y^(M3!#O_p#C$v+60ql+5Aq};4IB#3R`$7&vk)-hZg@!>{AFN z$*CZDbWK)NP!U~ad3;S0k|#hW_VIH5zQ8H>hH(lDLF`L#jyT$qUGzO_X&bfi5181+ zZ2@q8iBGIlD?ikx(abujt8({Yd5pyeDOkEewo=)%>t^$UITeC)t6=fhi0%WIzo@=U7FFa zBam&nr$0$?lr^slaI>=nf;X`288|@%Rivgl zr5Hyls^dHkajcR$&C?Jks#o{(J|#0EQqj91q9^G#&NI+{E#8Zjl%eVvGs!Bh05`ZXH4iX; zhz@1y!?N&Cl8lD^1*rHK$WA)4lf4`UPx+_+)ofL3Akdxc^m~Hk2JV2n#@NX}xx*EN hZ}-Kz`@MFz5!&yWb`kRQ*-1C*D~e=MqyrkB_%BpT+^qlr literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/sgc_dist.cpython-39.pyc b/sgl/models/homo/__pycache__/sgc_dist.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02a52ef25f1bd3dfc3e3bae79d9eb8e7970b08be GIT binary patch literal 868 zcmYjPOK%e~5VrS`Y(j$y6vTlOw@5U(BgC^Jm$XIXwinB+o$X5AtYbS8MSDx7p7;-N z>|gQ;i4%W;6Eoh1W>%h^hduMnXY!SW2{2G4fq9Av2LO-%>X)Quq-M< zkD~iDqu&m1`J3$K7^gVXir11cKV%4UqtHHJY`x*C%YG;)4iOnY7OnXogI5mFx|inaMu_+`A>GZfN=IM fS$B_YcNd}k9?pOmar*34H|T4MV3DE&8ou}+qQ~7S literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/sign.cpython-37.pyc b/sgl/models/homo/__pycache__/sign.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b23428af018733ee1e04604e93c50526546557b2 GIT binary patch literal 941 zcmZ8f&2H2%5VrHPo85MyhyERqI7for7eM?Ji3Mr9t+cmXtPs1~ASdzKjzm#zt<)3m z07qWQCnQdN1y0O3U5aqz$u~33_?vHXGMS7CtY@2d;%P|8FLbsm0`MBPc@KsYP9>>m zMk(Sz2344a9)~ij;w<(!l7niP4XY$eD47!;^T7?_13|CDY{Z8vk|uXhjil5~PI4nw z^Ec z#1l^7zIeHYqOg9xf|jY)6!0CIbmU^wDxvidJeM$5)}>`k<9ytNDJ|FRw(0=3z1vK%gT+L)~fa%UzVJUx{udw#Uu{L z0KpMsh0Kj%><9Us&Muou%#7L0R`bPkdG&c_)^g@=Gni{;rd-!kb)|QJ%?P$hz!170 zJPP*cmkCNe-8PoS*kqkof-yH{$hwt?cNyEXx$HB(6tv&-D1?+C$srW=CZ+lS@evx- z4ev7n0_y!`I(7%Za@w<;KIo(Ilso<>XUa+ef&OGK=pm*X7)ov%ZKi+q_cQU|J|wID gC*6)mO9u*&j`k_8c|7+euCkn}RnM`E}l)9eN(B|sRhGxsBq4bMP-CC0J}s8T;O+~2 zp`F6PDs~CV5l1_68U2Vl+C^Rb6G(A!7pi@aI>Hi{3iOkGxCQ>e4STPPc1f3XX%}B6 z5YGgGp6PN8ML~YOhS8}p6zp46(~(bYtEDkR@O;KtS(lD6gM0E5rj2B4B#1TIAVj|v zVwGKE1OY?jGmy@z!1%x?R;p7!)~3pot$6;rq* zdqe$*u|jcc8M`5Wrt`C=l5=a%=c~ox3w7~zZr5rax>?LMb6c+KnZ7VvfMx{8W}pb& zjUGii^xFjKp4|^NO>xMYSCTP5W{A2~Xm2xi-g4DvLMa%($5BY7fRcSE+N(XZ4p1R$ z*v&?0(BLnpW4{L&XFbN*13Lz9`Gbv|Eh`P}`Yx15ViLs*@gt6sx2J&11mvt%kp`ToUeho7luU_8Memk~o}|0D$V7iZ^6VZak(~O`1-EiBee+@7 z3aOxQJZr1ky1M)*D zk`z>sBDy0>DyWEVvm(AD2`LgF68mJan7%DhA_o4auy8_<>xiRUvWvb&Te^+5@pq8w z;%%7pHQExExK!Xic??2BrRVy1a);ea5V)mxS;0JrpXWfF8bbl!W6C3+TE~s^v0RD# z2zDDSjp;$pXN=XW+A(I3mLD-~B&*N>)@p+^{ew_zY=sph4V6z|DQkkCiBGIlOFz)2 z)yz7ntn}C7V5o4v|)-v{k{LUxWZ6hbvt|yD>?76!6Jh7FUgsB#L zE%d5#tZ&Q^Zp~m>OoX0959u-eGD2I&djWZhVphD7jQIgWb{mCw$k=+rRo4llKxW6I zluCj0Qy3ajPR$A8Gfe1DAC`r0m1H#RFTuW3U^(tsjt@QR literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/ssgc.cpython-39.pyc b/sgl/models/homo/__pycache__/ssgc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c84742b0fde6d88cd26ecd6b86197e71ff4ed53 GIT binary patch literal 892 
zcmYjP%We}f6t(A(OhbZDRazGOKqN?(Y!GjeP0}K=n~i10&U7SC#<3lVqS>WVSNsDk z`%7LSvFb0d;$9D-@yJ*A;koymWAo8yKwv%Fe3UODLVja#xDp87z_uU22qLH=HO(nS zKT=U0=W+03mDFjT20u}~x}WzcnGunS-Zc?DNq2FciT;v|vpeWS#?+7AaVwXzx1Sb` zkO~sV^QNq&pv`zb+(SFW?9-<%KS=;o*w^1@SjoV_oQ=6<1{!p1Ivwj++VWEcI2TG~c!{1Yg< zxDDODM=fEAO9jTs84wPcp6lbu9ab|zU{deA><0XL4Fsq$6v8_+ZN)&=M+8|8-AmnOWVg^A&<_k#6>Y!`l6D!rq5434C zvrcO3vz6qIiK_O4X6y7e__x#_F;*yUEn~Oj&v<&-)N*R=X1bispR22{Q(LNO=xVXn zK&>n$`pOJpY6jb)A@n4ANRR2a5lTDR6BwsRX3cBKm>)1iw^is58QW~R>MEfX(Cly& zQYoN*3Qaq6g5DV#RHqNy!mA}24e>dM_ZTovI*gNh6%3y8PY!Cfsx^%7PIj`LVz~k7 rxoeD_{8u?BAO5{B*4^XVUqTpvfHNRQoIZQm5q*VM9rPK}@Wp=tOLpP4 literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/vanillagcn.cpython-37.pyc b/sgl/models/homo/__pycache__/vanillagcn.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54c65769b7c85978af7c3152ee787bd0c818ee56 GIT binary patch literal 1116 zcmYjQOK;RL5VrH&S49DVxDmMJ(A__PC{SPpDhtv|9CEST*t?tJJZd`&t@c(VPW=yX zQ&WV{UrA*|>9k7yBwG!z}FLKt*Mo#gt5m2u1Xfh)DY9e%2H5oTR;@E@vAq7&tBC+Du8wmd82q<@M@}gvwGbC=1&i1IcB&tWQb<&gOenqf@Gec znJ?l`WI;Oy{Ve!I0+NMT-wodKs!)nUrn93#0`|ftg-Li)@nS8LwbZ7ltE65eUESoM zuGY0`Er47lGil$~`d#vVRphG#Lz5I$vVd;(j<^NPYF<=DwH$ZJe{}T7xBZ_WE^&U| zw0XS62;On|u*^$vmRPzW=iV{2AiN6@+WNv5!3Cf!u;GRtN9P#Fwg+*4<8Ay+xCumf z5q=(QqA$2N6a*TLW-dTdAli;xbJcPySqaiEp$^!$eb}Ko4P96*B)2Z8R)ug;l`C$H ziz~%XK?xTNT{m@W(~%2J+eob=_}cXtgC;Fw8Ykd(SfeFd;+$CB=xwW99J>)SIAf3l z@etn~Rc*A(`IA zjFFJ*BA3umpRtZVV_(Vd$z)ZRa$?NMWIjFIQ>Py$W~nA9f-<;QPuAPBu*|D*bEeYJoPmWPxsLE!KD$(Qt?tU<^~M=X%)sJ#!gzU{_ez# zi8=@6upB72Wt>Z;0820#x|Hg>7~eyOl4C7{V)y)jcE~L4yCHNm?j(lEbh6(+_$%-zPl8ivSitg`k0)s-gR1p0bM8E(ekl_Ht#VkM~g&~+hlhJP_LlH&@WEU(a$ePElMoOFDllLkI&4@EQycTE2zB1VUwGmQks)$2QuU{5HkP( DgR&fO literal 0 HcmV?d00001 diff --git a/sgl/operators/__pycache__/__init__.cpython-39.pyc b/sgl/operators/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20ac75bb2cee4b8d4a3be7ba31a48253bcb8a61d GIT binary patch literal 134 zcmYe~<>g`k0)s-gR1p0bL?8o3AjbiSi&=m~3PUi1CZpdu6eec~_h(w;MGZ%9H{n%g;ZIdjf;zS&=wmm37W+k21s6_1dA zaIpAn=-h`Q3#eie(}eicw9IFwwNl%6XsL63*T`zA=X*xR5;v{;bwSPzNxw_A`+{dO7hR^mh@=Q)5emtJ z7=%1XLzRvaeKi`zVemiN2sc`M3_ACr$giQ|=0b288nZ*7y?Zcc~@CV^-4cizokXOQ=TavKN14{EpVc?z;0tQ)Y^{+ggVtUnZ!Pgs)d5woYRU%mz}o*^HDKP- zj6NiPj#m&KtF;duQ6hD~5KdGv4y@sYzk>!d>o$3;7xm@6JZ$U57(=%`cn zgjsAPrw>}!&eLHw+W+t`d=c)89|jMjvIRx1LN$Q6ylx#(MUE&&=TEk0ZB)MxwT-8pfy>9nW-^p2%G2T=rk#O`OD5MBP{Cvf2chD<%_NV)MBanW zVnz3&xB!bp4JzFj${_$R%p$ItEOY~NzI427l=je6-lVTd#QHi7;-fg2vu#xxzcV}6g-Q&u1N*aLPc%vPrqey7)Jpbt7+iZ z%?S1=&bUC=t2?g1Ch{s&gf>}|HYk==>M@tLp?`_BX^T!Sp7Tw=>tfz#VakI*H-aF| z<57bBW)SR+!lcv)q=eFA-(nBB<&z52S74;EBhdjK^=!{FvIvINXIKSB@)xKorfj?? 
z+rXBwby)W;ZkI$pqrjDh#NiPzY2w%36Px%RB+8PmKjTt{1OB||S2Tb=UKIHa3=~m7 z0W)^QjtM>)KnE7kDr4BQp*Ra-bs?f7U`25gmYM5d_Dd@K@6yp>EF#tlnR9SvPsf(1s7R!Uh8laCk^;2;Rgk4&q z<|G1mj;>dx_Hd=4O*UCMIjPcrp1mQnScXsAo<=5GyZ*b4n7mm%1*w0niC@~*C*G{Q6-40EtL5)kn3GXTm~bxZ7zQWx&r+w zeuQcc0dWTfNNs1GV**71y4nZSsV1-M8Pk_z1 zP-F1}%5YC{7ZuBM)I~f=7RWZhVNQo}?V#uFHCLj&4@Y-=qAZl6!rA}3<8j6Few$K(KB%r7gm_rbUmC$O~oj4Brm6@@G zE$4zG(QE&Md0d$YC^lQPo0oq2C|=FNM*_uJ*Gs|^C}=KinZQb5RG zI62*1nB0S|N)QQ2Xi7qApDbkd>11x`(aPqBzSY&TAPlUIrG8cq>vrsy8l1n9` z1e7AmxXy;Dxf~A@9(}|b<3^_&gULPU>MIb!UJ5Qm6Lv^KCvg%Np5S2O!_!M@=!~g7 zlDyn~TE_VDyx|!8eduZp#Dq>r1n~d~k=iEjXb;?8iAq_NkIwzXSiVoRz0t?fsTZ=|EwTdJRK6{V0|7gBA#{1vkBGC zo;zjgOYJ3eO5f1@d(blN@6|wi8??-7uT01u-3t=uBAGCFx);eSM?Tkez!|(XVSRew z=mxA>nbHU3#hn9JH;-120e`QU{9U*3{fU#*E|Mub@bWdBADyc-$MnFPI7e#}@0e1u zzmo(>{g6zY2V{z9ZS=;sp6!Y@T%X4xk4J4lrX0L(XG5jiFTv1O@t_>F0i*2-!Lt>c zLVOEa1dBUaSqs{l2IEiBX@WBA`qXtFZ&ISsOi+9_J&?1JEwRZ~H z)=U2IJh7VPgEo!#c#scYzyB9njQh?TCJ?i-4P9-3=tEv!v<|2yM-;R32iteIE6f6% zJ%DDsYgT?u3$7Oh%eps>izj{vnZ})TpmbbhvjfvlphA5zPRH$Y@#lMAhgMvv_Y|N(17{01>*~B4dnRznb`C?KGuKEGZ96m5k*;% z3{xC8qiBD~)5;=OVdpAh*KmX!b4p+VOUyIs6_#Q^12=H2E{5q+vkU_o^)rYCb2i?R z9bnJIIjo0{a4T}3Qs7QY?C@c*#S+%u5|@Mll*@{#KM_jtzIa;pmdF5OJ@4vkm}sJb z2&U|a9TWVN04X>?vy9=$h3+BITIzy;cxnhdDZB8Zy*R{$_xm!^00rX%1#_u^f-#3q zMQC27b3g^uSC67N@44G)2=0L5Nd_b(@O1&@^$jfjeGpey`)3uwz@wBbO(qQ0|cJo~|rU%*Tv zwVI~=@N}z>4se?NeUeBx@)NWNmbtrm4lJ45keGG&;pi@YIJ!eUNR?$liMh7aVEiYR hZmLo{iIyE%90zwQLb{7(VE>LVT20%VY|UwX{vUv-h9>|3 literal 0 HcmV?d00001 diff --git a/sgl/operators/__pycache__/utils.cpython-37.pyc b/sgl/operators/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a960178d85a17d592a9a8c5904b363ce80228748 GIT binary patch literal 3492 zcmb_e-E!N;6~^L+0Kp&05-r(n>QZT&iftn~?p;T9GPPqj(?s@Eaxd6+FfbRAz<>a? zyR;(>cw_a}Gi{%MyzIO50Wy2rE5AZ++V3nV#!=i(XJTOR?ZMfD#qQZ}&pE$owL%Nm z?b8Rzmn)X_FM7GWYzQCWkzW9~#aU*JSl?!q?DSou^G3D4zt9DvdOuvqjeZ0DcGeuV z`Yoe#vc{<0Z$oyu_ZO?b!fV_|yUGJzM|+Kjyn%L&H+c)~I%fw~ul)*txAuEXdxI#= z6@=PhlFLF2Xo7UL@zn9$!z2F*P}Ynod(NgdB&USL*-dNeD0k*@yL9IE)K%`Vru<>R z9h%NdXX-69^`)j(((U1{sZfQ-e(4SyQIlJyZ7F}&F4=Hp&St9^sgCyA+?x8OHC(&Y zOar{FPwT+WG%PEN27UD>d)l0~N_W;M{W2))Wmq=KX4#sOAc~ICEQuoR$#Iq{?F>|=n^_U_2;+-boauVb#|0do2yN$M?LqwI9qk{% z2!bHz_tziL+r5L|e>TbPV^fkxpFP$=ERJakQU@ZDmE@f#qLBF8NsPJGVxy+PVx^K? 
zyF6C0_EXGxfbVJ$#!6@(-_r-ZwsvAZgnbgLiAZ$a%)VL*Peo)TA8Vl5>o67~`hXK# zbaNq_Y2(Mae4Yq#6O$L;Bltc!(1Am>4~RtZyb=aQ_76*JqxY-r-IHRJ>`HmMd+_kn z_p`GvcI9!lTZ|JCt3t@#iApoMGd|Os*q{g=7BY!Axg)A}iq4q)1;An z;4E4@WK9;>H&~n1a0e{3H*E1U*e^Mb!)ehlhT~sCczYaI;Wtovz7i@1$+Y7eW!S@B*%l14leBbyM-#5SJ`(VQ3CFg~U)e8^Mb3~0~ zpCBN`rl2?=$Wfv}&?IOPP>>MhMX^e7jbM$y5P9_~k;OXobqKB#Y!GY`+#nzVX(vna zuURDCfv&eAen9=V2!2S=CD$Q{5HqFoZsZi9do(Sjw_ZrFmjxuhC-lH6;k_IC)SAsa%NO<{D-Hm#|p zoS9o%6i(;N1X^cSEA5hT`-U~QXFdXba+Xy4MLVrMM1c5AJ92_>I4q;O%j0>6c0<({rAqes0SR{x_D*_SaStL(|in8QL>2*Qp z`-*b)H7k)GpD5kAVpeI}2xw7Wh*6xS=Lx^M&DW`06_Gp}jS?l&fwBD*6PH9vi*?v4 z9_**RX0XG=uc2S=BJJRE7ypVEMf=DuY|(*)^mt**Tgu_=g1un*9fjSzU^s->(VN!i zj`)?TamVy+pbz>B#$D(iLl5}{E0NuAT0;*>)}IAcE-$V77zMk8oyYO7BeT16yJHPQ zZu9y}Xeb#^-O|M#hWqCq9-gGKdo;-hDlPIZWS+?MI8S)@s1V&Om1=tt^LyPm=UpTV zWd7~v33;}+Rb|G#!;_@DFf-p&rchRPmF^^dmUMY~bd-oBS6!UlV|6084290$qaj|m zmM6s||5ZOf-!5{bnRFCY z^GUPt={YBDNzT|!eBXfL0P#%HDi=&u-gXJ86cDjVxd(Nqk&DF^BJ zY^SQ{bggQ|`!q?j)>>t63W&-rRI!;{YC^H0q?yfYi!#eIi9FaF6?~EX3W;O@ literal 0 HcmV?d00001 diff --git a/sgl/operators/__pycache__/utils.cpython-39.pyc b/sgl/operators/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..135738ac678833b14b9b8543f7d33c64fad49136 GIT binary patch literal 3399 zcmb_e&2!vH6<6zfG#_^CiJia_#tSSSVC_xdM5VF|S$6r@H~~8x1}K#>EqgSpk;bi_ zWbM*iVy>wI{(_lf{ulm$t-f-~nOmx;<@Z|SiQ{A`4l9*@dV2l3^}64$-|MK^3=CW! z9DfpjxndaqqK}K0iI0!)6kh^UgbWo;eCY%ypH!8Z}2AG>zwTy-PQ~I-Pr3g<@CcO zlMvj4I4g3|CkM~37K9p}yLgI!0;Dlx(wwuY3CSuUadzF9TGF05+$^oRIklxdbfq`+ zxkWBLv!>2MQ(J1>l5PjLD+3uEd8Iw9hYfD%cT;+^R>_7db2eMWNHx5#%#Ep68pE}7 z&D6)&`m_dYPlK|us8g>$o72X$S=zID>6Lz2D}%CLHp=E43tD=!tJTO@1b0s`%;|J} zXSZL7{%)E)8bxw6Nq36;1+}Q)?)|-k-+cP;{=@y}mU6f}8OO@bcrsGXVHzD3-3?{s z#aMaKqhcJ%f%3!fFiGPuR8BEY6RE7eOjRSzBOYRW5s4F3%lJ6QqQ^p+*;qLce|k%K zhcJR5C;+|nPw4CJ{vYm5(vK0C`2M{I%8$emxlkxy#6=}Jr?JQj{B01yH#gF18Z44B z&Xmm~87VJ;&wccBA&jL^9{N*0$Ny5-Ytsb-ThBL zdp|w-Vz)R-ck^*9BAJU~cOsLt*cqRwO#~>!3X38RIW0$6MT+)XaSLFufVIqqX)~Ld z%>0k*E+r>m_E)~?o7Y&2xww56m>Z_}F?zkmc5GLJ2Gnf-62dpfb{V|JcfaC0w8Z=W zlkal9^v~Fm@y77#6nxh^qp{$9bJ|+)z6Rzmcwe^WzGnW$8_eH)#r*z+M@#k#8?zS< zfeYwv3R)LYCuk5f2`D=VS~RgraD`xvV4XnIc=a-kMVs2L5^NA`5?mv=PC)cgRvKro zI3&IUU3W#?p!V+)bO^Qx-X-`R!S@M%KtMzEe4>4rWv3U$d5>CuNU%*n46NWskkmW0 z0O8lrjHpul4DbfcoL4k!F>^^Xe@V1!rg$IyUL%?h1ud>7+CLzanrcX8CZ*Fk+cc(z zv}Sf`P+FZcop7z0Tbd=~<~3t(&O9VUi(91yJ0FRc-KH^WBq0SaZ0qNi4lEsQ=`1ZF zOF^baUH0GBBewCVK9u_B~Q7lE$*S4R*@q##Nur^!8 z(?YUX)7WO>4)m`@lA^eb;jXM76(e3>_3)Z{}CEy)-_?D2NoIj|5OdrMKs1+ZYDi zEasiVTy)Z;klTy=-|a*h@1Stt zB(VJ?rq%6k=_>j1_Q4?TEX;JD$`s1VuF?&n$8m=zhljCS^ z{?&`GaB080rm`O=F^W%dJc;Ae7z%Ue17*`xE^{KdO^eabN8>vjzxp>G<8(XEP;`<} zSa}}vFyegYcAEF2w73J|yo>k+%%3wbJsEZDDj@n7GRpcmGmw$6(5`9+sNla6A{XK( z_}2SK1%HgvNu^Z|N?(7V+{ftK|6Kf%jBWzD4yrSpMkOw!s#o4rAA#0aM}UTw7xi5= zmrlQJzU{Y7?N9jEatj1(jVdei%>bL>1``j_`oyvXXX?M za9Gna=#MQCASL3|Ijk>q73sY_@gY2^%GgJcUJ#A`H~(%e`6qr2<8DLz7BBHT0t%d< z)bJ`~;vRkO0dV%BK&jd!TTo0{k47Uw3(|GO1G4))0a3WuTEuM;AaOt~mugX4KFQx; zTmr+lgr&tA2#n~VwcA!+j^h%YLv2D`wUv`iM&lD*o_i}=-s$VJBc14KOHr3|)hPPO z_++O#YALsRiyd-O&(*E$b#g`k0;58=)P5lS7{q}ACLqHBh>KN#L<&O+V-7 zu%t6Zv81pDGib8CWCSYJWV|KdlUR_Gn4FoI=U$Xpkl|kd6ypjgD1!2eKt>la0|`G( z)+nJInBw?!km~sS0~MrtROQ$#4i*5jQreG{o>+6 z{a|+=-JHq_{o?c-{rrN|qQsK?qGElRvHJ1xnR%Hd@$q^EmA5!-a`RJ4b5iXLph(dO~DH?m24Ai76-4V2xBvmWPM>jps);|4XsC=T0*gFT6z2Nh@XE6{=2K&Aax4M%jTc zwxSw-N=7T$U#%UXdv4u+6&>v!iEoo9ioY)A=%`sXRbj930z;E&2=oKkk|_S2GTx5> z93sHGh(*rGa=tOU(A6fTgW0aub3DD8)SwAeM|TV{nM6B0&lbM}xFT+_n@bpdb{AC=^MxjkU3zWFzk*wol5PHU(Uy z{EdW;j^E=7i2~&>(7<@_E-yy%?Cf~PGxN>2+3$A|$d7NI#213lFZZ()AB^{K^czqd 
zaV$}e6M~&hq?h}N-`JiE@-PW8+Ctpt!8ziA@Mm7q;&>ND;WcEVoe103mP)Im&lQ+! zDLaK1WGVR3E`~^bbz_X{qs)Z@qmZtDDf|lkCFgt zwe8AwTFI1UX|X7F{TJkS^Ah^`j(Kwnx$pAmXYgZ?!8MwZOH$)2j8RQ$L_Ln+RlF1V zHqhfrDCN4cZA!BuGnCp@`mi3QD)Owag*>p`N|hDWMpW8j1!H$8Wir*8((mZcMtoG} zBG&pO-rd@MEhne3K9q3@Wl~cr9UnTlXj#Q_ITKyFM%6^Mp%r&@gA`uD)AfhVjFycu z*li=rnLoJ14NwNv*x+L_BkF05FTE>thKtpO*Pjvn$^;zG@K zVMBJNn|nuM2>VQ#DA;6}r?Ah{VLsMoxDOXY&9Z7T(p9Ro7_qXX^N5&fwJ={q9ox-x zR%nwJOjx2T+r7QEfhi#mTsO(Kvp1>4N2N-o+(@bes)^G8zjB@bvlru-?mxhzcM@g`lF)v7<snwmt8U|g-*|xTy}RJr za1OG1457&ykZE*7rqOYx15VgG@*+gk5qD3kf!UbuO$V!|zpn^cqOH~f-X)l($cIj5$kJ^sZL|YD| zQvXJK>|gQ;i38#Pb-uqw+bj_3&qpYgvJR`}pT7)m+uI=SIFSTK1 zCM?ml?cH75z*JBNN4LecdoZcRXQe8o+?SLiqs`MW0@Zap4=~w*)0>kRllfF|HE_wT zd*pHkG^?Pi8fS(=W*^S9a|riw4^~9_c>4T*YVTFWSOriLhfW+ND+Gw6M`=;>sbjY*-At|SV^vV6ZvB{;L6X=Ar0^yFL=-vDy+t}df>+BrD{?s5OD5qmt~;TehjbK;XY zz*=?N{3g}B%vF_6es5j9t>o8My13JQz4_7|Gz>9Nk}18QC0qiKl9q&drH9(*1aF4j zD6oMZOQDp@wspn^qk&<}w%Kh>#*}?@Z=YncpG3a(wa5-_Q>sETZG_YolnrKHunor2 zEYX^=pXBdmuV3V%r}a^9Z|BSFY;xSw-?LtUwG&e)-TUrvWQFWWsn}xEl@nFR&fL-= zDgrAI!pZ7wcilzLDUP}l2ZD?Jh)W-!GNc3pLpr7EaS0dRB{_xB+Wgd?QvJ#V9H#IC zOjsgh!`hUdQjTy9>rWB3HR3D-7mX5zP3~=wOIRaM$SE!9#nPpN(%(JV-tUVJQXomi zD4lfj1SxRb$;aAs4$#rj=|E1pS|&=1Zdw#<7ST(!nw!tB0lQUBMJ?ML=)p*vWRwa^ zwY1IKKpU6>Z@{CQYiwhGBE=`A3Z)zXs^OG-2i6l4fn)9`5a+zf;AG?OMDAroR0SBW+(@F( zacq?20y)-XEHC&tb9Rfdqj8eWJk`k6GmN*(a}K2<%aDgRy+xoNxJzA&6?bcBXh-5Y zw1esdA5QBo<~E#RuhuZit|R-pH|*;kBw8T(l;eJt;q6h@Npn}w%_F`?fdOO{T=)ZpO83I{R$G|0jz1hPxG&x13<1L04-sKQ~Nw8I;$CVOqLm0cCf-ZzQv zX3Pi7Qq*f3XF;xd=TP2fboL?cK9oJ5pY&dCvQkNan}$m0PI&q(B-6Q4u@ zzE`#N?^4a5x-QGf@15(nrTp4SZhN|j^S!z17CMF)D9Dsv&;l+2NI?t2yuw54bAsJr zuN~MxkEBq_Rog0K{bAoQW}9sBE@R5B-8;vr?8R;0`dXw%wk}mJnKnXd3(5v%^a^Z^ zu_TSPX6%am-R}1COmwwA?(Xk?eVvX^y81`j&GGHn6aJ9C0-zY_<9%m0L`578J>fPn#>QuVli3-6Mg!f<1L>rbhEWdaUU z_z5N~5Vk>ON>3?A$OhG?2;By8o`H*60fRdCw#X%HkSFAn7W87}(!uHPO+FsaLBg9gL;;tW>U)13=ZBao>gY#8}{}d$GhHEYzy)sg*K<?Py$u zW>DVX=hIk5+=esc)h1@yP2^waj(?rQSPP_|axAjz_hTc|PxTB^_rTH!3yiAoW`E?p NTV2W32BOmN{sY2`F=PM$ literal 0 HcmV?d00001 diff --git a/sgl/operators/graph_op/laplacian_graph_op.py b/sgl/operators/graph_op/laplacian_graph_op.py index 8973a03..0d2df3b 100644 --- a/sgl/operators/graph_op/laplacian_graph_op.py +++ b/sgl/operators/graph_op/laplacian_graph_op.py @@ -5,7 +5,7 @@ class LaplacianGraphOp(GraphOp): - def __init__(self, prop_steps, r=0.5): + def __init__(self, prop_steps=-1, r=0.5): super(LaplacianGraphOp, self).__init__(prop_steps) self.__r = r diff --git a/sgl/operators/message_op/__pycache__/__init__.cpython-37.pyc b/sgl/operators/message_op/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709576a58e51c232da58acfae3d95cf8402c1b2d GIT binary patch literal 970 zcmaKqO^?$s5QdYcZTgXq?RL9eEpXt#p-2#k11E$KAP`bZk$OOyi)Cu7&B{))b_(5p z!kHU?saH49gxC+e~R4eRToDShy*Z-|y`L)-eM=*TW~t@lMBLkO*JiAcr}Ti+HZ zvIjluJK|LKp>KUxoXG(Uj1O4IA{Mg~)?=rv&(7F@oiDm@&W1B*bn(+Qg^t`!{3_4p zw489Q>4HxS6T1Jdu`EpCInu}*bEDXYG6pMg+ z|5~}^cna1Ck)-SDLaZ5LX^>W<*n~+<_k!3I+$nZ&$zdPZFz&$d&ntM2Y13sY<)Bpeg literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/__init__.cpython-39.pyc b/sgl/operators/message_op/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3682859131dd6c509afe0901988356fafc264c45 GIT binary patch literal 938 zcmaJ2+4c1Fjr`hX><4X{t7SJY?QgzV$XMnn zh8}nx@s?^~8`_SyRR_D!b-bf`=!5TgSM_lK1IK%6hyetS`zpi;BFFn`gfYa957aFj z!`Shmx{U{L;CP_!;2|7(Y{|0ln+Hc(l(OmQr)OnNy>0wT7Yk9(r7>bDPf8nl_^y_f 
zsO3zGst})){3!F~s+L(ZDVhnhQ+52`^l&aVJLPjJie@03=R5nmA1eJ>E^ODE0w-yX z<$hUVy2j?Q_fO7bm7>=5>UC~vQ7ryUV`mL2x~0lqAFDb9U4kBgPtYeA5YSy$0YPZt z{d5B3g)@%hrE^@D?l5?iU%JQDwSD~j%D-kV^@Lbi6Pr;I0T2FFxsud{G{OkQrdddu zAx3LzGfG3))O;sML&3fB+FpKh4w@K_jWc<#cya?qW{u5eTIp#3wp`s$U^At`C7jLC=(6F literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/concat_message_op.cpython-37.pyc b/sgl/operators/message_op/__pycache__/concat_message_op.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f24e9e184082db1e25cb6cf36ba9f8a6f6c61b42 GIT binary patch literal 782 zcmZuvO>fgc5S?8+CM^j_AP`6#Ip71a#2FPzC3=By=*^dr>)m|7+D>*i5~*@)QU6AJ z7{Vx-;KkMW!L-um5ecnB1)mhbhk0QkY1UBt9_OT)dONr4gua?a9- zO6F3)DX2)rSD<1oFQYtC;tbN>PpXD-D#F1>?OZX{pB5z_?IPMP+;duJ$c2JjUc*Ev zq2yJcU|P^TzJ?g`9z%xvCrw=zeh;_<;mkZayJs|c@L~g(@) zq->GWIP~0Vp{?aFLmzQo&pl!|#9}(N=+_IKMj>|EOsLWq)`v*fie({|rf?4NJN$Z+ z&6-MQ&MmXE(+@Ar`fKKx;BI`Qi{$#XR+l1QRVe+g3!~=2h$|jbErI_r+!mz?J6RW%Zk@!cQ7gmi0pfC1nC_<4 zY#%cGaO-8MjX~TuPle@%mN~R9#5T1S^~E4=Ptw@lV~9uO_vnw`qw~UPGTSoMwr&sl Rx2)G4wIllHx43c;{Ra3Nyi@=H literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/concat_message_op.cpython-39.pyc b/sgl/operators/message_op/__pycache__/concat_message_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5791dc3b12b668021335e8c0214417a6912ab421 GIT binary patch literal 796 zcmZuwO>Yx15FL9rAuS1&DnKA{+KptJ_?NUXPzVi$yn?aN zLd(k@K{dcddIc#I9fpkuC-uBk;U4r3gt2(fe0lXRXbKR-1}@~a+=z|5A-0-wOf(Rf z9K~+gpk-<7_|?KV$6v-C;%q((h;bj)Wa2Qa7bZ($>Wv*!Wl%1}#LV?A$wMry@*eRP ze!t77b!Bqzm-*T0$JciKE%y_f*MzJ>?R;LfUSPe*+s=4@GF+@(m)bvI@~i^Hq4@sn zAHBmnZl&asZScNBhMqL6YH%q>=oP&*6rjMT}S~Q7NafJ0&6I*Vx5YP>Q4| zldeaf$Gx#p0qx8OH=w+o9`Q{N%ep$B&CNr~d_&@YEVo5jm))GJ$~10bRqK`IeIIeT zQntNmbi2n4KiK+N8fy`EnsYnH44OA$qZ*TXqMtS=S?UfL;0ZZC{OkDeTzNxwJ0{xH V?LaolT2XIXVn_4`Y;k2j`2%G8ywv~z literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/iterate_learnable_weighted_message_op.cpython-37.pyc b/sgl/operators/message_op/__pycache__/iterate_learnable_weighted_message_op.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79623d9bdcb6afac7aff6d20b022fae695a5c824 GIT binary patch literal 1799 zcma)7OK&4Z5boEE$BrEq_7P=;!~`LBy%M&x;22r)5;#TCe8d?c2&0Tf47(kYdYK>j)tsiaEQkV(o*}`6cY%Z_gBgwL z0`H@?e+{Of?Aci6jjC*!#$4C_$xeJ!_xR=TUnG8opiq#d)E7yGZ+#I|{I)%<4gG)3cm!!lc zxgy{VtT`>IAV)Rz&@e$~_&Z!u*!2X0-iI@GMAXRCN;bp7JEH0Y1SQ zN^b+=Fe6S!Ab=X|+55>HD^HhT`SS%@7ZfG?6Qz?AdH!T4KZ(;sgv;ZwUu3-`k4;kK zp*b7Nt?+B;!)&5W*puOd8{G#T<@~(H(4Bw2v@?!XoJqJn%qLk-Djg0A6`GL@Zxq7? 
zxYeEUaH!-kHigwk zzvS86!fR>$JA#&`GFQZ*^exje9ROAnD`Tn5MRe;wx3E|ZwG9R`He4i`i$7pVW34&= z5&gN}9Tl1EYJJ?@f42KUdUo2?!?as~_bwp1o6R$r7vt{#zbaozb;f6^4)E6C)Ayl6 z_$};{02A!P=;Hu4t9gLWAH6bcon`F%&SlqJQafUQ3ZMQ1I)SE;e20+IkgmT&f}=wW z$w^OtF>c?j>5M!@-~K#hhFy9ky7XbJfh5~RB^{6}?0!?PWKJHeM5Q++`FbTqO3z75 zi3twtB@sl>9|)}2Wn+04>;j7KiC_b=i7wZKCxB#E1TZ``piweGHh_MB;qerRwbCv9 z4Fnh#`lbI*{SENq_@@JO{O|zf%rqTAU7^{lGGDn!z;jY~)hcbT|AN+TsV3}OpU<%oimN4BPuVOAtUwJOOGNcN8`-iCu>V6yns(rl7ic9O{F_V!+3cCv9Q zGnqrRsS>_jQ*SucdX-gaqK#dfGniM4){@TL30R!i+I&Y_e5me%>b$$>j06j2$hTot zRYZgtZouyjo2wd#?B)LJm4FVmFs72q~jcN}s^y}QvuhWaQ1nw#Kmr-4?Li2JApUi#w$dxbcPzAuc ICTY^Y0S9UNt^fc4 literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/iterate_learnable_weighted_message_op.cpython-39.pyc b/sgl/operators/message_op/__pycache__/iterate_learnable_weighted_message_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34b253f649536e64e4d35d4abe41e7114860274d GIT binary patch literal 1817 zcma)7-EJc_6t+E{$t0VU(uJ*xU^SHxOZZ8o6_+c71azrX3Q`fEk&rcwon%_iOlmvb zZ8fuApS*yq@v&wQVgbUG~r?cvdu_^pA^ zAI`bmC^(;A6Urh#X@^zl4m^?#6%rPftVeD~_ z(=!y)XNaQMhYftYxow0>jM#zI^JHI&??gHp8^Le3aVJ~006NZx3Lq`P7-VpX=H#4| zxFi<@P6a;mN{^Fctll-0dvp9FET8G#syNCU8))X6 zu=GoR9qKVBE{7w48}8aS(i$(5OTg^O09II{w4EO(GUd^7IvNz&ewrsH zE%L~mOvGmN4cI7~Y7^~?=)raFgPwBdyu@HmzgVs_NmP;v*geXp*}hOZ8Wt)tV-Z~^ zMhkE^dnTii5~IWvYE!x8x=rir=|m{ys%S&T(mXYc*%ou5%;B=FZxcBcpDR@;`w@HL z2V(5j$?6qeD(l}Bw4M~XA`WF>sTOhwSWTjg^+e9&)&J~56E)H{9E!v+o@URX{(wa$ zS~K<>{k7L07n$g5ebnE3yz`+vdD_<_*)QPqt{}Rf)l6o^r2qe4mEA=3CMT*1@b18) z?}I^j4F@E|1d}UD0uta5w-;WBPv3iC>3Ylh58Sbr&5_%Q{Zn}KpI|te0SOO)+Cbai zf#~P}14%vOMxE!)$!F+?UuV?NbH7CA0lXU!M>kOE4ao&|znNcB7Y}QEU-~nWuT)W_ z^j*Y^nDC%k5>B}HGXamDx0XA>C}8-Ba5^L#===`%IY{k-I0T%Ol#_LkARu@+199DC zmccp#1dHr4_-FPyoZ#q}qjz#@+D@SE!r85uuUrK{P9A*Kinf@Qwd#6?~tDMeK&ddc^oZ3b`qAfm9_aL=4>u!;7p$zdRc=`h{ z2vgjG-x@Gw^(r7O;Li$PAr#*S2CV^iR`I=y9XU{*Qw)eR9sN4u*c^xCghBi)vI3w(RGDtIr? 
zd&6lyaF<#lEy;7+y75)yBEw~5rEU8BM&bq7%;*?C2}F) z0*o|j3r1OSr*diZ4Wxqys)AH>!;dzEVrH+Kl27)gc1bv+#6m6%+ZUB`5f{<68c7&1YZV*cpEPTKL=6W#D9Z`rM0|Pn539+3mQr#78L#Xg$U{P5LmHj- zbRK?_dH!ap^nBM1r?zO(%|jcH`|yNs5-jWLcD}Gf+J%VzfAd{Qdc>kEipoY&u&T|< zaXvy^t_s)RL^A$_!N)s^O6wfrp>--exAQzf6q{*B?ui)bcyg7h?kPz)d6aN+TLeq` Xc^)#>9030&^tvDGE&l+K&`A9OYc-|r literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/last_message_op.cpython-39.pyc b/sgl/operators/message_op/__pycache__/last_message_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..363a86952ac2b176aff3faacc9b05ad9aa9e1419 GIT binary patch literal 720 zcmZWnv5wO~5S_I*C+A3@;shN9jSFIl8W0lFU2}I7*SQ((c{d@_I(Bv)iJ){B>3##s z`=xEE?guDnU}haSBSxC_%eUO_&(r!L4v7+)Cg)xvyiE3OPJJtY1m1qMo?p-gUJF0|0eJ}Xn5 zYPp9}-hza3#He9CVZ1&)zsI^l^Ij!2I+Drf&cM literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/learnable_weighted_messahe_op.cpython-37.pyc b/sgl/operators/message_op/__pycache__/learnable_weighted_messahe_op.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..829e8bcc8380f1cede06d0277d3fdd3088bc4a90 GIT binary patch literal 2908 zcmcImTW=gS6t+FHGqcxjlhQT_&{ka1jc6iO6%qoWs#gjnATkqw)Sk& zRuc&|L45-CC+s3#;E^AI_#GB+Af!q>^S}$7A^s)6WLtGBnkOjuTEJ;*UofGe#(hE#IcdVG)U$&kSuvh1l_(*!5jAw<9ku z`b9$*qEcM;%W=i8P;!%Shr0)ayP~*n`BhHukf!$pe3A7g)htPxOcy^3Wf+T0NbTJ2 zCL)xuRQyCJ6?Vj@eULbnCL-Y7IM@~4&Q>OP5OS_hWxMG!+!orNe%6FgZvY{PPdV|~ zBhsdva(3YP7Pq*4Kz#cVv58-R*b4gP+Ymq!ZbssZiL+TGEzA~v9<%~97n-_83CT!C zbF$BdET=geVU}PXSh>Yn&z0vhn_K(zAzw}i`+d$=C(M2+@D<63(++BvrFuB$U_e}PJApkQrTGE?k%rMyj8Lc-4i$gOryVsAmi!mN&wTiYrJl@u1pU0W2VUP@MUOnFC zcJCN-&AKkAL0?FTS4(?A&`r8o5a?2XnF+9Ly7YM%4a9XRQ>iZm&%dTYaH!&A)fbD; z>M=QV+KNP?3s8*KmM-6j(lFZ)iAtr!4Qm@?)3tlyPFKhvNo5>Hfkdp)Hn4Uj!bH>4 zwh~br_LYS))0RlMu0Io)9m23!9wW6Dj{!7D`)f~c5CGJyZV6aj?cbM+ z5b;T9>Kb%}E>MSglu-vpht_Gu&>ll-0ahk+T02B9uoA5qxdrGSy#D<9tWK-j!gS*2 zK*Pdcg{J<5^Uuj1Ak#LWD4^1%jBc~c8rr}qKy#qEU$Z4L0u<-gH*60uDA_RbfT$9w zBXVZ_l37Qj{SIt-fYV1T?OAg#!$NM2=&%S9E=rvR%p0-o68Lk+{*ph8HP4{mLJn~( zkzx6eybCC8gP%%{{WIR$Y()RdBVhc@=LlqG{&tMNqVacN>{5vE|MX~Q)rat2?qMcP z9UfEly`dq;fbu_QwWB%A881OC3dxF*D)$_a`W@bxBblB_8T&cDxs~-Xlp^lIDYxLT z8(1~)C+onKC$K{VlZZPYR34*y5Or17w8oP2^}0L_Vw&F2u1eck9Nv>l$Z%D+6Q^A+ z&*02~8;4iZ4)AU|J1}eY{slB0 z`5Gz~&2`17WE1B8$&jag{9&2#{TjSfLW-m%3{zL950#n9x6q{NmeHU$0WseBv&?2MzMNx>5={qTIDsEC z1H(a_@t)C#iMJ3 jtQ)C$`Aih^kPqko{qe>>z|*PlLG}P1E3^(-u3P^Cb><-B literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/__pycache__/learnable_weighted_messahe_op.cpython-39.pyc b/sgl/operators/message_op/__pycache__/learnable_weighted_messahe_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c79d7e97657ac6692e2c64094769bef57ff45bf7 GIT binary patch literal 2956 zcmcImTW=gS6t+FHGqcxjlaeM#pe+}rTWBIxl@J1IDosSdyK%59x}D8T@F3(|pUQUA7q~67J^NW>LcIe6L43-I z&mNOD<&?8Wo^NrB+mDEEKPEQu3lLvHzkU;fNWzUsd_8tJji!Y=Hu*dl6rj1#)D=od zMlza{eKue@&Djul3+TYgEzWkRJe%3v+NTd`ZUMKz3G@!k**&JNz`xLA^83uq?S0Cf zo+a;qyldob?k$nMLgo#M+{p{vIZcLa&&f(#%%xHWA8Wqs^dq(3o(5ODIk-MiMvtzB#K&#ecPpG^!);rO& zfLZ{}`k`iSb7zS#SjC?4I$A~Uo+5{~ID!2?xu)l$Ndl%RhqO4V8}}oj>!W`VMp3#O zB%-73PM8Vprm`DI(bjCM_h9ozvJ*yK-k6-3Mk|dsx=EOI)1;Bz?~5~yb?6)Mw#phC zqOm-(y8QlRbZ8%W=mGy@AIptYHkP+~%PSHqOqQX01}A_#xIS^*pB(E(vK?;-sTyt2 z$u>n}hX3l#^bD3(blKO+ZwO)ICC`&lA+DZ z$J^ZM9b>Lp*9EoR7gAzrX)g%6NjD1uT?%k#0!*7OeHBLA;;NLX)aQejmue8qWPGIh zYVm14#)nQ@kw|m_lF{1K)~-aT zXnNXKB5K3BvQTE)5((G!7Xq_G*-^UGmuWvxndmFOn#KL#0uM6dQG0C>W{7rA0UKdu zQK(c9JSG2pzPg#lVpXZTtGBP;JRjY^x2ihPY6_c($4ITlBOndZ{_67!1RyoDn*wH6 
[GIT binary patch data for compiled cache files (sgl/operators/message_op/__pycache__/*.pyc and sgl/sampler/__pycache__/*.pyc) omitted.]
diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py
new file mode 100644
index 0000000..24ec2ac
--- /dev/null
+++ b/sgl/sampler/base_sampler.py
@@ -0,0 +1,13 @@
+class BaseSampler:
+    def __init__(self, adj, **kwargs):
+        self.adj = adj
+        self.sampler_name = "None"
+        self.pre_sampling = False
+
+        self._preproc(**kwargs)
+
+    def _preproc(self, **kwargs):
+        pass
+
+    def sampling(self, batch_inds):
+        raise NotImplementedError
diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py
new file mode 100644
index 0000000..c1b5b0a
--- /dev/null
+++ b/sgl/sampler/sampler.py
@@ -0,0 +1,317 @@
+import torch
+import numpy as np
+import scipy.sparse as sp
+from scipy.sparse.linalg import norm as sparse_norm
+
+from sgl.sampler.base_sampler import BaseSampler
+from sgl.sampler.utils import adj_train_analysis
+from sgl.tasks.utils import sparse_mx_to_torch_sparse_tensor
+
+# import metis
+import random
+from sklearn.model_selection import train_test_split
+
+class FullSampler(BaseSampler):
+    def __init__(self, adj, **kwargs):
+        """
+        This sampler simply returns the full graph.
+        """
+        super(FullSampler, self).__init__(adj, **kwargs)
+        self.sampler_name = "FullSampler"
+        self.pre_sampling = False
+
+    def sampling(self, batch_inds):
+        return {}
+
+class NeighborSampler(BaseSampler):
+    def __init__(self, adj, **kwargs):
+        """
+        Neighborhood sampling method following GraphSAGE.
+        The implementation is adapted from PyG.
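+
+        Recognized keyword arguments (see _preproc):
+            layer_sizes: list of per-layer neighbor counts, or a single int that is
+                expanded to num_layers hops (num_layers defaults to 2);
+            pre_probs / prob_type: sampling probabilities given directly, or derived
+                from column norms ("normalize", the default) or uniform ("uniform");
+            replace: whether to sample neighbors with replacement (default False);
+            device: torch device used by the sampler (default CPU).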
+        """
+        super(NeighborSampler, self).__init__(adj, **kwargs)
+        self.sampler_name = "NeighborSampler"
+        self.pre_sampling = False
+
+    def _preproc(self, **kwargs):
+        allowed_kwargs = {"pre_probs", "prob_type", "layer_sizes", "num_layers", "replace", "device"}
+        for kwarg in kwargs.keys():
+            assert kwarg in allowed_kwargs, "Invalid keyword argument: " + kwarg
+
+        if "layer_sizes" in kwargs.keys():
+            if isinstance(kwargs["layer_sizes"], int):
+                self.layer_sizes = [kwargs["layer_sizes"]] * kwargs.get("num_layers", 2)  # default 2-hop
+            else:
+                self.layer_sizes = kwargs["layer_sizes"]
+        else:
+            raise ValueError("Please provide layer sizes in the form of either a list or an integer!")
+        self.num_layers = len(self.layer_sizes)
+
+        if "pre_probs" in kwargs.keys():
+            self.probs = kwargs["pre_probs"]
+        else:
+            prob_type = kwargs.get("prob_type", "normalize")
+            if prob_type == "normalize":
+                col_norm = sparse_norm(self.adj, axis=0)
+                self.probs = col_norm / np.sum(col_norm)
+            elif prob_type == "uniform":
+                self.probs = np.ones(self.adj.shape[1])
+
+        self.replace = kwargs.get("replace", False)
+        self.device = kwargs.get("device", torch.device("cpu"))
+        self.adj_t = self.adj.transpose()
+
+    def sampling(self, batch_inds):
+        """
+        Inputs:
+            batch_inds: array of batch node inds
+        Method:
+            Neighbor sampling
+        Outputs:
+            source_n_ids: global inds of the source nodes required at the input layer
+            sampled_adjs: list of sampled adjs (torch sparse tensors), one per layer
+        """
+        all_adjs = [[]] * self.num_layers
+        cur_tgt_nodes = batch_inds.numpy()
+        for layer_index in range(self.num_layers-1, -1, -1):
+            cur_src_nodes, adj_sampled = self._one_layer_sampling(cur_tgt_nodes, self.layer_sizes[layer_index])
+            all_adjs[layer_index] = adj_sampled
+            cur_tgt_nodes = cur_src_nodes
+
+        all_adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in all_adjs]
+        return {"source_n_ids": cur_tgt_nodes, "sampled_adjs": all_adjs}
+
+    def _one_layer_sampling(self, v_indices, layer_size):
+        """
+        Inputs:
+            v_indices: array of target node inds of the current layer
+            layer_size: size of sampled neighbors as the source nodes
+        """
+        ret_nodes, ret_edges = [], []
+        for v_ind in v_indices:  # global id
+            st_indptr, ed_indptr = self.adj_t.indptr[v_ind], self.adj_t.indptr[v_ind+1]
+            neis = self.adj_t.indices[st_indptr: ed_indptr]  # neighbor range
+            p1 = self.probs[neis]
+            p1 = p1 / np.sum(p1)
+            sample_size = min(ed_indptr-st_indptr, layer_size)
+            e_ids = np.random.choice(np.arange(st_indptr, ed_indptr), sample_size, self.replace, p1)
+            src_nodes = self.adj_t.indices[e_ids]
+            ret_edges.append(e_ids)
+            ret_nodes.append(src_nodes)
+
+        return self._adj_extract(v_indices, ret_nodes, ret_edges)
+
+    def _adj_extract(self, tgt_nodes, src_nodes, e_ids):
+        row, col, data = [], [], []
+        unique_src_nodes = np.unique(np.concatenate(src_nodes))
+        # global id to local id
+        nid_mapper_tgt = {tgt_nodes[i]: i for i in range(len(tgt_nodes))}
+        nid_mapper_src = {unique_src_nodes[i]: i for i in range(len(unique_src_nodes))}
+        num_tgt_nodes = len(tgt_nodes)
+        for i in range(num_tgt_nodes):
+            tgt_node = tgt_nodes[i]
+            num_edges = len(e_ids[i])
+            col.extend([nid_mapper_tgt[tgt_node]] * num_edges)
+            for j in range(num_edges):
+                old_ptr = e_ids[i][j]
+                src_node = self.adj_t.indices[old_ptr]
+                row.append(nid_mapper_src[src_node])
+                data.append(self.adj_t[tgt_node, src_node])
+
+        row, col, data = np.array(row), np.array(col), np.array(data)
+        adj_sampled = sp.coo_matrix((data, (col, row)), shape=(len(tgt_nodes), len(unique_src_nodes)))
+
+        return unique_src_nodes, adj_sampled
+
+
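+# A minimal usage sketch of NeighborSampler (illustrative only; `adj` is assumed to be
+# a scipy.sparse adjacency matrix and `batch` a torch.LongTensor of seed node indices):
+#
+#   sampler = NeighborSampler(adj, layer_sizes=[10, 5])
+#   out = sampler.sampling(batch)
+#   src_nodes, sampled_adjs = out["source_n_ids"], out["sampled_adjs"]
+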
+class FastGCNSampler(BaseSampler):
+    def __init__(self, adj, **kwargs):
+        super(FastGCNSampler, self).__init__(adj, **kwargs)
+        self.sampler_name = "FastGCNSampler"
+        self.pre_sampling = False
+
+    def _preproc(self, **kwargs):
+        allowed_kwargs = {"pre_probs", "layer_sizes", "prob_type", "min_neighs", "sim_threshold", "step", "low_quality_score"}
+        for kwarg in kwargs.keys():
+            assert kwarg in allowed_kwargs, "Invalid keyword argument: " + kwarg
+
+        self.layer_sizes = kwargs.get("layer_sizes", [1])
+        if "pre_probs" in kwargs.keys():
+            self.probs = kwargs["pre_probs"]
+        else:
+            prob_type = kwargs.get("prob_type", "normalize")
+            if prob_type == "normalize":
+                col_norm = sparse_norm(self.adj, axis=0)
+                self.probs = col_norm / np.sum(col_norm)
+            elif prob_type == "uniform":
+                self.probs = np.ones(self.adj.shape[1])
+            elif prob_type == "locality":
+                """
+                This sampling strategy refers to GNNSampler [https://github.com/ICT-GIMLab/GNNSampler].
+                """
+                min_neighs = kwargs.get("min_neighs", 2)
+                sim_threshold = kwargs.get("sim_threshold", 0.1)
+                step = kwargs.get("step", 1)
+                low_quality_score = kwargs.get("low_quality_score", 0.1)
+                locality_score = adj_train_analysis(self.adj, min_neighs, sim_threshold, step, low_quality_score)
+                self.probs = locality_score / np.sum(locality_score)
+            else:
+                raise ValueError("Only support three types of probability calculation: normalize, uniform and locality.")
+        self.num_layers = len(self.layer_sizes)
+
+    def sampling(self, batch_inds):
+        """
+        Inputs:
+            batch_inds: array of batch node inds
+        Method:
+            Sample a fixed number of nodes independently at each layer.
+        Outputs:
+            source_n_ids: array of source node inds at the first layer
+            sampled_adjs: list of sampled adjs (torch sparse tensors), one per layer
+        """
+        all_support = [[]] * self.num_layers
+
+        cur_out_nodes = batch_inds
+        for layer_index in range(self.num_layers-1, -1, -1):
+            cur_in_nodes, cur_support = self._one_layer_sampling(
+                cur_out_nodes, self.layer_sizes[layer_index])
+            all_support[layer_index] = cur_support
+            cur_out_nodes = cur_in_nodes
+
+        all_support = [sparse_mx_to_torch_sparse_tensor(adj) for adj in all_support]
+        return {"source_n_ids": cur_out_nodes, "sampled_adjs": all_support}
+
+    def _one_layer_sampling(self, v_indices, output_size):
+        # NOTE: FastGCN as described in the paper samples neighbors without reference
+        # to v_indices, but its TensorFlow implementation uses v_indices to filter out
+        # the disconnected nodes, so the same is done here.
+        """
+        Inputs:
+            v_indices: array of target node inds of the current layer
+            output_size: number of source nodes to be sampled
+        Outputs:
+            u_sampled: array of source node inds of the current layer
+            support: normalized sparse adjacency matrix of the current layer
+        """
+        # NOTE: Should we transpose adj since v_indices are the target nodes in the process of message propagation?
+        support = self.adj[v_indices, :]
+        neis = np.nonzero(np.sum(support, axis=0))[1]
+        p1 = self.probs[neis]
+        p1 = p1 / np.sum(p1)
+        # NOTE: Should sampled contain repeated nodes?
+        sampled = np.random.choice(np.arange(np.size(neis)),
+                                   output_size, True, p1)
+
+        u_sampled = neis[sampled]
+        support = support[:, u_sampled]
+        sampled_p1 = p1[sampled]
+
+        # importance-sampling re-weighting: divide each sampled column by q(u) * output_size
+        support = support.dot(sp.diags(1.0 / (sampled_p1 * output_size)))
+        return u_sampled, support
+
+
+class ClusterGCNSampler(BaseSampler):
+    """
+    Clustering the graph, feature set and target.
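+
+    The graph is partitioned once (randomly by default, or with METIS when enabled),
+    each cluster is cut out as a subgraph with its own train/test split, and the
+    per-cluster tensors are cached: the first call to sampling() builds the partitions,
+    while subsequent calls return an empty dict and reuse the cached data.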
+    """
+    def __init__(self, adj, features, target, **kwargs):
+        """
+        Inputs:
+            adj: Adjacency matrix (Networkx Graph).
+            features: Feature matrix (ndarray).
+            target: Target vector (ndarray).
+        """
+        self.features = features
+        self.target = target
+        super(ClusterGCNSampler, self).__init__(adj, **kwargs)
+        self.sampler_name = "ClusterGCNSampler"
+        self.pre_sampling = True
+        self._sampling_done = False
+
+    def _preproc(self, **kwargs):
+        allowed_kwargs = {"clustering_method", "cluster_number", "test_ratio"}
+        for kwarg in kwargs.keys():
+            assert kwarg in allowed_kwargs, "Invalid keyword argument: " + kwarg
+        self.clustering_method = kwargs.get("clustering_method", "random")
+        self.cluster_number = kwargs.get("cluster_number", 32)
+        self.test_ratio = kwargs.get("test_ratio", 0.3)
+        self._set_sizes()
+
+    def _set_sizes(self):
+        """
+        Setting the feature and class counts.
+        """
+        self.feature_count = self.features.shape[1]
+        self.class_count = np.max(self.target) + 1
+
+    def sampling(self, batch_inds):
+        """
+        Decomposing the graph, partitioning the features and target, creating Torch tensors.
+        """
+        if self._sampling_done is False:
+            if self.clustering_method == "metis":
+                print("\nMetis graph clustering started.\n")
+                # self._metis_clustering()  # requires the optional metis package; currently disabled
+            else:
+                print("\nRandom graph clustering started.\n")
+                self._random_clustering()
+            self._general_data_partitioning()
+            self._transfer_edges_and_nodes()
+            self._sampling_done = True
+            return {"adj": self.sg_edges, "x": self.sg_features}
+        else:
+            return {}
+
+    def _random_clustering(self):
+        """
+        Randomly assigning each node to a cluster.
+        """
+        self.clusters = range(self.cluster_number)
+        self.cluster_membership = {node: random.choice(self.clusters) for node in self.adj.nodes()}
+
+    # def _metis_clustering(self):
+    #     """
+    #     Clustering the graph with Metis. For details see:
+    #     """
+    #     (st, parts) = metis.part_graph(self.adj, self.cluster_number)  # parts[i] is the partition node i belongs to
+    #     self.clusters = list(set(parts))  # the set of distinct partitions
+    #     self.cluster_membership = {node: membership for node, membership in enumerate(parts)}  # node index -> partition
+
+    def _general_data_partitioning(self):
+        """
+        Creating data partitions and train-test splits.
+        """
+        self.sg_nodes = {}
+        self.sg_edges = {}
+        self.sg_train_nodes = {}
+        self.sg_test_nodes = {}
+        self.sg_features = {}
+        self.sg_targets = {}
+        for cluster in self.clusters:
+            # split train/test within each cluster
+            subgraph = self.adj.subgraph([node for node in sorted(self.adj.nodes()) if self.cluster_membership[node] == cluster])
+            self.sg_nodes[cluster] = [node for node in sorted(subgraph.nodes())]
+            # map the global node inds to the local node inds
+            mapper = {node: i for i, node in enumerate(sorted(self.sg_nodes[cluster]))}
+            self.sg_edges[cluster] = [[mapper[edge[0]], mapper[edge[1]]] for edge in subgraph.edges()] + [[mapper[edge[1]], mapper[edge[0]]] for edge in subgraph.edges()]
+            self.sg_train_nodes[cluster], self.sg_test_nodes[cluster] = train_test_split(list(mapper.values()), test_size=self.test_ratio)
+            self.sg_test_nodes[cluster] = sorted(self.sg_test_nodes[cluster])
+            self.sg_train_nodes[cluster] = sorted(self.sg_train_nodes[cluster])
+            self.sg_features[cluster] = self.features[self.sg_nodes[cluster], :]
+            self.sg_targets[cluster] = self.target[self.sg_nodes[cluster], :]
+
+    def _transfer_edges_and_nodes(self):
+        """
+        Transferring the data to PyTorch format (except for sg_edges, which are kept as scipy coo_matrices for now).
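+
+        Node index lists are converted to torch.LongTensor, features to torch.FloatTensor,
+        targets to torch.LongTensor, and each cluster's edge list to a scipy coo_matrix
+        of shape (num_nodes, num_nodes).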
+        """
+        for cluster in self.clusters:
+            num_nodes = len(self.sg_nodes[cluster])
+            self.sg_nodes[cluster] = torch.LongTensor(self.sg_nodes[cluster])
+            row, col = np.array(self.sg_edges[cluster]).transpose()
+            self.sg_edges[cluster] = sp.coo_matrix((np.ones(row.shape[0]), (row, col)), shape=(num_nodes, num_nodes))
+            self.sg_train_nodes[cluster] = torch.LongTensor(self.sg_train_nodes[cluster])
+            self.sg_test_nodes[cluster] = torch.LongTensor(self.sg_test_nodes[cluster])
+            self.sg_features[cluster] = torch.FloatTensor(self.sg_features[cluster])
+            self.sg_targets[cluster] = torch.LongTensor(self.sg_targets[cluster])
\ No newline at end of file
diff --git a/sgl/sampler/utils.py b/sgl/sampler/utils.py
new file mode 100644
index 0000000..6304d07
--- /dev/null
+++ b/sgl/sampler/utils.py
@@ -0,0 +1,35 @@
+import numpy as np
+
+def dot_product_ratio(ori_neighbors, good_neighbors):
+    # similarity between the actual neighbor ids and an "ideal" contiguous pattern
+    s = np.sum(np.dot(ori_neighbors, good_neighbors))
+    max_s = np.sum(np.power(ori_neighbors, 2))
+    return s / max_s
+
+def adj_train_analysis(adj, minimum_neighbors, similarity_threshold, step=1, low_quality_score=0.2):
+    # Score the locality of each node's neighborhood: nodes with too few neighbors,
+    # or whose neighbor ids deviate from a locally contiguous pattern, get low_quality_score.
+    nodes_num = adj.get_shape()[0]
+    sample_mark = []
+
+    for i in range(nodes_num):
+        adj_coo = adj.getrow(i).tocoo()
+        neighbors = adj_coo.col.reshape(-1)
+        if len(neighbors) < minimum_neighbors:
+            sample_mark.append(low_quality_score)
+            continue
+        else:
+            avg = int(neighbors.mean())
+            neighbors_length = len(neighbors)
+            if neighbors_length % 2 == 0:
+                good_neighbors = np.arange((avg-neighbors_length//2*step+step), (avg+neighbors_length//2*step+1*step), step, int)
+            else:
+                good_neighbors = np.arange((avg-neighbors_length//2*step+step), (avg+neighbors_length//2*step+2*step), step, int)
+
+            similarity = dot_product_ratio(neighbors, good_neighbors)
+            if similarity > similarity_threshold:
+                sample_mark.append(1)
+            else:
+                sample_mark.append(low_quality_score)
+
+    sample_mark_np = np.asarray(sample_mark)
+
+    return sample_mark_np
\ No newline at end of file
diff --git a/sgl/tasks/__init__.py b/sgl/tasks/__init__.py
index 195868c..82db4f2 100644
--- a/sgl/tasks/__init__.py
+++ b/sgl/tasks/__init__.py
@@ -1,5 +1,6 @@
 from .node_classification import HeteroNodeClassification
 from .node_classification import NodeClassification
+from .node_classification_sampling import NodeClassification_Sampling
 from .node_clustering import NodeClustering
 from .node_clustering import NodeClusteringNAFS
 from .link_prediction import LinkPredictionGAE
@@ -17,5 +18,6 @@
     "LinkPredictionNAFS",
     "NodeClassification_With_CorrectAndSmooth",
     "NodeClassificationWithLabelUse",
-    "NodeClassificationDist"
+    "NodeClassificationDist",
+    "NodeClassification_Sampling"
 ]
zFX}peMY}Ybz_(8NG1zIHP=FPFz$)ZC)td8P#Ex-bgugH2Rbgc8+)bQZ&0p!;Z?RkI ixJ=#J#hj~Lm7Je%ZmZ5~Q;4mWZ}2~JCjfnV@C0`w1?hVL literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/base_task.cpython-37.pyc b/sgl/tasks/__pycache__/base_task.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..981375b8c9d9c9f4eb85620046281912342e8e7b GIT binary patch literal 752 zcmbtS%}&BV5Z-M|DJU`GOF*LZ1&kpWB_^g9JoU1v%&Nrx!0y(7S9uvi^Qv}IE4lJ~QgSBL#4K_Gv9jJkewgz?Z(7MpTKFz`#L|WcO`gzQ;n1n3pC8~B1 zfMF&N-Vq8%%lJVN;R?Px{!^AnU+bAa8s478i@C2KW8Yvy?cZZ-L2hrlC_Eu1>BI;@ z#hQTe*W-%4S4H=aVahRZh?5{QH_0@XiA;@zE2T2k<%*WJ8BNjX$SUV@oEurg(+<|( z=pR^>iyii`G=zMK;yg0a+oEZ2UakuNlgAZCMU(Vki{xlSvWf-WI;}d55-CK{6e7tW zkI6q2VwOj78Bw$kRVSd4RfAytSxb)=sH(p~Q<}#;o1WEefXkrhuCl^AFL^I0;|rDj E0G(Ei=>Px# literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/base_task.cpython-39.pyc b/sgl/tasks/__pycache__/base_task.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08bcf452f23acf9279a6f290b7074f66fd968176 GIT binary patch literal 772 zcmbtS%}&BV5Z-M|DF`v*OF*LZ1&kpYB_^g9JoU0!nN^AXf!)@CS9uv;saG%Fy?Jt` zSiyuC6Q`NKoz8sU{>XN_$xynpOZA5Sm6C0-F&Utm4w@kY27HP7VucK>5$oAwT>l{! zW&>1nghsKD0}CzSU@ci_gAERP2WsG=uR$F=^e!}TRA*4juvy@+(&+fLydon;{}EN>7ww2m?V=- z2x4r=F#dc{G3Zy-->GCsMN=7(b@eI??akw!O<%bh^myrawH4mU8uyg? H-%;5o`mK)C literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/clustering_metrics.cpython-37.pyc b/sgl/tasks/__pycache__/clustering_metrics.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9e0e38aca650430b50f91eb8a0ecda9f5384d1a GIT binary patch literal 3268 zcmbVOOK%*<5$>M%zW9=?NQe*#@x^Pzq{%o54~$@2vK$05FEk)wAY(9`?p+Nzk7duS zNiJq{p#aGT9r6#B$K3LJ`kG7l7ktWBGs`6<(M^EL) zJ;we;i_6c%TB|5iHgLlFeGeKIb)* zC+)Q3cRBl(2}iiEnQ+exzbE(}>w6dAz;^pwyWfqoLnRBErm|FVL{sbgz1?qo^Q+-7 zBE-vYA>cM3lKEUP2*!o+n)#+Mg@w@)ws0`o5=@2nnw@dq6)n*Q%o80j@9MQE85gBg zaW)L*Oe3OAAB_guW3*xgUBxRFu*yIWemKu}`-X0pYAk~!+?R=NkCYV6yl-l&kjX$f zIO+FjBaX^u?;lLH0R4G$<5ix@jiNZ(*!%X`XUXZwMlnn_%CI;rHhwAD`sh@9K@ewg z83bLD%0kDu!zaJK0s9g&;)I3o4w{Qb7c2e=Xvs?c);KWE_!*z_(yVAS=1{AKnt+;s zTDZ1#XsT_pT4^6Rv>q6xD~xx>)IJ%M9-V%|e(+x!Q|Bl4h7EXa2(R=S9#3F^se7^w zOi%>o?||w3#NY5Y#=wBFtk+)U)>tiYZJqJTuAFb;hRc%AU2) z*{pNUrgqhybt`zI}l4IRHd9sjcn`D~yDzoYkMS*sY5ogHg65m=zoU6Gn zRlTZPErIL5!Lc`6BKfQCxiMWXdk4$SN_e96&YZ3Qvhp5e6_C~UAZyo7wd+&2uAOQi zx38UQARkrB)k?KmtyQ6pvO_~z75}?&WbHN1S9>6w)L8#q8}X(#$=XniNiqcABRVRA`eAu17Bh)QH@QSVIoTYi8ic8}d4Y@aJ~+`gd@ zshC=-=MZ*Fd9d{5UWzRWHwH58sUj_rrjXO z$Xt#~Kro+2N+g#ym#)^5Dcli?Fi{F?$yrOrRBG#`!@b-m&5qome$7scu z=t?$2%0TR32PUEi$J)BK27G3pvzPAFoLW;G5z3u<(-xu>uZ)4&Bp9c$?)@xRX_&+l z80bV30Q^tUal5qFbRjpo}I06=q_-{U^xMZ#QN(vozKouxc++pJPjr>#8m4 zeox&O)ct|FKT=n}uEi7659E(wG7d|))8i(2ewPcGJXLx6jFRA=ai$=1FusESjxpJ~ zQrAL*|ARAKT7gaJpfF*MHH5JZ+%bm z7up=eLv2QRf62T0mkG6+O%1m4P$r!B_+8V{#p>}e3?oeKQ z0$SS42ZM=6+}20g@ZzMVj|y)TPgb9z5Gr0tu@DAkv5!>@1ChdXZ-6?f1hHa(x1-Hb zQfdb_8!M^)LI_rP5*NN(AVve+cbf-dl}94=+pl6FWERMDUws2)-ztkt`p%1hM)fs7 zAKJ7_V$PpL@mCY~=b!PUHOH!g_S&dPnanMT$Bvj?5ty{Qm0=srmBsx?@8!Q|wRy#v_rTw60 zhLJ=pPcmTV;6wibRrHd-G>2Y#(z&;u`ra-@iK&VJ*~QMxdo%Oi+xH!Jy8(gc!{g`T zr(Ht+g}}|n0^n2l)JqsR;WQ&*Gt-a(%rYytLz^Pb%ADK{-Fj_jUhaoJC4$Ub+_|7N zmS#cT4m*^*Al&8NYr?&A7Irz^Cw>1CIFP+QHQpDK;z$aG%Uo!gCb+c!yubJR(E4mN zO7P Q?(zzPylPCz)y*=rJ7+~PLOHg~uSvm<~p_g|B98hX6N1Hk#b4FYt`W|B>m z7BVeH(SmV;w67jqGx(mur+P3{v?39y7>vMkpL|RA`pg7cPDGT&2O=}USPI^(`(nL3V-x7&-_|NXnt1=fmN*(X)zxe#)?Cf+$jj|mbt0T4ZztRoHXU31B zv`BRnbxN{;n(dWOFlviYUj{Jc*wNJ**i9Kw2yb6k+1&plFi(2$Qv@GwK;t4HyoaU8D<`w z18Q*YBRJ=OL*LLhY{N73i*gB_`Q@Qs@_) 
zW=<8%I~91_tgW57t1aLH5(C^>%JK+hnRP3xY9lQU-#vmnK{gLL*0RB^n(tcGtvb~j z@cXCk&eu@ts&m0+>$-cm-ay>vt#{U}2bkUpW&772ZEsQ8iy30GW0NXV2-$(6d42$77NL7ncYsB0T zr9oVY(m*J4t*DiQs5!Gdh|t1)y4dXENljc#7QZ2I7u-c^wpT%*9_S3C?o0`2+~KJ_UKn#`eSKo%i~ z7UU5Muzdpt{M@-9FTI&Hvu6%uDsSe`T9B=1#fBDys|!T>Ze!#7MJe+*OQ%p{@`;|r z8PtMd*%-3+r^c1;#QYE|8;@jM@Wu9TkWkPV+bZ(Z_;9NB75Q`6lAqx43mkrl!>@2a z&Bo}Xl8l38BQ$)WgxyYW&GXES4;cwxN!k8tXN3O@f7-n4K7^Bt~=J3fnOU1x^meS*hg*fHR&-VhvH(_YM8hSi^K= ztfb7K*n?e6`??uW5CNz1BfR5`O9<)jD$z!w&tPAnk<`7z(bn9Ca_OS8VM!ekZvUNfr?M_U4Y2ES7w=Z?p@-cP zlqv=gDmZ9bxw-^R@9*na>IW!}hNkJkMB#{y+F&ra#4H6luTw;rn94#FT7i-*Xnc$_ zl0p>5KFB7J3=d>^QiSe!BFC99?wQE4@+7p!X>k;~CHU+}Tz&}Dj8zVYQy=*a#>MFJ zwB`Uk-Z-6ZJcl+&y%KyW4)AhLWD2H1RhjM(x+D$6)DWV^SmR6^7pygrLjDmC*zswq zz?_hy0p0hS(6IF}&%@wV%DE^aksrtxa5%KJDn#gh8R0C^Xy3JF2gO_jgF*&%%SdRz z0{DEoUFVkR`~SzjS2y?iE$s1%I2VxXgDA>NKFM(1iK1f&`DRCA`zJA`!mAw8xi7osMA(z1`F|06GDx~34-}XQDJw^kk-s!KIAdfZvO_m-Q2WrCjOk$-D zy}aqv7X?WjWg?bEotwhm0vhCTkq*m@H&(DFN|=pndeFimW#cxp{43C{+1@6p@gFyT T{{&09!r~9&(hk0$)sFRFM;txk literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/correct_and_smooth.cpython-37.pyc b/sgl/tasks/__pycache__/correct_and_smooth.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8909a19c3b07fd722af67aadf0f8a5287f62a4bb GIT binary patch literal 4686 zcmb7I&5zs073c6<5~bB@KfL~m<-|>+z}bxx#|7N9sXyZMV7NtXAT7~?pf!|NT2hjS zRI;m5r!|mV4EWSb5kPzFzciRI^>?Ve; zF)gY1RYfu5TH^Vh!mYTTH2g-g=C379znQH2>q*ORB^&-mvgvQCHFms}Z2Q|x`%2>u zcVB4Sjl5&SKgZc)ty6go*|bL;Ce1rMOwd>FhgtY=6!J)5%iIHVV?CB#AQ7 z?*_w>pw~C=jf9B0*_|POoQy`oIO~=S9W0BAdIs@o#C%Z6Ex>AC2dw2bpqDoQ>scdrayPH!)x4H_$27B+ zw*Z^@I^a4dj;(x+YX=qaA5Oel*@k-KCa4D2x%raby;V}T^CoBu>ba7-^9J>Nwvk;> zJT49{9kVy)F6Y|(9muW)SGWdsT8}!q)MFvNZ=?RR{Vd>7H#~aqq+`fxVJK;T8c8SJ zpY(chBr9Bu#sQR{b=uNOC*w$nIus|}Ab@dYK_EW~-u8fkfWwMmrIH=V8!K#!!>3W4 z%8ysrlE7$Up_8c9aVSFAb0kuEa|IToan?_u4Y87}(RkF|Pgl}PJWA7`H(cp8Jo={J zjpW`6-$SlJc%+ffR@gNjr5U_Ozk47*U13|Q+(aN&(s;Mgr}ZN(8ZP&UptbD9=;3pM_RRp0u<$gDdpf(cW)+!>nsjBPJDeo&#XRzF zufQ&h1;?djpCkk7Pr_cF4-ECE}EoyyEJ~guui{4 zVlNIg4v%mH{rX}Y#^e3aU#BIdv6L&ka;O%_N;t_zX*Yy3vWm`G+!u(0VAm`T?u_?Z zQy+aw@GFaV3g2BACZbv;2)Yzkg5Y1;PdE1VM@h7orqA~tfA#RA_~>vieHQP*Umc`- zi~Zq{m%{FjkHi+fxq+761klWeZnG*iJa+M5)F2GKi(Y4&XW9jQdgEAH&Tlkl8c$ewVC~j9p2A8?5Q8K``Rn^x0^GkT<3z?;MRdF{+F#0Z!~Mr1gxP| z(GZ^ptz-MI`m8dm7I797Yp|){7NA;rY+*%h<_-UpZG(2o=iFii`hu5J^4eLhBVOmx ztO4<(m|gIsmEZ|Y*z=3I!<}>53F9vC3i=9qXn|MJ*U;C{d+5FS#p3I7cL(c_ZD_5Q zn+K-&KD#u(%#oeoP4ArOGs2Ad70OqjPkr7#(PoY8>fqhHf$X#KQlHs*ZD!`)%zFO) z*UlG6}E%rkGb#?A8Z!Kjrlx$=TF5f@~MQ6O3 zTW9!Ku-dIh)9pWsXe>tE2#MbCS(}fBQM*5EPmSH1J;~BCo71!h)B2YbV{fu>(|>)X;kp*ToZJ+#T5W)(r$5{(CT=pEP)C~lLk_E$quD87X4u+^=u^G zp^ZBPC><7;2)svtQebh7z(oR=2~-HYPv8Rrl+ue|5THs*{E|QiK-xW21Yw2}V?Qel*imzQyxmz@G${>TaadF z2Jr0;QV-IbbV(<4aI&H>(vDwK=utq-rKMtzWSOjLKkL6tB%p7pc zIVg~vo{medqxcXcu4DpduMF9wSS=`kC_7fu;$0eECC~D; zjvfM4*@tv-l~9vqN8&e-O&P6{SJunirW+v%3{F1@S;1{uNx6Wk;IaNV{70J-&h01?RUOsIB(5>~<)6zY5P51Qw1Am)eXaE2J literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/correct_and_smooth.cpython-39.pyc b/sgl/tasks/__pycache__/correct_and_smooth.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7af1cd483970601f08fdb57ac39d33f27dcbe42f GIT binary patch literal 4732 zcmb7I&5zs073c6<)W>SIA6~~_a^j{@;OxeU;{ryS#IfU`fV(Mb14W4z1g)XGl1WJ( zQi)fkPHP}113C3l1kfJyFU_fFbMY;Qo?5i-?+q#KIwwoSym|9x=JD~(oA;rxUiUQo z-hFa6+S=B%f70Oe&&1#}wDhk4T;nX({Gw+*Q)4|be1j29kIlsLErlDgojAUua2C6X zp=(S_D!!*EW?W5beof(4Tu&N)BU$s;lBVBG*8TOQ<+qXze&q9{SK4nEgmN5y*ptR-W!EH5*WFkhiUX6Odod~>4x3z 
zM184Q@q*27Bm5H!Fl+C0+><16Sta}g?6gT=hj3;3h$&FKUOX60@2SGLp(!(T) zGSTk_!;zrZH*b%Gh`QOWA-|uDM%h7^1U&t7%R#yiB+^KgzRnrfBV%qJYuw-_9*bK& z_gk&UxXqns72k?%?()hr?TGmf_jvW0=DWPc>*y=rYNqCUP)9?)vp?d|?Kn)+ey`sR zv;Jrp{Gp#61V5MS!J|TspFVx@F0lP_`1en{8>N8M5$>QFXz5P?nU=9!JJRR+F-|oH z9vHbcFa{QuO_=xxCtj^=L%nemRDd`|D)DyiFf6SM{OOiA5&g?cvI z$lg#q&JE5Vu~+6UaAT!Tuj{f;j6SjgHpQUB3F7VxMW9^QS>F{D=*OxmAD z(n${{y2s$qwbU6}H9U z!zfPW`zvfoU^ubRNmOqfiV$`kiBw)+fyHQ?^%H1AtR!nR9(51Wm9!F%(lqD|S9%SP zzV3G;xwpdikc$vC(#S_E>>7{K3|^$)eJnp*VOy%)MIctvc(?F2*tLp;Cmv1E#g-N2 z^MXEqDu1y`C93~Un&l(2ts?x{caXPZ6v~tRYagoB?n7{!$*-c>B>lVI7#4> zdF0<*fn6BK0r~g^OE-v~M%@Xbz$z!Hz68OhvmwniihI&o++wS^WNT2mXp-jb()jJd zI{g-jy*Si3Jj4z3>x*$1j}JnBotBiwQm&|#L$yFw!bvttyCIyBRdmkczCauVy=HN6 zr@Ysi`sfpaUs=3U`0m0me_*l_1YL?PLGXk2?`wMpqa@l((R3R}XskrA;HkK+{rX$MFU0IdO{B%pJwFK6}4rQ~MWn9IPsY`fH&Tunl8%|hwV9pOu)co8?r7h!KVP3Ye3s9{*vaq5$s||mjZG(2g=geXS`hu5J z^4eLhBVK3GtO4<(7+&zCmEZ|2*z`tMD3HtJ2^G4Zd(PDsD5PTb|KrDTSZjO@VW z8_1{VbT@PB6dwyVyR|>v{+)=%V$_Y0`VAko`Dhrm`@{Cs*uCD9EG_dpO}jg-e?bxV zX8SKU-~W)$UeFzwwv>xdj)5kyXEt4jjThzSx_PUIoT7dE?fX~TOGeYG!s%MvY)`kpDsl*wxwh{jV@eMqetCNC!6*ykHt}vp zaX3yRzI?~j`y89wN~~%9YWpFL3Vxj{f^N3B1Rzb?EzT0^jhD(2sC_hPAfcD+skFwT zKg^__jl@OTxI=&vWO1IrTLdT}7FP(IBXEI0g}~bc-XTCqzW4mmY~eqq&O}pkSMEG)8b7UUMA26 zkWLwz3)3bumUa>1g++^&M%-}=^Rf#I@vF<2&Ii7?NI7tAOtMK?>-C4ooJwo>BeC)c zyc#{lSD?jc=^FrK7cJI8px=gN=xqC?S;FZ@uOHk`Ux_N!Swj>rsI$Bvz>S*3gx7-L z$s~-+8GVSF|F%FC@x?Ghy>@sHsLD>Hi_3+YEjtvSKt5&0N@iIvbDwU6Br!PkSvP4N zWd^FwqgpSe+|JcaR%7OW8&`DZ8SV`=+u0GH)5_lfAdw&+usa+Ia?XCW7^)(wsKc5C zWs>#dbeD3-B3<8D?Az^4hFvQ0aj(4-T>O@}Zy*`lRms@yLzIvKp@RO?k!vS2igj`B z%1EXk5aaaP7QI4Hi`l2gXs7-SNvanc#h>D`^#x$Ws5 zdw1MTAnU@A#6~~?PC|;=;)#DCRgj{fiU**0q2K}1P(`ZXMbEtP0>AH^o}Qi6NM&Yq zF8!VF+`sdk@B1CSQYd5<{I2iccK_>XMfp$qnEni8UPAEx8WB^N+Er|ssmQFV^T=goFMI@fFW+s@sl-`W*iNfWa8?8Ayh z5AzLrtpRtVO0Rw33LX=aWd=)qn6Xo?$xN31P&rcV zH1Zb9h`hzJEcc;eXIP#UkY-ttm5}B>QVhk;V*|=j<;`|~?`^r3?Lq(bmv7y_jy|_` z16)HxT>2+s#J zwhN3(>p_ModzpJG|2syl*8RuOE{!jX5_3a2(%;p>d|2==PA&!cpm3~(MZrVk&j;Cr zrxcb2Pbnyary_VT_fGKyPc>W;Jk?+cJj;To3ZA7CJS$;M@T>$i@T~fmgCa(vvlL5z zY+!_sgHi#1Dxq8fWqFQr6_nLE$~Ej+<5+o@*7G{T=a>S?u(rm_w{Vmo&fT_O-^S7L zc3t)u$Hpt~K+d%;kMi%k?VVlUVQ#B=c>BJFqJ|Nr8O|_HYm}8|&1;X{DDCZzwzs=( zlwo`@bh=H?Z#)sDywT9*yo8F8q^MFt~&a}}LXUFZkyy?605O0P%U8Qr#r&4(In zLch8EA)E7l+uwBrlgCMN7BAUg=(l?iARcWj>ed|&TD#t2iRlhJ&)M!TuE^Z?+ATL) zThyqyWfm?kT{@1x%Uy4G&}GrWEt*{;oK$CI(3KVy-Vw^5@;DltlxS36)VY+juDAjh zcR*qimoJu8x&|eKy--%gcAjX(a@+&AHS*nujELIfgO<}AwW5WBE2Ogd-VWWUt;Skp zcy4!_FVh@TLNIs*d0r)g8E=^x??8IGO{j#+y)5T{LUWYp{*A9~p-M8z2H2(e*iOaQ zj;qrQaqhdY1PYQRRv&hO8cXEaE>1sV?i_tjd5dyhN#mvDJLb*{Ld+gzX3dt8EDgi$ z+BqV`j4*66j?;n)^c?5Ul>gbhu{-FwH#~3u#+}#S{CxND;D)!;y#Ybl^KNu;Jx>i4 z+|QfCL%xn>B#r6aMxP5uyAC;*lO7e3-Q_E@`bfnsN;{lDk zIaS9k()jeFl)|eBo{0$86I~T53^@Kpd?R@t3IBy8dK0-wu2L*OuYuyxaaXkQ`3E}{ z3mUV2wvL}eQxEC5AonYl%P*bi`NBhbp2~32HJwwkxy78XAwzTSsJ=FVNr0Sjtgb4P z)Ugt(zBx$;>aiMXfhJeQ;v1;J&moF5Zy4$Bp&MEKQ4aX@DMu?BNZmTA7lRG4Dw8hno*7)<}6PR5+5AxtG<#L-&iCbRFH zpYIeHwP<#V$d_2^V?EU4HVtONuu6q6YLTWx>z;b>wqFK!hvs38!~Dgv2>>qnV)^)M&Qwheqs?lZ`pML36W9)W7CA6f4j zVKId97FPOO{squxZP#P2S-V!!E}4mv^CEJWV$Ss=bJiYK&6-?>@s?Rtj#ZXLK8G}i zG>tst!+txhh>l?bwHXz3`6wJ!##Jp=3UPeCzE1?4au zlmc_l;FF*fnn5MB_9tJ(4ox11(U;6HuBCBnzScdBZ%pf9)Jwutx7c#huOa#+1KwgQ zK}twLEl5-PLnG#14NO5_3#=I(#<#UiCuMzWT!RNP#Ibihuif9Nvq9gjxBK<6zWK~{ zqcBg{R3mWHp64Y13!s8c)0!SdM^ zK}nWwMw;#iEv8+a=OS8^5=Ky@!c3lts9i|v5nf>zla__3M!9J*I*Rg1q4=5>iPj5A z8GW9gXgj>Sh!i!eS?aQCsWq((L!gE){AonVgT4Vn;XJ-->SeV^CHT|g6!o;M zYN!V47^qXm{ev$VgH}^vkf;kyi`d>wE+N`R5;$=QeIFSpO=Y4ATtVq*OocE1jG*bh 
zxQIID5nW5TiY8DgxOCON#gq`p$Q@Z%`sphCWZlnA^0*ZAN#R&I);h(|@Jj);BCZsq z)gwb*I5C~B5xPJm9TyPL0#m3f1M14yx(!ciIBFP)_IX66u_8~7K~}7g_oyIDVg~~H zev>o)bxcsNxQO#X)=DEI1->8z zdfRRKBls>+dNA^bBj2-^;Mc?8oiEJ`w>f6!yVQQ0q7FrS6wMoW{2QQsg5cE)fyJNCc z4C3Wj9w46Krs%#Gz@3SFiIrnOFU2aXDgiyvm)Nq%XV?m>2~cm9t<8aY>rr*y2;6%4 z_MP!t^?72Qa8<8<7w#ylE&DVV3dPnNFU~~saeQ1Ck4fH2Lhp`E|z#4%=rV|{}^6C6rootLg=3qvOKj2#e zdGw1+7$;;bIwgSymgDw~J&k|YuTWoo^oa5*X$xmb;Dpc%lNB@Tdb;r-X<@&n=R_|j$?~I1Sfe%Ej{_5^v;JNjy zw3t_)K-xUGS|9NGRayIL{ju@d?qKNE$@mdb2oi_{b`$WK^khzh9+FRp#zOooE~m?LF@AQM zfj}l{P@Dz=B_}{7EaC&0k1XU4p-)1oMyG))$w5o`kRpLVoQ6zfEm#aEkRPE%qQ@7m zn7HoI$<$Pct~{D;dRadJt;mk)3Qw9k4N9G0nKsj~r0*yXsF0LppDPW>l7N!{}VU(UG%ZoeHV)R z$21T)r34q8a@s{_Cs!x2Cdl@u@G8OsK*jHY_>d0ilK3mC6L)w@W6nC9uJvc@@pm5D zV7@Sl#1Jk&Y6#N`$r#84rz=pb`#&P{ zAe6vVSX!{jn1<8{QwLB$9sM=sYrCNtyYzUI0oY(36Q>17KgINkN#D#sY9UP_%^)== zah)J7oKOSKuN7!;b-xu@fMLwg2r_|=@9RMZ_lX&%_OE|cL5p~mg2S4E6FzML00!J= zNzI^rEzpl~1E56%w*$Vm$@a@`Epk2^nz+$KG|tkIyZOP*r`QQ!Q@j^Wy|!`U54JtH zdI@>?q(4~qT&REMjn}r|i`4J*1_OVWeB;mP82+4)l&M=V*S~c6lKq|=L?bO7zmKtd z{8uRYE~2RB>^USO-FY-pXld?%qmh+bhaqq0PxfsW#Me|<_#Y9+15)>1q@#Ul*QVo9 zM^;kll)*_h+&>_A*AYn9Frnk5>i>qIyoh9ODG6u1Ssc1bv5-x#GyxXs|bk{ z6Ckz&zrqJ321(UAI&yf0{TMQg?^Fn&<>1r65U|uIAuN%J3ERc=A;lf|4Mza5V2eQB z!Zg~k0&9|mWI?VDuK2kC5(nAe2s1?Y75FItMVPh@KZSVqzd=9^pfVjc$)|Y9%*LfC zdz;2Zb&9Z)U|JsQlM=%z33CU(6j*5WjW91t^K+#I(8{oS^2^Gkf?8_)Zb{T4Kq$On z23hn|oK(rL5~KJL&WwDCq?bl5Ecwd-*8tjyuQJ9_IruteL+eu$GujHP%nI`dU!Xa% z)GA(QF`{eX65dJ6L3LdTmJlov;|pV{GMwmVl`wx#-TxA1UBMUtg{gr1_MeLTufTAo z{$ZP;6z~_$@G8KwHM~8?D<7CR#{z=fzY?!qZchUk_OH^Sqxaw=^$d2jzbjTYk5Y}7 z{SE(&G`~>}4zc5SW0pNm&eG^uW}WIiDe9F(y(h%FXz@yDeexQ`ZlTSWP#Zf*Eltir zgJ_fUxAPBlx%Hfrsg=-AjB=seW#f`T^F(!kcyeJTR}Pw!?a>*=>!w5?F%ACth+1Uz3p7 zRMiN>m46d`BpRJ=Q61q9f=B8dx*7lJxdcZeWV7`blzBZ>-$BC|Ohx_#+*3(k*Wf3`& zJu>4on~cGW)OMtHF#mRs{{T(%Kc?uE4*ni!_Ypi=Vg*`u873h?M5KiY1S+Z<(8Wbn z)2J+qP{S+KGR(*#C{^%O)oXy4iYPk|-Cae^Dh$g!XlMh_0bqm1XBLx2fsZf{{vJi| zBZ|`SUhrz^2=|7Mi9W*u=o;dmQbd;tr`hliC`zUYS$=~E!YC#;iA2F8XV(I>ae0kk z009M@7fw{;f)utZV2EL|Br1{r=|m zsNW(e7+{~)??>6$w-m*^*J=+BH|75=f{Q8s+Y$e=Xaw(vo8tX&ll*&!vYei=qSCBH z_Draogjds)BS&0ea2(a%8UiwMDA0x*TCzGBN*wN>Ob zU+e0&t`bi7jc&%yh_s5dsVS<`&DvIi$#wGrQ?QE(rqnIl6_GalYPV+BL^|WwyG!;` zciCR3EdQE8-?qL?nc2kXwmzdi{oWJ6R?rVLQc^o<=mp55Fo^4OG$D7}z4%nO-D^8H+_1GLsHRQG;*$?6 z7Cp>27_|D_i)-EX9cR!FoSwHU`)z7*BN(_m@SN@)C+y>&qqhVVdcB~};}w^6s4Up; zcD<0dTTZXfyDeIZqhBV;1U3?$LfdO=p58H-$~64xJDK;C9hDg@^KQ=0cqTJh_Fd&b zwX;ZDEGN)G!2UVZ76dyk{f z?f;S^`1coE#bhiLts__nIzsTTs-lESs7A_xHqj2jU=h-e9w`&!P(dmanvv4cJJ|#E zdKP&LA%~DhC?FION(kk!!qfv*Rl@3|h8m#hpuAJs(GRJ`dV)v24(fEu34STkP`Vth zgsYJ@Sv$o153oKrF#=~Dbn8qXD*R8v4OtrLk-?1j^`6FPbemBI@I*u-I~if5&Z&q- zcsio7onchkh;mFBsvoHQZyB}P2p>VaG`>7a%uVG$zpjmo<5GBbaxN-Hr9*987Bn>e zVw9iKRK`_7Q;DjesRhbePw2_pjd5P}DTyEW=sGsg3jU%mwX{7iWXL z;m(fl#W}|N1IKrRu=zxs35ElY^9m})mgAt4&~f4u&I63;IBBCR&aT(NGvSw1R^VuQ;#faV_Ga) zIK7DEh^RaJ$>pMr%z(FhVVn`-DArqUXfK?->wn_Frp2PsjYS=I`^@uW-S>L2;kN^9 zKE$C5Ip^_*Y|dNlaL*A$0VmE`ykz}B*zSUXc)Yo&TW`>B?FEZD#_tD#v(sB#k$Jb< zEiYbQ)Tp>^7A`McK2EsDyw^LO|v%`qbwgD)Y?mkD7eTV^IZkWIGnrhv*v?I_E zX??W$fb88$1#^pJfBI3H!pjJO34qHA7sp){Y8`(b`MWBA5fT5%DR>R3Sgu*ZPFH{l z=u3qxzkrBu6FAm+N$z06sXqaVwD&U)={@D@x$f!ol147( zeh~?pdq)km2^0mSj6-!znPd)?u^O6_Y@{BlV=dCIDRO`oKZ!d03_z>}gIM?Wz1Zpv zyHE%juFA0Qa_;VPYB?@k#ZeBW^l0jGNvWotBXv@6*RfghQz)jj7Tcnzn_z&^g$L)i z@FJa?XB~|Yr-^xXuumb)RH|!2?HfZy>*y0GA04Pn9kY`?)S$}AGA*>23N=k#hkf!r zgXtgXr09tnLX8p%9<5kmGRqgjVyDEYMYB^zy23K=>tj7>(_ki)uFM!pFXHUj3ag;3 z9jI5757almIL;+>q%PK3b_gu5hfB!mPyvE+Hyz_xC8GQ~S~|#M z-OEwtrh)bN$EZBcMAdONszl~ZgMS!R#%5F-Tem)hD)=Z=!Rh>xIvUsJ>wKW`<+K*Y 
zyd)HOi!D!kHAJtZ)>~{P%8316jj|MfWF*vUkty)&ku}3ZCAW5PQkomvqgB`@0~~)Z z2->~f2J82{M!VM->08h2#A+Z^Q3%rv3-<@!T`e(NaI(m9x;iSoLdwDA#@m;le435# zk+2%o#cd<51S&v7o2D+SP=1ik8I9$2b|L{274Qpsx96D4EBVpX?ED5CMs{SiQq9P^>}mggZmw zrwIsoW6yp*YW2)Gkfyktf40Pg@qO z8W+-HbQBk-g(8<0iPlTgGRhvmFP_8ld>tX^04Qo+v(#nPQdhOATGdvO!ynR?S_Hm& zKTr7kWqnyKB7I&hQ#t;iKqP;c<}yQ_i?no-}csv~HLC zRBn)V$qQ6zb}YnE5od!OJI$PA_L5-c9nTGiuw>$Fe;5viVPG%8%7^+pUs@DKbVAH~ z)PA49fWR#R^Ewva0q$230+IyaCj6Od`M#lQ_f3O%qN>(OQB@add~;{BL_UEYj7azd z@{WUr8v2wjI?k=3>nA1Z)wpp31Lylri<8Oc^+LE9_8nnh`u*K@7zmv?ws7d-ZV31J zS1~e9>a)<5I2`~n?zogF<}DBUx^rY|UnUY#s)Vzmj6^`@j74hp{qpBDsNfW=jP<;r zXrXZKBL=1~poL4_dQ7!-2DW6zuEH|kSDBeO_sEBuvkfmJy{B;QnMfB{F>&-|NC%LP zKHw{?D$+StV|C%|TVl&|&c2nnKCb|7zjXE5==H`tvOzxApm75xDr?}%A)?9>PCju=7NA9nKe)bxQ!*?wp`nnz=v0KGu$r`Q+|>BD!W#9} zLysu0lc;c(gbNWuVX{Kn6EM&(9q54MigY%FLoKrnk?=0&5bIs{PHeXNJ-B0HGiY}Q zzBhXIbuZlOvqlf=dTlru^g}q?8ZYnl`+?WEKx=y83B>N53ynT+T#&UdG#(kP@AU^k zgESyf7Y?;SpHpQ{+Iw>|&iMVe;6BO!lam}NXR7Hp%`((28l77I+B0Hw5 zd)m}-_tX)Rw3((QtwrI7F~np|xnleZZL^uI(5sYD9WF$i5avS85i9%wLO{BkV#vE4 zT7Xu@jov^^4u(yVf7*{k1yeu6^odD%W+1i@XAtKQo0FtYlof`k0rS_2G#IhAM`4N|Rtt{-X9O*e9`vu0M>b-#%=_%{gr20*;(+;nL3 z{dxDJ5XRgCMKdpj3z-ad@o3+6N#s&4;2#mi1MGA=*3rIHU9-ukV{2OJm`+Iw+TSAt zPXI^@1@cJC!`>(hXsOlzHB7QL4B*J-ASnM}ox}{mo*fti2k=Wl#{eshvuMMLtVteBLo8H? zz)E20&2f&fp#swb{tl*X!1N%BTQ=G#F&%owxAAb9Ps&iD!jpCDkI){~DntLl^Y5WP zsW7nrxDZ)r^(*6|C@apDm4K^4mnkkQlN#!%$wMV!%kVDY(K5=Tr}Csurji)NA7i%k z@JJ7f+PD%f!)FFRoXAx%h8kuO!M#QZLberwQqsbA<>D1$c(E4Kz;*gBq=lhuw)982L% zyhXmW5-Z%);Mw|XshiQG?}2&(``6nOrA1Ntf8l0$Lh8*Z2Zh*QJS5BhCMRiJEVn`R zo)q;esHgEwF)A&Yzt$t;PwWxOwon_pM=ecGK^$n4(_}{C0rw0ZUX>2{1Ih1qRv3|Z zz>V;DEBss4ZT{0>xqJ8V@%ND8D+GQHVD|@q{?9-Bx9DG<$4?=2*6vCpQV5mHBvjzI zV&lsV{yZF5Pwk8wE!S@ied#5F2`3(E$x$Z#Td4pM$|wJ2v_B`9saYL?%*}wLG=waE zUvfyl)6lksJi{<@!(MP}=y@X#?)9Gs)e+N(w1wDu&Ud>vnES#K8vJiyl+hO!ToQ}L z%GPs!ACm}PKti^{e+4Zx&m}G+`2!06E#O4wab|~{;{4YMA%qqgWc+;sC4ks;2LrDs zt!cbQ4PYdHKzM_GaY3)Zj20F$C>VD6A5yg@fnOk+Y|!5cp=b*a)lA@V5`2-3C$n536M&R23%{3_^zk_sKZuekgJ5!_EN>luMRIG9-3N3Q#i$RuB0;$G+ zNK}R7yhuuM(2XG{LY$kDjI))%t+p*fg{J~v^XqtbYz%i-& z`@r2p2z~}YfmB_Fwns9zC}eCN@j9e0>4Q{8UZYh!Gu5FXmVv3FMpb`5sTI{a zYSf`K7J)|_@C3jM(EO3vq*>x`V<0>tFb0UTFiY?-=?J@pf0OVt)1EFK{(A(3eNVID zzeHd*%`a1y_*YmzL=lLexu}MIay_-_)7UyKsG4Ua%aQxVE{K=)E%8b%U$LtaGYs2) zuqAvlbh_e6aA&Wt)J=oupKPG~~D1K+e zzb+cVL*bTqDBL2W-Jv8msH8G0kv(UqXHtHp0Y_%F^p(bU!KWx&mL)l+^E6XOa`R|YHqO)>4V;M@=QdG5uqx;NfGJG%l}3D58>%d8UA>{JRHyky*KC+Fult#9 zwvm;2!?(J*Mo#8U-|prc`EH?6=oTBrZmCh~mK)`6rBP9p_Y{_4*+Ye8z1)G`m_ylO zIh1p9%`q=bjrjx1G4g^6tC-&<#B?f1+(%*`!&=7wrmN!3MxnXeSz) zrym)|*deAdtz~?vv{a@u<1p7Sy$my1=1@6M8(EfR)}hkCb1@tD+?R@>G;BOqUQ{>x z%zM{&gP`4NH{Gz^?>&AQ6Kpo~G#M_sXOMKH;BSCX3DrnB(2lfYMO7l?9X(PyddEP{ zj5J^-%tqRgb*w~sWH9wBv^%-T1lmy+n2$2RLSzApj7F5A98-33{4b39mcp`>SV66h zXEVMc*>g#4HL~e>liK;D_S~u3^ImU84L#==d?@0avZ$i z&~e0y^Yp$t4uj4>Cxp5umZx^`-Fu!Nh^48GUA!ruYc!EPaJk#{LXQXHnW>%ogRtGj zs`(i+d4qm)J2+#Vem@ADR_~0V%)8%idg9{Ld8N8vp4v1~a89};a`Ff`ZT3K1nK~eD z_vIlISsWDEce-$DO?q0P7lg8FZMs-$PWqbM?QMC&pf@9|-mvTRu*>54sRLc#cj#!2 z2V0K!z-tb1crQ)uFV_KUfMaF^p5GGLu>zR!yDVZGc!~BRle~{ivbovuE5Jb5k4n51 zCzP=)7c!}U5z0AElMb@u{7L!t)wS(@*INsM-L((jyYq^__h2p9^4Fk$JHZ+~W861A z^s9qCK97|Ykb-LfMJs6;)ux1BCjOyhA*ZS5PNXGmfB9)`Umc%?rb^Oq(Vfz}NC8J@B`#T{61+PQ3^=^mFY8%_dA`D9)=F4obemX9nzfYfvPH@ zG+!eauMSKo(Cok@0kg+i+Ny;XQ(@55ML3#27)<|Kr?GNO4rm{bjbHv~QqCu`VTeZt zlKD`5Pz?*{Up!E6DG$}pv{5GM17p}JF>}Y_-^6ANWwTQO=87!9q>{WPq*x~PgiP#0 z=4d&Bo4lYLtIR^)!aaw34tE=O>;%i>p2xj#v>ZQOlwCysgA8`ej*K0HA4c}bi1H(I z_wZxv;?c^n63-B~KUDdj$2~A7b_;L9Vx?4`F6G6wWjGdkCzZ&g``3niGbZ$DZgYR( 
z1MUrYzv;oP^|orP-}7qiUTt4reZ3{rV2gB!u8sZTZ8C&6YQMX2{WUgvO!~QBm4{Cr zIcmU|(f%CtFZS7KRhg<_4$_K>>;vnP))Zlcy>ThnU-(6Am8F}lZNOIq+a9~JzkIhJ zx_*s1*F4`H1Rk3VvTs6SP0q7lyjHtMCDQa8wfZGa$B1772!lrR3zW47>C^LbRH3`9 zBh&}N9PoB86k6ElD>QI{;01yuf@cV_1eXae5-bzs2(A!3M_>`u2q=c*&k{Tj5Sdn= zcim7FoIyVblNV4g$WtxySd~MIbb|0i0oE-WIw5UJYhjN8ID(wf=z$|qh_fKFk}l_MXexj;CBvX8(2}RY8Ca1$d%NB zRziLJqaKE8X$ADd2q|Mxoo@OPzmOUt+Ef!s8lt}mx)>kfBq&Mg2y4goqGw?nU}Y-CYmuqda0W zy(LHQuG4`-e=XiVT7e9o~?Y5TFG3P9pn$Rk(sv5ooGWIZ5c#znU;p5rf+Tt zy*U`xtFh)8N^dAKak)`wwR`Z@af$z&1l%Dw?OpggsQU~l zr~#lG(%CI(hFX9Iz}ab4tqNziB-8%oX^!s1YT@b;1ojY{Aqcb`2b(bTDPM4$-J$Cz zEm{v`;9ZU|>vk_hFu8XJS?Pa;g=2$>h@o#CoLa$~B5&;@+Dc!VqWj81R~}Q%~}h zG7Ke;7tqNhy{)bk4C}hAZ=x!_=nT^p+yzWy6NtNIUAF7No3dT!>u86Y{z47a5N88CMNfH(A=wFdNXLlX{Y9J|oG?{7 z@*LO-7)qFr^rHg!fdSWINbXXUe5EA$N?9@(FcV@Vj5rPP5#lc2GGaAV>7)%fXdBGZ zWLe2w&L_1CQ66>U+6zhT;;CB6PohFpj7m}Y*!YzA(K6DCWGmJt{|&Zk|C8J8aNFax zV0dqfyMt}VZ^JY6dF?7LAYm;ss+kze58+Z=t>$PCJKlB;gN)3v4zs<-W zHPWF{W;;mRuGOXpuVKD3j@~UZr|Eo$=Xt);UX~Cfw@kj6U!{tlbRb(3ihRa?b<6>~@Eq2a0!P#^8jT)n;^= zFf_hMa-W~kcS6u+O`6bxSr0yD3bVu{tm5*F^+=xa!i)}2&^H+)h?iyzI7#26?H04d z!*L|MN(LF?#aWUmOo$nrG?3!C@9K;YLA*SpQ>yUs7l3*tW+Fs{MDeuRP#`hUVZKVm zHG=B|-zRvL;5C4lpK@HpLOR>2vnYybJsmr-z`bUlv&lJ5)|S|N>BpqeC*Bq!1vdan zW)ZF`g8&97IUMre#p2-Z40Zp?S*PoNaV;|+|Q;KrBtMUFFqw9G3t7ks(6N8!*fY3!oDy5YF!g`*fl;m9!N{L(sF_;$V z$M`~-P(FMc1F1v|HHB2-dsqjQ3bvkznSC@V+hf8ALTMr8B7^^)X6oe0SB-3t%xA+Q zYDx*otb(}3B$r~ovLu=R3@OHoNgt6xBA7&6t4Q;bWEMxIkf>$chv=Z$q3>w2rzDx> zQ3WKEXyrr>QOVP?h&qxiqLmX_L@i@^@Mnok>Is?HmCVr!{NR{mu0%yiGFPIKB$+Gl zg+ww}qKYJ$D+$Rgq$IN(Ss~l(nFXyJONePI(aJ1? zjD%L^PS8pN)btbUTb3<3rgJqDM|N{5Qa259v6ZxSl{7JemPoo;bCwe6 zIpF7S9vRRyM;!fTWl1&rADb+Zj`%~Mx|SARGId+#%-RvtMOcr-!U3< zC_6lla$c3^Ir~7XyZ5mYZLQ9v`3Cp9xVu;U*uU2Ic_2`7-t?p3x*y%F+tTry&7ts{ zThfVw*o%UIODh(ByCwIr|drUOusaZju8Q5*t`d9AJslh;owW+K`EFIrS~a6(zBX zS_98#eo3)I;dg^L5RrUrX6ODO zZg;V2amY-;px@kx4q0c|k0P(tJ7g#iZnc|%JT-G(rS@lMHcb^>kS@u*ItpH!-;w8L z4oKTWb=YJM2S)XsEu31jo=zM@vFcizEtZwz@s%}A#=?0P-yvV3ypKtBvU zI-=vjju+esnnN7oGc)_Eb-*g%sF_g^wq$Os09N`g%h(1{p}nxP_hDz7n;X9Z41^u2 zrdzR387pcb(~5gc<~^@ThuQP)Y2UqYaiiZ2E=JMj#Sh-T_I$W?=VG)TUWESLj4skM zrhO*|e`T;GPGBWPr09nLO)u*6?^njBqRB{Z zc2SwCSIAqkC2n;+ofGaDWcP$=`&{?@BV>V>G9~O zs#*Vx`D2=w?J%ar%Ak#4&N?uxK&t~I1mxgBhiUATxi!?pUs9unvel^q3rH@~9Mb%b z{gFO$M)~*{N@RMi&T)k&fU4(?g&yvx(0e~Mf~u9|oE-PSJ(ZQqWG$j*DhnnOQc2!v zq*NyLgiP#!y}Ojad7jkv76;;&-mS$2@f zer|i|dO!BV8g;G(p+ATMJ{{z~35hi|&vxlT?FN-d&o9^NX9OK0aTXv=8ZAyy)*Vcq zUMx_B;q#7Ucce8C?OrVPxGx@~fr|vs5G)ha2yz7H2$l%W5abD-Ab5(vA$XjCVmt99 z!P5X~xB8;%$Fk@R`ca&{fO=7#X<5LkJX)j|#rq1dbh+4zX#=D~4@X-n^D0M=DJ^9t zC5xpyeJ?UUc^@*Dy$2PIc{JuB`yrmg;>fNV*SKwNUmNmh zvk~9fEgkWcv9>CW=3rQ#Pj%2FU!^okmXzc3pc}V+!Bc%BwG?E{Q$@T=ISNA&J@-X( zL)y2Z*l*s1o=g=ks3>Fpi|9e5AC~a--O4_u#?PlK3SF zc%R^)ml1EH?qj5g76L87&t~NZCS0Lym?)8s*m@QEp|gbtCR`!m_SsqPZ{NBR3Lr@A zA$CKM=z1QuV;EAt=y{t%Kg?S69!SGm0%6$IUW{OK>l(7kKT8Kk2?1xkC0@nKlsc=` z%2MY3n*qWUZ{GuvHy7kPRgg;FFQUMMG}wGlKBF_&a4xB~CaMoybh_z^t^;QA4n*fF zK5sdSe-P_P`G;a1#6Q4A6!SnHW}Fvz@h;oPLdtUS9g1W7k8YeGXT5g+0j||%m6nlR zl%Ft?1(@litTeoq6gj)8M~g{`?xGgMN$;@FnQCX^b=7V}uc&rIypDFb@6TDxVw@H* z8e_s}OvPisSUP4J2Asla!ii(b!RNtiz-8h>V(b>dBuqFSQ?Z>=R{V!kQ7iiu05I6E*_{=OeQHNrKFrx_RNon zNi88QD_-NQ368DW{`J*%yb*|6G`z7c{K1A7w&6GWqIMq9FYkA~Ndp4l3)5}^Ux$_6 z8U!Bd>$Pa3KMXmXP7Nw}tIf$kzc#r4nq>&Uhnkq$rPLtdxkLWw)ZAVR-v4x`^d`uvY5|fdQvq!8)(T}H&=k>WpYXK6^x`iFT(S2I|$6 zlC02MEFDyvi=;LaLMJX!@dpGi5WGn662Z#=d2GhnkPDO94jepLn$*)7lM5nf_646F z+H`F~nR(^Ll&vQ+7b8XQ05p3MeoBXj%I*e{JDV;3xAYD^55MoX&mDHs{wI;iCx2*A zlcJLGHHXZEZ~g;odW3J1i^=#VayCLUJu&w1VKd`|_>>05iQsAmhXC!Rk`>OMzfTG0?xq6!@i#keDK|Phzz5 
zNXHZdEse@C@y@glaYd^`pWIYW#Xu{gDi|nn&Z!*YnFnPNuOwN-Ij6FScc${-&@!3S z6Ed+Q_U|VOlNuCB6O@Yx zRsr?7)TrI2iaP{b1lt5;JH$^3en#+ff@^@ramC3V(tPm?YF{IGm*72uLqqK6(E3ND z=mtR3iHNChOW&KuK{{r#;{QR($Y7OII!52D^rK8y_8G5ks(4Ak+G@QZKElET(E#xq z0*~Mm0y3_`Cs08+y`3}?r)Sw`%%-6?MJu%B(IUXfh^G0ka-G6&N-$UtR7_Jpma4{w z=!mafM1fX%JvDiwkQP;tkbViSrZvO39Y!nkc~N~^Z5+eLz7_ScZ{-F)IeU~9^e+Nz z+@wVfN=XzM$yAdl*eG*TR-NZ=YqoH$9*rb-&C4N4?sDUy=7BDt&K ztSxW2CmBe4aDkqS0;^kkE&8|gFxOt1e<7#--f&lzg*fO6e4Ouj^XB({?@{ly+n$Es zYeyd@zq_Pq|Dnp_XQJ>fp7I-rSc_Sv`SqLmOx1N?S8pRXeUnk2o>{r=+p5emCwKkE zyzJ#ozop7X*3LVAC-3@QRX4L5Zm${#@_8kr3yW3m23^NrYuO}x$6ifz2D*n#@4Z0*p#bDbjt6X{29 zTJ3#2We=ial@@4~4vAGUlWpUyW{e|BD+_m%tZzuOOtL-R#&F(&_=m=6>UVE%9Srki zt1OSU9(?ri>)GSut#Uuxf|U--tzsA_L6n7MneL@gDAHjO#AzuuN3c#1z#1Y5x>%+p zfN0ESlgsB--dv1{81>^v?=#&AZ$hiI0Xw1VQd%2XQYbC^jMK_W%UUd$RQ0Y;S7n%S}PxP67 z3KS^FlGU}qtPDtN0bE;J<#?2EPFJNYAUA^V-!TXX;^fXZUN{1lO)`1?h5h8;i(7UU z%USg-xC}a(Yy`Cv1i>>PD$Pj6(uL4iN4E+7H@{!gXO6HyC1IbSqMZ2OCT<{S9CC7Zv#_h9FP&w_jR?|;rn%qjcJ(xeW)0mDmg#1q;yDNBBn8jLU=aB^m8 z?vKW$w1@{w={*|qEDp+al1P16N+&5ErhHgPJs$zIqwrB7?Rj8m1dAXbtr1TPA?>{( z&qDzZ9}P<}->>i0B-h7gep3X{Bq&A5h2N+vNfG;9QkA;qlcXa8L5Gvh+&n61DYN)4 ztxhKM8>&RMy1*Hrp5!=*#v*wY+bZYOn9El2xO#`EOLy4~)?%xA3vDJdQD6N5Y|mGV zEG$PqqAl_Y<@>(?lJNk34LHt=tG-XS+VgH8H2EZOIxAI8Mw1X#ZC!B z=FF{2!kaZuwYoO7MGHOcQ(CoyryHA3jj2qfpuM93L zxmuUk!9~-Ca0XWfYlHQ{)v7Vq#>PtCy#1PLeJOT`4l6JAp4!vqu@KiW_QnajuYI{U zZ5{tn+^m|awgXRUf1&?bn|2hrHsjV4w{q`mCv4iSIs>c{x1X|U2ikSf25+l+RR@|Z zRnBZfG2_zgW%YJuuh9GGzn^PI@^fTV@TShTG>a(8CN{8lduyw2N@q8W4gnL{RN**J z#KAD;TQI*gc~~66r}T`=)}tidKM+BjMB!s_ninSz( zA*WEi)3^C6w3Jt8Jqr6yc`)AF%M#h7S#XNOgu^>VX%XZ}KID(3Q-=8{OZZvXC)Oi$ zsHjb>xes3Xa%GC>`zLK3@fi6cd&>T*ADK9%z5@kY1(J>3Y#1G;#Xjn0G>$_$%;p=? z+)F}%HY1IXpHcF6-kl6_C>rLYyT8RADm?%v-+B=z3(BrG?_iHGEAOJPn2p@ymsp_x zs?-b4Uq=zy6guUyvhmyxgMB_6kCdsogE9PN8fb|jj}G{+sO(V&SsNNFA%T3K56kj{ zLh#|}@y94~G7L-?a-;$RX5{y%L4Op>6?nYjq$CdLxB+P=T^gEY2x(U+=JIqGZX8pF z`02buMgN+|8tQ2mqpS`X^0;}pQ9rLJ9C0MdRU?PN5J z4ocZbaPbMJw6`%y_z;E&W9bqQAB8+dswt9De|erdfkHX)n*tIi%25E5jff*{;>amw zV_0)IqWKs(ERrh8pdzY3{O0rbiD%CCJW%L*29u|Afbk1-rMq`1{ zl9A;pnZiA9lRhewdJCnW<0XctBFgm!-E}x4 z-n~T)bmX4(Eafrd(#&FzUrj(JQDBCLM3J@g;Lp(%VjbfKrh9a9utOkSEMEs%Y1fdDt-Z9s8g&hDI~zvfx23}GLa*Dj8o}g3Q04ESZ5K~TiG!x% zCqZHbK%%&M11*zRUj&SE2RnlYo$U-Be?jH{g@?+i2`maCxKwlIK)M5psk-{=S+sbn z#Sl%(#4&;ZgK_GTO|(+3xM=lMtA)sAeb3c%4h{Qlb@g1CoEVS`5*sAOBp#A@M1m46 zJ|XcH2|BDIDFht#*+)xAk8XB;M@@CFb%nq17Ji0lM|jkCG2H_48GRYfyGj|%#*6H~ zNY+3I*D<@_;Olttt0bkrhypQxKQkXFYX9 zQMac;AfqT!(8wi>E5Y7e(zGgp`3>})CJepe?4&UhS6*P4_@xr A?EnA( literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/node_classification_dist.cpython-39.pyc b/sgl/tasks/__pycache__/node_classification_dist.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2085d834e888e58f1f172901bd8bb9e64b60518 GIT binary patch literal 4800 zcma)ATW{RP73T21+?yoJvSi1Oaw&oh8Yxa&IB;DzX_Vw8Q36{H3L+^%aVTk5`}L)WdWnYY4L-VWQkZD$MlVz{X5PPUX^2ruNz;c|X4yqK?qEBU4HQob6l zGUG#oyWD$da4&70Sm7GlKChu&OBd0;%-LO|6MTc+8rvPF?00#bcTDAN#WF5ZsRF8e zG>CaBFzo*VojY-H*m0E~C&@^}$+7Yw7!_&Cl_N#m?W={n+wVrZu}lu4s>e_7#n~v9 zsk(4(Y~GpB^5=gh8t>pKUPX}_WIC9Qar2=OTHN9`-Zpo*i?_o)e1fkUJGAy}--(5Z z^B>ho<9m3D1r#MKjmRiXl+b+Fc+9pRo66?+ROMvxZkly0Wfy6-C)yAfH0QA)mhk%5 zjm?8Wo^BS!(dOL`KYAlOey~~WXPdCrVX@gC@H9%YxG1`N-6WRXK|kW%LT(IUr6__` zWE7Fgf`DQ$myIu;-}c5_CSf!_`%(RD|Aaqjltz!8;OImXO`MmNrG35E4ik&$(tE86 zz|>G)6hR`2s8NtEgN>J;Sy>AtdXNP0qfMSI`~n(S7RFmMCNN!(>6!&IT6^YfE20Id zj;;5gtZttx_CT@YXSY+2qWx5MWtvCP3T6~83WEjs8?QdItuw)G0$mjK<2;Qb6+}@! 
z;G>M{jVL-A#aT5XmZ3_JR&kjMI<<2JSwJ*U&`$NbKJ#tg(tZB}M14#ocpJqT5F1|^ zdjKQ19@YRy08ztH`x)S<{VQ~N!0YJuzJeJ;9}sP-m2KeTy{F)@1$2D#I+Q2#rBK&~ zC)CoTeLxx#m_0G4=BdFd{F32~$Sy6E4xzZhf~%ZjG)#rKhPm1nYCZbyb)$$7PhNiQ zxe~B(nyIVL9VfS+pSL$x=Tx)cK`DRCR4uA}B8r|MQtyl)Rrn|j)-ql4h8Mqkr%(Li zJUq!22sqlW|AvY*rpDmEM)cD+-xD!vdxjS7b-BZRryuDL)f|Vh$ zs3rVKt@KOx&=!9!{fS%F^bG%mVRvm_n=VveP&%syW}>$wZVR@@wd-%)*mx6v*H!)F zckk}JeUd4ZiA;IuQZ@~UmtlAn z3`I(drbQuMq7EzOJ+VqPdw)18lta{5sNnuUWIQUm<5Zc0LV0QbURMnI%FKrV?l8XJ zSMDrk)S`KOP|i?v`%<}k1ChrP{yrQOa<*P4s2J~uSS+l|2s?=i84DTKsz%!9VVkx} zL$gWBlaZv4Q{K!xI-Y5>gg$8}lZ7>1BU?SgARwajL7I$Y`U;lSE^08Jt>E#^77>~0 zvumuuR?G(aY-Xdq@;&&TX-gcaoBxQp$U8I(yn`rEnxJ2^#|W{PjEN)7o<&g+eRk<^ z>k$$P=>VIqMwHj{r?rxu5{~StU)5wVt)CiIYvRfVMw+LjwS}k6?MK$cEA8IG#Fs&D zvA0yhPxcv*Zcc1@0lJra7j?N()z`pElbZB;mwKzcwch2jHrtI0O~I`Hitc@Zdqj?9 zz=KEbq%N;Q^4ba8GQQZIG~|nAUDtQuNzKpA-y4&brrA1gJoHO{%Q#_^cG>Dd18+WJ zlNNT?MjyPbTqs-E#bW7A*EK&bPG8b*Z~8L5kN$b|cVfs+Yg^g0k3J9-wRbwMc$HKIRWhV;?-d85y}c||b(#f-I7|gRW7zFSd72NzvGR&I zA7-g|DhnXyBXsBxPRzLvUio5qf-w9$V-0~BsU&;E{%9WA_)>ER#M|g2|Jcn2$ziv@ zk4@Xjh{yCn8*eImFO4N+tS*1>gqBrsD?>y|2Kn&zFR_MZXP24Z`gq3IpU2LUQm>6$ zSSikm+o;TSh>xM9^NKS2LcED8k}M3VC2i}OFGl-fFdAx;3lB1aTnee=KqLp^6RHQ4 zU{(f)Cghjzi9u1k-8*`qH!F0yrRwxOv*~K-oUXgo5ef?H z%F)NK8~|!RRc<;QBnO46r2s%GC@ZdwQ!#)cVy=7w-@2C10z5t&U;SRtJ>?>l5RkpeF2&a<5KfsidgjzRHN+}W z<)4f{n|)+%rwoo`^z5law1|AxoVr96(wllnA&??WeOZG<&;u<&2K0wMN&^BU0yf&t zOx6VfSx6{6*(!;AgI;?Lu8H@;yajUd&qf!Y8byl;6lpc2kquj~`DxYX{izvPqcX3!?3LS(^IZ!smOz{SZH+2>kA!ZVq#2)%8Jp05u6pvN? zTpSgfBtZF;@}T1h+|pBpFsp`muIb;# zv?DzFwoIl+N@Fg;dsirlS%03<=lL2gp)X=~rzWo8B`8>lt5oQ#+0Us)JfTlm&|0cF zyCNiMmd&pTnps{WIn69|pCNSxdAZM8p1PW?9Whu`m}gH%4ByIJ`h1w`i81@3?+qK1Rz>$SDd9pYPqCm zmyX0PPco2`gB%Lc%$u1vZ@%CAy*GEN)slhl ztABhRRj(Mvk5rj|77DkK(hNixg2hIw$XSc&I`5k;lM#={R^M*fx@^Wy-)*_NY{g#R zZ~3}x$EAL`Rqj_>m43BV?Ju>K`pd25eyvqw#zR9m!hLQCH}X%+){0<{ji&b!GaI{2 zrmW9I*hgNv7iQt3K`0`Lns+BmqbFf{*mRT^w%a2awvUyUMp=+Xkx+If!)~H#{ch3? z_QI@v5EL9QdKSi`FpJdcg|<1fO|#BFItBXD2OyC_J8$vGoLZ+i5y1uTSl<{OCQM;H z_gi-42wOPMjT6>#g)6-0M#~evC?WU1F)X81!oJFCYj+@`yK$JN-A=b1X5B#&JP!NA zxSQ<1`~WlUw#zgtKFVz*6DiF>G9zQTal)tk)L=%Aa$x30&+J(!**OPwGB@W_@6^c6 z+!E{s+5&!-B>Ic zhkH?+s_ToT`Z(8Ea#}-a7|O68WsywPPZswY46|;(JC5XKi-?AU_Cb2trs6@G2A$-x z!9?_|+m6(`i`S+#d3&*p*6Nvtr`U)3Xt6wL`bF{g7mJ_U2rf1{irOPs&dtSrvM9|0 z*d&zQN~5@=-1$}8#T`{~M{etzf(E3W;|?v`I$KcK)5f;O2`5d7Wv z&#!h42K{I!O`q;Oe)#C4`1oii-H&&mX@}`f0u3o-&rL6#ncnvBSk|zl3Q~Fv#NbQZ zVI@lVoeO`Rjc+W_?CrS$v>6S9k8(#*9qE$OIrl03tu7d}<5hluZC6;5i+Qb6uA6 zzwmF-BiShUpl%NU0_qF%1wR|h%u2!n@URU-*xy^i`N5=_AX)y7Nhho5K2^{q;3qKh z1p(Lcj1R!+t$;h{*Ju3I*Z6Cwo4q%Le`v~oW|eFmb+>orgm!YV|3FpEe_k7|{N?8N7LWcJ1_Goun^2CNjc~VJkXYb5pmgjBAUmfM$nM55qoyRL9 z;1$xi`S8$u6in5P<33&R?h813wyekmnU{! 
zp%Zv|?~*bQC9!%ovZY5Rtt*MO+}7IhMh=hpz4e2o zb;N-|tnaGoD2;-2w6`zA;X(QYF2@W-PqCx%`j;{q%0WAV=S=n+Vvt0QZqgW=+aGrn zOCK~%RVqBa#z8kd7vql%(U2vGwjnv7R~Y(*?_##$y8aP^E)(=vd;Gfov=CYsmM@%k5q+tKd7 z@c<4pJ&43R8!XY}EmrCHzj#92r$891z)WsgkZj#A2i9vCP@(T>^iX*v&dh3Hrz@Or- zAl7l3a>?0HPXg4|C9Z~KXZ;G`L&sM0lG(8Rd8dlrUL=DAu3yg7;tqS^LT}A@D1j%b<{XppC|A+yT-xcHC98L#VlTjT!pQ% z8mqH9FR=|cN^tlRXz{OsV`1dF&Mov?hty&pHV-q}+TmekbkGzAYajM+6o5d+l<5%I z=yAagJ^8WbnHl^^&+a+e7r3Vggas#mmPlbt_)&~jANajDfsYdAA(vA3E9EAns4ro& zM&}OCt#1kA?1`BW!Wz?A%PB9oIY+?k$v@}zgg?EVm8TU1;C4sP_w+VlCHjybuzO26 zKjjlATc+R}_1aZxU743pIQ+@Ht$v{mWwdz++`SIhA@~DjwudA6OPoS;wJ>svR7@MJ zs_OVEfF^z&N+Ev(slvW*Q%S4y9V&fBgvP)-X&~DN%6XP%VfzpYc&@tzDA2*nPl#wk z_%%vRzlc`u(*W*ZlnqB&s%(wDs$5uMAiC{LIYpzYbh-(wyQs-8X|&%Gc{SdV_rNQlj4x&*!O$ZzCG${pSkE58%l^}RJ z3gcM|Pw;BID-p)u2Mi8|$B$6dk)iTn)OhA($MQogPn(ZkvZ~G^&~^k63C?$gu&zmx z6wK+z5k+Dtp#@n-yBEZZk%pY zXk7sHDgxbY9q4ZF;b9d}QPLY$EuXjOPEud0_81z|G>B<9qqXT3oo}Kj74hgN#i4yl R&M~C`p7v3pANbHP{|oy$qJ97X literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/node_classification_sampling.cpython-39.pyc b/sgl/tasks/__pycache__/node_classification_sampling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5788532f037ab97c1f14a09bc6cdfb5492f6f89 GIT binary patch literal 4302 zcmai1TW=f372er9mrIJeSeC6^nAl0fG?oH1XaP5E6I-@{z)%9%Peg;onzNKhEtmA{ z(y`cOP(TLql){e%iU1VQhXyG6gXXnQ{s%=1)cwxvQlcG{+Qpo?oio>SzH=t5)jSQ~ z_UK;p=gXS*PpT|_CMusHr5Om}L z9)(%wFeqBw=t&rl!z_~P3w`s(4$Zpw(Miyk-UW#?+I?G(jG1|gQ{g(-yXJRVmvMue z&nj&zvbn|WXW9vCJKW*!Gp+6N3ips#zSB&t?O|_KxxGK+(d{@)(_XjN3A5fX2_A%l zQQS)oUVMm|_B&OYwfv}TAsI+%4w7jZ%e51Irk`p|%TW%DTM5gYh7CDcoYdi>y>U0;Ow&?knaUcwi^T-_lkRy>{wo3 z&KZY~qBxb;mU9howz1G@4c;ge;ULN)k;*$#6^pUM#IiwdeNrhVVVZr z#^AgsCZlijVnw*DWPtOP-aOezR| z)c*bJ-NWG^+D+5PyAST&zY!ll-AxbTU1-`-x|={l3f^<`OH1?H86ArTmgFO)Z-Hof zUALJ>iRmw03GV${>@)1hX zou4Q1q@>@PzI^(*z$??!d6R*v>HmRYGFkvTVNA;upiBWLeIOjD)dxs`I;Z+sub#P_ z0bUV6H-0ckmSp-B@ohPVf}@#7@Km;#kTY{LH+~{lea=K9^VLi>&^lo#cfKbu zTPfQCLj6^=&fC|@_Vw4=H&8eFm$-Xmh<{~%wu!pczkEV_T9n^Pw*JouTQtJ#3h33- zIk0Xo)VIt1SiCZ|ksLM08z*c}d(OVrr*7E~NbOg-Jw~hed-f)cie&X^hucV3k={aj z`^5geKJ}*6tchB|L%X&5v^T|hVu8b-){^_#J7+uq@&3x319>r)Q;zlSD|~CH(a98} z(R>&U_uO5zUzks{APP6ylJ7vX8C7&*`IUnX;RT}5=wX<<8H7@C)%hz*TY1c;XQ2v9d zbj*R7T-u*(ekG!j7y< z!*RD8N3z1ja1;PLGP!y?9;fgZ_ipd20U~ighDP6{c88$KmG-uz91W(tSO*U z@I)P0z;v<-%qyaev!)el0^*)EG!zd2@~ou>{}at?>16ZE!Xc^Ppm`UrG(C*?JCp5) z!z_%O#N3SHaFj;;T*$kH5u0kBN$sQNBdQSy+-$bq7X&rpDu^^mT5M3!8=W1cSfvId zt2Xkyhd{v^Pc#{UBctlnTZ0_-=VjK1LhZi^nipzpn#c|RhCq$BO zkmLm}hJ^OVG2S zxI%}pP6R=;CLAK0M7D@jh`2;3*p_y8C*2_krHm8sICPEXYVjq)XRNN+&C-M2a9Svnp&AfGy4=$sXj`EnaWPEftWDm*-daz^wmNsK5 zn%4R{XJbv=RBc99Izr__;0n>|o*J_X*H770Prl{a6!CW7+ta=lQzOSy;7AvLRI=6F z01wF%uouv-s%gw*CKM_D$QzIc~pv_+4IW>yDTt_^viNEF6RDZme`Li0Fuhmtv zKi*Svbx;e=KtI)|cGjTC9`)5Lv@3#sRX)+-s^)#`3w`PMD;p%IalJ1#bb=`I}I1@jGxS=n$V#=~qOQVceq9=R{~AyvT;4b13a6X%=>l zpw;JAAbv?)^uiIJ5>Y1g8IGl=j8&mb=#8z0A1_JiQ@IE>34J;6)$w!o9&PLjbD9p6V$MWNCKP~b_G z9g7=SsIon^cv&l>-A)8x3KsU)B0dTj)iaDDy_7;eG9mM2eO+gs;oehytu^rp@jV1V zbV=8Qz*{r5Z)cIP1mf4h8ZB{I4Ud-ce!o&LceI0xAl6HOrOr9wk&< qu<}9*DHSv-kR*ImD_Yowy09Yd{j@l?-;%>U+lvp-@O7UxjsF4DhQr4I literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/node_classification_with_label_use.cpython-37.pyc b/sgl/tasks/__pycache__/node_classification_with_label_use.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5693723aca018a818898200adf2e774747f9901 GIT binary patch literal 5266 zcma)AOOG4J5uPVGd~><)l`T=WESY-j%8%HI?AY={@lD+e&eL04B*S69EPtE%~a 
zrBc-JyZZ2M@TccB?cYR9e-@D2D9IlIn8x%_bF*4^bwwMI>6$w6jL?c~*H*Y0=Ayit zSGXQJk!fhU78Ts$gi~_M=wXGGsOnY~&kk!*-K|Fp?n2aX8_}Y>7%jO=(XzW7t+*?? z_L0VNEdQ0p@mL z@!Rbo_uGfkNdn;|LBOOfxZjOsJ?h3?Z^swyJul-p!RLNB^hF>S&$P{$MV~P*^ml?V zY10g*zrqAb?gI%l+9THp%&~QXy<>*@OlC2g*`31gwT`aVRB-Je#~fDpN;}rwJS(!& zSDNdvGOM75oTa%%thOr8Z}(a7LFgw*x6^I=qT7!@>59F3G?QN>!IP^Pc6&FQ#JBH0 z;5yKDJ)6}U*O*|(zGa+)OgF_ zn%Erc5skbu;jX4dwP6`|ZNgoDin}3}#5twox!&`~`crM^!CzqZzQw--R%6D#!@m<3 zSngQ=5^Ehdu`12G$TVzie!FEzBjmFDMbO>d6CMlN{^8vREmIaVyGXi6fy^g+!%im* zWP$Pi02Yu4wZ}(qeXW)tWkKLeA)Zqm zKKG+Q1U!*f=XUH5L^r}d@n@J#FzC1Ul4n?B*iRC#6F-AW7JS}q2lA!4^Hg_#{&FHX z4v&))xHeZM4szlU50Ny6Pm}d`9VPoHO=2e;udZj+38| z1?3ODE<2Dn=8l@lhRXZOJUpoCH-Au)ob+-;kOI0)K97(i>T9yJ++m{wg{Nf35qetJ3W>^@bo zoxDbSnw!{uZepJ%(#xuuOQxiO8xGR}8eE)pi!pvmd;;BWt3;$#GuUj zJ&d|JIOI!MC&ioO7Jz0}jhtSjf-m>|4}1qaLti>mHjE|RF^;bMum{|n`pPy%I{Hyf zdje~!P0;uA#{-O5R z`1Efhgx_9?|5cR72(^%^p6H{TDEDgPdYVJXGuAa4qu1EiMtKD2{JJ)BsIF_rc*R_K zk&=Vn#sr6+QCXbpJ*RX%59~b4Ju*g>+1xK+?$sI91XTldF)bj0sO_4d>j(c->tCn! zfBN^)!mc)IK;H|LNIcg0U&M=)a)8d`a~@fv#e>V~B4p2xmeTV0Qc6h%Rxw&m8{?PK zm!(MG&2UR5%rP%05qUUubGu87WQ28F| zdIKfXN0QKo^DT|vOl?+Pg$=O^@Kw}R)N80~sMk@~QE#AL7{8gVD_vQG{BaI;zMjGo z`2BQ!WTYD-lh%%8>&GWrHbU0^*yNX{5*V{VR)l(SrcYP;GVW4})DYIToFaGr+WN*) zHfQ2iuS}C&Txh!WM;mtoUkrKBypR0M-$m?hvVI&iyK(c#+$lU-F%IWo+$MV<*Ai_rUob+9WB6VWUq!9t41uu zlo_M0WDwbAoC-3fU`i!ND?iIpCzaAS?;-(A_5ybC=-h+8ptKRf%^>s#Nx)7!72k)( zrW)s{a=p1j6uG2Z&DKvi`96N0z#4!wM4#hzr}0&yEK z=l&soiRc_cE9*znq8h2{o{wDh0*RBW<0~XnqKK;^NRS|D9d}+2fP~# z{sM_#Byk(FY46D#l1oHweuH>c*pGMl6_RNYU26AZJaLdNBj11Y^ zNL$7JY-25o*6t+SXW2PrK5jKj0I-WpWkMxj9-SCcMtMpEjK7X~xlU(@90navH_xpM zW2JDSLHDFZ>fFZJ#9^8`sl-B>A{t0t$l|Gu%figIWqxAQ>K*OTs=gOO`AtlQ!hP~3 zfQEQq*XxF(Te=0fgeYIt*AVH8`UdC*uoY_!aeoQ?st#F5mccJ7s%%t274=K{8l=mr zy`eXZGIZ1-LG8}uL#BYf4%&*7(N-WeKX;fUceChC)?lNk=>0Ds@Q&J8SC3=(VW#5& zNuE?2n}U3`-NRE7b^ZiTWM&|a=<#ia7=CaI+yWxdP53}1PmgJmLtMvy0?@OQnj`(; zZBZPT@bES}s`ta&vl(`sbCz@l_Ay8sWs&h9@s19cHIH}N5nn*{Cy$lT|6t+e-HeW#|+x2c&n zEHgavEb8jE@{>E!ABw?HB)qFsWRxs-x-ru2jN&nM>JxZ6A>*Hc_Xm`u2>@%tmxJYy zjnr}8Iiq1TkS}hia&&E8o_J>7trESlV?1l|#x8mu9Lg{xyzF@ohkiI|F=G5N^a01~ z^A3Dhe{gsYs7h?5gUb&eC=U5Mn42r)Y)*_$kjh0d=VM`rpoX)ckq{jM<}Pi|77j7MGcdkxC$@IeJX#c nJg4%PDkjj%PFJ5{S(^QJw%hMeI87+#(gu{(2bX{=e98VFU+3^a literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/node_classification_with_label_use.cpython-39.pyc b/sgl/tasks/__pycache__/node_classification_with_label_use.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d00dcc637b3d8cf293ba3d8d04a3114b94a509ec GIT binary patch literal 5292 zcma)AS&tjX5uTGA-dyf3_mC}7vMiaByer>{BFOg&2?Rxn6&o1D7?0S)U2J^svH8RCTM0XNR?@?$)C@cP?tUjcDGTj~3j8XwhAanr>6qZfh*Z z@?U8zA5`~EcZumAX>I2T=B3?j>(Y9c`4Q^k2fpy{415-FP|kaP65RKb?RH)|ey20y zerH!YNg%u=2$-}5_j|FdN4>b`J@7?m%gZ=U@VOt3d=bd`6KzM#qR*HY`VWFI>Cg;L z{t8nfxd$ZBXrEjoFelal_Kz9rGnvI~W_JsJ(z?1@SHZP|9CKLVD{WtQ^Q_29Uumwx z%B+GKa+c;6vEr&cdw0Ns_d`EPdfi^f7rjCJNl$Fuq1pT{37%ZRuy;4JNqqPAeXaxT z)H7N^!gDAlN`e#7geLS<+czf00gfpJ?wP69H~SVaJ2e1vBA*%)=Rix%)MEN0w0nir z1}vs|z*3q6ET;}&B`pJ1(+XfMtpe858sJ=70&FnSH=h=mHq!Zbj5N)Q1vTCxxF$Bp zdPE~Pr`)Bqs5UL*E>F2Dr?{(PL7Y)Kp6frqub*l=3;rClw=MoXuo^Q)8vkCLW4V3( z4_ND@g;iFL|hN2f?pZGJ(CKwJnTgfx5F&rd`*NvY+B?~_9bprX~%y}vwFnc)>9Eius z8C;#I5(hbT4D!{P!pb(hMEG2IY9HjanSIphc|CO!Z#|=QHqHtAoxP4@ugA$x$b#~V zUXMMLFV7q`lMR&@mU(zm)o=EoBsuNnh#(QF+w3h-qSJmh_v6h#T680%6OSSsx~;SWl9j^^b1fc{$vDxSni=>u@#O$vSOoCVLcyye|-kGVAv+ z>e_IZFJhe(Z;}rHG_z{t^dc4AcDj`p(&7Q%b1)Q3}8FVe7 zs2SYDl@c@ueP2hw)%wPS0w?Z=axTk(=ZWw6V_gv4KL8`Mxu@L;+ZFU;jXR zY<&9nF@kZw#Q!1669isJR1frVPL%t#Nj=RW02(VA>F+nTwQ(NdI=`Zg9ja>@u!5LZ 
zy%50BcYy*q>cxEv<7DF(!7(OH#a~LT14GxQEhVvxFkVWF%gE`zMrQMkIbGKZGq$bA zI{$%Eh5qUkhn{g+oasNObUhF3EXzGI#+9SFpTpd%M^p<`4b=IxfOMj^X@ah+)vwU% zKmF%;Zc`gKpzQ)B6&P6{UP#N6iz%fWSiyKPZA@NF7Y}d> zkM$iP&P*<)4QOw~RV7C{NUj07%Lm%zC7eS8^KDPAP^u)Z_OGRlw7#OT0xLc;QY)>a z2llOmjQEd6*SMLT?x|hFNWYry9PwXf`4LW#e=OEO8CWCE9|b8;e;wuee(v|icxk*W zUI9hck?)^Vumz|j_iae5j925k;#DO>I$k?$MV;x%PxZX6WDfPb0V>}oU2mdf`bZM` zaIU4vTdB>;ORybQ0ltK~ih3D!4fP7@I_g!_bCb8Sb*0T^$nWQ1+bbz7fZs`1#zwk2 zHh1nI;riKumg&!0ADg^+*p3k!WIL$mkF@DZTgF{TksiXv7E|QVUt8Z;%F;~S<>q11 zi)&1GeQ$Lm@WqG+t$WDP{7uB~78}Gts~5NS%(d6MQcv(iu$kB$1{<8b)RU+9foHN_ zv{(6%GLsvvFK=9bjg6lul?~;gm0zX?C>ig~!98TJh9g#uScoM@jJlFRL_6YCkRb(! zRI=Cnb(S)zbiTEL#4p(j*!jIP_XmQ~MF_Wo&>toNJML6`4;ou)oW07m)&ru*`P^u= zf6mFv@pA-L0Hh%X{47z+L|LSeHwo1r@@vF@j=*{181YcrKIi@}zeIEnA(agxX;F<7 zb<0PN+9GjsXncueN)%~T_F-p$lM2A5%~}Ki=;76fLt@j>sRCSbmI~W3j4(>s8x9gNU10l+3LG*oJV8e#9XbX&d|6N!9f)+gnJFgHRxv-@Sc{^xI}P($ zR!$j@Tg?JLc9E$Js07TT6GN&fkBNZsH!v^P>5h=QpyTo8xs_3@6izkhp0r4v+c=pp z9HvYvp^&DChEf-@cx>acaAeyuKecJ~j`nC(KMI!o6-?#^O7a%~4ROA%*9}LvbPKSK z*k09_5#NjYD(D8VrnQV{zW{z!ucKIcS@DT38z_*ys5c>9M!SQY0KJ|TO;Es9j~^VX z3QgHp|F4y!@p1Ks;GxV6L=io{$q=`}DYHcs`+a#%uz){|FEspt@;JKxYRFCfLwO-FSBhd?!HF|zzu5x+W-In literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/node_clustering.cpython-37.pyc b/sgl/tasks/__pycache__/node_clustering.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..626db58444aae2370673be3f023e9156ea315d42 GIT binary patch literal 8114 zcmb7J-H#;KRj==^uKt*w?~h%tYaM6T9m1?-8=NeT9q&5pVC+f61_xSpir&68Gd)vX z-E(iXJ=3n{!DA&nBw_^;2uBg!3KB$-c;E?y_yc(0i3gO+14umWGe}6h!0+6u?w(J3 z(lhtet^0k?J@=gRJ9V#8@ihFtbnqa&|3yvvPkNdBEM&fpB>otLX-to_M)uSjI^K=Q z=$H*t&&sThZD_jIaT*0xVn%MK*eI&J6?vUfqone7RPIz7l}@!$?aVdiI<-cvGvAn3 zbxyRV28%54C#f3D%`;)pcpT&YB%Oo!@#V zY{mjP_g+(k-)V~dx-DH1CV>b;CTp#zFOrbAMdatVXnGd6S==XT*ERgc>ohVT2~r znSnoZ+x~&Jtuu?+pA;KbXfuZuKGBZ#hJ(DziYi|~-eV<|cUhTLKG7OQR%LT|dOy%C ztx*D7YO=c3W#RqlR-b$cQ*J$30&SJ^WzdtBWFm>b21&F;PqkxXWQ=i&X_0CBMR6nb z=*e#-Sm?3-v7Vw$V5QogwPz#eqy}gqaZ_Vd9BZkWTFm&7i83#x`chf|EvF7>C3Qio z$=qIzp#?^mM}C1>s{NuWSwf5Dv;?}6YNOS$_8yJ3hO~~f!8B~o-WuGQ@!C&t#y;zY z!jHR&FZ#V+mnZ(CX42aA)e+;Y>xMKUF3TT=?Va5uU}39y#C2>M=Vr6ys2A3+NgI7b z&PynkZVdD2#+$Iakra=|tS1 zi9E<#6XFtbAvex7?Dmp&r#%Sy1*3<(ZfjRuuqG1ZU^~8`3kyGJw?esmZXR_vY$fi} z!GhwQP`Yuy1KDPww4!djBcD0fYjSfGmxDGtlqKrZikd=%LS8%9Ax53Le&T|~#fszL ze0UhP`bj94&h>A$S{IbYowi&*-Nb&e)ObtA4xbMxMFmHbZAU|8}E3KX5hazCBKfGG{I*o7bKyYF<3Y=5vs2?$8+d@ z*&H~v&Ci$v)mNLNe)$|ORJD?U7fwucmTQwX)7j_%NCzmc{SEk+5Z|EubS+WJaR?Q- zXO2h-+IvuJpoOt9tu+#t=}?RM8Vux*EoT12q`s<85!4%#3MB!mI z4eb=lec$5WPs>B+;QQajK1a)>wzIjXJ>NRj^LJ-GG3qw;#VX5+$1bZ(S753Y5a(T1 zCDQ_a&ZV|`)>4O_KXDXyEtnf;+k|w_Zw*$y&BGq=wn8CrvtS(J&~L~7z}$RkTk2v5 zr>lO0$~Q=f-t|Aa`}~V+_=GOYU_sq}b?2!9U50}aM1J-H*_^)ACh{{asX%3ilQ%I~ z`BtXGn@PJH`ww7M#BRuL4z7Ktn=~Vz8v9|??1_-gI(hd(k@#xTLHW!6BfJG_*1>%K zHk-D-!pV;Fr$D4ddc7$VI9YyNYqlr#LuvQm{v^_Xy1z~xSE$lfN}q2~0jUc9 z43Tvr*N8kz1pXnk@fJtw1lq^`6`jqN{in@8EG%^Hi!nD>FO&wl2OxZ$XoP> zylYf-AIZ)CmP`*dlu%+;RW2`4Bls8S+1x-)H3m7*sv&<&8w?0@Rn_Kp8_bHAm&t*sF}HV{J4y#!bY{!!4h@5A9?g zGH)FlLnn1c3+m1lhVHNkm&;9y>ModD(roqmaDTj+?d30GiS@-yN-Q#EvZ734fRp{O z$r-E#IO(HhcouHj*+GcyK%iHZ3_BMN!&VGkneCZsi*`$-ZSw^e# zIU@1~X5T{+bR8ffFsExqS+D6!5C~7N8iQ+REZr#)FOydo!&xJ*FbKf1enk0l5FGTI zQC7nX7+M)}>Edp{p-hg>E{jB!(_&;L_ZC`V6a~Rdq+v8)qDCap${{6hU698@lDu@i zY|)?llB#Q=c={s_(@#7GnV<}&H%!79KoEcg8a4EI1MwgT6`vtz10K5?@ zO4JnxcCyr^Cj$@kleq^$0jNj|T!n#(iI<{BQrb88yD;v^$A-dCl@!-IAj58xSr2Lo zKf$1@t4;U`2r0<$kwQnoD&E%=A~Lt?7AG-~W~X^5omMw)K|?#D-RVW);N^G1WVg!{ zK6*&_D1ia;@9%cIBJ^()qi%m0&*tH6zsvpGS?}Ba&HB2uqwa^mM2i@k&l7n>g#1)_ 
z{fx^^r!)mbe3_b@MF3fo{JvgX2`Gf6Uec2XQ5$x}BETk17%`=(&N=_tG`roNxTK2Q zdj17k;nNpxVx}uESxDyYE1M&4UO0p@Xjno(hVFBkkOjQ;IipNW<5>ur?Ug3ExQ51l z>cX)n9z?y6fq-8jeprfYPIe0Xc@!yYA)So%6TPHtuToAV3PU78o`Ys7vkGmRffh&K z08mk2<&0*fW&e_ZrH}D11`|rZ{s}E22P`Y6^@*AcESef@^7AO&r#|HE;MPrn(Fb_H zWFB%p{*3y3#ynv7)#mx|@_ER=`SFQ)E@Ix~_^vh?zjQKLJ}Ihyabnh|FW}bhK0i)Lg-B3x{rr$j_cq70#&+WYMmiBo0%hrN|=Qwp*0&%LTe75gw_fpCr4{V zg_UfD)(jxB$3R?0>H?vaQWNhZU@jBy=tpRdZvu5;wqfC5^SfGt7^6a8Ht^tNL?Bq; zKSE!KI$}P9!h7Qk%>r$~&r~Jb4lH|Vt8VUK^Xhj#Q#X}mosoe@n&$3DcVA$?9K2dO z39oz+!dfvgSfeBJA9lLkWS1b{6>-*Al^j4dOrY0Etd(KEb9gH&hSpT5tBVcuSBS(Q zvKs6Mr1kSS7ymldkYwZliN8q&-z4&zMBXCu5s{P#Y0`#?0XU(Uvwq_Ri09{MCMdj- zAs$4D;Ly52Jl+)FWY}j$U4?5F&N;i8N}gaFrQt9vsr4d~lI|a(bVmbLK|BP~YwVfG zSxi@W1&+%QF@nQSQ|k@wvHt#JAQVU`+$ASL5E!l9)CO890O=Ht^)P0$HV2O=zOYweCfuh_5pwqOTGV};Rz`iQaLy1# zADg2&<^VO74u3m!G3r~0givkyRBZ)iHOO^&NgK_hm!4L}#;~eNtGH>%ryl01jusf! zpjPpBxG#vYAT(~X0`@+aEbT3)bNO2h>zF_M9(F@}UQs*R8ZIz*SU&s;?UC85i0flT z&kYxG7i;Olx|S{?xr*WqYgu53(Y&k;%a8Sgx3KGZtO5632WJmn%IBYl4^Q*YD^AQ) zZL*4+x`rTq&imBC9acB?;8xB%xSoi)NT5e@MAwO4!dY5YBb#w_omR(+>(uXg)o)Jq+f;lZj6v%%Vz{f>X<^KJ=#7)aNKc{k2GY|=KCAu6 z94@6xqnp^pGYaP}CtpC#v-FlOLt<9amC+X!zFi%z4cF7vbS+)aV?D6cHHhbv|HF@O zQ;sB;{|3m;Km6mr{^H-$fBPE3vLgaUKMV&U z*8Lh<%^;t`deVI*YIYv6=BxjM%Kt*r1}|P7A{^zChN^Z^n}j+79N=c zS4Mn?bvp>lG?TD?i{C(vbh2++e3zQha-_W-b(;y_rvjzb=$h#o*jm&>NL|#_gF+v8xcGyh%9N!?^NsbSwTIWi! zHYhTzzS}h_lLi^I(1S)P|42kjYPhs78m?gy$zEH1F3qIVYj_|*)?7N-v*ArwAf(=w z`X2v1T4{sGdqh4vjP(a-^3O=>E1^bUAHFn&t0+;Rs;X~5xmR_=pgIp}MK2o^!Kk8S z0c{ra=MdANh{g>Vg9V+^251e*HI`6AwS$|VD_8Y7{ywevdqe^dSpdN6wh*Zxx1aw$ zku%#lZOarw;&N5dT?#;}NQkkW^g8H7)(wl`O!KhdG{6s1tEcC6PDKQkK-W zs!jEMEc=kQn3crw=5{}BQLq(pEhCPlH+$!QM1AMPeiSx&yg4z_bZ_M%zRYZ@FEg7I zcnB!Vt*rgDrk}JUkr{r14Oy^7?A$WT?4()9upleTk>fhu?an#J z?yBRNwC9-1b}xg37Og;9E~{PaT_l+0jW>jNhX6rU>3|ihFfix*|w?5d=(DT47Hl0dL1Ue!`pWSQf`6e%KL0)U6xnr9MT>usI<$9taPL`@~q4j5PKhMmewdh zG!agI=cRNvVU5Nebx8Hg1r^AE$ z*TqhFoitWlC(%yi0Ttcs9`Gg1ri36agJ?$Ctmv-c=pF^#ct~mWXRceG?il^Glj^uR zQQ8)%_T)!t3U49^2LuNV$LJFs_LF}B`A>BIQ;7IW$K>ack|v~0wnLDfYDT|$iW#bZ z5f{oxLMYJj%>ThXZ|BFFx+`>)DG#0`Fje+1HR1g*sst8Mm?j388 zA9#bpporPn!{s!`oHcEvvmDZS#4cjYgyj(z5POJ=hzrBzj8E#Wqy3?SnR}^q&-%8` z%c(c8V6JVIyO0(Jb_!#E&*DdEap3G9!HBPr5zjau_q%zb-}}>k81pXm#2kxCzAh^r z^EVZLlDf;vbkRVT3#qN*O6pMjJx58_!sX*^o470%xB9DJj5xMRd_z%G4y{{2qU3~SW57KF{N2-vif_eLO+nqt~7 zUaQ?j&R1Ofi{spsZGDx~tfzzn*D+4#6hTbXC6C(@g42LA<`!plTnA#&2F@T(~MIRx=l5Dga8 z)mL={qoP-kw&>3&!@8FB<%dNRsqtUse=XNTEd^J1P&3IPR-@GB*xm7!b;4?pi%pKS zvKqgK1kRy0)Rl8ZDI>|8Ap4N4hchyUIi`_e)Yt+Yvr+^!ew2hOb_cEQzb$Us+mk3$hZpXNV zv#fOPMFaq(bD+3MXY70#t>`hTatqw|5Co|=R0UUT`JsjD)-sEFMPG)BczW6BpFef~ zoly89d5kd}Ir12NAA;+JlrH-IeyXMf(P70jOrMJrd8{CKenLR#%SEI(!Fo;eBbi>gfv)hB**rs*e0FDH0}=?#<62ylq8 zqcQsj`WTlG8p&|U1d9+ZnP3sZB?^rIueeAnM1pi4=?Ys^c*R4y%9d2Rz?RtxFiVlG zvNgmdgdp=*e*ga3&Y%4${eJw~&I{`I&y$jWoFxACH3b@4w2zY? zr74i!i|>Hoh-xE^jw;hX(D`}-kVnWVQ3oFQ%2FnPu6!V*%wh0zKu((PDx8#0yc9i> z!aak34HrJrv7vBQDaB#;>9V)!s`o1jW8sol2c0k$P?VoxD21T>HRRV7YBIO#77sB$ zX-3Wa(rI<#7VNeo+EF(Q`Y+uIlHCqd80t1*s07zi?akdzM+CL2B(AI15I65%t#x?q zYS#N|?NWV1+F|EI;3h(tyh`LA5%Oi_g;Q=h?bR`8;-s}_P(s#Ze0onqh^!KD2~&MU zPwt0p+%^`0I%(o+Q+DeNK*&bg?R3Q>s>t={&(I7n&RxXRaz0`r86;3Qk-Ri_2zAY2 z3O-$bk7+{p;+4mYGPy8M!_;i8#AJRAjs4Wzu_t~-J(od*Ya|a#NzKW2f&UDO)ZHPS zO!QpYCv_hyBSnfZK*;Q?r7l|7)C|@*`Z)j@1!hj!SDN$_0h$TW_)VJ79H5~r?H|!Z z#(-w^q+L-AA*2(RIQf7S)Fn<<6K4+mGT!{?B)3 zc+C^!WFw!1;YW{PD2ua-`{y&9FV2zZY}+T{o<5#C1o*?Bl=~A3n-q55#>&O%ddA1? zFK4Fk%x4N<+~I^2=sF_*M=98UA#oO4G9CB{z8_Gafs{RP26$B4Gl2oWF>uEi5Pl-? 
zr#qsy8OpPnIdrImS%?YQ*@y|*Ifx0_<%Z4}+2s|ovK6v3fZqN9_{>OMpt(Y7BL96L zG!uFBBjm@g0i%K2Ah&-l!3(8AYBrGLWIP~UAVorI@DsqzpohUxhI)a_;ESq~Z3nl0 zVXJQP*Rb*_K>;1QDd~G;b8dWe<9RmVuT#|r<3dw z>N|#7{4MJAMYWAM?GuzW3#&4icLrKz#V~~mVBMz)yiMdo5Lx!``DDt+4_y2gspdV9 zF&5$kMEEZg`4u9!i5wC6CXwF-X_y#*|0-o44{rV%%>=pZ8MyHPaXbK)ilSs=EB|NZ zGN3c1jsh^%GrnwQYbPK^nMWEKKk|w52ueNssM*l~O7Qpr1vK_dq%5W@U;>{6Dq+2) zeOv$f5nu?E3|^3vkg6JcZfXOPkNeIV@r&wAxj*ns*0P=7Kj?$Luvc=VLBk7%brkRnqE)mM$T8R0g!nEh2`nnft>JjB<4t}n@6^G;Q^#!ovXU5lNIa?h13)2p zlp93PV+Z41Ra-=@!L{UEa-Nk~`A9=OTBsdfN{x))@B+UTlS${4TT zh~al;dxSBcLvQRJMtTaRPa|AJsIkho%)xTHJiLT8TvqUPCHWj`o}pa20v%aRSBIZh zuyt*)KG;at()Dy>{FK8hP{jui@#7szk*4x5f$aR<-~Y>x|1JI3SMWW?WbG(3C1gYT zH^_3}<%SJz)c7kn!_RLIYOQA2>V?@O3}B>sgrPT~?1)c{h;o;Cgi%hIY#P=?pd&2E zcEj}I8`+4yOc`6sXki-pUM%)|LC_EI`1xhDnn7^n7fJW!uo>ND%~$^ix&KDc`Y+76 zK2F+|&6mRtCNEw^BJ0g#^n}6ubwrsTR>b2C>qK~SX(mDa3crXN>13a=_&znG=}3D! z>@*XeP=T^zbk6xZL|m$On%!;?XRsen=-nvdg32xW?nTc>j70r`G(|7szfUzL0Sf)5 z%!$r+0;rI^6C7JYX@)_Jj|t?S1A}yxWV82z1LE2zqRi?YO8o{A5-;aOe31IN%nkh= zq-D7s;~l`?4w^}i<8y;MN%2Ff*10mJ4SM!fpXwT=NrMa_=zgOx{vJeAYPhs68m?gy z$+E2;mu3=m8y<+CHJ45nH@xE+2&uQFzQ^yKSq;(LQo$CHK6xpa*AQP z={1U8po;oB%sE_TgX&y_RlR6b;XId7QbC)Feg*aPKykq+>Q(fp!l9;Tiz42(mQhnd z=wEuQUey=)HqCg4$S#P?0jG6Zcz9T%qBfCJ%Q<;_DWJn?O`Kj+X0(0!UO~W+wp!Fc zutxL3>g1*3bnISMFcQz1N;gS|8wK@&YEyj_%f6m1WhHUEx!sFf^rnhuD!snPx^*Lsf-V=OEaw}_pT+>V1p~xGvbpD}v? literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/utils.cpython-37.pyc b/sgl/tasks/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d697cf7953104627438a3b4b92c940dbcaaa11dc GIT binary patch literal 10871 zcmbVSTZ|;td9GV`SKoUsJ9}k!*$|)s?_wK4E`)%=U~?HoVg;lHbZYz5%yv)T=2Xq@ z?4~+G9y=iqScC^+<-zJn91A;lhd7ZU+@?O6(sJ5y|=T!ciR;#88%12yRMO8xFP-Rs?yrQbA zhIm!gRRi&wT2ZTr*VUR@M|?t^P#cIh)Jb&;@kw=Bok4s`omJ-$pH_FN^N7!=3+f`` zv+8bj3Gq4gOX?oPcd1`i_aZ*8?o;<8zMz`w0mK*8gX$r~cPsm4tNDeGv1)6xX{Y%Y zUJ8R@)O6CS9(CHmUZ)*(MmkJS1~)<-Y=`Z!4m;hbI~tk_^jYon_o6t|-Qjk75XQRO zi8|a*{yaP`o~Q9dzlI>Orr6y>TiK_rshuD-aS~_hFg0}#t%;ksQ_sY4F0l{nv(_|! 
zXknJl=EpWhXckg;5bURg_|`aV2T|&TyJ=pH_OA8A)Q$EAsoM{Ssn?CeLDNg~qrG^% z7pM7ta4qadsfPuqw17onA@L6^y^1kEx#!CD(IC7MMY~sCe(uFD^l$B7iMIP!;vm|I zuI$C#espZI;u#4u)zpNWCb;qMYG4&BY_#xt^8> z!Ei6=x24?c?5QBFA1P`F{eJ3S3pzXDP(^7D4Ti(CszMy&pgZhFakrBeuZ6?T^+BL_ znzr7+DrBonu!d~vM`0{)SV1nTA+YR{U6F^OzKD_zH-bNs0v{oThmc4dW$)Pf{{8cb z)w7{YBm2sehL3>+|_Qnwoxi}if z-9h(usP!eb%7x<*ci~90<;kxHst2-d_rpLBNxC|S!!zg{xXJsN{slOEl@tzJNU1wXjtM=3KaBtA= z^n)l0BM|&BygrJT*INpdY8xK7fR5QJ*X@R_zleN?tK+YTXMRl2AVZEl1W6+Xzb4+H zGs#IBD@8gVJ4r6V!JW1aoJj$>JaSk)a(+@s{4K88-fX(57v2c^X>~ACVZXgK(l-OG zgmIe1G$*BL0WI2Gk{^dS9Hh)L^D4a2?SyHG^%+N(rAd}4fHm73o;=!S&GH8LeES|Y z@g6KHsvt0jTybh(Vf{Gr^G#?V!{d>oejKMY+XQ7#a>{uf%bpbW-%4yy?LL-i^=#$t zIQna7=OjJ|p@=d!@e^;$Ns8aL(W5l+hYi$A=^@8JNm816%7ZxcAP%vO=vl@FmZ#iA z|1?C#v~Xxm{X^?9Ym!fLaWU}^Y@~C$XI{0?GI!RRR4}V5X5%c)=4YQXo7$&lgOa3n zhuIW9V>X4PI-gB-_XErZC74Y;sZ2|UR#Imuho2FA(=l{U;)m=A2ID^XPS~JX*23M@ z*%`G-4*ECHQh%92fdK`sCVgtkUtMOvbJSG?&4zvik)e?OC_wkrP>9J09qGr=q$NZK zmTA>9EzlSww7`I+U!HiH7Z)#^; zOKIeiQ^BWs%BGA{-6J`0Dzz@S+?(bkE>P<(xVHe!qJ!=*EmAl^@&?MB3>BYIC1hu4 zqBtCz11XvVDN4UNE6WjR@?cu%kD{o(HB2j}Nyf@0DbC&|+{_m*ZKwYJK=xi{R)_W z*AZLG-ju_@*v{pLQ@fYi`!rD$%)*6C`=t2zb~p_8$NGui<+z-aXWhP#IRkBy!R5#M zqfXF|p13@3+bpFy7_Hl(k=&=yEA@sW4XYhSPU;Qdei+FvrjUEPJJt`PQ8O<@ljd*H z15s&x_MT2SgwYzMb~p9-&X#Woa4|fhtR|s;Uqp^+W4*cT$_wCk^yzKvLI| zdeYb;l{lMl-%3LmgyNRLA{@4%e&XiJ5t$%bm1mZgx4Nm~DMKb>MU?L<=p+)`d*X!BKqL%;P1Reb4IA;6*xOMn$ZM zh|ym|kA;3-NEUsK6&(hY7-@co24&m8KTFRx1{tt_KUb zm*0Ze=ebYd`}KVIx!`r<;y{(S#wsqQPBczy-6$Fl0{G7R?apXKFAkCM2J$ze7>-)% z>d{S{%&4DIqz)R-uub(VQ#UY$<4Lgz0zrpIbcKD|rP!a1}B2JXOZ9`pJJxING zoKx5tv;rT%19qVh#iYgVKnYUrMw@Zz8#Z>cOuu(x8A=KKbs4uUhe1ygnO~T47Se zdQMC?5*Prmii!%x_O~D=VGI<(;jpDnVx`D~i<2_q3gQxC%&vDbaWThJFc3=g3Y85F z0&Z1#Rrn?hWSLQ8|GV+&q;X&)tR2`p*1NV&<(+2~=|xrtt8L{DWvj_bvby7-$M3&u z?>+&}M;^OfIAGbINLd4^pP&Y9QPYr`50AArWj`^krz>U*u9hX(5$@aC{zsVKn#{5y zvvl+acbL)V&C<~iVZ@&wo28?lJXXfJIQnWZoU^J+eM*V+BivDLV`cC3u6arLFFY_S?wKYFYk=j`g| zkF_&pZ%WxZYQHmEH7M$AvJP5jpWiz=&as;Bn4ZUW0Q1ETxYKhcGFihrYDpuhChIV9*R~1^n+FgSg;-|X-1B6Xc;rb` zG%2H~%lKo==Vd$s(ZQOtck&`_jYZ(3c97cFmgS+}jRxW1TB!epJ!Lri_T2*@fC4A9 zL$w_qv1n$kjk$fd9-~)UQr)0!ErSt)X~5=!ZR#$bfx@@J+Y+5c9x1L%^Hys1N%LdQd5- z4+?bHL$ugl!%+GJ$16nF_qMiB0W_r#s4vZ57u6!BSXv;KD8mL?2e(oWI8v0BI+32O z$i8ZhtkTfvDPSg?I2r39lZblqudDeSpBqKLm&5j5sR<6;IixoIH%oRhXR> z{Nd>1f5!3MJj@UNOMn9!c=8f5wT>Rg>sb1~ZL?^*fnI#p{f7InU2^zl-`POfhO0T_ z4$T|>2rSH1{aeV;)WHK7n%*cF*P41r!-YD8=z!Ap*ur=d$iVqo=EGOmx#fx!| z03gCJNb=~Dj|%`#^8zCRIg``^Ld^VJ;WvFhX4EVSIqDYD`aK2$20efjh&eRqX(8%v z4@O<3=>%!64q#(NC@2F`(;_S>Lo8`=7;ev!@pXMMgl7H4D$Qj_lH9wECW0ql5V zTkb}d#mefuG3;w|XR-zaEQ8B0fe;?RK@j`Ve}UO`xbgha>zE&ahaLi#`}!1YCxgf4 zk0Crc*b1;&;BnCJ0EUBn3HdypLhL859GW{VBJIr%ZCXll*!VKr!F~Z)37gE6Rp66U zvICSOo96(7P&eQJ%L>pXEuLwG5+0rjAV7FB4j>8#cAc+30#IQ*nf$aS9!;6PORoXD zNxT@O;Ljxfj)!qpU@=wEyAIpTq#EeIb_g&Lz6~wr+~(BQw9siWs6y{sOQ9T)J;hnMv4p}qv24g5WxG*ZhXBhw+xZ~9=rcJgU1-$ z!+`jR{yqabF#69J5RPh<77|HQO6cEV@B;=vWUvGv9;YJdd!AHuFWOoe8&j0r4>19H z@@c`7wMkb7S=4`vmUEw$mo+&&V#$+%=2T^vI1lRafp-wB_6cnrI;5#y9yC?}g#jOd zwdo0nU6v9LfS4(%;6&liLW5Vu9|iXB3%qLpLmrR*O9p?%V6o{W|7Oj2gnx~6t4a|( zXHJuyGia+OB70G7^7DFg2*+d&v8>!1RAyBiHm%O{ZKBH4sT2W7MRNzxx{Q0u1~A7l_BOe=cKxpX?tu zD2@PY`FP_jy@Gem(s0eh9U;r-H?L;K9;gM|W{@w#`=3HIgqoHpOt}1z9fhrh!Knz0 zAb6Ik=NK%I&$R^+>c3&tOAPMB)i0o;HK!`#oeQc$PLt0J*FS|+nv)ASVfa>w3dG#V zeKz0Ubp!=HMsy+qp2v1XU&6pM5dcsWuqXkH29yd^OWwChi4+5?nV4wM>TvWwqSqWD zCx}0%&1W#`OdKd1w1m4ZiG9$+;u5fgyDPZ;Jr`<*6v@1yNKzJ^Ma zu>hg)3sj0PDy0TDAP0)9CwWk0HK_|_0a%Mm^ITzflQ&?XPuL)!r3Pwe=@q<}obdmD zuP^V-C>m7C_nKUB_anU5$GZ#-ucCIAhGRs!h3i5fE(aiPPP+5NG7VRb$VMSwA<)mK 
zVIg0No96lvQvFR7=qDKnseT`+R(&ScC`(D7jk#o-m&}@oKxLbid7D|$5+Q4X=ouLc zNeh6e|Bj_UV(|A2K4h?@$j*HhK~suGr00jwO8_~jFM@CxC4;7EGm~~{GA(lYpU{2| z^&hiaNyz|H9JrirW}o7sKJf_L@KVgA#Cpxqdd0)*d-oI|C8SYr($Ye%wILH(O4;b` zCy5v10}cWS`e*D9oA=J67k=UBxC%gjsL;i1AXDLngNr`(=H1uf&S&o<;AU21Qqp7o zJ~O|?ijar(eq>dj|6W`e${mwqE*_e>Ws94}wwT>8#^?4=%O3#8jPN-IzF~+14X+fR znY4;5r4Wnop%jwlV+33f4aT!#CQb1&-K?+&mHk}7zei%ZB*ap(&)b(oyw#Wjq3{~8 z;+zMeAnoXXCW(kkkA#nBZu`843{gxTp81mF&<4hl;ka;7Z~!<0xIM6tp^6o-OxQw1 zbzOy9Bb39oKD7Cz2atWFKxkRp@~kE9fpB%tmQc4@ODbM`$`Ji3W%@4{9;h~cBdxsAHRtu#k$A;K*Yz6!a4 z#xTd;N0b6^EU0i>hd7bChDm#)9qnp<60;Sz*2L{!Dr_8SrMuhL>9E^tT^- zdNjnhZ}G+8V$|u5Z(TIF`9%ZsT#QGkkdLlr6$Tx-jQdlKnWgz4;+JX}3Ixa|h`;*1Q{1yeXscred+e1167%RI;*F%<%lxwGK^e@vTwg~9p9|e&SpmKXzKx0W zE#wScED7CgX1Ez>6Yb|0V(l6FFHp{px%I|;C0KAyfC>FmzT}c%M~3f7iDn%Y*^0~A zZ2pbq1i)JLXh%M%x}4p$Xfjv=TsY&cbNI~evV3MY>vH5XLQOW6FNE|(rtV|# zMF#v@N4|tG&L3}k<YXlAyC;m+UlwQI6gwz y>xZ=Mt>_U1;#0}q;Fj6(3x3hh;a~F4`&IvxU-wUfJcP@ literal 0 HcmV?d00001 diff --git a/sgl/tasks/__pycache__/utils.cpython-39.pyc b/sgl/tasks/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8392fe8ea3e28d6a12a0abc85e80ff2dc389df3 GIT binary patch literal 10956 zcmbVSNvtGCTF$Mqvi5pQzh2$lZp^sLKn8zKG1KSLL=1 zR4HTmW66m4_dhbM)yfurcYo!%$oisX{W~j1e-0{7;fcS8L|DQeT5XftZClr!wySe* zRA~D)+nix>6tn|fc88@=xm_Mr+7(^*hSgE6T|>Db{MW5^T@*!tydg@WjJzo-qKbS& z)I=Tmrf7&J@-49;Hj!_OEwPRKm^dbOknf1&;so;J;-ol*{De3y&LBT29uQ}dpAzT9 zdE}?XgW>}6GvWutL&zTxKPWzh{H%CbTtt3Ow8SIG&x=osN0C1$?3b<9XFkBLt*b3t z`7gW_h2yy8sG6K~JK;gM6L%*vQpdxakqmdE&QwO-e%zmoO#{Ykc83RX63PB}w=;?o z+3&_(o@f5^P;l`)jVJz5B&jvW2_M?RK55PE6s4JyI&+7mxqE2M+|-?WCQk~fea}8^ z&HY0QD|N1ZU}J_>QMse=wkjrfrco!1l@skNUrY|J4>IeS6L(M?1f7y_vQk;_Jlw0-o^4f>FT=5zjFcBg*qAe(szNf6-5aVh z3daZGu%qE#_dtZIaRk%}heH)y54(HOSj4J;4&$+^i3sQ%^~e1<>33D>dNl6d7=?1L zWy>AxLLZk&wop|;93}b<8>q!~B$i#atNLLo&jZrsLF7M@1RqTj7g0zZVei@Uq0|~! 
z&sbmC1_{`yvkfviZKbYou2}o^Y{MTsj{eFSN|0W48p?3g*_D%nX$(Y{Cex%p>feo| zyue(K)#(}6v{D4ScXY-3WH-Gi3Pl30}rFr<%08H0R=vK1D`{| zV0M6QoWz269niEVoYYG})v0sR>R}6sHE>fa^>!WM3hzw^_$yqMBofJYeeiw!-m#Sz z-tJ>F@;G3sfW1tBnfOLHjZ~#O8F#V1F`iaI1-&Rt4rCO|2K%~waa&c!2cu4R7{+lF zgYd`EjY+b;=W-}S$MC{A49s@AZ8vSn!|ifs`LBe+0#(nTLhe1ZZ06wI)H`%$1zpB& zQT7ujEu=u-N$Z|7E28G3h7&+7NQ-IE!@l3KueMy}MK{Bts*NTh8g_aUc`KAcGf}Ig z3L2`4=+WV}f+PZxP%>NfMRc>@jZ~TK8FyE7m#k6*n|3%o`L)BA^+O2w_5&Q`L)cVY z<$i(es$F&JU}yOx>Rn9&Y{8Gj&ALYUGr(GY5W=4M!g&>&#ReTv?`yLX=-3wSo+DpG zFDEU59Lj*XX$ePG=(&9>Ex&F5hK*5`S!uk9URm3rYhY<*?go!f?LhLaG}C_^=-qG?P1 z2)ed4;enOfjZ8~)$XB44cJnBJgry4C!-P^sFFWdpem9ZSbF#_g3a2yViz(!f5%%Lu zsB}!CrRWv~Zbwhm$+vn5U1BCAO%D&HQQ1D@Y(nRpqSkQ;U5CkMF=B;ySGf2ZT65w} zy}6xnGbNQv4hH}Flv^1GyGLr^U}|P?zc(-FyhzP#xLSS|6GJr@5*hZ#szlKR*&GUk zC#pxIY{=-)RY^28=qVZWl=Qd-lWRmsa;GW|Cvn{AjaAij$(X#X!P(n1pYz3Q2P>b$ z3@_q|i4GoQ;8byRRM~-}Te~Eo{I|mA-vu;3DsWR#K!_nPX+D1&JY6_)d}Z)xMZjDf zy5QAnc2L%!kd`SZHP=@%uCHONpVlB6%leq>X`Kg&gDsG+k=N#daKZmB_`d|i*s})s zx-l=CG2s6S_`jO*Klp22)8!y7CH1tdWkCtDAVAHb2?beDOiP*z>)=52)*R>=74#qZ z{teSL^ufrV0#vgaY^kZZu1Zw(IPh*{aQY>-KhNYNxsVK~Sux{6`4a%-h+OA8jc!S7 zwV~^@c}B)-*O!>?z`Sc`q(X+GPLe;# zn|Ze97*=Z+ZSC)98>IE?pT;y5TT>^3tMbtE>vJ zgm)t&`cGht^2QShcK}wC@<#A&jB3!fm3y~8m5-uR%hzP1{5y1SglgpP=|*E%!m+aZ z%HuoRfuY0E)U5PX43=7YSMwq4AOVg^be7Wu?VNEgLQFU9$7omd0qttHf$jKF-tYjm zUGXgxa(Qzf(HtZc#zkSj=|BdR;3@=g7(fG*8MHBy0J5SqfFzn%w9KeLm=t?!aFRFVPk5q z4@Ty=1I?9yy(wvp$k2+Jl%y)XexgqmS`kPNlMu+59;gC%cX~%^995o$97wK3<9H%f zGe61f@H@(ZFS8UhG)~*+G7Sko^l7t>muN*kD<`8D>Ej($`|<@Yix%=yqi7K$2ZZf0 zq0~_R9tk9t5g|P$lq1UNf_|V~+U|JNZwK%fv`hRUsgkXmK7^Jkzly;zJ&8=`xK&7z zO-c|uK)DQQB5$H~g`G$V3u>SZn!AiKZA1r7v6X0d5n9C>_&qop#Qh9z^PoIVt9gB zhT#)dUE1#p`HTtxp;e(Kq)(@qTbx$Up5B~NPQ z6v&=bD4HjJ7Uz-DdBPrf0&svmgkP+$YSAL>UAxD9Lju0(oWS{@7dQegwmUDP3=0`& zR62w&x#z@JpSOO^#+kwz2fx@!h2aH;A1EhhF6xkyy6h!Y%m59Inb9WP(-?uig*CXB z!TqFukCwPc?t#?bq`Nb3CL3u$xeEdKUb2~Nfq(Hg1l;U|Ez{_JE@EH9ZxH1neHa+j>4PZS>cP~>(;EHSJO;Rq|JLa z()vAn&w9(2|EAw~QjlKszF_ce+y}OqZls%g4o1B8mc9SO;C|F`&c%C#{kDcRQFWrdFX+zO2A zaOCqDY`I=X{(^>Wq4gV!oowmew^wE`uy2^&YdcYTC)?P`w%$oi@5Ghg{6FvHKi_ip z-(J~?E3Fk6cjC(J47S{fD=)6}<9c0rWu>2ieObe{(faj#$DpHA={BgFW2Q&vS!wyY z8M(51tQUPwr}d1Q?AbAJ>Y4Odig*oqRDe4X63J_D{}~Kfa_np;#eeX?acj1dY?4}M zP2@Pq**0>}#%wD&J3p7!4{e<{4{ZmwEP4@I%b-|2pf!vkw3#b9If-T8eQ<=(qK-{;BC#7Cv1sz%#uh&>Z=t^`i+>Nc?@q?y(1`Ei zIJ(8Ck=}}GL_-9@7$fSK=n-tIa2LT+sK_zoIh|wWK@S{g$&0wr7*1?&2yu?!X@odw zG8w`egVqxQYJ!3r_7Cl|U&K&(jN^;(je}kf4G6^y5ieB!4Xr}7>7|McNa|@rDZ@L; zLm(CD6|VLj-RQ$?2O zm)91ph(*HGV6|OEfKwkXy~=hK{g$jxzO-~PMmMNc2j96{UTmCnFmVI|H7oW|b={{zc zwaXMZJAm!D^7pXTE)ALdS3+T-(qBV`h7BIXv#El?u-4o|8BWt7LNykTWhlo~IH5FyiSXF}6WC%u;#?fwGOm1+F$xxUOW+^94 zkeeQEE+w*S7sM@Z?&`~yby2dp=nRY6T*7SIAdfnhV%ZnqV>|+6FeD`Z6p$_tKL63z zSRlc}DB3*-mOKGF$pm%%Bk>GGDq$VKMuDF}Z=(Q9%7KsJ&V4+^BuHJ2ICow`*~<|J z+p2)WudpBN6;Io!23Cb9Qq7R4(YC0$MlS$~6~S01ES!0j(jERt5puf%B*H_{tNkYv z>wv}N&+FQcsnF&cG+`%c-vuRjF=?>pVV(_GNHvUZK%1CS6XUnw)6k=Vjf@!NwPh6E zP{Q`QFQfewXEaiFazFZzbou1(Ga;KEaaYEyY72TnFh`ve)ily%4A~4fN<@g@-|8nf zI{L~Ymfz;^KV-tN`7_o`4WC>Wjv%D%(jhM^5kLLUbjzmCL2u=%u7f5lBV`5Sq{s>+%B+9@=6RM?ZKjH_lAU3ri zXl8wyv_BmA_ZYQu^T4S8f-bL+E;Cglc=_ccq|Q75H7xHF%W0X8{=kK@AvLmYsFBoJ zOe>(PiVnfQg0U0}5DK55RsyY75N1ycpvH#2Fsr2vO;re~C1ofKQ!DPjh*r$6>8Ode zyu6_~;Q#;5031-uDH-(1cN)4b?f)YJ^>Fj`J2%mmm$%@yfL`IWF!)wL@Xe6y{phxO zL>4ZE-^bIfCRt6M+egUri@=(sC|Sz7b|V*Kl%1r>=0cJ!AbA%pK-p(w(I*G35wH9^ zzU5sc?Xo6g9W0dJA@r}9{56xmVX~&a&U_RRQ)(up;YTq_2j4)$47g>~3p%dDVkW(U zmXFl5ceWDPr}4Bfn4ZF%2%I=@G~LWP#a(~uF+>9?nw7NeTA=NUhhG4=ClEM78SSR5 zd+1wnNI^m=4H1M-@D&I{2ME%T7d~q4;wR3H&_BjQZ7$aW84A}MeDk6A?zi)|(Oow? 
zF(uu9@iw!%k8Zw0S>GB(B{V?*{vU(BGb2$_r%RlTF^bAV6&c7wfI;Fwl}P2eKUyrMZk(F zV;w854kbscYGMj*bSKh5106?@zk?uWEWpdrLe0Yry7_Ub4S?BWZ81bu-X9!vMU) z35TEpv5w4RI!Z9<#CJyK+cVww9U_#Phu>y3@zjCBO`UGr@}B`;2{%&)E_Bq-p<;9t zI2f04I7EwD2u3Rocz}@dKy#9AMc@|gw$@S>uU&vz4_7bNfEJURc!S!!P9GLXrNv?2 z{ruCDF+QbB&WGpYZhv~`yor#XH-Vk=$pj7h+pWC8#78c*8uAWSru;DGM{il&2SJ-K z`RwB+lv3{D`h(Y{xL3K0zVert@cIns8z-N8Li-`gdzl|dQ8p^yyj)IE!52rrOjr?K z2tKWeZ>|^L!-lFI;5dpYVGxC3tXE!T5D*zpeal5 z@(7vg!=rVsf-CrP2Cy;3KiiQ)_krNbyQc%FZ9Gnem!O5Tz90w#0D#7`n M8nl9Q!A?;BALorsasU7T literal 0 HcmV?d00001 diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py new file mode 100644 index 0000000..dcc6f0a --- /dev/null +++ b/sgl/tasks/node_classification_sampling.py @@ -0,0 +1,127 @@ +import time +import torch +import torch.nn as nn +from torch.optim import Adam +from torch.utils.data import DataLoader + +from sgl.tasks.base_task import BaseTask +from sgl.tasks.utils import accuracy, set_seed, train, mini_batch_train, evaluate, mini_batch_evaluate + + +class NodeClassification_Sampling(BaseTask): + def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn.CrossEntropyLoss(), seed=42, + train_batch_size=None, eval_batch_size=None): + super(NodeClassification_Sampling, self).__init__() + + self.__dataset = dataset + self.__labels = self.__dataset.y + + self.__model = model + self.__optimizer = Adam(model.parameters(), lr=lr, + weight_decay=weight_decay) + self.__epochs = epochs + self.__loss_fn = loss_fn + self.__device = device + self.__seed = seed + self.__train_batch_size= train_batch_size + self.__eval_batch_size = eval_batch_size + self.__mini_batch = True if train_batch_size is not None else False + self.__test_acc = self._execute() + + @property + def test_acc(self): + return self.__test_acc + + def _execute(self): + set_seed(self.__seed) + + pre_time_st = time.time() + if self.__model.pre_sampling: + # ClusterGCN samples only once and the sampling/preprocess procedure is done before training. + subgraphs = self.__model.sampling(None) + self.__model.preprocess(use_subgraphs=True, **subgraphs) + else: + self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x) + pre_time_ed = time.time() + print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") + + if self.__mini_batch: + self.__train_loader = DataLoader( + self.__dataset.train_idx, batch_size=self.__train_batch_size, shuffle=True, drop_last=False) + self.__val_loader = DataLoader( + self.__dataset.val_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) + self.__test_loader = DataLoader( + self.__dataset.test_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) + + if self.__model.sampler_name != "ClusterGCNSampler": # TODO: need further modification + self.__all_eval_loader = DataLoader( + range(self.__dataset.num_node), batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) + else: + self.__all_eval_loader = DataLoader( + self.__dataset.test_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) + + self.__model = self.__model.to(self.__device) + self.__labels = self.__labels.to(self.__device) + + t_total = time.time() + best_val = 0. + best_test = 0. 
+ + for epoch in range(self.__epochs): + t = time.time() + if self.__mini_batch is False: + loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, + self.__optimizer, self.__loss_fn) + acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, + self.__labels, self.__device) + else: + loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, + self.__labels, self.__device, self.__optimizer, self.__loss_fn) + acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, + self.__test_loader, self.__labels, + self.__device) + + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) + if acc_val > best_val: + best_val = acc_val + best_test = acc_test + + acc_val, acc_test = self._postprocess(self.__model.evaluate_mode) # Test the best model, this part might have bugs + if acc_val > best_val: + best_val = acc_val + best_test = acc_test + + print("Optimization Finished!") + print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) + print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') + return best_test + + def _postprocess(self, evaluate_mode): + self.__model.eval() + if self.__mini_batch is False: + outputs = self.__model.model_forward( + range(self.__dataset.num_node), self.__device).to("cpu") + else: + outputs = None + for batch in self.__all_eval_loader: + if evaluate_mode == "sampling": + sample_dict = self.__model.sampling(batch) + output, batch = self.__model.model_forward(batch, self.__device, **sample_dict) + else: + output, batch = self.__model.model_forward(batch, self.__device) + if outputs is None: + outputs = output + else: + outputs = torch.vstack((outputs, output)) + + final_output = self.__model.postprocess(self.__dataset.adj, outputs) + acc_val = accuracy( + final_output[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) + acc_test = accuracy( + final_output[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) + return acc_val, acc_test diff --git a/sgl/tasks/utils.py b/sgl/tasks/utils.py index 3a8a631..83c2a14 100644 --- a/sgl/tasks/utils.py +++ b/sgl/tasks/utils.py @@ -1,7 +1,5 @@ import random -import math import torch -import torch.nn.functional as F import numpy as np import scipy.sparse as sp from sklearn.cluster import KMeans @@ -45,20 +43,32 @@ def evaluate(model, val_idx, test_idx, labels, device): return acc_val, acc_test -def mini_batch_evaluate(model, val_idx, val_loader, test_idx, test_loader, labels, device): +def mini_batch_evaluate(model, val_loader, test_loader, labels, device): model.eval() + val_num = 0 correct_num_val, correct_num_test = 0, 0 for batch in val_loader: - val_output = model.model_forward(batch, device) + if model.evaluate_mode == "sampling": # clustergcn still uses mini-batches during evaluation + sample_dict = model.sampling(batch) + val_output, batch = model.model_forward(batch, device, **sample_dict) + else: # other models use a full batch for evaluation + val_output, batch = model.model_forward(batch, device) pred = val_output.max(1)[1].type_as(labels) correct_num_val += pred.eq(labels[batch]).double().sum() - acc_val = correct_num_val / len(val_idx) + val_num += len(batch) + acc_val = correct_num_val / val_num + test_num = 0 for batch in test_loader: - test_output = model.model_forward(batch, device) + 
if model.evaluate_mode == "sampling": + sample_dict = model.sampling(batch) + test_output, batch = model.model_forward(batch, device, **sample_dict) + else: + test_output, batch = model.model_forward(batch, device) pred = test_output.max(1)[1].type_as(labels) correct_num_test += pred.eq(labels[batch]).double().sum() - acc_test = correct_num_test / len(test_idx) + test_num += len(batch) + acc_test = correct_num_test / test_num return acc_val.item(), acc_test.item() @@ -76,24 +86,26 @@ def train(model, train_idx, labels, device, optimizer, loss_fn): return loss_train.item(), acc_train -def mini_batch_train(model, train_idx, train_loader, labels, device, optimizer, loss_fn): +def mini_batch_train(model, train_loader, labels, device, optimizer, loss_fn): model.train() correct_num = 0 loss_train_sum = 0. + train_num = 0 + for batch in train_loader: - train_output = model.model_forward(batch, device) + optimizer.zero_grad() + sample_dict = model.sampling(batch) + train_output, batch = model.model_forward(batch, device, **sample_dict) loss_train = loss_fn(train_output, labels[batch]) - + loss_train.backward() + optimizer.step() pred = train_output.max(1)[1].type_as(labels) correct_num += pred.eq(labels[batch]).double().sum() loss_train_sum += loss_train.item() - - optimizer.zero_grad() - loss_train.backward() - optimizer.step() + train_num += len(batch) loss_train = loss_train_sum / len(train_loader) - acc_train = correct_num / len(train_idx) + acc_train = correct_num / train_num return loss_train, acc_train.item() diff --git a/sgl/tricks/__pycache__/__init__.cpython-37.pyc b/sgl/tricks/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..320ffe7fa8cb05e311c4f5b58418da641b608403 GIT binary patch literal 208 zcmZ?b<>g`k0=+`F)NmmE7{q}ACLqHBh>OL5L<&O+V-7~MrtU$G3;+KYgMt*LpesOW3ez3cbZcb%|esOw^eo0Yga(1zPe0*kJW=VX!UP0w8 V4x8Nkl+v73JCF;CLALQQ0|3KGGfMyf literal 0 HcmV?d00001 diff --git a/sgl/tricks/__pycache__/__init__.cpython-39.pyc b/sgl/tricks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..536c1fd317a2a22554508c64ed34f75b20eb2d53 GIT binary patch literal 214 zcmYe~<>g`k0=+`F)NmmE7{oyaOhAqU5EqL9i4=wu#vF!R#wbQch7_iB#wex~=3oX* zmY0k`NlnIE0?zqGMXAXpj(I7;x%v4e8ATxNewxg;gpwgL@rijU@x=(aTkP@ii8(p( z@hcgMSb=K6#4io~jQreG{o>+6{a|+=-JHq_{o?c-{gR^0ARsz$zvicVILA}`H;Wq3FrB+&*R!Zo<;kKROb3Q)84Y8ADg6n0dX}a=dymQQv@>yG24wt{!U_-D zDH}t+OqwO_SyMX3m41mUNeLR(Kr`YhIVQ(+1L}#kA3ni;ZXMG}dkSejhCJ&n z%D?nG#{@I4w*>XVjBU;fXY8V$H8NSA5?oO_?~-w&Y?Lnc^xDEj!n|P6>xpLRmEPnc z%<_c4XFnhtnBQ*ct&l5_!$;(Zc9D9mLvP$XYq4yK%33v_^j7K>Rpz%hU;YC{Lexl% zfCXC{>K0Vh0w~D@t`fnf49H?_?m!omSHI4Md&o=3f&PGlM4!vcFkHhGkc_thW7{IS zdmX<1gi1!Y-a6M=Zge)X0uox?1)*LP(6BWql*U^qPLz6y{yhztu|(Q8jlCopfN=#4AX`2M~N&H&(ggF?yKDh7FlVojirWti6Y+W#~~i0SNP`M}3IZDs2I~SLhAq(yP#3hT5_g?aOqu`d?8zp}7m1us(#UwgKLx zIgp!bAUKfwzJ~V(i0sbDU6Fi7ZV$)}lt!YL7E-#jHi!nwR}}Suq6-H$P}hN6XPa^r zECnt30W|U>1l*Fu6p)`F1PHefZXW<` zjIw&9x`C^WnMPjl9T^Gv6*`LlB20UBzsE_IaenF|Fy<0YLQf@f3{OWPcx=UAx>ofC zUaj|HmGPZC?{tTGjDIAe%&@~ literal 0 HcmV?d00001 diff --git a/sgl/tricks/__pycache__/correct_and_smooth.cpython-39.pyc b/sgl/tricks/__pycache__/correct_and_smooth.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e5804ac888d6852127d5a2645b47158c500eb89 GIT binary patch literal 2365 zcmai0OK%%D5GJ_~tz^k@oR^&hXwfEZl{8Y3OItuq)22v(qVUBz=|ZfRD|_YDu0m36 zEUbF50lSC%gC3-#e?|WXuRXN61!zybq%&Ok6}u<}e#^&jIGmY}RM=?L2$a@~r*XYX z$nV%WT^#5=T!C1K7T9mY(EN#*zLe6YGOK*IUp3D0g}DHCpA%SdU~D 
zRZiJ~Z3Y!RpdJEfvP-6P24^m4K{zWIpi?-2Zs7uYg$L*tK43+Vb#h48nt}1uXc$Wq zaGqp|=G@%n?^?roHX87#w2`ObPArvq{{pcikuV)@h33u$5~xVoBOd8_?*gg0M5*TX z1p?tn=PC-**sNSIYiF7j~s={uVC-EAhn*WT0z3CR8# zg%tsQXKVsFGp!c1@66}~SJnz#Nea-g3YrmD$N@Q^9jM3JeunmfV5=mn>G}bkHfE6S z6UepxqWoK5Iv|*R{UxZE=gjk7c@rP?tdi;SjNppGyG16IqEh%c(r*YKiSvrVXdtRZ zPz2K}AQgz(w)-vVV2=AmaFyJEJl-e!w1*U|?*)_U*&U0ju)A4#(!W}+!0c}})?UAb zBEf4ULO3eecA>6ARds-ZOyOz~Y{p=jv$hFCP~Q9|Q40A9c`z7q$O-yVuE1CctI#q* z2j(_J_~?D;{RNeb@12Y7JU?2UG>4>CkDyU5YSM0fm@ADpQWPuo8f|;tWX2K6z*M&5 zcnA$^NAf2m!MSr_VqAf$mH@MEs0)uxXq}Gl|NlvjS2nL1CmN27r*j!?8DGblf~`EEcZRWX(>&WW^r@-jS~h8P=@HOE~WDaFzEoezhFM@@^?C`rQDp0UR@nCHrH+ONV$W*&KxO8i>k|5Nwn` z#FmY6Wgfw07*Tfs2n%RHYv806S_f)hr5`Y#-h}UUsC8%2y-vp~|K+|j9LFGGeF9Z& z0KCg_U^dmjZeaFf4UZ16*q<}oGJMYL0GJJYMwS;2^0{z2L<8R~KWo6xgp_yU zHF*&g{P_Pk_e6<@S09K`UV)i@vIVw7UL&iRW` zn3gm43bqF`l`Q9*G8D3dhJFQL8uJT2iqn+yQzl+V5wuexy)ZV`$bgZzjIy^&3qGyg z@~NF|2BrQhHtkzk*6NM22>)4xsbN`WYA1c`*zOWM9If)gL7_DPU4|6<7l>M7kN^Mx literal 0 HcmV?d00001 diff --git a/sgl/tricks/__pycache__/utils.cpython-37.pyc b/sgl/tricks/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b2be5c68a7e95e552c51aaaabd4e25543b67ce6 GIT binary patch literal 2237 zcma)7O=}!S5bd7to!ymWd2P!O$Q}}sg+u};#1M=f?8Jd!oPz@n3t<>d_wJ4~U+Qkj zS~GjHen7}C*gfVi^))Bm47uf!yzZ5RB67&U)YMd0cUSeRSN*i#j|raLqvw1O5%P!m zxqLh*-{6(MLNLN;K}tGy=vsHjp0)X7-+qHKbO zGdk|wum&uIwa+5fIV0l@7PBtCo7`V+tqAL}{u%b3IBbJ$p3!%7e2dXnWVCgTsL0Ey zk%YW|a_-=hx%Q6`6T5kz;g!1(iY#eMTDo!;u5y-c>$L7YvZBf}x{1%){T8ueb^3A? zX!>0DGf_*K@JflgnQLFwBAaQqsHfU1Q#FI+X*Ke+*9gvZpi(jAN{76Wc~Mu|D=__c zBKip7uaEa;b;whl)*T|DEZl82LY$w(dGdUzh{a!~MKU#XR{}vdNoV%@SmqD=B=q2@`@a!Hhi7 zF}kT8(KVuHgh_HOHlXRD$>XVA`Fb|5u`EWgs+n{>uUq)f42 z4x{&h+&!`&%3XRbS^5kJuIR$Se*k1Db9TteS%xd1%G}n4Ju=8XGHY~Ny%k+{S4In93)>JU=)Dvak)0cHQHVI4o_1uTT|Vd9m$S6t+Nm1t zr$VH2?aJd)`%V3Z3ms01R4HC*KbLtm(LT%5snjm3mF_5!Rf zF7>8J^6Dg!M?xh9pD2A>Bq+xl#;t2!BA?Eb9^7zSTQ{(LQq`hNi~KETH*vgUhE<)& zd0BE(tk&0nB?l0M4(Jx%4H`S!Rt~861oq4vi+KzYP3>4He};ksB{)i|oC5;|0SSGO z5(j+{;*BPN(~fp5tcIjM>@j>D=w0YN=&te(`pb=$8ZxhJNxyygtgcSDP{VX6v%HxP zWs?fYhd3RPpAJ=x0ZZb`0i3_A#qMa+kUOhu9i&pt8!ql5PyN2U0 zORzLJkt)p&wJYj3+Rf?$C|2sLFSW-~m1^%5j#At)YXs{MboaSg>nFP(I!0%0PRdgq z<`v5`;2NBy#W9z5-c1-T0*1cc3 z^fO2>&}HSIoZZWAdR5KupNyyuvLY>;1R6cP8C3KLWBV+9Ji=ms;w9p1h|G2?D5gVP zL2vF6Kt~ika1I^u5xBNyFSrZ91K2}T4hSj&bAuK)%V9_SpeQ&C#v3NK{(aJ71Wb-w z+y;w~0sI}<24oSPex|yu^NxN?ei<&nbIk0uAsFinNE@MJ1g+Nwm=zq&USp0;`uuM^ zVjIDY`rx1i-*g_+ysPQDLxy$5lNop!3Vy+9Rcqf?hxRi#avcNkS|&w)$nB|X9~7Q) z?a8JrN3jlybe|Ve7@TyBDoToU&V@v4D4JQSdtkdt8d0MXq`6)ss;fR$do5X8Zy;Y2i|G3_~#+XF`j(8P}z8*S^J<;m{4l*9uV z#cc@f+d8*2SXcjTTXVZrwL3YkGP6foptV(?#nrz(Eqh0-Zwqr(3DYWUM_3qi?=C!qfTqg6&q literal 0 HcmV?d00001 diff --git a/sgl/tricks/__pycache__/utils.cpython-39.pyc b/sgl/tricks/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40ab04cbdeef6f00a1d33a645900613b74815fad GIT binary patch literal 2240 zcma)7&2Jnv6u0MlXEz^Bw{6OWl>mtqB9Q=<164s3iiE0Mf&^L#jizJo?j-XS+nZ*q z*;8}mfcOh`kNuZ?<$!WS;>M-?>`kH&35k*X=K1;Ae$T)6v(t@@nBdzxdB*z@A-|i> z&Cf&UD>V5l1S5Rqt2yAJc^ma<7Jn*E4oX@lzGfQ zr=#92YrsNSH(10v=VY|WV%Eifi~Gy%6=6NLagM#m4%=i~=kzrl-C^{Q47V>36*>6B zNJ9Sn^1{I{b8T-eO#G%jMU#6FiY#eMTDo!;u5y-c>$L6zvZBf}x{1-+{TA^ak~j2V z7-;%TZ)BpDGU1gHbu-t#szo-{Zc$IPSEgzT$(q4h- zZ;04N2!DUDKdno?FXhSp;djqJF6L+Za#HLok!Qzpf2Q(6?qQSd=k9_GvTfZ%BFq^Lec|J$2QTs}74zg@DJL&;HBFFjuB0%52$TyE z9V5C%n9SDV4m3S5!GCQRx|za{SOp_kqzA8?$@J0B41QOiS>MECq5BO=dzsFW!{|dm z_JAyia+h99mOcY?E4pyd4*)4;&MsLw%WwrancKRsM+VZzW{r-~7T%R__sY|Ur*A#| 
diff --git a/sgl_dair.egg-info/PKG-INFO b/sgl_dair.egg-info/PKG-INFO
new file mode 100644
index 0000000..225f2a6
--- /dev/null
+++ b/sgl_dair.egg-info/PKG-INFO
@@ -0,0 +1,175 @@
+Metadata-Version: 2.1
+Name: sgl-dair
+Version: 0.1.5
+Summary: Graph Neural Network (GNN) toolkit targeting scalable graph learning
+Home-page: https://github.com/PKU-DAIR/SGL
+Author: DAIR Lab @PKU
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.6
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch>=1.8
+Requires-Dist: networkx
+Requires-Dist: tqdm
+Requires-Dist: numpy>=1.21
+Requires-Dist: scipy
+Requires-Dist: gensim
+Requires-Dist: scikit_learn
+Requires-Dist: ogb
+Requires-Dist: openbox
+Requires-Dist: munkres
+
+## SGL: Scalable Graph Learning
+
+**SGL** is a Graph Neural Network (GNN) toolkit targeting scalable graph learning, which supports deep graph learning on
+extremely large datasets. SGL allows users to easily implement scalable graph neural networks and evaluate their
+performance on various downstream tasks like node classification, node clustering, and link prediction. Further, SGL
+supports auto neural architecture search functionality based
+on OpenBox. SGL is designed and
+developed by the graph learning team from
+the DAIR Lab at Peking University.
+
+## Why SGL?
+The key difference between SGL and existing GNN toolkits, such as PyTorch Geometric (PyG) and Deep Graph Library (DGL), is that SGL enjoys the characteristics of the following three perspectives.
+
++ **High scalability**: Following the scalable design paradigm **SGAP**
+  in PaSca, SGL can scale to graph data with
+  billions of nodes and edges.
++ **Auto neural architecture search**: SGL can automatically choose decent and scalable graph neural architectures according to specific tasks and
+  pre-defined multiple objectives (e.g., inference time, memory cost, and predictive performance).
++ **Ease of use**: SGL has user-friendly interfaces for implementing existing scalable GNNs and executing various downstream tasks.
+
+## Installation
+
+Some datasets in SGL are constructed based
+on PyG. Please follow the
+link below to install PyG first before installing
+SGL: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html.
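+
+As a quick sanity check (illustrative only; the printed versions depend on your local environment), you can confirm that the PyG dependency is importable before installing SGL:
+
+```python
+# Minimal check that PyTorch and PyTorch Geometric are visible to the interpreter.
+import torch
+import torch_geometric
+
+print("torch:", torch.__version__)
+print("torch_geometric:", torch_geometric.__version__)
+```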
+ +### Install from pip + +To install SGL from PyPI: + +```bash +pip install sgl-dair +``` + +## Quick Start + +A quick start example is given by: + +```python +from sgl.dataset import Planetoid +from sgl.models.homo import SGC +from sgl.tasks import NodeClassification + +dataset = Planetoid("pubmed", "./", "official") +model = SGC(prop_steps=3, feat_dim=dataset.num_features, output_dim=dataset.num_classes) + +device = "cuda:0" +test_acc = NodeClassification(dataset, model, lr=0.1, weight_decay=5e-5, epochs=200, device=device).test_acc +``` + +An example of the auto neural network search functionality is as follows: + +```python +import torch +from openbox.optimizer.generic_smbo import SMBO + +from sgl.dataset.planetoid import Planetoid +from sgl.search.search_config import ConfigManager + +dataset = Planetoid("cora", "./", "official") +device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu") + +## Define Initial Arch and Configuration +initial_arch = [2, 0, 1, 2, 3, 0, 0] +configer = ConfigManager(initial_arch) +configer._setParameters(dataset, device, 128, 200, 1e-2, 5e-4) + +## Define Search Parameters +dim = 7 +bo = SMBO(configer._configFunction, + configer._configSpace(), + num_objs=2, + num_constraints=0, + max_runs=3500, + surrogate_type='prf', + acq_type='ehvi', + acq_optimizer_type='local_random', + initial_runs=2 * (dim + 1), + init_strategy='sobol', + ref_point=[-1, 0.00001], + time_limit_per_trial=5000, + task_id='quick_start', + random_state=1) + +## Search +history = bo.run() +print(history) +``` + +## Related Publications + +**PaSca: a Graph Neural Architecture Search System under the Scalable Paradigm**[[PDF](https://dl.acm.org/doi/pdf/10.1145/3485447.3511986)]
+Wentao Zhang, Yu Shen, Zheyu Lin, Yang Li, Xiaosen Li, Wen Ouyang, Yangyu Tao, Zhi Yang, and Bin Cui.
+The World Wide Web Conference.<br/>
+***WWW 2022, CCF-A, 🏆 Best Student Paper Award (among 1822 submissions)***
+
+
+**Node Dependent Local Smoothing for Scalable Graph Learning** [[PDF](https://arxiv.org/pdf/2110.14377)]<br/>
+Wentao Zhang, Mingyu Yang, Zeang Sheng, Yang Li, Wen Ouyang, Yangyu Tao, Zhi Yang, Bin Cui.
+Thirty-fifth Conference on Neural Information Processing Systems.
+***NeurIPS 2021, CCF-A, Spotlight Presentation, Acceptance Rate: < 3%***. + +**NAFS: A Simple yet Tough-to-beat Baseline for Graph Representation Learning.** [[PDF](https://arxiv.org/abs/2206.08583)]
+Wentao Zhang, Zeang Sheng, Mingyu Yang, Yang Li, Yu Shen, Zhi Yang, Bin Cui.
+The 39th International Conference on Machine Learning.
+***ICML 2022, CCF-A***. + +**Deep and Flexible Graph Neural Architecture Search.** [[PDF](https://arxiv.org/abs/2206.08582)]
+Wentao Zhang, Zheyu Lin, Yu Shen, Yang Li, Zhi Yang, Bin Cui.
+The 39th International Conference on Machine Learning.
+***ICML 2022, CCF-A***. + +**Model Degradation Hinders Deep Graph Neural Networks.** [[PDF](https://arxiv.org/abs/2206.04361)]
+Wentao Zhang, Zeang Sheng, Yuezihan Jiang, Yikuan Xia, Jun Gao, Zhi Yang, Bin Cui.
+SIGKDD Conference on Knowledge Discovery and Data Mining.
+***KDD 2022, CCF-A***. + +**Graph Attention Multi-Layer Perceptron** [[PDF](https://arxiv.org/pdf/2108.10097)]
+Wentao Zhang, Ziqi Yin, Zeang Sheng, Wen Ouyang, Xiaosen Li, Yangyu Tao, Zhi Yang, Bin Cui.
+ACM SIGKDD Conference on Knowledge Discovery and Data Mining.
+***KDD 2022, CCF-A, Rank \#1 in [Open Graph Benchmark](https://ogb.stanford.edu/docs/leader_nodeprop/\#ogbn-mag)*** + +**[OpenBox](https://github.com/PKU-DAIR/open-box): A Generalized Black-box Optimization Service** [[PDF](https://arxiv.org/abs/2106.00421)]
+Yang Li, Yu Shen, Wentao Zhang, Yuanwei Chen, ..., Wentao Wu, Zhi Yang, Ce Zhang, Bin Cui.
+ACM SIGKDD Conference on Knowledge Discovery and Data Mining.
+***KDD 2021, CCF-A, top prize in [open-source innovation competition @ 2021 CCF ChinaSoft](https://mp.weixin.qq.com/s/8JX5ymkUt5MvDcHLOjB3Xw)*** + + + +## Citing SGL + +Please cite our [paper](https://dl.acm.org/doi/pdf/10.1145/3485447.3511986) if you find *SGL* useful in your work: +``` +@inproceedings{zhang2022pasca, + title={PaSca: A Graph Neural Architecture Search System under the Scalable Paradigm}, + author={Zhang, Wentao and Shen, Yu and Lin, Zheyu and Li, Yang and Li, Xiaosen and Ouyang, Wen and Tao, Yangyu and Yang, Zhi and Cui, Bin}, + booktitle={Proceedings of the ACM Web Conference 2022}, + pages={1817--1828}, + year={2022} +} +``` + +## Contact + +If you have any technical questions, please submit new issues. + +If you have any other questions, please contact: Wentao Zhang[wentao.zhang@pku.edu.cn] and Zeang Sheng[shengzeang18@pku.edu.cn]. + +## License + +The entire codebase is under [MIT license](LICENSE). diff --git a/sgl_dair.egg-info/SOURCES.txt b/sgl_dair.egg-info/SOURCES.txt new file mode 100644 index 0000000..b890aa4 --- /dev/null +++ b/sgl_dair.egg-info/SOURCES.txt @@ -0,0 +1,123 @@ +LICENSE +MANIFEST.in +README.md +pyproject.toml +requirements.txt +setup.py +sgl/__init__.py +sgl/data/__init__.py +sgl/data/base_data.py +sgl/data/base_dataset.py +sgl/data/transforms.py +sgl/data/utils.py +sgl/dataset/__init__.py +sgl/dataset/acm.py +sgl/dataset/actor.py +sgl/dataset/airports.py +sgl/dataset/amazon.py +sgl/dataset/amazon_product.py +sgl/dataset/aminer.py +sgl/dataset/choose_edge_type.py +sgl/dataset/coauthor.py +sgl/dataset/custom_dataset.py +sgl/dataset/dblp.py +sgl/dataset/dblp_original.py +sgl/dataset/facebook.py +sgl/dataset/flickr.py +sgl/dataset/github.py +sgl/dataset/imdb.py +sgl/dataset/karateclub.py +sgl/dataset/linkx_dataset.py +sgl/dataset/nell.py +sgl/dataset/ogbn.py +sgl/dataset/ogbn_mag.py +sgl/dataset/planetoid.py +sgl/dataset/planetoid_sampling.py +sgl/dataset/reddit.py +sgl/dataset/twitch.py +sgl/dataset/utils.py +sgl/dataset/webkb.py +sgl/dataset/wikics.py +sgl/etc/__init__.py +sgl/etc/auto_select_edge_type_for_nars.py +sgl/etc/hetero_search.py +sgl/etc/hetero_test.py +sgl/etc/stability_of_subgraph_weight.py +sgl/models/__init__.py +sgl/models/backup.py +sgl/models/base_model.py +sgl/models/base_model_dist.py +sgl/models/sample_models.py +sgl/models/simple_models.py +sgl/models/hetero/__init__.py +sgl/models/hetero/fast_nars_sgc.py +sgl/models/hetero/nars_sign.py +sgl/models/homo/__init__.py +sgl/models/homo/clustergcn.py +sgl/models/homo/fastgcn.py +sgl/models/homo/gamlp.py +sgl/models/homo/gamlp_dist.py +sgl/models/homo/gamlp_recursive.py +sgl/models/homo/gbp.py +sgl/models/homo/graphsage.py +sgl/models/homo/nafs.py +sgl/models/homo/pasca_v1.py +sgl/models/homo/pasca_v2.py +sgl/models/homo/pasca_v3.py +sgl/models/homo/sgc.py +sgl/models/homo/sgc_dist.py +sgl/models/homo/sign.py +sgl/models/homo/ssgc.py +sgl/operators/__init__.py +sgl/operators/base_op.py +sgl/operators/utils.py +sgl/operators/csrc/libcudamatmul.so +sgl/operators/csrc/libmatmul.so +sgl/operators/graph_op/__init__.py +sgl/operators/graph_op/laplacian_graph_op.py +sgl/operators/graph_op/ppr_graph_op.py +sgl/operators/message_op/__init__.py +sgl/operators/message_op/concat_message_op.py +sgl/operators/message_op/iterate_learnable_weighted_message_op.py +sgl/operators/message_op/last_message_op.py +sgl/operators/message_op/learnable_weighted_messahe_op.py +sgl/operators/message_op/max_message_op.py +sgl/operators/message_op/mean_message_op.py +sgl/operators/message_op/min_message_op.py 
+sgl/operators/message_op/over_smooth_distance_op.py +sgl/operators/message_op/projected_concat_message_op.py +sgl/operators/message_op/simple_weighted_message_op.py +sgl/operators/message_op/sum_message_op.py +sgl/sampler/__init__.py +sgl/sampler/base_sampler.py +sgl/sampler/sampler.py +sgl/search/__init__.py +sgl/search/auto_search.py +sgl/search/auto_search_dist.py +sgl/search/base_search.py +sgl/search/search_config.py +sgl/search/search_config_dist.py +sgl/search/search_models.py +sgl/search/search_models_dist.py +sgl/search/utils.py +sgl/tasks/__init__.py +sgl/tasks/base_task.py +sgl/tasks/clustering_metrics.py +sgl/tasks/correct_and_smooth.py +sgl/tasks/link_prediction.py +sgl/tasks/node_classification.py +sgl/tasks/node_classification_dist.py +sgl/tasks/node_classification_sampling.py +sgl/tasks/node_classification_with_label_use.py +sgl/tasks/node_clustering.py +sgl/tasks/utils.py +sgl/tricks/__init__.py +sgl/tricks/correct_and_smooth.py +sgl/tricks/utils.py +sgl/utils/__init__.py +sgl/utils/auto_choose_gpu.py +sgl_dair.egg-info/PKG-INFO +sgl_dair.egg-info/SOURCES.txt +sgl_dair.egg-info/dependency_links.txt +sgl_dair.egg-info/requires.txt +sgl_dair.egg-info/top_level.txt \ No newline at end of file diff --git a/sgl_dair.egg-info/dependency_links.txt b/sgl_dair.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/sgl_dair.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/sgl_dair.egg-info/requires.txt b/sgl_dair.egg-info/requires.txt new file mode 100644 index 0000000..5c6362f --- /dev/null +++ b/sgl_dair.egg-info/requires.txt @@ -0,0 +1,10 @@ +torch>=1.8 +networkx +tqdm +numpy>=1.21 +scipy +gensim +scikit_learn +ogb +openbox +munkres diff --git a/sgl_dair.egg-info/top_level.txt b/sgl_dair.egg-info/top_level.txt new file mode 100644 index 0000000..d1cebd6 --- /dev/null +++ b/sgl_dair.egg-info/top_level.txt @@ -0,0 +1 @@ +sgl From 57261eac63e7737bbe74d57ae3f9eb1a83c118ac Mon Sep 17 00:00:00 2001 From: infinity Date: Sun, 12 Nov 2023 05:05:13 +0000 Subject: [PATCH 02/28] add graphsage. reorganize training scripts in the examples fold. 
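
The sampling-based examples are now driven by the YAML configuration files under
examples/configs. As an illustration, the updated clustergcn_nodeclass.py below
takes --device and --config_path arguments and can be launched with:

    cd examples && python clustergcn_nodeclass.py --device 0 --config_path ./configs/clustergcn.yml

(The dataset root and other paths inside the YAML files may need to be adapted
to the local machine.)
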
--- examples/clustergcn_nodeclass.py | 67 ++++------- examples/configs/clustergcn.yml | 21 ++++ examples/configs/fastgcn.yml | 25 ++++ examples/configs/graphsage.yml | 24 ++++ examples/configs/vanillagcn.yml | 19 +++ examples/fastgcn_nodeclass.py | 76 ------------ examples/sample_based_nodeclass.py | 38 ++++++ .../__pycache__/base_model.cpython-37.pyc | Bin 10475 -> 11100 bytes .../__pycache__/simple_models.cpython-37.pyc | Bin 9192 -> 11685 bytes sgl/models/base_model.py | 50 +++++--- sgl/models/base_model_dist.py | 3 - .../__pycache__/clustergcn.cpython-37.pyc | Bin 1017 -> 937 bytes .../homo/__pycache__/fastgcn.cpython-37.pyc | Bin 1362 -> 948 bytes .../homo/__pycache__/graphsage.cpython-37.pyc | Bin 1316 -> 1134 bytes .../__pycache__/vanillagcn.cpython-37.pyc | Bin 1116 -> 1072 bytes sgl/models/homo/clustergcn.py | 9 +- sgl/models/homo/fastgcn.py | 17 +-- sgl/models/homo/gamlp_dist.py | 2 +- sgl/models/homo/graphsage.py | 25 ++-- sgl/models/homo/vanillagcn.py | 7 +- sgl/models/simple_models.py | 94 +++++++++++++-- .../__pycache__/base_op.cpython-37.pyc | Bin 2650 -> 3060 bytes .../__pycache__/utils.cpython-37.pyc | Bin 3492 -> 3828 bytes sgl/operators/base_op.py | 6 + sgl/operators/graph_op/__init__.py | 2 + .../__pycache__/__init__.cpython-37.pyc | Bin 276 -> 336 bytes .../laplacian_graph_op.cpython-37.pyc | Bin 1090 -> 1160 bytes .../__pycache__/rw_graph_op.cpython-37.pyc | Bin 0 -> 1000 bytes sgl/operators/graph_op/laplacian_graph_op.py | 7 +- sgl/operators/graph_op/rw_graph_op.py | 18 +++ sgl/operators/message_op/__init__.py | 4 +- .../__pycache__/__init__.cpython-37.pyc | Bin 970 -> 1044 bytes .../pre_normalize_message_op.cpython-37.pyc | Bin 0 -> 845 bytes .../message_op/pre_normalize_message_op.py | 11 ++ sgl/operators/utils.py | 22 +++- .../__pycache__/base_sampler.cpython-37.pyc | Bin 764 -> 764 bytes .../__pycache__/sampler.cpython-37.pyc | Bin 12125 -> 12551 bytes sgl/sampler/sampler.py | 111 +++++++++++------- ...ode_classification_sampling.cpython-37.pyc | Bin 4180 -> 4306 bytes sgl/tasks/__pycache__/utils.cpython-37.pyc | Bin 10871 -> 10871 bytes sgl/tasks/node_classification_sampling.py | 12 +- 41 files changed, 432 insertions(+), 238 deletions(-) create mode 100644 examples/configs/clustergcn.yml create mode 100644 examples/configs/fastgcn.yml create mode 100644 examples/configs/graphsage.yml create mode 100644 examples/configs/vanillagcn.yml delete mode 100644 examples/fastgcn_nodeclass.py create mode 100644 examples/sample_based_nodeclass.py create mode 100644 sgl/operators/graph_op/__pycache__/rw_graph_op.cpython-37.pyc create mode 100644 sgl/operators/graph_op/rw_graph_op.py create mode 100644 sgl/operators/message_op/__pycache__/pre_normalize_message_op.cpython-37.pyc create mode 100644 sgl/operators/message_op/pre_normalize_message_op.py diff --git a/examples/clustergcn_nodeclass.py b/examples/clustergcn_nodeclass.py index e912174..2cb56f5 100644 --- a/examples/clustergcn_nodeclass.py +++ b/examples/clustergcn_nodeclass.py @@ -1,50 +1,33 @@ +import yaml import argparse import networkx as nx -import torch.nn.functional as F -from sgl.dataset import Planetoid +import sgl.dataset as Dataset from sgl.models.homo import ClusterGCN +from sgl.sampler import ClusterGCNSampler from sgl.tasks import NodeClassification_Sampling if __name__ == "__main__": - parser = argparse.ArgumentParser(description = "Run .") - parser.add_argument("--clustering_method", - nargs = "?", - default = "random", - choices = ["random", "metis"], - help = "Clustering method for graph 
decomposition. Default is the random procedure.") - - parser.add_argument("--epochs", - type = int, - default = 200, - help = "Number of training epochs. Default is 200.") - - parser.add_argument("--seed", - type = int, - default = 42, - help = "Random seed for train_test split. Default is 42.") - - parser.add_argument("--dropout", - type = float, - default = 0.5, - help = "Dropout parameter. Default is 0.5.") - - parser.add_argument("--learning_rate", - type = float, - default = 0.01, - help = "Learning rate. Default is 0.01.") - - parser.add_argument("--test_ratio", - type = float, - default = 0.9, - help = "Test data ratio. Default is 0.1.") - - parser.add_argument("--cluster_number", - type = int, - default = 10, - help = "Number of clusters extracted. Default is 10.") + parser = argparse.ArgumentParser(description = "ClusterGCNSampler-Models.") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/fastgcn.yml", help="save path of the configuration file" + ) args = parser.parse_args() - device = 'cuda:0' - dataset = Planetoid("cora", "/home/ssq/test_data/", f"clustergcn_{args.cluster_number}") - model = ClusterGCN(nx.from_scipy_sparse_matrix(dataset.adj), dataset.x.numpy(), dataset.y.unsqueeze(1).numpy(), device, dataset.num_features, 128, dataset.num_classes, args.clustering_method, args.cluster_number, args.test_ratio) - test_acc = NodeClassification_Sampling(dataset, model, lr=0.1, weight_decay=5e-5, epochs=30, device=device, loss_fn=F.nll_loss, train_batch_size=1, eval_batch_size=1).test_acc + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + dataset_kwargs = config["dataset"] + cluster_number = config["sampler"]["cluster_number"] + dataset_kwargs.update({"split": f"clustergcn_{cluster_number}"}) + classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, classname)(**dataset_kwargs) + sampler_kwargs = config["sampler"] + sampler = ClusterGCNSampler(nx.from_scipy_sparse_matrix(dataset.adj), dataset.x.numpy(), dataset.y.unsqueeze(1).numpy(), **sampler_kwargs) + model_kwargs = config["model"] + model_kwargs.update({"device": device}) + model = ClusterGCN(sampler, nfeat=dataset.num_features, nclass=dataset.num_classes, **model_kwargs) + task_kwargs = config["task"] + task_kwargs.update({"device": device}) + test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc diff --git a/examples/configs/clustergcn.yml b/examples/configs/clustergcn.yml new file mode 100644 index 0000000..f861b10 --- /dev/null +++ b/examples/configs/clustergcn.yml @@ -0,0 +1,21 @@ +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" +sampler: + cluster_method: "random" + cluster_number: 10 + test_ratio: 0.3 +model: + hidden_dim: 128 + dropout: 0.5 + num_layers: 2 +task: + train_batch_size: 1 + eval_batch_size: 1 + epochs: 30 + lr: 0.01 + weight_decay: 0.00005 + loss_fn: "nll_loss" + seed: 42 + diff --git a/examples/configs/fastgcn.yml b/examples/configs/fastgcn.yml new file mode 100644 index 0000000..97babcc --- /dev/null +++ b/examples/configs/fastgcn.yml @@ -0,0 +1,25 @@ +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" + split: "official" +sampler: + name: "FastGCNSampler" + inductive: False + adj_process: "LaplacianGraphOp" + layer_sizes: "256-256" + prob_type: "normalize" + replace: True +model: + name: "FastGCN" + hidden_dim: 128 + 
dropout: 0.5 + num_layers: 2 +task: + train_batch_size: 256 + eval_batch_size: 256 + epochs: 30 + lr: 0.1 + weight_decay: 0.00005 + loss_fn: "nll_loss" + diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml new file mode 100644 index 0000000..0236a38 --- /dev/null +++ b/examples/configs/graphsage.yml @@ -0,0 +1,24 @@ +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" + split: "official" +sampler: + name: "NeighborSampler" + inductive: False + layer_sizes: "5-5" + prob_type: "normalize" + replace: False +model: + name: "GraphSAGE" + hidden_dim: 128 + dropout: 0.5 + num_layers: 2 +task: + train_batch_size: 64 + eval_batch_size: 64 + epochs: 20 + lr: 0.1 + weight_decay: 0.00005 + loss_fn: "nll_loss" + diff --git a/examples/configs/vanillagcn.yml b/examples/configs/vanillagcn.yml new file mode 100644 index 0000000..ae0c76e --- /dev/null +++ b/examples/configs/vanillagcn.yml @@ -0,0 +1,19 @@ +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" + split: "official" +sampler: + name: "FullSampler" + inductive: False +model: + name: "VanillaGCN" + hidden_dim: 128 + dropout: 0.5 + num_layers: 2 +task: + epochs: 20 + lr: 0.1 + weight_decay: 0.00005 + loss_fn: "nll_loss" + diff --git a/examples/fastgcn_nodeclass.py b/examples/fastgcn_nodeclass.py deleted file mode 100644 index e550504..0000000 --- a/examples/fastgcn_nodeclass.py +++ /dev/null @@ -1,76 +0,0 @@ -import argparse -import torch.nn.functional as F -from sgl.dataset import Planetoid -from sgl.models.homo import FastGCN, GraphSAGE, VanillaGCN -from sgl.tasks import NodeClassification_Sampling - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("FastGCN") - parser.add_argument( - "--hidden", type=int, default=128, help="dimension of hidden layer" - ) - parser.add_argument("--dropout", type=float, default=0.5, help="dropout") - parser.add_argument( - "--layer_sizes", type=str, default="128-128", help="sampling sizes per layer" - ) - args = parser.parse_args() - device = "cuda:0" - dataset = Planetoid("cora", "/home/ssq/test_data/", "official") - # model = FastGCN( - # dataset, - # hidden_dim=args.hidden, - # output_dim=dataset.num_classes, - # dropout=args.dropout, - # device=device, - # inductive=False, - # prob_type="uniform" - # ) - # test_acc = NodeClassification_Sampling( - # dataset, - # model, - # lr=0.1, - # weight_decay=5e-5, - # epochs=20, - # device=device, - # loss_fn=F.nll_loss, - # train_batch_size=256, - # eval_batch_size=256, - # ).test_acc - # print(f"final test acc: {test_acc}") - model = GraphSAGE( - dataset, - hidden_dim=args.hidden, - output_dim=dataset.num_classes, - dropout=args.dropout, - device=device, - ) - test_acc = NodeClassification_Sampling( - dataset, - model, - lr=0.1, - weight_decay=5e-5, - epochs=20, - device=device, - loss_fn=F.nll_loss, - train_batch_size=64, - eval_batch_size=64, - ).test_acc - print(f"final test acc: {test_acc}") - # model = VanillaGCN( - # dataset, - # hidden_dim=args.hidden, - # output_dim=dataset.num_classes, - # dropout=args.dropout, - # device=device, - # ) - # test_acc = NodeClassification_Sampling( - # dataset, - # model, - # lr=0.1, - # weight_decay=5e-5, - # epochs=20, - # device=device, - # loss_fn=F.nll_loss - # ).test_acc - # print(f"final test acc: {test_acc}") diff --git a/examples/sample_based_nodeclass.py b/examples/sample_based_nodeclass.py new file mode 100644 index 0000000..88bf386 --- /dev/null +++ b/examples/sample_based_nodeclass.py @@ -0,0 +1,38 @@ +import yaml +import 
argparse + +import sgl.dataset as Dataset +import sgl.models.homo as HomoModels +import sgl.sampler as Sampler +from sgl.tasks import NodeClassification_Sampling + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Sampler-Models") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/fastgcn.yml", help="save path of the configuration file" + ) + args = parser.parse_args() + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + dataset_kwargs = config["dataset"] + classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, classname)(**dataset_kwargs) + sampler_kwargs = config["sampler"] + if "inductive" in sampler_kwargs.keys(): + inductive = sampler_kwargs.pop("inductive") + else: + inductive = False + sampler_name = sampler_kwargs.pop("name") + sampler = getattr(Sampler, sampler_name)(dataset.adj[dataset.train_idx, :][:, dataset.train_idx] if inductive else dataset.adj, **sampler_kwargs) + model_kwargs = config["model"] + model_name = model_kwargs.pop("name") + model_kwargs.update({"device": device}) + model = getattr(HomoModels, model_name)(dataset, sampler, **model_kwargs) + task_kwargs = config["task"] + task_kwargs.update({"device": device}) + test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc + print(f"final test acc: {test_acc}") \ No newline at end of file diff --git a/sgl/models/__pycache__/base_model.cpython-37.pyc b/sgl/models/__pycache__/base_model.cpython-37.pyc index 0d708a32b99663f0d9b2ff90b02f9a76b4fe79a7..d10861df33f4f1a640f5cd340b6210a54370cdeb 100644 GIT binary patch delta 2214 zcmZ8hTWl0n7@jk?-R`a1?e>197hshNqTy1I7Fv;@fLO%1El#&P)9v;$yPlaM;FO}R z5fe+fm%2i0 zNb9DA81NWkP|#%pLncXwO$i=V)pqYjA6zFOcRYgNM+_%3Rh-P*IipbYYm8#Sao1Tn zo+*yu0n~+t!wHm^SWJ$v2)r9Qs~y1Y0Q?&|uk|4H8d8n05YgajxJ%uK9QY;Fzy{%u zaBB8cRzU{+G+Zh!+Bo63C9Q~RYkhSDBG5MO4GxDYg4K7EUqT}ldE0Nre4(W$Tt%a=4_8~dBE|d!My}0v(tIg*t zIm`7!c_)t<$kX-fHDNv_CiKZa#cpqmjShP}>$EQJ3A zQ(0V4)kr1f5T2-b7K52c!|YmrQ$lIcg`J}a9+q5mZ9P$?PRE&)ZCy+&_B-wrSM?-Q z#$N43DQ!{rLovU}8B_%JIW~xl!bnF=}PpY5W$7sUQvWM5#P{KRV91fFB=nIUGKx4(}`f<$*dz#10(1*%Qnf(bK z>k~2_Ux&*TC)h){8$S^8%DmvHjCTt|Yoa5-d(oS(!C>M@hacWBL^uFlc=W#k0Qe@c zoB8lK(VxV((#^Ssoj09SOmr9TfxgOiwjXkpN19HOwlMJuQS@R7XZ#dWaHq13y$XL; zRyVv(2VyE$Jc_P(61i8PE!oQwkWKDmr{J^XrZ_!-B}>l{)em=*ec@+YAOj)^)O<;h(w%_A6Ye@6h;Rv;jBk(=Bh1u`~jX z8NSSJ>9uYo?h7QdtKmJCftw8@~DC`xLh-g`&59hnc~|?^bFi+ z+QUx3qo#xGG3;pW)=7>JL$Ue!x*;UL#N}W=NkJBgW7)R`6+Q%an(wh$x<=(05~r3Z zS*AgOLiN+K&LGx$0Izne^l2Lr&u=;U1)q1cVc@8skwOpL?`XCtW1Q6ZC;`1%e2gGZFitQ*Ko1WO5Ul58 ziner9zLr>S)3!iRB)CNI4h(nJz;b8hAg6iYX9?(L{4IiW1m_88ka#P>1%kH;Jc1bn ztmIgxXstGuNXc5=su=v%c_(^wH3<^C&S%5>tz{{VH%@Gk%W delta 1736 zcmYjRO-vhC5Zj-LR`_kdb#e}WN^+lvMynW1;rRb$O$=C z5N?QD;FLNi_;FNvJ=0-(eF3}D5>?_t+a4`;zj;q!l)MBUw;5$c$ z)P*a!?(kVpqY?*)Q(aGrsM14kD&|!bgU&4o^THM@ux*BEv8XEYQ)n$J`w5_-3Jr1l zbRw6W)6yB;;65~zz)xB)!etc0{0Lj~4=@ta2u1*g#{ZQFIaT8pzsh`N@?8|H1W!%Tfz*}McuNg6xfz!TR0=!VC2yyw9$U? 
zs5K+&*|S;goO^tk%k%?7))WHw{ucs`la68gMq;V)tuMXNEZrS+o2EIy@^ z3qE(%`)1K;RwoN<;lt|SgTOKP(={gdqn{5zPla!N?l4r9O*VNSdAkvc)^sklWbjHd z5SOZCHkUV&ldZAmv&u>#W~Yy1Tmr>#AqdRQ6wxjEnO~IITAi5} zL@IBxC3sNzF*D#oRj>FjgO95A98y8T#>yWuei`mm$GltCf}>#F7O&YdpQ1}%5YhED zJgYv>{(|$KVP=O#GOVq&yy)FgI$95avwbMv<3J8=7MRq;p@m$bl~zD zgli%!4l^}p6qh+zyayiEJjc4>kD8IjVcKDC8zGIRn7@Wlw9*F*81{zzba-5z$uBNy zt8vAABP)Cqu6t|Q3Vh-1V=mbBc3Xy27xcG!AnbFv>1LFUqbtCB;f~^iDc{LTr&FRo z#bS3VvSQ&8=sg~TJ>UCm3^x3qu`hS(YJXvDV&_TSCM*8|H4%MG}QW*JJ=;t9G`&4!NJ-wRPUh}G}sDr zx^XUAu(UBOZ5KPV;jed(Q#j2`>Et)yY^YOwD!_WE2Wx#8`lnQ~#sWNsQAnnlqIu8F zJc&#XG&Q{^T|y0>20Q^DJ7^9M(a~=api$s)__e7EX|HY$NrV5#GZmuE4XL-uFA_`= zyhA|09b6-rBA@~1(*$J%hrLMBlB(p?Rjw1H2xbWAj`DX2R^jJnA8a(&PAt)qD&iS} zEWt7XT{s>jxI(}Q41yd25;&JwPMNd3---jao(T|Z`H6MH^XkZUh?L5L;#|j2KK{?9XOUN79)e5tcqLRz<aCt{Z=| diff --git a/sgl/models/__pycache__/simple_models.cpython-37.pyc b/sgl/models/__pycache__/simple_models.cpython-37.pyc index 84e059c47a671ffc994848d7b221d96502b7c424..c73db554000041b2b71aa2a2f0d62a610a71ebc3 100644 GIT binary patch delta 3012 zcmbtWOK%%h6ux)fb`mGe!*SBmxD9<6RRxK+T1rbxBMLHA8&P2%*3@^B*f`_l&XlI^ zgh3>wr4VQ}66<6EKL9rT1K5-eJ8#MkRtSj=V#Rmv#3`v-STNR{x%b@nIp6nP|2p;6 zTe+9|`mzkqwTrKNbBAu{3jEL82Yz~Niiu${vdYAWr#I{sy|M$IycmUNbkj2^_<2?; ze8caKJNDGc^2wI}{ zG`-@KbQk8Rsn~R{c=Ay3Qs5QmJ=ZU~LD8L^mENoy)?0q@vU?>P)4=JPNfE$9UwU=C z!2(RkLLRXdZC%@7oP~PCLt{mg#fV3IMq7rBSk;U)SS!+D?a*mtA};tCgVzw+8hNu3 zYuHk3w&~O+R9lS3o`|jY()5DiTxMBvEL`Gp8Xx%>%;{2}q#emcG6qa1TMz1f5W0TN z>*#eqlozR~Mom4k)e5VZz5482C@pdyA-CDA`_*RpkunEY$&~%W)V?)VhE$zSPK=Gk zcTo(eneiOYa)(>=)AE|d^ICgHH};3C)w*8~tJNCqnjZSnZcpM3_QK%oY@Mt4Ax?rL ze!Jn*Q=dAtR+jIhXUUM4@S)uev5r0G25uP2w2909WQQ{v|3=f&u^4uh!6-FYwcWKgbAsOYirQdw>34HQCb2GFZSZFq1{+8o&Y_ZL$zH4?n`G!j9W$777rZpS>ZG>p zCqG}3K7E~VHmv$}5y)exR&4*$d=;*WzEqWR1g_NNSsBGgQU3#PFbKnVFVAWYLXp!P zBqyh}2ev{o(~_56DN5S(bcvTsj@*XP^E@OPUNu-1yZMyRpg?p?AfT3iv_5{Uji%*A- zHd{4!q!}DTDeYS7NoC+_a0xlMyQC9hYsB*I4)Zh@aYe_zf&mm#{txm&p2utPcAwg3Qq)J{EuQZ}zok5a&XAN| z(#4eBi_V2KQb3<3KSHlmdbD(-Mh)`44c`b&|9wR6qn#)!1M2NzwR+km-Xi-wu6Mzu zxc(OwjKJ&=f*X?Xe7y8-2%HEcl31-Ubmnbs*;J&~BNNhWMAnMZg@yq=$jH|~bI_X< zA_S(orG_2FPyl>b!y!y88;eB+JS85O1{WHp(^t^q&n%roKL|{^FrsuE(6y=8w#B}C z=wkP;_2mUtP*c((m?m8)G;#eHiM_Y!iZzhgA(LDA)brN>!b;00}X zvrM%?-tJX7>xz-C=q8Eyq)N$JW|ZGZ&g35*xeAv?^O8C^JePG0c@;8LPJW#@7?V&w zTq*ujC`h%tlCUWz|0WdOdxRWjGed`zaC z939)eYd;+-hx&FhwR`bbi2fv<_@6|Q*Ks!vow|Vq9^b~6AuMqCK0eiGF?zaGc@^d0 zM;M|?^1_}y@dM}(8o29H5*Q7T0MVp~+7CxQ;o3jDbWDNt9_|nQ6Za3EHQt)`!|jBk zDMO<{Hvs-2u4(E2CA!g}Uj*M&j(MuZ8ASkjll-5OQOJJ)=;dcLLn@vG_6a7};OSWJ z;r5&jhRzLsMDHb&gT{W!sD uA5FE1`wVbTzS2vm&yrV?R7 zh3HXenX8B9;9t;xz?+`DC<;;VqToLv_$IA_?ZA6)_j}*_X5Y7O?*CfOULPCF5IkQx zD?TmnXC3&oKljn{;0~#tJv;~B3%T*-*nP6K*ovOE!X4RcMTa9(0<6g)totJqdWDVZ z*Tr|a*^^BDxb!~foI)3-pX;~zx-M5My8&Js*n5A0(uuK+zuU z1CY)InN-{ZhxFhEd64f~|2Xt7r+xV@P9INF#f~Ob0u@sp!JX)8T)-Qv>6l@OxxRVc zx#F6~v`MR`myFL~X(Xae-<1s>Z8g1y6>WsB5RGIMy3hPr#H;$@WNGvp5=9mhfsDR3 zUnm+?bCMVq#v+FJj7dS4>EH9SyS$5E6EMXmiDD|zqg^I(Q^oqokUX#uC74N&^yn+r zW$l^aPRxVh1U1!6MfW;Ok{T_Kc<_=UYJ9o0M!OunF*@knN%i?t@-P0EKN<;Ais`+% z>h1I((-h4}NH^N>2Rv*FIY|4_&bA-B5{mIfIKTaVU7Yg2B!ont=- z&>3*o3xu{SmBLSC3G*iiaL7Ob#xXmPr&Bttm8=*WRZH*HCiI8wjMK)~uNZ=3NE`kKWYom(KGd{dZS*Ra)#X$jDpSVqFu7{cGAw+75kz+{Rg_#wQB$X diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index 5325f87..c0d4c20 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -70,7 +70,8 @@ class BaseSAMPLEModel(nn.Module): def __init__(self, evaluate_mode="full"): super(BaseSAMPLEModel, self).__init__() - self._pre_graph_op, self._sampling_op, self._post_graph_op = None, None, None + 
self._pre_graph_op, self._post_graph_op = None, None + self._sampling_op, self._post_sampling_graph_op = None, None self._base_model = None self._evaluate_mode = evaluate_mode @@ -91,20 +92,41 @@ def sampler_name(self): def evaluate_mode(self): return self._evaluate_mode - def sampling(self, batch_inds): - return self._sampling_op.sampling(batch_inds) + def sampling(self, batch_inds, to_sparse_tensor=True): + sample_results = self._sampling_op.sampling(batch_inds) + adjs = sample_results.get("sampled_adjs", None) + if adjs is not None: + if isinstance(adjs, list): + if self._post_sampling_graph_op is not None: + adjs = [self._post_sampling_graph_op._construct_adj(adj) for adj in adjs] + if to_sparse_tensor: + adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs] + elif isinstance(adjs, dict): + if self._post_sampling_graph_op is not None: + adjs = {sg_id: self._post_sampling_graph_op._construct_adj(adj) for sg_id, adj in adjs.items()} + if to_sparse_tensor: + adjs = {sg_id: sparse_mx_to_torch_sparse_tensor(adj) for sg_id, adj in adjs.items()} + else: + if self._post_sampling_graph_op is not None: + adjs = self._post_sampling_graph_op._construct_adj(adjs) + if to_sparse_tensor: + adjs = sparse_mx_to_torch_sparse_tensor(adjs) + sample_results.update({"sampled_adjs": adjs}) + return sample_results - def preprocess(self, adj, x, use_subgraphs=False): + def preprocess(self, adj, x): if self._pre_graph_op is not None: - if use_subgraphs is False: - # We don't transform _norm_adj into the form of sparse tensor, as sparse tensors don't have strides - self._norm_adj = self._pre_graph_op._construct_adj(adj) - else: - self._norm_adj = {sg_id: self._pre_graph_op._construct_adj(sampled_adj) for sg_id, sampled_adj in adj.items()} - self._norm_adj = {sg_id: sparse_mx_to_torch_sparse_tensor(sampled_adj) for sg_id, sampled_adj in self._norm_adj.items()} + # We don't transform _norm_adj into the form of sparse tensor, + # as sparse tensors don't have strides. + self._norm_adj = self._pre_graph_op._construct_adj(adj) else: - self._pre_msg_learnable = False - self._processed_feature = x + # For ClusterGCN, we have already processed subgraphs after sampling. 
+ self._norm_adj = adj + self._pre_msg_learnable = False + if hasattr(self, "_pre_feature_op"): + self._processed_feature = self._pre_feature_op._transform_x(x) + else: + self._processed_feature = x def postprocess(self, adj, output): if self._post_graph_op is not None: @@ -120,7 +142,7 @@ def forward(self, batch_idx, device, **kwargs): if self.training: if sampler_name in ["FastGCNSampler", "NeighborSampler"]: sampled_adjs = kwargs["sampled_adjs"] - n_ids = kwargs["source_n_ids"] # source node inds of the last layer + n_ids = kwargs["n_ids"] sampled_x = self._processed_feature[n_ids].to(device) sampled_adjs = [sampled_adj.to(device) for sampled_adj in sampled_adjs] effective_batch = batch_idx @@ -145,7 +167,7 @@ def forward(self, batch_idx, device, **kwargs): sampled_adjs = [sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(device)] * (num_layers - 1) sampled_adjs.append(sparse_mx_to_torch_sparse_tensor(self._norm_adj[batch_idx, :]).to(device)) effective_batch = batch_idx - output = self._base_model(full_x, sampled_adjs) + output = self._base_model(full_x, sampled_adjs, tgt_nids=batch_idx) elif sampler_name == "ClusterGCNSampler": batch_idx = batch_idx.item() sampled_x = self._processed_feature[batch_idx].to(device) diff --git a/sgl/models/base_model_dist.py b/sgl/models/base_model_dist.py index 261f773..b6d3fa8 100644 --- a/sgl/models/base_model_dist.py +++ b/sgl/models/base_model_dist.py @@ -1,9 +1,6 @@ -import torch import torch.nn as nn import torch.nn.functional as F -from sgl.data.base_dataset import HeteroNodeDataset - class BaseSGAPModelDist(nn.Module): def __init__(self, prop_steps, feat_dim, output_dim): diff --git a/sgl/models/homo/__pycache__/clustergcn.cpython-37.pyc b/sgl/models/homo/__pycache__/clustergcn.cpython-37.pyc index 281266c39858e72ac638d9c1d009016f6d462137..e0e867368bc9da3554ba815596aa1b3a57c42135 100644 GIT binary patch delta 590 zcmYLGJ#W=86!k}JhxP@b3W2IeS*tK~1_=Qz16zgEjl~MF`x@lLk?nvGkvf!xN^f|| zAAphn!Y}Z|hPtt`;Tt|`>t3IGj&ptQ%`fp)6h}p|h0^-)?4^FRDR%M0&Hn!K1yY$} zACY33UeH%u@e@>vE4aA14Cg-pAoTs|6-h^=bwAYkOd$lPNJf+SqvQe#ctjWE4N+K; zHO6R6v7{kWRApwzoqlZ$##YUOUH9F)^32>6##AT*!YZrR`k?;$7KW zAFlb8tNJrg=Qc^Oaa+Ewyq2wfuCyuhn0akZqnIro_dnbuJnV$!UcQD6`=^NSI2GYSKT|WIJ4-VmCETo+KthF&xx;>D= zmDt?_d>( T`Qz5A!RRMAMIMS=;IzqaoDG#> delta 657 zcmYLF&1w`u5bmCy>6!hRgcS*dfZ1D^fVp@~P&7h7WP>bTa%m>po!trZvvl`*YRF-| zF1XAK=+V1pU&33zfaDRprD`HtLseIO^)vO=_xMkozDd(Q0kO~Dsh`_)pT6B59i2%c z!;)9=RZPhRk&)z|iR3C;v#UhLQ<5e3*pyK}IxknoDt&Q2nda59RC-(3JI{4q%ci;? zL)GsWo&0(=8(gX$`utf!+9|y`;x?%#UV|yA}0ZRLU+I>39A@qFz@}y%3H^E$c$POrLPuao85JJC)1r1y*w~aO@|BizP>!$qw;iEUk VuH+03vntg~3QdVlis)YYK-ya_mWjRFRL-NeQ&E&#k$UQX zfE)ZJpDJ-c{R^C!NxCW{c``HhJic$THy)1&$j|CnEdxS+qqC6^jCXMCYfzF%s!2gp zO3@BLwfXF@j$_436Cg<%pv#xeQJNNo8;i5FL|DhSw0~VH<#sT-?&r zLCZi7+o%Pg*U|T_HvWNcMv_4H_{fK=MkBJZIv2VTPVoY;uAv9W%`V(fGmU&yu9R>- zEKhRj<1*92T0bbYxKIWRX{x$v+;rPVwyBjd5oG%z=Xses&J7Ok$GkR*uaP)j)jr`? 
z6t&LFHFyR|^GVqh92+%8S-*{5*9n}!Reqa)hK+OsYG5w3O0S@|a1eEV(3Mlhk6=cr zN-p!lOt9<-3*C(xOLd-S3UF=lzBuQ<$lsUqld4d2YtQD(#lb6m@nvq;dX8LZi+9z0 z1$$a&<*dFiJ5Uw(VG~e_z z?|O`P6)Z})%GlXvC;gtD{Fi8RSkWEjwp&tH|9+I-!a@A-O^@X*#c6P)1c*u4KU*XF A>;M1& literal 1362 zcmZ8hOK%%D5GMDn)gwmG7%7VO)Qf;hYJ1ATaN0Djfx>YM$fpGgVzZ>=Rk;t7Tq8C% zCppN;hyDZmQlNi{rvN<^{R=*IW@S}HDa>d{4rhknH{2J)VN77WoBSvmkB~ocvK=rd z-+<_spg7@Fl7c3bqU}hha1*z)UFj8m;&-+ugCa~qN{$Hkd2mg5An4jnA|B4kIQknR zkukN8UZh&gzkT)d@%OJPE+j1N9jCQSvpg-2R9c_?RKwcnhg6%R?@s1vQA?p<#f4dh zLAJwxF!X0&2$E1v5{F0c$%68Lhu6KteNQ}+c#ytl!-s%>ZBBi7ejk4I^|J|x9hQ|U zQkh=~mR0i9x>?<1y@%-dAA-0b`gc%<7}}DxvvN4)&JBRl`nN83@90~fd)&W<`Af2f zneniW8}Men3EGfHzeZ}+l9hiOwsaMN^Z-!u&RIAgBeY=Q8qV%$IoNsL(bjJR-n(%j zR&4shTDqpcIp+Z0A4r>+V*Nyy=LaJ7GvM~OOYBfq)e2#=BWRdX8`Wfn zrF>;wR6V28ykt3lV+TxQ|K{Zqk6^>vLD>`xi)xh6)|H~Pj;U;)y@q$u5hS5oJuTz{ z>Pro*8XH6TT$GIGg^i)$x-n?0PavXtf@&XB(#uN@rQ{dF23%a^nE;4-jCI#$Y{S-L zbpI4o1frjUBJ?qh>4**-$9+Zv>d@G6=svyL`#&9HU*TI&cvRubI92!&0iXf1GA#sS zHfDH>Mxs4n?7T_kW~WOEFmB4pgp}Kou#3i2O)zR#9@2n`y1s~LN6QJo@&sCHx@oED z4ljk>WA|~d0uobGDLws{C?qfG<7USNkla!C^gp3=jvtu~_`0L1%ed?A9db50*xc?A N-;hQi2DJOR{{Zu-XvzQp diff --git a/sgl/models/homo/__pycache__/graphsage.cpython-37.pyc b/sgl/models/homo/__pycache__/graphsage.cpython-37.pyc index 32fcc0876bfc21ed5e4fca0d600a6dc203107601..c7c984eb984dd150f173f53801d4fff9d79d1aab 100644 GIT binary patch literal 1134 zcmYjQOK%e~5VrT3C#|YdIB^5C2a-R4fL5v^p*D@k1?|P!#?GViK5Fk4iuTq@J@r4p zk-y{<5(m`3z=;`e(%_XR^N2ssd^0}k_d5jE$BQ$u=MwT8jm`2v`2br_KybpTBq@z4 zMct83>c+0sUFoHM>|5QFK^n%P)qUAY+i{zc5#a$3ZwU{Df8)j--kOrA{Q#&WqGo3- zR*U6%uBJ&^NTI;xO%F%M7&=O{m>!-^PL5A=E+n|xpT3M#QY_CVVC+tm80RWI61qFL1M~D?C{`@%qjP{f7_1TJK~YpgKR?6wy=2dAA5@YdjH^nX53j( z&BDe<@E135Mf7J7B`Il5Zk!8;Q|{aYJEeO=YQntI=PrESy#rQIqYT%vR@tuUMW=Q_ zc5ACt1^0Gd|eenDJ;OW7+Ynlk}>6@*=L0kY=P9Wyifu9 zcB~JvD~RQV=`oExT4f8gp0Og=<<|9h7|AictV-GXtTtV=Y+^zir_4>4okMSJK@!;a zw2*U1H7UVcn$X@Ol<7c8Tx1LiRr?6rL7;VI0T)-RnZT)F4LD;z$=|_nnWti?_2qCn zI(Z|nzYg_64v{;lF_sT)syH%(;#wU5ZVR^F1wrT@ZP6z*0O`|Y}F5GA=^(n@mDpfMU0=&^u+SaN$doF+}8i|rQBO+Y~~OG;jY`(cu6#Kz{N z6gm0OTaL~#e<%l@0u*S`zu;46R#04^6gV2rketUiPkJ^Qr3A*O=2-m|5b`H3wj%`R zCz$p*7)d16q@W8*(Px?!!6NAVK!-)Nh&n&iagi(%N@hexGJZ>BtfFw zCNie}-igXrr>`nA7e%d=flx5}`2;i1gjMsW$G;xEIIg795W0UPYAv!{lrtmh(_iWg zZ`lq+iNk&eMv(=TWWi+ejx1@1puY&-k&rAxtnd5X74xUF7k_uB$mb_7Z=)ag_kP&- zK~^_eZxbf?oxlWAh zdF0}joJBWD8yVTsvmVSo&_{k}OZEtf1P^1lDns1yfwr;iT{AeHx&cInEo%ceYU4F+ zL)pKMZg$!Pb`gG05_nX4;=^U5k*Sfo5V{dg@dCKJrF$6#Eklf!djn{-6n9Psp_h6Wxy%cn0^quF=$l7)V?Qq? 
z0LU+t!Pfa$s*5~RKt`YQ?!}yM8uS&U&9`7szIGQ3p$}+Ecj%BZ8q>!tWsL69s|WvQ ze;gVV{7ZvpO&<)<3rv-wP@MaeBlbq4Kji$p5qgv9>H&EhB$?8Biv~3_rsgr;*C9jl zu|}tX9_^ZOGVc2U2x!s~2_5Qq{{|x)Nlk2y{}CJJ?1tlq|4HA`YEa>R7jqT#TDZ#C g$(Md&kGqO@TD~LA_YO9%JH%g%MGA(%8GM3&0c0Fqq5uE@ diff --git a/sgl/models/homo/__pycache__/vanillagcn.cpython-37.pyc b/sgl/models/homo/__pycache__/vanillagcn.cpython-37.pyc index 54c65769b7c85978af7c3152ee787bd0c818ee56..694f6b09abdf562e40961a59cf444e6c9aa62dac 100644 GIT binary patch delta 374 zcmXw!u}Z`+9LD=^nzZ&>kki3K5d_bp4lXVZP70nb=OCQm5`v9))m|-0Z&2doO@hOb z7tlA*XAqnPx8e)jBeh#sKocS$;KDN8FZEfG zsq&JOY}v}nQ2YZh zvw+0Te?94sI!q&U4lVld9MIz)^j3Jz=9F9&Z+`tuC!v-*?qMG*_t3eHJ zY>r>>o7<3g)AQ6fe7N^pT-Jqhvkk2^a-gE`1FE z_MxxKf_>$qq338xQUg78jxOLpaTHEWP}3<*M2c+KMZXJstyeCRvo4taPZxI2C~3|` hQl33VrC=SGNd2G#XsazcBqqEhulBKToQ%CL^bhxbWsU#< diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index 5eae03e..dafbcf9 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -1,11 +1,10 @@ -from sgl.sampler import ClusterGCNSampler from sgl.models.simple_models import GCN from sgl.models.base_model import BaseSAMPLEModel from sgl.operators.graph_op import LaplacianGraphOp class ClusterGCN(BaseSAMPLEModel): - def __init__(self, adj, features, target, device, nfeat, nhid, nclass, clustering_method="random", cluster_number=32, test_ratio=0.3): + def __init__(self, sampler, nfeat, hidden_dim, nclass, dropout=0.5, num_layers=2, device="cpu"): super(ClusterGCN, self).__init__(evaluate_mode="sampling") - self._pre_graph_op = LaplacianGraphOp(r=0.5) - self._sampling_op = ClusterGCNSampler(adj, features, target, clustering_method=clustering_method, cluster_number=cluster_number, test_ratio=test_ratio) - self._base_model = GCN(nfeat, nhid, nclass).to(device) \ No newline at end of file + self._sampling_op = sampler + self._post_sampling_graph_op = LaplacianGraphOp(r=0.5) + self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) \ No newline at end of file diff --git a/sgl/models/homo/fastgcn.py b/sgl/models/homo/fastgcn.py index 220dc50..191204a 100644 --- a/sgl/models/homo/fastgcn.py +++ b/sgl/models/homo/fastgcn.py @@ -1,24 +1,13 @@ from sgl.models.base_model import BaseSAMPLEModel from sgl.operators.graph_op import LaplacianGraphOp -from sgl.sampler import FastGCNSampler from sgl.models.simple_models import GCN class FastGCN(BaseSAMPLEModel): - def __init__(self, dataset, hidden_dim, output_dim, dropout=0.5, layer_sizes="128-128", prob_type="normalize_col", inductive=True, device="cpu"): + def __init__(self, dataset, sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): super(FastGCN, self).__init__(evaluate_mode="full") - layer_sizes = layer_sizes.split("-") - layer_sizes = [int(layer_size) for layer_size in layer_sizes] self._pre_graph_op = LaplacianGraphOp(r=0.5) - # inductive-learning - self._sampling_op = FastGCNSampler( - self._pre_graph_op._construct_adj( - dataset.adj[dataset.train_idx, :][:, dataset.train_idx] - ) if inductive is True else self._pre_graph_op._construct_adj( - dataset.adj), - layer_sizes=layer_sizes, - prob_type=prob_type - ) + self._sampling_op = sampler self._base_model = GCN( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=output_dim, nlayers=len(layer_sizes), dropout=dropout + nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/homo/gamlp_dist.py b/sgl/models/homo/gamlp_dist.py index 
1e7be08..877c2d1 100644 --- a/sgl/models/homo/gamlp_dist.py +++ b/sgl/models/homo/gamlp_dist.py @@ -1,4 +1,4 @@ -from sgl.models.base_model import BaseSGAPModelDist +from sgl.models.base_model_dist import BaseSGAPModelDist from sgl.models.simple_models import MultiLayerPerceptron from sgl.operators.graph_op import LaplacianGraphOp from sgl.operators.message_op import LearnableWeightedMessageOp diff --git a/sgl/models/homo/graphsage.py b/sgl/models/homo/graphsage.py index 832b48f..d1163ad 100644 --- a/sgl/models/homo/graphsage.py +++ b/sgl/models/homo/graphsage.py @@ -1,21 +1,16 @@ from sgl.sampler import NeighborSampler -from sgl.models.simple_models import GCN +from sgl.models.simple_models import SAGE from sgl.models.base_model import BaseSAMPLEModel -from sgl.operators.graph_op import LaplacianGraphOp - +from sgl.operators.graph_op import RwGraphOP +from sgl.operators.message_op import PreNormMessageOp class GraphSAGE(BaseSAMPLEModel): - def __init__(self, dataset, hidden_dim, output_dim, dropout=0.5, inductive=False, layer_sizes="20-10", device="cpu"): + def __init__(self, dataset, sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): super(GraphSAGE, self).__init__(evaluate_mode="full") - layer_sizes = layer_sizes.split("-") - layer_sizes = [int(layer_size) for layer_size in layer_sizes] - self._pre_graph_op = LaplacianGraphOp(r=0.5) - self._sampling_op = NeighborSampler( - self._pre_graph_op._construct_adj( - dataset.adj[dataset.train_idx, :][:, dataset.train_idx] if inductive else dataset.adj - ), - layer_sizes=layer_sizes, - ) - self._base_model = GCN( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=output_dim, nlayers=len(layer_sizes), dropout=dropout + self._pre_graph_op = RwGraphOP() + self._pre_feature_op = PreNormMessageOp(p=1, dim=1) + self._sampling_op = sampler + self._post_sampling_graph_op = RwGraphOP() + self._base_model = SAGE( + nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/homo/vanillagcn.py b/sgl/models/homo/vanillagcn.py index 61b6478..7bb38ce 100644 --- a/sgl/models/homo/vanillagcn.py +++ b/sgl/models/homo/vanillagcn.py @@ -1,4 +1,3 @@ -from sgl.sampler import FullSampler from sgl.models.base_model import BaseSAMPLEModel from sgl.operators.graph_op import LaplacianGraphOp from sgl.models.simple_models import GCN @@ -8,10 +7,10 @@ class VanillaGCN(BaseSAMPLEModel): """ It is a naive version of Graph Convolutional Network which works in full-batch training. 
""" - def __init__(self, dataset, hidden_dim, output_dim, dropout=0.5, nlayers=2, device="cpu"): + def __init__(self, dataset, sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): super(VanillaGCN, self).__init__(evaluate_mode="full") self._pre_graph_op = LaplacianGraphOp(r=0.5) - self._sampling_op = FullSampler(dataset.adj) + self._sampling_op = sampler self._base_model = GCN( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=output_dim, nlayers=nlayers, dropout=dropout + nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index ec0d1bd..63f5ea8 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -185,13 +185,13 @@ def forward(self, feature): output = self.__fcs[-1](feature) return output -class GraphConvolution(nn.Module): +class GCNConv(nn.Module): """ Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 """ def __init__(self, in_features, out_features, bias=False): - super(GraphConvolution, self).__init__() + super(GCNConv, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) @@ -215,18 +215,96 @@ def forward(self, input, adj): else: return output +class SAGEConv(nn.Module): + """ + Simple GraphSAGE layer, use mean as aggregation way + """ + + def __init__(self, in_features, out_features, root_weight=True, bias=True): + super(SAGEConv, self).__init__() + if isinstance(in_features, int): + in_features = (in_features, in_features) + self.in_features = in_features + self.out_features = out_features + self.root_weight = root_weight + + self.lin_l = nn.Linear(in_features[0], out_features, bias=bias) + + if self.root_weight: + self.lin_r = nn.Linear(in_features[1], out_features, bias=False) + + self.reset_parameters() + + def reset_parameters(self): + self.lin_l.reset_parameters() + if hasattr(self, "lin_r"): + self.lin_r.reset_parameters() + + def forward(self, x, adj, tgt_nids=None): + output = torch.spmm(adj, x) + output = self.lin_l(output) + + if tgt_nids is None: + num_tgt = adj.shape[0] + x_r = x[:num_tgt] + else: + x_r = x[tgt_nids] + + if self.root_weight: + output += self.lin_r(x_r) + + return output + +class SAGE(nn.Module): + def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, normalize=True): + super(SAGE, self).__init__() + self.gcs = nn.ModuleList() + self.gcs.append(SAGEConv(nfeat, nhid)) + for _ in range(nlayers-2): + self.gcs.append(SAGEConv(nhid, nhid)) + self.gcs.append(SAGEConv(nhid, nclass)) + self.dropout = dropout + self.normalize = lambda x: F.normalize(x, p=1, dim=1) if normalize else None + + def reset_parameter(self): + for conv in self.gcs: + conv.reset_parameters() + + def forward(self, x, adjs, tgt_nids=None): + repr = x + if isinstance(adjs, list): + for i, adj in enumerate(adjs[:-1]): + repr = self.gcs[i](repr, adj) + if self.normalize is not None: + repr = self.normalize(repr) + repr = F.relu(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, adjs[-1], tgt_nids) + else: + for gc in self.gcs[:-1]: + repr = gc(repr, adjs) + if self.normalize is not None: + repr = self.normalize(repr) + repr = F.relu(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, adjs, tgt_nids) + return F.log_softmax(repr, dim=1) class GCN(nn.Module): def __init__(self, nfeat, nhid, nclass, nlayers=2, 
dropout=0.5): - super().__init__() + super(GCN, self).__init__() self.gcs = nn.ModuleList() - self.gcs.append(GraphConvolution(nfeat, nhid)) + self.gcs.append(GCNConv(nfeat, nhid)) for _ in range(nlayers-2): - self.gcs.append(GraphConvolution(nhid, nhid)) - self.gcs.append(GraphConvolution(nhid, nclass)) + self.gcs.append(GCNConv(nhid, nhid)) + self.gcs.append(GCNConv(nhid, nclass)) self.dropout = dropout + + def reset_parameter(self): + for conv in self.gcs: + conv.reset_parameters() - def forward(self, x, adjs): + def forward(self, x, adjs, **kwargs): repr = x if isinstance(adjs, list): for i, adj in enumerate(adjs[:-1]): @@ -235,7 +313,7 @@ def forward(self, x, adjs): repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, adjs[-1]) else: - for i, gc in enumerate(self.gcs[:-1]): + for gc in self.gcs[:-1]: repr = gc(repr, adjs) repr = F.relu(repr) repr = F.dropout(repr, self.dropout, training=self.training) diff --git a/sgl/operators/__pycache__/base_op.cpython-37.pyc b/sgl/operators/__pycache__/base_op.cpython-37.pyc index 634c3dfc4eb54b7b00c9ac90ad6b28d54705a7e3..bf9180d67c78d568ef6f0b6ea2f570ef6e6dce8b 100644 GIT binary patch delta 523 zcmZ8d%SyvQ6rGz%n@LP+>yx7RLZVTuMFgLV;znKg!d(`Y!o(L=GBzJe=OWs;E()&oe)sywI%1`66sIBC?bA8OIcB^n{b1`&hH zPC8o+TJ7o${2TOW_M3)_f7!&Xo?-1c-aWy3^A-QC`GnXc65wzgS8nzN%IY<{1(T{^t-zEDtzno}Hx|OI`m{d0c`^q) zeK*djcQFJDDsQj1%+eh7CN8Q2yWgMDGX04%sS+X9e$)ZV;gjBs74y7^H`Tk&l&!g@YL^swptpkvoIy7MpK=N@-4N ek@n=btm2c;bMuP`fR%DE@-Tuh50d~hpDX}BxG5n3 diff --git a/sgl/operators/__pycache__/utils.cpython-37.pyc b/sgl/operators/__pycache__/utils.cpython-37.pyc index a960178d85a17d592a9a8c5904b363ce80228748..0bc3309f62f64d762cae870992598f75023d801a 100644 GIT binary patch delta 865 zcmYL{OK;Oa5XX1?h#$4nq?9UBsh~niYL)lnP|?r=;#JX>imC@AWTLM7Ky4DOEqd7~ z66zhH{Q?f^843C|5O74{$OZ1)Va9FBSo_D0n*0Wkq(~HSQ4XvhY<#16IUL=kos1*Jr%fO z0_#4AezC!Ml*Bt5(N3$?PC}2zn<9>rXeZuC)dj0s4E;e%Y;SI<6^nN(^e`}B6iT>? 
z;A$lpp&e^6i@&YYt0~Bs6LTJ5fr0^LYR57G>)ar;$3S~D6scj)lv*ybkXe* z$c`b6?vXxH_WLi!e|e55VN+ZtTp?T~&~e2z!bQS$!VN@V9@$_W3k)3*Z8u*>qO}+L zBHC@FgHqh0nW6d>m^=JxWnhXr_EqZC zp{YJOH=(Y6IjbTVli^}(HhmBY{O8cg5E`vOmpS>hpq{%GP6VVW>TcuZJOi{5zc z4yLH&33V0;%ZRYJ7vtATBM~QCjdm=as<+-0B>18+^~JNCHJXCe;I#VTy?_NZFgF-*_^WFEpuBSTA+o_(zHI}fbvfdTjsHn>BQs^0o0f9O zyquhgb5?4p&*pjX)eo~%FPubf@5jUEv1r+{hR-QL78VyfQixbRv>NJ_)vrdlz#xJ% zs1RJOGG+v_eR|Hx!1(?LOMP_Pz|}8z5jtwYdkjnJ#2dnf`W;N$F7B*@?4DNqC$P4H z$F+QU6PsKm=!9E@D}-snZNeQyVZ_OHf3I*3_v3>@Azp|=e>xb*Wn60+@-F4>5$>x` zK?D!fmtYp2;k7lhN0CTf4man=Cz<`^cqn2?f8*b!bL8d;Z9G=(X^iIY5Z&3mOYY9Q7)dE;B-KIMR0cnyr6^(=W3e2!7cgvcVfWI=StSrHEW-YQ z2>u9bYYRK!YX67sMj<%xW@lz+=Dm4y9^9Z41Rm1%G`vW@7QxG|Ge>=37PDt=kFXLe z&rr{rBMbHHD-^nmkNQzJw#;A(sT%_P*U#@dPd2K zhU})Y5Z#&G%k6z^5>bnYQNd~?j5dt|ASd`%a9GhvN6ul!Xp5)Xj#rn=A3Y@1?@l&HK!4*(hFRM z!iPYcBf^2);~iR@OR{(d5xwI{vIvH^g{d{KW9srw(sb(jn_6_pmse8n z_QW8C*g;}SH+(z6I`GnlBYNPye#Z?9mFGouWlS-ssyzS6@nITj<<%`u{=v|}_(4Kb U&?n!3CF+WU_W$|cA1bl_1^O8}k^lez diff --git a/sgl/operators/graph_op/__pycache__/rw_graph_op.cpython-37.pyc b/sgl/operators/graph_op/__pycache__/rw_graph_op.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7523ad8e14146874044460cf1377f694e36add8a GIT binary patch literal 1000 zcmY*Y&2H2%5VoD4%`WL8L8&+{7o;GuIf4*Es0dX^s717=T&y&4+n^*)JzllD+Fsa7 zeU0|WEBVT)uK))aC%fBrB#$Sa&z|{aCfgeu1A_8*v?snWLVnwimAGhp#076r}CB+|9j9l6%5kI7u(3`y}#j@iy6wsOjyfv^m*p(0rWc zKe(1$$unL{Rb>{l5*ybFyhJ5PLUWR^+`l46G-tVU*-xAmn0rW>6!lbv+n(G zxBpNwWYdb>o#lc}G&Iq4P8RH%&FKxb4DdePjU3~_q!CKl_$J`Is0+<`#Ec7~JTiku z$%ca#4IqO)=UJHoaQ=(@c@v+=O2h!C@&3-&*X8s)hT}4p_*$x^g80}n=dy{_8DH7s zW~zD^D!>J8wsiJnO=7%?zb3LG@yX)fq0(edHT}UBOl{5SwR1x*XuY{Sbr%d?X)mV> z`jcvZj{60hv+LfCrOS3_+Xp9Ngl$MOQD@Upm1Lv#656u0@7^VGuR6i|j>-pW^YBrH IETp0HA3CZAG5`Po literal 0 HcmV?d00001 diff --git a/sgl/operators/graph_op/laplacian_graph_op.py b/sgl/operators/graph_op/laplacian_graph_op.py index 0d2df3b..eb1ca7c 100644 --- a/sgl/operators/graph_op/laplacian_graph_op.py +++ b/sgl/operators/graph_op/laplacian_graph_op.py @@ -5,15 +5,16 @@ class LaplacianGraphOp(GraphOp): - def __init__(self, prop_steps=-1, r=0.5): + def __init__(self, prop_steps=-1, r=0.5, add_self_loops=True): super(LaplacianGraphOp, self).__init__(prop_steps) self.__r = r + self.__add_self_loops = add_self_loops def _construct_adj(self, adj): if isinstance(adj, sp.csr_matrix): adj = adj.tocoo() elif not isinstance(adj, sp.coo_matrix): raise TypeError("The adjacency matrix must be a scipy.sparse.coo_matrix/csr_matrix!") - - adj_normalized = adj_to_symmetric_norm(adj, self.__r) + + adj_normalized = adj_to_symmetric_norm(adj, self.__r, self.__add_self_loops) return adj_normalized.tocsr() diff --git a/sgl/operators/graph_op/rw_graph_op.py b/sgl/operators/graph_op/rw_graph_op.py new file mode 100644 index 0000000..e40f4ea --- /dev/null +++ b/sgl/operators/graph_op/rw_graph_op.py @@ -0,0 +1,18 @@ +import scipy.sparse as sp + +from sgl.operators.base_op import GraphOp +from sgl.operators.utils import adj_to_row_norm + + +class RwGraphOP(GraphOp): + def __init__(self, prop_steps=-1): + super(RwGraphOP, self).__init__(prop_steps) + + def _construct_adj(self, adj): + if isinstance(adj, sp.csr_matrix): + adj = adj.tocoo() + elif not isinstance(adj, sp.coo_matrix): + raise TypeError("The adjacency matrix must be a scipy.sparse.coo_matrix/csr_matrix!") + + adj_normalized = adj_to_row_norm(adj) + return adj_normalized.tocsr() diff --git a/sgl/operators/message_op/__init__.py b/sgl/operators/message_op/__init__.py index 95fec89..69ac949 100644 --- 
a/sgl/operators/message_op/__init__.py +++ b/sgl/operators/message_op/__init__.py @@ -9,6 +9,7 @@ from .simple_weighted_message_op import SimpleWeightedMessageOp from .sum_message_op import SumMessageOp from .over_smooth_distance_op import OverSmoothDistanceWeightedOp +from .pre_normalize_message_op import PreNormMessageOp __all__ = [ "ConcatMessageOp", @@ -21,5 +22,6 @@ "ProjectedConcatMessageOp", "SimpleWeightedMessageOp", "SumMessageOp", - "OverSmoothDistanceWeightedOp" + "OverSmoothDistanceWeightedOp", + "PreNormMessageOp" ] diff --git a/sgl/operators/message_op/__pycache__/__init__.cpython-37.pyc b/sgl/operators/message_op/__pycache__/__init__.cpython-37.pyc index 709576a58e51c232da58acfae3d95cf8402c1b2d..b0231416230aa9c1e04ed4d093793179995fe406 100644 GIT binary patch delta 185 zcmX@bK7~WwiI;=lk8kl_Ht#TzE7cW^VL@Z^Z)ibshvGECec z!j~eRB9YD$C6OW-%%CZ?@f-`IkS60TfqnocC`ygb1FA{P$*f9^&xNRr&o2P0yTu+KpO}*qAHR~Jh##o0h<|brvlpY@ QTGcY^`abSQO$Z!DS;zbkHJLGv%cvJXN_)`Q@1XF}k zgi}OPL{r4lxue8V#Df_$B{tq+VVwMh(VW>&Q+%^AlL@2HE%x~M#GIV?_>~Mrd_ZH0 S_$JR__F~ZjGAF-a)&c-Jq8qpX diff --git a/sgl/operators/message_op/__pycache__/pre_normalize_message_op.cpython-37.pyc b/sgl/operators/message_op/__pycache__/pre_normalize_message_op.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb57b9a48f50bbb2b434acd694a0c0b9ec031b26 GIT binary patch literal 845 zcmZWnJ#Q015S`h(bL?P8QGfy@TBJkbPNGBzA&`;=9HeyDtd84t;%M(9yJyL<+{IEV z{sWZ!rERHF(o!*V#&!fF&D_rJ-kbMk_WfwI4`jbyoa)yI@RO5`g|v7_?p{$)pwK{x z86vNR5oI6%p$xMK;RIBm!WF1c<4utDlsJPpx}*OfMnC*)^(XCIey&fOalyMA;~H%4 zIRykcGDpLKYG3oj=pAu~B48vk{%+horKE)x zmvAG#iKX~X0+y7?Kmmc_I1YW}W{tL%|M7h(r`5ET(hJ$dfe)QF6OWA#)U+fq11SrW zJ16B2_9^VFT2v?jE?t*uMTwQN~8X=Amlh;(kI3oUmCNpo#` zRGHm%{lG(9Jo#5*vf)(-y2!@gi2MV&dqkmyOJr%6f;>?8H6Vdv91A<3Q_GVkv3TqU zTOM|gHhm~ro3MROaBhcuy$d7W8Oqk?m75TGIcF5tQNYEs|4H7vNq{V;3vrJdTjiyu z#_mg5)@o*WJ(TilmYelS*9UvVHAjEjTxesY+%byLK44HL)(wn1(6z0P2Vu8E59BdT zA=yqL`I0+LBd|={KcvAF7fDqmlUY@?)4Ixy7ggo=cey>-!0#&GYe@gin*UpVc3ttI HI1+yVXBo+_ literal 0 HcmV?d00001 diff --git a/sgl/operators/message_op/pre_normalize_message_op.py b/sgl/operators/message_op/pre_normalize_message_op.py new file mode 100644 index 0000000..1b48b5d --- /dev/null +++ b/sgl/operators/message_op/pre_normalize_message_op.py @@ -0,0 +1,11 @@ +from sgl.operators.base_op import PreMessageOp + +import torch.nn.functional as F + +class PreNormMessageOp(PreMessageOp): + def __init__(self, p=1, dim=1): + super(PreNormMessageOp, self).__init__(dim) + self._p = p + + def _transform_x(self, x): + return F.normalize(x, p=self._p, dim=self._dim) \ No newline at end of file diff --git a/sgl/operators/utils.py b/sgl/operators/utils.py index 3b92f1c..a8d84e2 100644 --- a/sgl/operators/utils.py +++ b/sgl/operators/utils.py @@ -73,18 +73,28 @@ def cuda_csr_sparse_dense_matmul(adj, feature): return answer.reshape(feature.shape) -def adj_to_symmetric_norm(adj, r): - adj = adj + sp.eye(adj.shape[0]) - degrees = np.array(adj.sum(1)) - r_inv_sqrt_left = np.power(degrees, r - 1).flatten() +def adj_to_symmetric_norm(adj, r, add_self_loops=True): + if add_self_loops: + adj = adj + sp.eye(adj.shape[0]) + degrees_left = np.array(adj.sum(1)) + r_inv_sqrt_left = np.power(degrees_left, r - 1).flatten() r_inv_sqrt_left[np.isinf(r_inv_sqrt_left)] = 0. r_mat_inv_sqrt_left = sp.diags(r_inv_sqrt_left) - r_inv_sqrt_right = np.power(degrees, -r).flatten() + degrees_right = np.array(adj.sum(0)) + r_inv_sqrt_right = np.power(degrees_right, -r).flatten() r_inv_sqrt_right[np.isinf(r_inv_sqrt_right)] = 0. 
r_mat_inv_sqrt_right = sp.diags(r_inv_sqrt_right) + adj_normalized = r_mat_inv_sqrt_left.dot(adj).dot(r_mat_inv_sqrt_right) + return adj_normalized + +def adj_to_row_norm(adj): + degrees = np.array(adj.sum(1)) + r_inv_row = np.power(degrees, -1).flatten() + r_inv_row[np.isinf(r_inv_row)] = 0. + r_mat_inv_row = sp.diags(r_inv_row) - adj_normalized = adj.dot(r_mat_inv_sqrt_left).transpose().dot(r_mat_inv_sqrt_right) + adj_normalized = r_mat_inv_row.dot(adj) return adj_normalized diff --git a/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc b/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc index 81938ce009efca0b3cb7272540c12c790336b204..1eafea7d0f68870afb2929b2f1c36aa5dc74220a 100644 GIT binary patch delta 20 acmeyv`iGUg=ssEqp{^x?0}O9#kMSY6%l)FEMdbC(-}?o&dzFf zW^}s8-kq8)Ok`u3ih~1HAp}V6K0p=x0aU?L9(bg9fhwx{i7KAV6Hh$&1-|chx_f4K z73D&x;(?j!)2C0LTc7j&?&rn1xr%{b^ZD1p=0(H!H$G%P2bmXeMSp`NFaon@H1*wV zTJmgfJI#X0GOJhIE;U`r+r9F3rCE`@)2nXRnl;Iry}9kWWthhHd~+c!S!^z%b)mPk zeWG~+WyPQrxIuZ-X)a$if=W<*&j_kv<&NEa0ME5x4$pJ){2-p|!91Sl<@qF@7lK7R zFUs>uU|u&GOP^rw#(Kj{oc=&<<6eES6@}MZ+rwU{@VL|pw)|MNx_!UZZ}oPgZiLbW zT@r^;>_@|1H%=C*oy{nji@Mu>ys5%ybI=PCCyK*ive+BEGnGf8%q`THWC|ao)mAr(^4HScG>say;c)ajj1d(G~v!(COwL>9RSMI?r^Y!=P&^2~~@W^>Ayj!S6bx{NC#y1NFJ z##~n?kZ)LO9=D`&vlX{D@i{>xwU+Q+SGGLyioYRcI8sBwPY4NyYcK2f?Vv ze9ae(=r{5S>Oporm##rO?Dc$qI&i~OOL(t1zAW_uhvwW}v*t=k8xQ%9$lxymG2a(v zK_MstZUxdC?!}-SRB$f|{9G(cHJM*0+;0xlbq$ex9XSZEpB94KIf}BLHwS}2BM|uT zw!?T6&l`gtXesip5_#7zU43bazH21g?XVxW;_jgD0luxEHH+Z?C** z1h%eytEhStGq6t?`*u(|F!vmC%J+SK=@(%>_$d@Jmpp;kQ z1@<;6P25R&Qn^#ut4^vib~PV+0pE6VQutFNSow3C-ze=1!Knjsx<$6%UpnACPXr|G z)9Xo@9B8NpHzQ^TH~n~b7$((TYd2JW)E$RWQt6MjeJO~Nk_v~tRy!P@zSe&mygcx3 zg}d(zRN#SGjflJp-lyhM?QEyX1W1k7xzi-l4n}{gwr9r1#JppE$GYuYH#C1S6T91w z8&+biem<#8Com;;f!wE+i@k0Xw+Gw9=bmD(9HMB9&mBV@z6@W~UcJ~Gv|GLCIb^bi z^Eg3U7CV(39c7vo+gMNg`E7S;|%HM!dP&m!DYV`vT61yL5 zpmI0%;sKlJZC}07SWX=5VwB9k{OaXPFJHUyoBr3{y!86j>q-4m6orbcz$!$W zEr>|sLwRgnmPrl=QJ54&?`YIC2O{%H>MioUFz~fxObT+B*(ma$#BC3HdNeMxtnhSU zlX|uZTKhNxE6Af1FOQ>b^ehs?tmE%lHMzSMBntApT$C&!ebB5pD3Sk?xnwO^b#r{2 z{iPk}@_}2=8D7Q}eH}^1*n@IHmbYi`JWF+D%erHN{hbBd0y_f>1gn}jU?bMo!HU3k zto;I+j%iE^dp6nN!~zoq8+{)P5WRtE2KJv>RBc+H0SI%^a;+Z^NAZRHnM%)8q+b9X ztJW@YFNI!z0P3Ky&l|iVQs{688SUk=DzE+Oh{HvPtM~nGaKY>Jfa_j9LbxN;)@pBN zMS2Q(SDYzO5nV_raU<2uzTkHE+z-I87eK)^Hg1 zgGNaM_yO#4;sU_ZS#z#EQhwZtfz|+okg{7WbGW15MXH^>BO!?16b45yE2`1U1Q8{) zQE>p#1%lFat-5s_h?(})rgFj`Q&x9sDhTaL?KK%LxJa1IRsN%h)m}1AsuO3wI;jZYZR8;HOyl+fU~lZ#CMMwS98iz9 z$A6LI759m$nk3!AP}3>3)he)xc`6HRjX>+>;_EHH=kKN++kDL8CS*jKIohzWk|Mwt z47QVEdvk!#NQy0u#x-_RBe?u<2VArrr>rw6=bx+|L#1XJ7L10L&I%1Q3sS3DhCQd} z(kz3~rqgopO;KVEpH2#)P)0e2=TStygdy6vqSHu>Vg*=y7?`YpVe&cyR*Abc{!GR( zavkwFA+;$O_#?UrQ^GcS0?dps?cfm*bjS^3V;>5|KEQKe9+)2*$lG_JU>sPI2XDP= zY~bC8#?+#Wow#rpiXr(d)WcF-phB|e?z|g26IZxzOa*LG2K*{`Ycq!$)k$F=0PUEQ z>N3oBPk1}2OzImT9D(!q z-ZXA+yvf;D0B3$tWl~7*;Po@#TivRGTnZCs!`iU)%=Oo7fa^mooyXVDZwU8aKctHl zu!Pj!BU)8m>SFahXo*~`Q^p@nwK3UOn0%i>wZ!Z~uUUXbb9)49HRDWr>#5U0B?C=q zmF!(SXDZp}k8Sb{K4;XjyfN&HFQGlig4j=rddQ^O9`xHSuyb5VISS*TyRq?fpst{j zdRl4%^x(EaO^Q~95`riQjVD$#Oe*ccz~64gs=L!z5T4vD=L}f=I-A^JPqrGolh{xs z5+`WIt!BO74bTqE9O|%YH|MAfPaD30HcdBshB}D1LQ#pi)hxmL;ghc2Y|+#;4#W+> zr@qO&7_P4&w}VUk6$a>K!K~SJ>$E5uP&@Ek7bSxf^r-o$H9m8{rEQeeN%pz-$h`>sf(ii0YW>6C^u`%WfB^Y9F=h#n10qU=8Io^&4&Ly;t&|45enMZbjPey*FN zHnrH&C2;L>QeSID@zu-gDLcqjCHBqhCpV>O*VyZ71C7=)+m;VMn{E4tXsK;mQ*7Il zhv#|Rw59P}#3iQf_ra_F-?W7Z{l8_}I+q2g9T2AS3@{ToIMWpUFyvoeE3!7&E)Cq52ybOBwBcKWwBrGe+w2eT~#BhmF*qzm2Lzi?`9;fjg9*9%}7kf6=}TX{|xqhl{@z zdm*h*(ekJM(tBF6wBrV&RH2@+DQ1*@v7Xtfgbcb0P9*91kVTPMhW9dUBkycHP?(VR zr=)gmofl%^;MSD_>~Or 
z&{%Dj_ylwxMI|+hAL5GWkf$-~Ky7j0CrQoGqi7R@32yD&8m0S-bZ{A0^mj;rC+MYg zYrGFN7fLid!24Eg0%eHZP|-aw1K^EL673$Lmlkzws?D%EX}cDJ0<6wvT!h_8Z?CpG zO|d%7{R(Qswxro0mM7Z7@+`tvTtvSmj3c8@N_u>#+$?t|_I@p%+pohD?M|#^;~^;U z@HpF0@NH4MORdYt@O;DxlG|l7u=(jtQEU}RMLe}rTa-wfK_KkVfp-%{5~m1qsN56L zqtw|jYa!3{rP2IjE#%aFc>x5L`bR-?;tr&<1`gzYd)L$I`2Ef8^+OU!tmc%*|Ht!< z`6-`HO8r3}Uf6-Qn8kqhfFVWyi01RJFt^U+RV0YDfd?dZFo@M}^SKaoTZp%nHP8Mv z7FB0fHz}pxq+Vwi6!Xnex}wc;wrWAMl9u_yr;kzy)$V`6rH$kIlDU!_#>?<4(s|=r zE9UsoqtgDkZQO8w5*7bG+WiLR@g2^Eo@Z$H1-Idz6s@_9g8WC_o4<(GLZ;gXX44i! z8Z?N|ti!mC#L5tzmROm@U795wpKiK3F5N8acyzNOdVdvq`a-gJxi><1SWQ*`@fQW3 zX*t7QM6-y(bKXYSLiiO(hQCIf+Tyy=JUVw5(8kSehG+3W%JvP)C*O>%7zjuXlXx>pUshfJ&faw3N!K}>B*4NyL=H^`52&yrvYR8Vf-P#doHLLn$gS-R81bs)AV2rL7xdeoZ;XrRaa}Rvw>zca` zXgcMe-oyhxJusmRyp9|qB?!_!j9@YXrKZteKo05wJah0J5-qqp>AfK60gk{#=$wY% zE+{g#!t)UU9T@I7%oprR$k^l{$`&nUT@Ze8m!Y(*JfqSXMt06?E3gIDxFH&!u!Ly3 zzH3PVpzTKEM=A)zAiOV~V=sh>vRwp{N4$TvGJdr3y7-FkQ|AWML~4kMxdX^v{a<}Vx zqo^$eVN~y-Hyy4m$mb3qO1890kJyTjSZWO

@1QX~^eM5{%<*GC$*c}79BlyE zt~bz8*SKr>IyEJ4Noq^#NLu*V+7C`%h6KLVri_e})+y z+B7m}K|5_rIpp!+o_W`#0e1>h1+RrvK6KneeSn%N#MS$fPf97&W(0x2>a%$L(hSsJ zxCf|yBwLz+eC{MWGN0}8TxHg?&-aWhK(n$B@XgcPC zAil|v|HPVM0(uZzVlL_p)Rk2^Iufb=tJ9IjKkHKD)Su5)vypTgAHy)R6lbuFw8DWw zZY1q-<~o*M9xMtblez^7fShSD3<57)Q3p@-XXKlCZTbcO~WiB|xqZzw!7jm>grLjC)GyzI$MR6mZ-iE^i^m z3RjH1RU^P4AK^;e00W2dMmBDJWa)P3Lo2OwdC$n@rS-fNIVqh-yM??IxxDUTUWy!; zqszmDV2;jsm@~&4sX~)!ZNekrPI+xtYiK;=1>fG;aHky(W1P)Y9`GPriz6ppcc@Jd$D76rW6i#x86()>SCM6%1Xd6d-X;mtkKZ*u`j$2Z@ zQ13F~B<>fyq_O*-xP+yM29<jyT0 rICJDvXLc~(#>Bm-4_!Zpt<*Xy^=f^-t!6Gk)5b|RYoWFJT;^~(Z7KLI&Yg009DY4B)cqMswqmooQjFfLsG_Xa zm4;T-SR~UMM$weKR*yB}#f0Sby46S)lae>;sYbe(mb_WdH2R8tl8@E1ja)G&c~@b_(T`EDxfbwlgSaWO)P1eRdY*tSk@N z>}jQtyG1KHU0|Vk#&NwS2dBGD-oQKcbgRrgCz8MZ!Hx&=o)nZCb0xote_pv%S`zuo z%2}V6Yi`MRJRf!SnlCP}cXJ60a`be;2o0~@ayU?tB zqtlvedz|DEStoXCXAdW-yq;R6OQlM^?0FTZUN4nyH!Luoqekn9(2!X?o(>XcsW?w8 zzN3w?$45}6xp3M{P9-k8P7#*|DQI#sypIIb#d5!LMzZsy*kU};HBm>;2|V65247&d zYHPM`8@74L?5K0w0@hO&lq1T8p^g@4fjVP)BR;cZfwn=JXD_qak#mZzMXhE`e%V(8 zW^2QWug$A`I^wGg^*VZuFQ@}bpp=Z+FH%i+SQgy4Z`uj!%#Q`iv@y@nByN)uhDN7v zp)r$f6Yf*8E7i6Z)Q%qLwl$;jcYO z%jrHlGpsD^rGXN%L&NS{P*KXtlIiDq{dRVNNm)Os^wupvX|)UMLw&}%;)R(*p677C)^s1|yvakW=G9!!FS`{d)M~CD zYE{SQ8=)Oq&y?$JCu*kE9hYw+snO!ii(Z&qa#Zqd-X~5ZMt0pLtEC(BJn;$muIsg$ zo)a2nd%ENo%*Ye6;+4c`ag5r;E6sXIcCyMpWkspQ9gJ1kj*_-PaKQ}JK;N$fMn{|7LC$3B*OWj#0{4PGT+TLbsxZ2__-ktiv-sT@;?}?L z6A#BGcFV)ZcnrdV8mi-yk=Nh3wvVj+^_u5bnvK?hH$*QsCVrcWi{9YC9_XaJW}8=> zl3S|T9!lj#tM1q(*w4cmi?+aRH(FOh1J$o4j1kmRi5>i zQ0dVbR4g*v5E}%Ii%Dxe=8)S;$Cc*(7#{D3ATXYy!eAYJ?q?qrIQNuUNT57Fs1|=xD;`-*>iB@ zUMgU)t-heavg*}uzmeZ4&W{}x8#0gWAOjcD{6V}z6G+zbc^JbRuI}-1Nckj@`^1gR z=7*&Zn&no@aqTcgRx3GkaKnm^E+ix-E)yJ*a%k1dR~%mQYF8ccyUc-O$;jAJ5E_o$ zfP}{@@J}0Ktug0ejQR~c-VlglTFhdb;Gq`Fh*N!gGlZOEN)3-UF5c|h(q-3`DL{v< zUW@yhuit=6&BLZ}tk`vi8f~Qp@}(M>mD+U@bJu(=04!k-<~r65C9q)9q_4wl;Lsu! z7(D~q0#n1`Qv(W_ckVc}Qc~+cx08+<2WhAh@|e@H0Latn!_LHmc+dwFqTMl3LT!)? z`li%KHhc4&GCOrn@fGZ{zQE)gm?q!m0@KJ%^p4RneQP=iKQID)N}W=p>4UQQ0XC)D zIv_E7g9X_sIz+~4F|mD6gxS%DR;-KIwI4PMSq(8W;aC0D=BH?ZNZMN-{*6gJWJ5PzNTvlq}Fekx_rgny7EbF2(*!+#f^SkFUi2F~A7x0m|(}HHwqks*wHs zO32pTAIWywa>(w=nEj@2kjJlz*+8_Fn610lmSXk~{22V7PL5?K=wu)W>+Te4Gc)Sf z7^i>@C#Ue;F(q#xaI2k|znlZ@xiO7o^_E*SLY7fJ}Gx7AH%L#`K8|zZmaE! z&hQ8ECFAcCvm=LNGPU53h`R=N-z!mKIj@lT_;HB^|1{R{chMOf73W7~M$rYd9DtA? 
z1S#m!z7c*&DEN=OAz8*1U}+Y`bfwKp&Gs??Q&Lb&kqj3S84`<0DlIi%YjFac_X$fp znmqCiRY$&|%S0`XL`|F+onSrjh0(`1_-xwfnLg{q@Wd&+;9#k?_;7Tmk?^fvaxp(P z!^+Zy$+6GSw&VA~LE!X?YlShl>0Wh^)!v7?C^a}n+!I8OgM>Q8`cSi*K0i-o!>*O9 z9$=GF1Aa>UW8;YVB(YD4(M{XdCQxLukMMW|vPz~O%R#hP2bd-5n{ zqNcM%DDC^0s~4nWOl;Y_n9Nn`Z4dDc{fPHAKiVbZ-O7Vh_sHMx4mfl%A^qcE*C%5!p&@JPDJMs--MoWV+%i=f~I^s;C1m90gVBMXIW?sOr}IauzA zzin%7l1h|DubnKp^gi*@-G4qH&!fgy+za=Z$adax`hr`(?c6Z~8Q3Y^_^fzq`?2gH z8bDg^CF1(TN82}#uFQ8?Qz`{$aB}g%d-AGGlTddJlBJRO%aRXcOEXxtW~oe=P!+Eg zUf8)}ht-;E*GSDTiSeC}Ox$j+MTkS$lpJlX4lQaLLp&zl+nx~b?>ssw!;du5QZ}bM zb`|Cb%?2gsJnCU(elw!D>%@Ih3BItB6zwPe%OF_s^<{aqi8Rd_@$Rm>Sy%jX*VgWf zD2J9;E%{}R8?478qmmS;da=sOtxKVa4>G=E^LY{{99*Luj)V@mJ|^dJI(eZ^SF_N( zOm|;TI(6iGbcA3xxk0aimn*=hjV}>@U9LpdNv;Fovy0!q_m`}@2khxWe>5K%<6a?+ z`$0nMWYeui--Dt{%(GP0CGr3f!YF=$$ZjIk75@s6uY!chqxELlUs8XO#IJ%ZAAXuw zm~iiDHru-tw#PG!N^LHXx7%KrzD)=HK;f78sr?V5KlH`+|*2$&* zCmB93TNsv$=Z})^BoP$^H{D9Db!9yItbm=8pT;7TUbnhxyk2w5^(xm$ zbAmJnqc3Mwr`f=#i&|wIX$v5X>!>UadDZ%Obg$#%hc1>q+ydm4nsbBdbF1syey#4w zlqBpUxnK5XyzxjHqw$W3S0^`KAU~1c!=kTUgxnD<-!`n|$Q8o){zlVo*PR38*xn`( XT)1WO9(}t2)lgQmsccu-o;~+(fuNx5 diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index c1b5b0a..351f2c7 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -5,12 +5,14 @@ from sgl.sampler.base_sampler import BaseSampler from sgl.sampler.utils import adj_train_analysis -from sgl.tasks.utils import sparse_mx_to_torch_sparse_tensor +import sgl.operators.graph_op as GraphOps # import metis import random from sklearn.model_selection import train_test_split +LOCALITY_KWARGS = {"min_neighs", "sim_threshold", "step", "low_quality_score"} + class FullSampler(BaseSampler): def __init__(self, adj, **kwargs): """ @@ -34,15 +36,14 @@ def __init__(self, adj, **kwargs): self.pre_sampling = False def _preproc(self, **kwargs): - allowed_kwargs = {"pre_probs", "prob_type", "layer_sizes", "num_layers", "replace", "device"} + allowed_kwargs = {"pre_probs", "prob_type", "layer_sizes", "num_layers", "replace"} for kwarg in kwargs.keys(): - assert kwarg in allowed_kwargs, "Invalid keyword argument: " + kwarg + assert kwarg in allowed_kwargs or kwarg in LOCALITY_KWARGS, "Invalid keyword argument: " + kwarg if "layer_sizes" in kwargs.keys(): - if isinstance(kwargs["layer_sizes"], int): - self.layer_sizes = [kwargs["layer_sizes"]] * kwargs.get("num_layers", 2) # default 2-hop - else: - self.layer_sizes = kwargs["layer_sizes"] + layer_sizes = kwargs["layer_sizes"].split("-") + layer_sizes = [int(layer_size) for layer_size in layer_sizes] + self.layer_sizes = layer_sizes else: raise ValueError("Please provide layer sizes in the form of either a list or an integer!") self.num_layers = len(self.layer_sizes) @@ -56,9 +57,20 @@ def _preproc(self, **kwargs): self.probs = col_norm / np.sum(col_norm) elif prob_type == "uniform": self.probs = np.ones(self.adj.shape[1]) - - self.replace = kwargs.get("replace", False) - self.device = kwargs.get("device", torch.device("cpu")) + elif prob_type == "locality": + """ + This sampling strategy refers to GNNSampler [https://github.com/ICT-GIMLab/GNNSampler] + """ + min_neighs = kwargs.get("min_neighs", 2) + sim_threshold = kwargs.get("sim_threshold", 0.1) + step = kwargs.get("step", 1) + low_quality_score = kwargs.get("low_quality_score", 0.1) + locality_score = adj_train_analysis(self.adj, min_neighs, sim_threshold, step, low_quality_score) + self.probs = locality_score / np.sum(locality_score) + else: + raise ValueError(f"Don\'t support {prob_type} probability calculation. 
" + "Consider pre-calculating the probability and transfer it to pre_probs.") + self.replace = kwargs.get("replace", True) self.adj_t = self.adj.transpose() def sampling(self, batch_inds): @@ -71,15 +83,17 @@ def sampling(self, batch_inds): n_id: global node index of each node in batch adjs: list of sampled adj in the form of 2D tensor [2, M] where M = number of edges """ - all_adjs = [[]] * self.num_layers - cur_tgt_nodes = batch_inds.numpy() - for layer_index in range(self.num_layers-1, -1, -1): + all_adjs = [] + + cur_tgt_nodes = batch_inds.numpy() + for layer_index in range(self.num_layers): cur_src_nodes, adj_sampled = self._one_layer_sampling(cur_tgt_nodes, self.layer_sizes[layer_index]) - all_adjs[layer_index] = adj_sampled + all_adjs.append(adj_sampled) cur_tgt_nodes = cur_src_nodes - all_adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in all_adjs] - return {"source_n_ids": cur_tgt_nodes, "sampled_adjs": all_adjs} + all_adjs = all_adjs[::-1] + + return {"n_ids": cur_tgt_nodes, "sampled_adjs": all_adjs} def _one_layer_sampling(self, v_indices, layer_size): @@ -94,31 +108,34 @@ def _one_layer_sampling(self, v_indices, layer_size): neis = self.adj_t.indices[st_indptr: ed_indptr] # neighbor range p1 = self.probs[neis] p1 = p1 / np.sum(p1) - sample_size = min(ed_indptr-st_indptr, layer_size) - e_ids = np.random.choice(np.arange(st_indptr, ed_indptr), sample_size, self.replace, p1) + if self.replace is False: + layer_size = min(ed_indptr-st_indptr, layer_size) + e_ids = np.random.choice(np.arange(st_indptr, ed_indptr), layer_size, self.replace, p1) src_nodes = self.adj_t.indices[e_ids] ret_edges.append(e_ids) ret_nodes.append(src_nodes) - + return self._adj_extract(v_indices, ret_nodes, ret_edges) def _adj_extract(self, tgt_nodes, src_nodes, e_ids): row, col, data = [], [], [] unique_src_nodes = np.unique(np.concatenate(src_nodes)) + unique_src_nodes = np.setdiff1d(unique_src_nodes, tgt_nodes) + # Similar to PyG, the target nodes are also the source nodes of the same layer. + # Guarantee the tgt_nodes inds are always at the beginning. 
+ unique_src_nodes = np.concatenate((tgt_nodes, unique_src_nodes)) # global id to local id - nid_mapper_tgt = {tgt_nodes[i]: i for i in range(len(tgt_nodes))} nid_mapper_src = {unique_src_nodes[i]: i for i in range(len(unique_src_nodes))} num_tgt_nodes = len(tgt_nodes) for i in range(num_tgt_nodes): tgt_node = tgt_nodes[i] num_edges = len(e_ids[i]) - col.extend([nid_mapper_tgt[tgt_node]] * num_edges) + col.extend([i] * num_edges) for j in range(num_edges): old_ptr = e_ids[i][j] src_node = self.adj_t.indices[old_ptr] row.append(nid_mapper_src[src_node]) data.append(self.adj_t[tgt_node, src_node]) - row, col, data = np.array(row), np.array(col), np.array(data) adj_sampled = sp.coo_matrix((data, (col, row)), shape=(len(tgt_nodes), len(unique_src_nodes))) @@ -132,11 +149,18 @@ def __init__(self, adj, **kwargs): self.pre_sampling = False def _preproc(self, **kwargs): - allowed_kwargs = {"pre_probs", "layer_sizes", "prob_type", "min_neighs", "sim_threshold", "step", "low_quality_score"} + allowed_kwargs = {"pre_probs", "prob_type", "layer_sizes", "replace", "adj_process"} for kwarg in kwargs.keys(): - assert kwarg in allowed_kwargs, "Invalid keyword argument: " + kwarg + assert kwarg in allowed_kwargs or kwarg in LOCALITY_KWARGS, "Invalid keyword argument: " + kwarg + + if "layer_sizes" in kwargs.keys(): + layer_sizes = kwargs["layer_sizes"].split("-") + layer_sizes = [int(layer_size) for layer_size in layer_sizes] + self.layer_sizes = layer_sizes + else: + raise ValueError("Please provide layer sizes in the form of either a list or an integer!") + self.num_layers = len(self.layer_sizes) - self.layer_sizes = kwargs.get("layer_sizes", [1]) if "pre_probs" in kwargs.keys(): self.probs = kwargs["pre_probs"] else: @@ -157,8 +181,13 @@ def _preproc(self, **kwargs): locality_score = adj_train_analysis(self.adj, min_neighs, sim_threshold, step, low_quality_score) self.probs = locality_score / np.sum(locality_score) else: - raise ValueError("Only support two types of probability calculation: normalize_col and uniform.") - self.num_layers = len(self.layer_sizes) + raise ValueError(f"Don\'t support {prob_type} probability calculation. " + "Consider pre-calculating the probability and transfer it to pre_probs.") + self.replace = kwargs.get("replace", False) + + if "adj_process" in kwargs.keys(): + graph_op = getattr(GraphOps, kwargs["adj_process"]) + self.adj = graph_op(r=0.5)._construct_adj(self.adj) def sampling(self, batch_inds): """ @@ -168,19 +197,20 @@ def sampling(self, batch_inds): Sample fixed size of nodes independently at each layer. 
Outputs: cur_out_nodes: array of source node inds at the first layer - all_support: list of sampled adjs (torch sparse tensor) at each layer + all_adjs list of sampled adjs (torch sparse tensor) at each layer """ - all_support = [[]] * self.num_layers + all_adjs = [] cur_out_nodes = batch_inds - for layer_index in range(self.num_layers-1, -1, -1): - cur_in_nodes, cur_support = self._one_layer_sampling( + for layer_index in range(self.num_layers): + cur_in_nodes, cur_adj = self._one_layer_sampling( cur_out_nodes, self.layer_sizes[layer_index]) - all_support[layer_index] = cur_support + all_adjs.append(cur_adj) cur_out_nodes = cur_in_nodes - all_support = [sparse_mx_to_torch_sparse_tensor(adj) for adj in all_support] - return {"source_n_ids": cur_out_nodes, "sampled_adjs": all_support} + all_adjs = all_adjs[::-1] + + return {"n_ids": cur_out_nodes, "sampled_adjs": all_adjs} def _one_layer_sampling(self, v_indices, output_size): # NOTE: FastGCN described in paper samples neighboors without reference @@ -200,9 +230,10 @@ def _one_layer_sampling(self, v_indices, output_size): neis = np.nonzero(np.sum(support, axis=0))[1] p1 = self.probs[neis] p1 = p1 / np.sum(p1) - # NOTE: Should sampled contain repeated nodes? + if self.replace is False: + output_size = min(len(neis), output_size) sampled = np.random.choice(np.arange(np.size(neis)), - output_size, True, p1) + output_size, self.replace, p1) u_sampled = neis[sampled] support = support[:, u_sampled] @@ -231,10 +262,10 @@ def __init__(self, adj, features, target, **kwargs): self._sampling_done = False def _preproc(self, **kwargs): - allowed_kwargs = {"clustering_method", "cluster_number", "test_ratio"} + allowed_kwargs = {"cluster_method", "cluster_number", "test_ratio"} for kwarg in kwargs.keys(): assert kwarg in allowed_kwargs, "Invalid keyword argument: " + kwarg - self.clustering_method = kwargs.get("clustering_method", "random") + self.cluster_method = kwargs.get("cluster_method", "random") self.cluster_number = kwargs.get("cluster_number", 32) self.test_ratio = kwargs.get("test_ratio", 0.3) self._set_sizes() @@ -251,7 +282,7 @@ def sampling(self, batch_inds): Decomposing the graph, partitioning the features and target, creating Torch arrays. """ if self._sampling_done is False: - if self.clustering_method == "metis": + if self.cluster_method == "metis": print("\nMetis graph clustering started.\n") # self._metis_clustering() else: @@ -260,7 +291,7 @@ def sampling(self, batch_inds): self._general_data_partitioning() self._transfer_edges_and_nodes() self._sampling_done = True - return {"adj": self.sg_edges, "x": self.sg_features} + return {"sampled_adjs": self.sg_edges, "x": self.sg_features} else: return {} diff --git a/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc b/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc index 65bc66409d5fc091593469d445756a1203926580..68d145236ba442c4e7f854a9fd25d9d66ae42ca0 100644 GIT binary patch delta 1544 zcmZ8hOK%)S5bmD$dtYAf+78CXUYQNV4iXP3iY(zI4l>9>5MdaQHmmIk+4aoq($k3u z^vV)zkpr@1XHF1BvdN7D5?F}e0pc__j!4|#z>(@%8;8uOzpAdTuCA)C>fh#mop;w< z*Fx~wSKbbi19u94d9bv+g+ncpC4^C=`vyh?5jiv?%eQh~3GK-7ot#%gH}ZThD)@z{ z=oh1sUy90pIjZ;-jIJZ15q%dCJ#a>{KXLHQgSAU=Sb%&?kS`JEKDvd8LeyQ;mjh+2 zJ;J_93}WsgUn3T=f$H~>f_&p$)O53%v@OppWkwu^t+1aYf1C%4n;lz30iU2M3>ikU zpZb%kr^0FdNK8n(oMU>d#(x=qwkYCEeRZX=2%1ey8$jsFe{d1F%wR#nTJ26} zF8y9z-}zQTh=G?qJi-_;=}Pr@j! 
zpRAys-hIx_F zBFIZ1J(ZVF>aV33I(y8;O9P++UCDV+&622@~l*+KTE7RpO*C z>hrp}ZGdavr8o4iF1UvaTqj-)9br-c-UM0%Y5^?)wSks_IzTI_W6bPqaC@k7XQ=Jt zPd2t7!-tExbB!bXkun3>R#b(;x=}!@>F?(AwX!IfR|E*9OdZf}%;NI`dn#A*I-}-5K zR_IO$s0+{q)C4?}uG*JyBfV#TzV!bxqzkYvsx%1!5LC2*4XKEyk7u06xUveeW`$OP zWJ++i9m4UyxRsmNKtOZ<)&-Xz{D$DJ3lO_9mGx=ocBb8hF6(S(_Mo4z0qu8!BpC-K zTAi*5$X&hyT+0jRB4i$}> z8F1DhpzI!fNoa^Lso0mftyU*&CrKxOSEzMn8rpPOD2e8z;wciA6Y;{rGm5>kqZrq6 zd9z9{3!M)Epo>M6%W<4ZFl?AzglT^O*GQqGQC2xsSiW^9?yzn@Zikt5#?y1@&xP)* iXA$Py_=UI}^~s$uSjAu{DFVRVUhEgR1aKeQ?|8Csl*U+9I?sh*OC zooawiY5=;a3FxI3U?DXb>6(0#Nj-_lz!Xz5E=>q{WpP!i(TW&If6`?!k!uIv$;(u8 zAvE{IBjrJ@2(E)W9R(ngCu!8e27ireB37?2AA^B9f>U!6J zv;83U^slek9#{)(0jv$S2-X2x8ZS*1OpPkcA8C-MovK|0qOGR( zP)(hocAtK}D()Eb>o=YlPM9Bjs))TK;rvQ9>C9Td&u)H#Bet|aX!;4|X9_39fXtcMkZQ#hm!c)OSI zc_grYGc)1McCXQ9xA-g2_fKLGil41ndQ?2NemZg(XI2pA5OjoDgsRxLPtyhQyM5=x z|BH^-V0;N$90Cw1SA!aoL#qd0IuGggyTH{edw+@|=~`;^*Zy>NUxlX|_Z_&)E)@fzsYAKbnSF2~228L;Uf z;K^Gk+JcI|1mKY-ijl6bS&e(QM>xi|d+oB?6C!r&~0MX?8fn>hG~J6Vwb E1D&fs7ytkO diff --git a/sgl/tasks/__pycache__/utils.cpython-37.pyc b/sgl/tasks/__pycache__/utils.cpython-37.pyc index d697cf7953104627438a3b4b92c940dbcaaa11dc..f26496f07063693762bbbef11d5827416bf6bebc 100644 GIT binary patch delta 20 acmew!@;!vxiI Date: Mon, 13 Nov 2023 05:39:25 +0000 Subject: [PATCH 03/28] update vanillagcn to vanillgnn; reorganize the sampling abstract. --- examples/clustergcn_nodeclass.py | 2 +- examples/configs/clustergcn.yml | 1 + examples/configs/fastgcn.yml | 18 +- examples/configs/graphsage.yml | 17 +- .../{vanillagcn.yml => vanillagnn.yml} | 8 +- examples/sample_based_nodeclass.py | 20 +- .../__pycache__/base_model.cpython-37.pyc | Bin 11100 -> 9590 bytes .../__pycache__/simple_models.cpython-37.pyc | Bin 11685 -> 11524 bytes sgl/models/base_model.py | 120 +++------- sgl/models/homo/__init__.py | 4 +- .../homo/__pycache__/__init__.cpython-37.pyc | Bin 715 -> 715 bytes .../__pycache__/clustergcn.cpython-37.pyc | Bin 937 -> 1183 bytes .../homo/__pycache__/fastgcn.cpython-37.pyc | Bin 948 -> 869 bytes .../homo/__pycache__/graphsage.cpython-37.pyc | Bin 1134 -> 1045 bytes .../__pycache__/vanillagcn.cpython-37.pyc | Bin 1072 -> 1422 bytes .../__pycache__/vanillagnn.cpython-37.pyc | Bin 0 -> 1557 bytes sgl/models/homo/clustergcn.py | 18 +- sgl/models/homo/fastgcn.py | 12 +- sgl/models/homo/graphsage.py | 10 +- sgl/models/homo/vanillagcn.py | 16 -- sgl/models/homo/vanillagnn.py | 26 +++ sgl/models/simple_models.py | 48 ++-- sgl/operators/graph_op/__init__.py | 2 +- .../__pycache__/__init__.cpython-37.pyc | Bin 336 -> 330 bytes .../__pycache__/rw_graph_op.cpython-37.pyc | Bin 1000 -> 1000 bytes sgl/operators/graph_op/rw_graph_op.py | 4 +- .../__pycache__/base_sampler.cpython-37.pyc | Bin 764 -> 963 bytes .../__pycache__/sampler.cpython-37.pyc | Bin 12551 -> 14040 bytes sgl/sampler/base_sampler.py | 8 +- sgl/sampler/sampler.py | 215 +++++++++++------- ...ode_classification_sampling.cpython-37.pyc | Bin 4306 -> 4345 bytes sgl/tasks/__pycache__/utils.cpython-37.pyc | Bin 10871 -> 10787 bytes sgl/tasks/node_classification_sampling.py | 62 ++--- sgl/tasks/utils.py | 16 +- 34 files changed, 327 insertions(+), 300 deletions(-) rename examples/configs/{vanillagcn.yml => vanillagnn.yml} (71%) create mode 100644 sgl/models/homo/__pycache__/vanillagnn.cpython-37.pyc delete mode 100644 sgl/models/homo/vanillagcn.py create mode 100644 sgl/models/homo/vanillagnn.py diff --git a/examples/clustergcn_nodeclass.py b/examples/clustergcn_nodeclass.py index 2cb56f5..96ac77b 100644 --- 
a/examples/clustergcn_nodeclass.py +++ b/examples/clustergcn_nodeclass.py @@ -13,7 +13,7 @@ "--device", type=int, default=0, help="gpu device id or cpu (-1)" ) parser.add_argument( - "--config_path", type=str, default="./configs/fastgcn.yml", help="save path of the configuration file" + "--config_path", type=str, default="./configs/clustergcn.yml", help="save path of the configuration file" ) args = parser.parse_args() config = yaml.safe_load(open(args.config_path, "rb")) diff --git a/examples/configs/clustergcn.yml b/examples/configs/clustergcn.yml index f861b10..e105376 100644 --- a/examples/configs/clustergcn.yml +++ b/examples/configs/clustergcn.yml @@ -6,6 +6,7 @@ sampler: cluster_method: "random" cluster_number: 10 test_ratio: 0.3 + post_sampling_op: "LaplacianGraphOp" model: hidden_dim: 128 dropout: 0.5 diff --git a/examples/configs/fastgcn.yml b/examples/configs/fastgcn.yml index 97babcc..71401e6 100644 --- a/examples/configs/fastgcn.yml +++ b/examples/configs/fastgcn.yml @@ -4,12 +4,18 @@ dataset: root: "/home/ssq/test_data/" split: "official" sampler: - name: "FastGCNSampler" - inductive: False - adj_process: "LaplacianGraphOp" - layer_sizes: "256-256" - prob_type: "normalize" - replace: True + training: + name: "FastGCNSampler" + inductive: False + pre_sampling_op: "LaplacianGraphOp" + layer_sizes: "256,256" + prob_type: "normalize" + replace: True + eval: + name: "NeighborSampler" + layer_sizes: "-1,-1" + pre_sampling_op: "LaplacianGraphOp" + cached: True model: name: "FastGCN" hidden_dim: 128 diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml index 0236a38..fa4b076 100644 --- a/examples/configs/graphsage.yml +++ b/examples/configs/graphsage.yml @@ -4,11 +4,18 @@ dataset: root: "/home/ssq/test_data/" split: "official" sampler: - name: "NeighborSampler" - inductive: False - layer_sizes: "5-5" - prob_type: "normalize" - replace: False + training: + name: "NeighborSampler" + inductive: False + layer_sizes: "5,5" + prob_type: "normalize" + replace: False + post_sampling_op: "RwGraphOp" + eval: + name: "NeighborSampler" + layer_sizes: "-1,-1" + post_sampling_op: "RwGraphOp" + cached: True model: name: "GraphSAGE" hidden_dim: 128 diff --git a/examples/configs/vanillagcn.yml b/examples/configs/vanillagnn.yml similarity index 71% rename from examples/configs/vanillagcn.yml rename to examples/configs/vanillagnn.yml index ae0c76e..4d91943 100644 --- a/examples/configs/vanillagcn.yml +++ b/examples/configs/vanillagnn.yml @@ -4,10 +4,12 @@ dataset: root: "/home/ssq/test_data/" split: "official" sampler: - name: "FullSampler" - inductive: False + training: + name: "FullSampler" + inductive: False model: - name: "VanillaGCN" + name: "VanillaGNN" + basemodel: "SAGE" hidden_dim: 128 dropout: 0.5 num_layers: 2 diff --git a/examples/sample_based_nodeclass.py b/examples/sample_based_nodeclass.py index 88bf386..1fc86ae 100644 --- a/examples/sample_based_nodeclass.py +++ b/examples/sample_based_nodeclass.py @@ -2,8 +2,8 @@ import argparse import sgl.dataset as Dataset -import sgl.models.homo as HomoModels import sgl.sampler as Sampler +import sgl.models.homo as HomoModels from sgl.tasks import NodeClassification_Sampling @@ -21,17 +21,23 @@ dataset_kwargs = config["dataset"] classname = dataset_kwargs.pop("classname") dataset = getattr(Dataset, classname)(**dataset_kwargs) - sampler_kwargs = config["sampler"] - if "inductive" in sampler_kwargs.keys(): - inductive = sampler_kwargs.pop("inductive") + training_sampler_kwargs = config["sampler"]["training"] + if "inductive" in 
training_sampler_kwargs.keys(): + inductive = training_sampler_kwargs.pop("inductive") else: inductive = False - sampler_name = sampler_kwargs.pop("name") - sampler = getattr(Sampler, sampler_name)(dataset.adj[dataset.train_idx, :][:, dataset.train_idx] if inductive else dataset.adj, **sampler_kwargs) + training_sampler_name = training_sampler_kwargs.pop("name") + training_sampler = getattr(Sampler, training_sampler_name)(dataset.adj[dataset.train_idx, :][:, dataset.train_idx] if inductive else dataset.adj, **training_sampler_kwargs) + if "eval" in config["sampler"].keys(): + eval_sampler_kwargs = config["sampler"]["eval"] + eval_sampler_name = eval_sampler_kwargs.pop("name") + eval_sampler = getattr(Sampler, eval_sampler_name)(dataset.adj, **eval_sampler_kwargs) + else: + eval_sampler = training_sampler model_kwargs = config["model"] model_name = model_kwargs.pop("name") model_kwargs.update({"device": device}) - model = getattr(HomoModels, model_name)(dataset, sampler, **model_kwargs) + model = getattr(HomoModels, model_name)(dataset, training_sampler, eval_sampler, **model_kwargs) task_kwargs = config["task"] task_kwargs.update({"device": device}) test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc diff --git a/sgl/models/__pycache__/base_model.cpython-37.pyc b/sgl/models/__pycache__/base_model.cpython-37.pyc index d10861df33f4f1a640f5cd340b6210a54370cdeb..a503947e23b0eec4cb0e09fc3ad0d9e324cabe08 100644 GIT binary patch delta 1916 zcmYjRU2GIp6uxJ6c6N4Vce>QVZm}(;K+6=UfCQuh+6pQ^Z77whD>zJdX4@^@AMZ?| zkl9oP#0R5^H)@QL>^}KmLPB_kkm$1s#s@R9of1=Y2Khe!laT@kljJq3E(aI4mzY8Ielnb~}obxXSX1qdG~X2}g5uq^gs0 z(nynz;dCI?t}~hC^=W3V5w+wF*em*SS=E6vQot-(9L(hksFNn~f6`3yIDUxlBTx`R z!*Va!i%YieTIH(aa`KC(NVC@Rxm~K1Dsz@+muri}RBPdy=)$vG+`!{vUfb0Q3Jg#n z1u1b;+a>?aUA$thrDO?_S8)_CKaTcSbVw> zJfECqz81*eF(>`KbOk;JGFFtZ4-D~5*JdM3MWuHgO0FmFb(NteM%Er1@`yY~tjJ zBNI68izCGn5(nTSo%|uF86;;C(UkMt`OD+7lx+`9vRwW7A5m(So)tQ5_2`<+2< zH;Rv$7#gh0?;tB4^>p?8hs>0$#`EI$!BNPI=Y|Gk>R&n!x|kg5bKfI9 zrInu{I7{$8fkkkRz$Ta_NDw?thC^cXv{PESOHd@3BbX;xCb%T-4`s!*p{{(H#0mk$ z#j6B0f)5C2s(3#EC-4Y-f+YkjEuV?>kXBV;bSeXk}HpFg~P7Xfm znZgo@b+*RK7I0f6HJ%UyA|=x6yljh%$gZ<;QshJ)vm^RM0dq?168)IdFEEppGt;b8 zR9BJ{>TsdZ7JSACnH$3}yLcS9B^Y2L$8#$mF& zpMXVxJd$aEPTKV=OU=bvea>w(J4LtI@PjVt7-(~{(wxVR%o20f^;Rm2%ay=$mm0#8 zIZ~4+&_~-zx89ITZbdAp{m#+d1<)AJ8fl*97PkxxZmAjP=(vrIkS(wU4kI~?ln;`8 zx+%S$Icn1>&K(DN8pAgLn6Wy1y5f7&r!HK2^x+FIVDZRyi7F>1{Ai|e2E)JGS=wz+ zy7kJEM`ul(0Qcuf;%;Z{khtj_;-6OEafad#LFNhdwo__fuEv1VkHT$Y(Axh6z>Dy9 zgNJ;Rw_;&z!w|SUn?}nFO~Gy?<{DR z0r?(~xwEgrCt zukq~`a7nr7ov1DyD7NQ7>1yw(BreL{F524Tac!=R(xv6hpoP~WZS$YV7p|+&C%^my;+~j$U4FKym9PF6b z%X0N}>K=Yxy_Pz74YwE-H84Xa3_O~XC`0iLE2Na7rq+Sc+^wEdot#S*X4b0&%a|L@ zk{#7fO0}FMpetWuzT#JcK*|g1zVw5A%h>R#VlwW;@-d66F#QT&Q@PA5{446anNK~G zz+Q^D45U6Pg?Mpn!I1BvB=8%SEQc{+>|w1q<``xRyRp#P!!}H{FFVqXhlwzb5|XTc z9TL_JL;fJ-8`dUk!4q(D;(U?eB`2^(x9F`Mtl5H1l$IT`4btRn8ljDKvOWcE326Hr zY3XW6p$Y<}dL&3JJOL5|`JvE7QlP zTA7_n=XjVA$wAgnF2^_0(Qb%3DcEzyEJCq$ye>g)PPQ+Q5pgg%UyH%^Z$y#6fp7d@ zvB=yd7N3b?0e$aocYsWLU}re~S2V(8XwMobn<@2|?08z9MD9EBI!c7Eem^pxPUbFg zrC!dRS3G~XEyu83!bd|e7r1rsl(PCYO6pM^JCK!H9UVJry`KOrYR@LUKYc6lnC4u~ zhL$WmUq44VPSQ?&dC6U@-14N~u`11`R~O|BaeHX|nz?i?8`a{^lI95lj}xH7wc+wI zuyS&%`af09d9$-#HK?t4u72XOM691vujdEaUm#RB0UqhnkND%Deig$f%?yv!!k~X% zKjUl{&*KRffTuyvbJHm3M{V;VK8B}Vz+)$gh>MfZ9Kx@{NLGvmYRc3-kUeF#+}lQ!w^4A@!q{k}i(8FjnxDu3%v zY1c0~zjo(?{&k*t2dq)?=BxLrqvQSRoiRt%M;}xF94PWXsQ%)3Tuy;Woh_C}zDQ@4 z0ni=r$rVJUWSFSui03E8=Xgb(8@!&oPBa~;1HgW&el++t|EpTrbEkL>>rz7YR7suw ztLo&?1N<>{Yv>ezU%fSS%pyK{P5oo&;XN}z-^B180K_tns`ODiCTG-ZeOJ^E_dd_t 
zaBbhNt0Y2IE7DFiR=r<1itje{>hN*6{Hx)2BhOvFPgbyvrmZdgu#HgD{uo!MMq9i} zJ$874_adZkwG?y!h)jCwz0qUbQxgZqU^5WIj=rAFxaN5 zxy|(@T9Q+;7eiU6WrIMIz_SFt1kf@4+A4b1O3k}jDn$4DN!pYId;$RhD&2C0068dc z5_pQhEyj*_TjwaPR|z~#ASBQNKw9T3O>gVA?@Z!uu9{;%M;AJA@YW%nz9@4fOhGSM zkx_?bD0c`31Fi{Jw4J9BqRkPEp;K&QBtpL@(v!v+B+9(4w5nwwnCN{juf3 GBmV*?+SjE3 diff --git a/sgl/models/__pycache__/simple_models.cpython-37.pyc b/sgl/models/__pycache__/simple_models.cpython-37.pyc index c73db554000041b2b71aa2a2f0d62a610a71ebc3..7d9a691429c7887be27ef6f343d6ade7984df1ad 100644 GIT binary patch delta 1622 zcmd5+TW=dh6rP!#b-ZruT9NEDsc?L$o9^Z!9)PMU3YP$tqF{o;MJht>HtRwxXH93f z5T)bVrD?e}0df97{J;wm59J5+sZTuc1NMQGCnUrJ7l}76=j@g!NFZJrYrZ)?bI#eB z?>oDny|#Tm_jER!VK{Do`C{vyb0_!36nu634E#E^l$GSiL92x4%TMdX1G79*Gm37TY6(552RTHlNVKp=)F_>!LqI2r%-~J;tyWoGB~Q$iM$1kmY+NC)j{rYZ zr{{l0Bi4-ZQJxVQ#GFW>wUs{*4HR^b7T|)tJm(3|*uh&PohJ5{6;ml2agpb29v{}uD9TdTV5vSM>=Dp) z8@!3hBK287x+dC@W<_3vZ|ci@05c~av z&=T9-(B=)4B9wbbH%Sr7HWG~cDV!tJyFKG_`Bl?o?)-g z_EMVi#lKgr9+VyYy=;=wPHO``nYLfB)Fz%flAZM+jI5S_Wuv9su%!m*RRO^Y0dAuknx|hNYcN04z?t4-?lCIYdVC>W>o`u!f+QcZSHvgBTBBoyTF%s@rD&>?$MO}t5%x7nlzaiXI!>-c zU-druKOh(o{DZLSL)6vt=o_wKR_B2cImSSTNtl44@3(HKOreqpv)P~&Q)3$uJ|>8_o=#aIKA>&rkQ>NqyV$y9wLO5JN2O4{TO7rkvY)H*Oyh=pA6CKei@C5$aIWMlD=6uvt-+x<`Nho z6MV&x&9Ipjo1UMMk#p$jTk;xSdIFIpK*~9lYioBMG%y}ln{CZYnF))&aJhI$ z*TG(g5+D>(#CMEE3n@Udb~rRsu*;r&EHnxxC2`S4GFfP^8PJ`g7xED4TXZljOI`?R zq#Zgt{Bee{77~aF>I6Enim|gy&LWhaYh;(FbMNLW+0koob>bpEFN^3DdBIdLR-6%K zb-ump%srUOwGfO~@wnk3dnv<1-crUw&B!w^oEt0^=T1+SiI|9kDO)q;VD`z#BH!{j2LD098Tagpm!Ag`1aw$MUzRIjE_cnL{UkvEjl zamv*reNXju_ygJrMTJA+Z{IJul*EAvZ{->&6*6miEpNc%@P(xI{)^k7e)9YHvbx)G zmM2xBv+WQ~`-#vr{Rt-_*a@($0)qC|{5t^H5LhE|v%;Ta7hitcShO`aW@uvrTA{O| z$tzeOE+$_ANg#+A4D4-ZUn{u87$hrrOO&of>uT#8v{N*mXvCobcjgLOij(sfbjw~~ zgdJFPLvmsKBiQ%MDsf61PZX)TO3o-l@LuFZJjZ>l4waENGp2 zXpj$RmK(73sg9nW`hA3kU@89$K?td#&h*j|m;@G%nF<_M0dSzobOgEoO$61uI6U9~ zFCKai^Ki7+r31I`A1eMHLV;EA@LR#(Y0IemF6b=9?xM2l-v#uC$Q09H$PdZ>5ec22 zBVdppQ}_E2v@p2^OXOU}(Kc09 i#!%hcs=v3HN9tm4n%`7Qy-zsA0L(x14w~Jy?SBKHE|OgU diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index c0d4c20..c6e2d93 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -69,60 +69,26 @@ def forward(self, idx, device): class BaseSAMPLEModel(nn.Module): def __init__(self, evaluate_mode="full"): super(BaseSAMPLEModel, self).__init__() - + self._evaluate_mode = evaluate_mode self._pre_graph_op, self._post_graph_op = None, None - self._sampling_op, self._post_sampling_graph_op = None, None + self._training_sampling_op, self._eval_sampling_op = None, None self._base_model = None - self._evaluate_mode = evaluate_mode - self._processed_feat_list = None - self._processed_feature = None - self._pre_msg_learnable = False - self._norm_adj = None - - @property - def pre_sampling(self): - return self._sampling_op.pre_sampling - - @property - def sampler_name(self): - return self._sampling_op.sampler_name - @property def evaluate_mode(self): return self._evaluate_mode - - def sampling(self, batch_inds, to_sparse_tensor=True): - sample_results = self._sampling_op.sampling(batch_inds) - adjs = sample_results.get("sampled_adjs", None) - if adjs is not None: - if isinstance(adjs, list): - if self._post_sampling_graph_op is not None: - adjs = [self._post_sampling_graph_op._construct_adj(adj) for adj in adjs] - if to_sparse_tensor: - adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs] - elif isinstance(adjs, dict): - if self._post_sampling_graph_op is not None: - adjs = {sg_id: self._post_sampling_graph_op._construct_adj(adj) 
for sg_id, adj in adjs.items()} - if to_sparse_tensor: - adjs = {sg_id: sparse_mx_to_torch_sparse_tensor(adj) for sg_id, adj in adjs.items()} - else: - if self._post_sampling_graph_op is not None: - adjs = self._post_sampling_graph_op._construct_adj(adjs) - if to_sparse_tensor: - adjs = sparse_mx_to_torch_sparse_tensor(adjs) - sample_results.update({"sampled_adjs": adjs}) - return sample_results + + def sampling(self, batch_inds): + if self.training: + return self._training_sampling_op.sampling(batch_inds) + else: + return self._eval_sampling_op.sampling(batch_inds) def preprocess(self, adj, x): if self._pre_graph_op is not None: - # We don't transform _norm_adj into the form of sparse tensor, - # as sparse tensors don't have strides. self._norm_adj = self._pre_graph_op._construct_adj(adj) else: - # For ClusterGCN, we have already processed subgraphs after sampling. - self._norm_adj = adj - self._pre_msg_learnable = False + self._norm_adj = adj if hasattr(self, "_pre_feature_op"): self._processed_feature = self._pre_feature_op._transform_x(x) else: @@ -137,50 +103,32 @@ def postprocess(self, adj, output): def model_forward(self, batch_idx, device, **kwargs): return self.forward(batch_idx, device, **kwargs) - def forward(self, batch_idx, device, **kwargs): - sampler_name = self._sampling_op.sampler_name - if self.training: - if sampler_name in ["FastGCNSampler", "NeighborSampler"]: - sampled_adjs = kwargs["sampled_adjs"] - n_ids = kwargs["n_ids"] - sampled_x = self._processed_feature[n_ids].to(device) - sampled_adjs = [sampled_adj.to(device) for sampled_adj in sampled_adjs] - effective_batch = batch_idx - output = self._base_model(sampled_x, sampled_adjs) - elif sampler_name == "ClusterGCNSampler": - batch_idx = batch_idx.item() - sampled_x = self._processed_feature[batch_idx].to(device) - sampled_adj = self._norm_adj[batch_idx].to(device) - effective_batch = self._sampling_op.sg_train_nodes[batch_idx] - output = self._base_model(sampled_x, sampled_adj)[effective_batch] - elif sampler_name == "FullSampler": - full_x = self._processed_feature.to(device) - full_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(device) - output = self._base_model(full_x, full_adj)[batch_idx] - return output - else: - raise ValueError(f"{sampler_name} hasn't been implemented yet!") + def forward(self, batch_idx, device, **kwargs): + sampler_name = self._training_sampling_op.sampler_name if self.training else self._eval_sampling_op.sampler_name + if sampler_name in ["FastGCNSampler", "NeighborSampler"]: + sampled_adjs = kwargs["sampled_adjs"] + n_ids = kwargs["n_ids"] + sampled_x = self._processed_feature[n_ids].to(device) + sampled_adjs = [sampled_adj.to(device) for sampled_adj in sampled_adjs] + effective_batch = batch_idx + output = self._base_model(sampled_x, sampled_adjs) + elif sampler_name == "ClusterGCNSampler": + batch_idx = batch_idx.item() + sampled_x = kwargs["x"].to(device) + sampled_adj = kwargs["adj"].to(device) + effective_batch = kwargs["effective_batch"] + output = self._base_model(sampled_x, sampled_adj) + ret_full = kwargs.get("ret_full", False) + if ret_full is False: + output = output[effective_batch] + elif sampler_name == "FullSampler": + full_x = self._processed_feature.to(device) + full_adj = self._norm_adj.to(device) + output = self._base_model(full_x, full_adj)[batch_idx] + return output else: - if sampler_name in ["FastGCNSampler", "NeighborSampler"]: - full_x = self._processed_feature.to(device) - num_layers = self._sampling_op.num_layers - sampled_adjs = 
[sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(device)] * (num_layers - 1) - sampled_adjs.append(sparse_mx_to_torch_sparse_tensor(self._norm_adj[batch_idx, :]).to(device)) - effective_batch = batch_idx - output = self._base_model(full_x, sampled_adjs, tgt_nids=batch_idx) - elif sampler_name == "ClusterGCNSampler": - batch_idx = batch_idx.item() - sampled_x = self._processed_feature[batch_idx].to(device) - sampled_adj = self._norm_adj[batch_idx].to(device) - effective_batch = self._sampling_op.sg_test_nodes[batch_idx] - output = self._base_model(sampled_x, sampled_adj)[effective_batch] - elif sampler_name == "FullSampler": - full_x = self._processed_feature.to(device) - full_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(device) - output = self._base_model(full_x, full_adj)[batch_idx] - return output - else: - raise ValueError(f"{sampler_name} hasn't been implemented yet!") + raise ValueError(f"{sampler_name} hasn't been implemented yet!") + return output, effective_batch class BaseHeteroSGAPModel(nn.Module): diff --git a/sgl/models/homo/__init__.py b/sgl/models/homo/__init__.py index 2774fcc..541f34e 100644 --- a/sgl/models/homo/__init__.py +++ b/sgl/models/homo/__init__.py @@ -9,7 +9,7 @@ from .fastgcn import FastGCN from .clustergcn import ClusterGCN from .graphsage import GraphSAGE -from .vanillagcn import VanillaGCN +from .vanillagnn import VanillaGNN __all__ = [ "SGC", @@ -23,5 +23,5 @@ "FastGCN", "ClusterGCN", "GraphSAGE", - "VanillaGCN" + "VanillaGNN" ] diff --git a/sgl/models/homo/__pycache__/__init__.cpython-37.pyc b/sgl/models/homo/__pycache__/__init__.cpython-37.pyc index 6d1d9941205125e4ae42fec31c73ec244948a1b1..c6909b0e6f7581b4782e34d3235e439d17f90386 100644 GIT binary patch delta 32 mcmX@jdYYBniI!W;sQ~t-1cHEYZHRJ5A0n8Xm22K;@CSJ z`2mSEp~apJ2wQJhkOmVe@i?P2M}o8ba!Q~E1@9JFYAW< z`1t66WZbl@)3DVM>hBhlMHDpJvgTsx-@r`)-Nsv0;m*HW+Y8>ev_uL&BTyLbM5 z&~zTq-r)qodSesnBrD}CpF(+kah}LJsRYZQy2v+vek_v8c=^eUn;=hRqP6MdGPx8A z1g?s*sH-S6zOGB56t=*GjLq^{#h4keN+lrYQ?`i0iqdQ{jI*jzwos=k%sM-Q7Tfp| z{CZl*<9P7b&GI2cL&QxNQsW|Le3q#mDAW_H&FhTK(+s#cpQQroTVu?Yz}R>4do(&J zGBMKnY&6+FcrPz6MtUkoSh>__D@JJx$u!M}<)zvH$3ARrv7lQtfbZ(r61U-^@D$O( zg`C3Pg{`qUw(E4PUV$9BY5>f*2|qRT!s=S`8r9uWiBc7*(0U!t8e2r@)nFyw@jZ3 delta 589 zcmYLFyKdA#6rDQ{+Z%aTQ6dqF07^k&imgCO1msyP1==7%5xE&zJ7!}vURyKXCb$by zsx2aaK%nFwDBw4y;!~LGu*yhtb?%*`bI+N(!Ow7XIvj2gf4&}n$Pei4=<(sc?-Ws> zNeL+cVybD$Qr2Upc_~r>WI_~I;u}#SXV)wZRPd1`;qAqM{+lSrJ{C@l2BkvPT_X%Dc51g>eIyXl6 zHy24+EQ&@-7fYL!wJsKOS=9z@yB%3qwt1M(F<iJ>h~6WV UbkUo#QcIm50_vi_*!z#^KVPkajsO4v diff --git a/sgl/models/homo/__pycache__/fastgcn.cpython-37.pyc b/sgl/models/homo/__pycache__/fastgcn.cpython-37.pyc index ed054b5f969d3aff9dcb7192bb36cf78e545babd..910aa41eb97d334cce6c85552b35259fba85a9ae 100644 GIT binary patch delta 551 zcmY*VJx?4l5cSym#j#5u5)zR}`a=Yf9w7ljG~on7DA6_6f(>lsy*t@^RF){9prJTz zK}zmFq|g6wpMrmp89)Lo&6^p|$Mf@_^j}hJ7e$BR8Qvc1)k*Q`a1|596c|=RAAn+R zcrE$@*d9|tMK?@ET3m~Mq~asiO&)2!3qkDd^w53SE}cHw+CTWdyFXOgOzwaD%%}wq z`4ml{xebJdHQ?Dbf8bLF%yZAA@B%SvVh_I(A0ehb@v(xt_#yM@6#M0jV#v&U?HE>e zIw5xBk+wF%g&~u&8dQyxVNo`=#CUKnU0II|1;bHTkovMTZ?7zG1FUamA#&QBg;bRd zN@&9T`+03c{yK_T=b}cCEM3`zAZ`Ehl`Cos zg_NhJbna9eBc=UJWr8l)Ws1KLkR&q45i)qnd)+uJxpT8t)AzbHSD`p9pEFC1Vew7* h1hqUDEqyz#`1->1%K1bVvJKoZHhL3?k!W+s`F~wVibVhb literal 948 zcmYjPOK;RL5VqqayWM8RsshA`$8jaty>dVlis)YYK-ya_mWjRFRL-NeQ&E&#k$UQX zfE)ZJpDJ-c{R^C!NxCW{c``HhJic$THy)1&$j|CnEdxS+qqC6^jCXMCYfzF%s!2gp zO3@BLwfXF@j$_436Cg<%pv#xeQJNNo8;i5FL|DhSw0~VH<#sT-?&r zLCZi7+o%Pg*U|T_HvWNcMv_4H_{fK=MkBJZIv2VTPVoY;uAv9W%`V(fGmU&yu9R>- 
V=5!`FnB4JMA^PCN&Y=Z%;cwMNQ0o8y diff --git a/sgl/models/homo/__pycache__/vanillagnn.cpython-37.pyc b/sgl/models/homo/__pycache__/vanillagnn.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c8a725d6406da9cf2a5002d1b29fdb7bf55195b GIT binary patch literal 1557 zcmY*Z&2Aev5GMCWyIM((o7id5OZO18*#eQSJrqq7BM=HW>W_gQ_ChdAT3VaCtB{oZ zgT2{6PCoSo(qkW>FGHX}fF6pxLQkDp*|kk+W;w$lIrGi8`*Ji&2#nwU{8|1L5b`$~ zHv??)J&2is5=2l%Dw4l#UgnvM?=p&p-GU^_@UAoCm9+sI&%PO01rGByC`ocENbU)m`yzQq4yovg=(3*&A{PBi zl85g|K=K}BIB?0Yye^gE)7k9&O_IX+!KS52d0O*wEz`Bsrfll8IZV6k(!Hi$H)>_U z$yGX&_N39r>B*uj7Aa0mTGr{|N~vdWxCNipysXQ5KJE^8{&a8lPnRqC@cg3(Gv^n} zRng!02tFeaAH*n7mRQ=73-1W(5#AMarVWLE?CB@gzaTAPLA?W>06gE=5W8>%@hoUP z(7o0N9og9S+wj6`=~2=Kpa*T#0w!1C`(YdZj`1Tw0FQX)0?Zv6Fp_JoR@_1hpk*#N zcrksYF|#=cj1SKti0UkKVSOmMbwRx-h3nOY;>NhBR{RXGbCJ-^vRT<|%Y|mOlv?*- zw~HAASS({2d*?=MsU@2u6SiRe2H`p=5tEB=EVa)Rg#@trY;!$he~^DClSNa>i7{^{2h;s$>g;r4=4#R% zVsNdQtT)V@*Y$XLrg!1s0f^ZFMd(8s(*$Hld+>|t4n2Q#OXcwmx&4yw5475JqtWJ@-C1GUdNoE;{%;#p?(a}Zi^L7ZLD4umT_?e#ng?iaH>OGe+Uh1 z3TccKH;`w({sNbLjHjpiD=;?dDY_Hr`1PJo!8S+>Lbsrz^W8g@+(u19-Zi9wiU zM^s96nuUe2>^0#(#Eax9M2BVWM%p x&6#d=&#gbloFCvWAn36f!zW0Y4*V{m+x+if9Dlok<2#BaAaFyv4cr;}{{cw|wGRLQ literal 0 HcmV?d00001 diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index dafbcf9..f312681 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -1,10 +1,18 @@ from sgl.models.simple_models import GCN from sgl.models.base_model import BaseSAMPLEModel -from sgl.operators.graph_op import LaplacianGraphOp class ClusterGCN(BaseSAMPLEModel): - def __init__(self, sampler, nfeat, hidden_dim, nclass, dropout=0.5, num_layers=2, device="cpu"): + def __init__(self, training_eval_sampler, nfeat, hidden_dim, nclass, dropout=0.5, num_layers=2, device="cpu"): super(ClusterGCN, self).__init__(evaluate_mode="sampling") - self._sampling_op = sampler - self._post_sampling_graph_op = LaplacianGraphOp(r=0.5) - self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) \ No newline at end of file + self._training_sampling_op = training_eval_sampler + self._eval_sampling_op = training_eval_sampler + self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) + + def preprocess(self, adj, x): + pass + + def sampling(self, batch_inds): + if self.training: + return self._training_sampling_op.sampling(batch_inds, self.training) + else: + return self._eval_sampling_op.sampling(batch_inds, self.training) \ No newline at end of file diff --git a/sgl/models/homo/fastgcn.py b/sgl/models/homo/fastgcn.py index 191204a..503f3f1 100644 --- a/sgl/models/homo/fastgcn.py +++ b/sgl/models/homo/fastgcn.py @@ -1,13 +1,11 @@ -from sgl.models.base_model import BaseSAMPLEModel -from sgl.operators.graph_op import LaplacianGraphOp from sgl.models.simple_models import GCN - +from sgl.models.base_model import BaseSAMPLEModel class FastGCN(BaseSAMPLEModel): - def __init__(self, dataset, sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): - super(FastGCN, self).__init__(evaluate_mode="full") - self._pre_graph_op = LaplacianGraphOp(r=0.5) - self._sampling_op = sampler + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): + super(FastGCN, self).__init__() + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler self._base_model = GCN( nfeat=dataset.num_features, nhid=hidden_dim, 
nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/homo/graphsage.py b/sgl/models/homo/graphsage.py index d1163ad..0ea8a69 100644 --- a/sgl/models/homo/graphsage.py +++ b/sgl/models/homo/graphsage.py @@ -1,16 +1,14 @@ from sgl.sampler import NeighborSampler from sgl.models.simple_models import SAGE from sgl.models.base_model import BaseSAMPLEModel -from sgl.operators.graph_op import RwGraphOP from sgl.operators.message_op import PreNormMessageOp class GraphSAGE(BaseSAMPLEModel): - def __init__(self, dataset, sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): - super(GraphSAGE, self).__init__(evaluate_mode="full") - self._pre_graph_op = RwGraphOP() + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): + super(GraphSAGE, self).__init__() self._pre_feature_op = PreNormMessageOp(p=1, dim=1) - self._sampling_op = sampler - self._post_sampling_graph_op = RwGraphOP() + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler self._base_model = SAGE( nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/homo/vanillagcn.py b/sgl/models/homo/vanillagcn.py deleted file mode 100644 index 7bb38ce..0000000 --- a/sgl/models/homo/vanillagcn.py +++ /dev/null @@ -1,16 +0,0 @@ -from sgl.models.base_model import BaseSAMPLEModel -from sgl.operators.graph_op import LaplacianGraphOp -from sgl.models.simple_models import GCN - - -class VanillaGCN(BaseSAMPLEModel): - """ - It is a naive version of Graph Convolutional Network which works in full-batch training. - """ - def __init__(self, dataset, sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): - super(VanillaGCN, self).__init__(evaluate_mode="full") - self._pre_graph_op = LaplacianGraphOp(r=0.5) - self._sampling_op = sampler - self._base_model = GCN( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout - ).to(device) diff --git a/sgl/models/homo/vanillagnn.py b/sgl/models/homo/vanillagnn.py new file mode 100644 index 0000000..c1f897c --- /dev/null +++ b/sgl/models/homo/vanillagnn.py @@ -0,0 +1,26 @@ +import sgl.models.simple_models as SimpleModels +from sgl.models.base_model import BaseSAMPLEModel +from sgl.operators.graph_op import LaplacianGraphOp, RwGraphOp +from sgl.tasks.utils import sparse_mx_to_torch_sparse_tensor + + +class VanillaGNN(BaseSAMPLEModel): + """ + It is a naive version of Graph Convolutional Network which works in full-batch training. 
+ """ + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, basemodel="GCN", dropout=0.5, num_layers=2, device="cpu"): + super(VanillaGNN, self).__init__(evaluate_mode="full") + if basemodel == "SAGE": + self._pre_graph_op = RwGraphOp() + elif basemodel == "GCN": + self._pre_graph_op = LaplacianGraphOp(r=0.5) + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler + self._base_model = getattr(SimpleModels, basemodel)( + nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + ).to(device) + + def preprocess(self, adj, x): + self._norm_adj = self._pre_graph_op._construct_adj(adj) + self._norm_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj) + self._processed_feature = x diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index 63f5ea8..f20dced 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -220,74 +220,66 @@ class SAGEConv(nn.Module): Simple GraphSAGE layer, use mean as aggregation way """ - def __init__(self, in_features, out_features, root_weight=True, bias=True): + def __init__(self, in_features, out_features, normalize=True): super(SAGEConv, self).__init__() if isinstance(in_features, int): in_features = (in_features, in_features) self.in_features = in_features self.out_features = out_features - self.root_weight = root_weight + self.normalize = normalize - self.lin_l = nn.Linear(in_features[0], out_features, bias=bias) - - if self.root_weight: - self.lin_r = nn.Linear(in_features[1], out_features, bias=False) + self.lin_l = nn.Linear(in_features[0], out_features) + self.lin_r = nn.Linear(in_features[1], out_features) + if normalize: + self.norm = lambda x: F.normalize(x, p=1, dim=1) self.reset_parameters() def reset_parameters(self): self.lin_l.reset_parameters() - if hasattr(self, "lin_r"): - self.lin_r.reset_parameters() + self.lin_r.reset_parameters() - def forward(self, x, adj, tgt_nids=None): + def forward(self, x, adj): output = torch.spmm(adj, x) output = self.lin_l(output) - if tgt_nids is None: - num_tgt = adj.shape[0] - x_r = x[:num_tgt] - else: - x_r = x[tgt_nids] - - if self.root_weight: - output += self.lin_r(x_r) + num_tgt = adj.shape[0] + x_r = x[:num_tgt] + output += self.lin_r(x_r) + + if self.normalize: + output = self.norm(output) return output class SAGE(nn.Module): - def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, normalize=True): + def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5): super(SAGE, self).__init__() self.gcs = nn.ModuleList() self.gcs.append(SAGEConv(nfeat, nhid)) for _ in range(nlayers-2): self.gcs.append(SAGEConv(nhid, nhid)) - self.gcs.append(SAGEConv(nhid, nclass)) + self.gcs.append(SAGEConv(nhid, nclass, normalize=False)) self.dropout = dropout - self.normalize = lambda x: F.normalize(x, p=1, dim=1) if normalize else None def reset_parameter(self): for conv in self.gcs: conv.reset_parameters() - def forward(self, x, adjs, tgt_nids=None): + def forward(self, x, adjs): repr = x if isinstance(adjs, list): for i, adj in enumerate(adjs[:-1]): repr = self.gcs[i](repr, adj) - if self.normalize is not None: - repr = self.normalize(repr) repr = F.relu(repr) repr = F.dropout(repr, self.dropout, training=self.training) - repr = self.gcs[-1](repr, adjs[-1], tgt_nids) + repr = self.gcs[-1](repr, adjs[-1]) else: for gc in self.gcs[:-1]: repr = gc(repr, adjs) - if self.normalize is not None: - repr = self.normalize(repr) repr = F.relu(repr) repr = F.dropout(repr, 
self.dropout, training=self.training) - repr = self.gcs[-1](repr, adjs, tgt_nids) + repr = self.gcs[-1](repr, adjs) return F.log_softmax(repr, dim=1) class GCN(nn.Module): @@ -304,7 +296,7 @@ def reset_parameter(self): for conv in self.gcs: conv.reset_parameters() - def forward(self, x, adjs, **kwargs): + def forward(self, x, adjs): repr = x if isinstance(adjs, list): for i, adj in enumerate(adjs[:-1]): diff --git a/sgl/operators/graph_op/__init__.py b/sgl/operators/graph_op/__init__.py index cf2e5bb..05edabf 100644 --- a/sgl/operators/graph_op/__init__.py +++ b/sgl/operators/graph_op/__init__.py @@ -1,6 +1,6 @@ from .laplacian_graph_op import LaplacianGraphOp from .ppr_graph_op import PprGraphOp -from .rw_graph_op import RwGraphOP +from .rw_graph_op import RwGraphOp __all__ = [ "LaplacianGraphOp", diff --git a/sgl/operators/graph_op/__pycache__/__init__.cpython-37.pyc b/sgl/operators/graph_op/__pycache__/__init__.cpython-37.pyc index c1832ca54e71ce449633e12775158bdb5e24a640..68a3a28e85627f05e50f074030e4b4dcb96464b7 100644 GIT binary patch delta 39 ocmcb>bc%`FiIJi3(&9fNY7=hG7CU!;u+4&1^ diff --git a/sgl/operators/graph_op/rw_graph_op.py b/sgl/operators/graph_op/rw_graph_op.py index e40f4ea..ba56472 100644 --- a/sgl/operators/graph_op/rw_graph_op.py +++ b/sgl/operators/graph_op/rw_graph_op.py @@ -4,9 +4,9 @@ from sgl.operators.utils import adj_to_row_norm -class RwGraphOP(GraphOp): +class RwGraphOp(GraphOp): def __init__(self, prop_steps=-1): - super(RwGraphOP, self).__init__(prop_steps) + super(RwGraphOp, self).__init__(prop_steps) def _construct_adj(self, adj): if isinstance(adj, sp.csr_matrix): diff --git a/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc b/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc index 1eafea7d0f68870afb2929b2f1c36aa5dc74220a..1b41422a2d77522196e892fe99d3bc7218a20b84 100644 GIT binary patch literal 963 zcmZuw!EVz)5Z(1U={9jh3mo7N@S*V!AcTsbLLdt%J^8Y7z1uX!_PV>SpjA#ZXa1u< zz#s4vd*u_jabjj2M@iL*W_CQ;ef#E(eLR_r2-vUtFX}fX29~B7+%;SpoHpZFWzgq)0|3$pT4bup$|gOop(Aav+DW4&+#Jxxb>>P>uj8a^vIN zs##t&r7}gZ21Q8pPvG_yge6@P0TdttJj6f-5ZRJkhUa7%{G=3UBDK_osjk(Bc;Hwr zFI{9iY*FWxa)+YPwiUj%sOLgAE^3V06kH_mu`#+(*2c_*R+Z-lxwzriZ@HP}RWdvId{o|iPwc!**4HIx&{hYl%(S_2T!=-zXobKYt=FCoyI*kod`pCB zzp;R^c*7Hzdo@Bxe5W+^JmSD`D4Bf>2-}4*W_u8@WWBER_QlRMTa9jn@2`4?rx3R# zKR|F0mZmXv&r;of#8FkMZk2p*j5ccmZgiHn#RW7e?E@Z;_j*sZ$n5moZ#XSbYBXH2 z>BpHEM2eYj-Gwk`>ke8iY?B+1y;Zf<<{^toH~aN$4*;Oo<>js}_=c=^LDQ%49u5y* zA0gaGh)T=r67xuitLwb%{xJ~%dd1BX%<$9MV&Es^hk{L^dwDtKoHNc0R^T$fK6cFU N*WJBuaOYoe@*f=**RB8n delta 422 zcmZ8c%}T>S5Z=uu%^H(>@c~5ZAwu;g9;B6mdW(l5cnh)F{>1#Ts}xGWn|KmJp2Fww zE%r6Mi3ev=3dM!_X202Q=9~Spe=Tq1c@BZI_i!i9Yj5BfK&~KEufFeD_L%s*em{CW zgi=K-5|S8^#s(ln!ltA$J}4zI0WAx(85Fs=@lDO5IMJ?(vN9Di%%e91}?huw2b6TeG<6?^rwF|FS6#AURWF`Y2N7Y8XtT^+K?nxTjd(? 
zH?jkO8il_hbhST zFgR^?t%j{rp3y0Eiw%e4dZ*Ma8=9_lD~&4Wn4PI^tx@B+)tTC-~Lrqvfd#Ei7nx}I2JPj>NLdC`x8Yks%i2_%wpzS!_$ z>9<4A5B<(|)Q*trMtxsKf!E#g;vW8GbKOg`;vkHA^yZ9uD-NOYt?5D+#vP|&3R75jw1)MrW@!x@D`6*$pVlCE3+ODx-jPx~qWW29qgOKBy!a!&V6OCEJt7SL+yd1h~eltGr z#_R3K<$(v%#q|4~ZC3{IK!y>@0(TAFyDjPW*H`kFWexpaUB&3^XwVNNrxVBX+F?8P zyri18=Y@VZNUD>IRX4G!L1!&7v4u%t^R_Qrw09-XYj*r7^1Q#+K00?{z1IycMA6L) z*Is({iO%-cg{alJkhYqAuk^R&3Ub{WiRcKm zH7t$3Di;y28%q9TV!fo(sHRKy!{sLQNEy+BJsOMoRQdtSFE z1|5o5Jtikr!c&l^sOhP64Vpox<9Xwb>$+S(xupJispqI`8oQ%c9nNVYk$;FX=o3AU z;C^H#Y*7HY*!{H$bnKBuJB zo(P_}-Hrm6PayipXfdxMA4CsCY;qbAPN_AUr6esy-Q-=9^OSLtGSUU3Emf;VbVBV+ z(#g`HX^m-x=z~5}c!;oG!4o}$y{2#|oFy z@+!*ja(Q*v*wuEdk+oMCX&c3TI^sn!mACj4yfKxw_~XMZqOKF3N^4M?KN+`4Wx~wb z`9PO5sAn!~dnHlZ*LQ3&EoSa$Bl`;Gv15;H)GvolToJQ0uegfQly_+~Q(!y1Dw_qg z@+?}>mo?!bTOZl^V674<&S&o?uACvHY;dEk1Z3l`nPLQQYix`_KM5wMtzSwBeIS&R)7c%^-?aE_HfMzY{%+ zNcQsZ>g$~VGJuQ4yVVwf%U$Ba!pgbrP+e7!72MvM8?aiIT==qx%Zf@^41G)ymT!oe|7cc zEM9vdiULW!e@V)oOp2`__TyM4PD%t(QuCS^RxAh2*u$mFY&4nmjz}~%%S-ZM^qrL7 z^gDw>jccVdM7#Eb^r#Oj5hbURx2AvO@H$ww%l`6Lx3 zF?nHIM0o-Oit1B}zY}XurlWq)Y_GMOo{}9Bo4*r>WfaI_JV5yn)peS^jz`HQh0AFB>&HRlSV# z9Kr>?Y*B6co73lv8M9`hCdn&j?spdQyO!vX?xM?hq8}jux6$PM2t0Q~-!t|#+&Le# zv>gb5d*;5jV{SbYTPU-4b-tev$DLs!r#P~}J&ms!J2v>RvFD5m`w$mL_Ktb;{99;u zuQW1n)8Vdr2R8t?KX@Jb`WC+p4tPl$Q4V-1Vz~8$`90$rjR$4K>!}2H5a+TyZ}|nB zs{BTOIqVPOi+PeH(y_dC5gbnX+jNdO(+wda;P{Ftf8onPyxtQB8_3pBi4PFZ^4hNr zVssZBtR8x8anWsca85gU55X3556JCVmKs9d6pfS!@?y$a)>2^wTv==!=0b}8U4=Zx zgJ1pdzwrU%v}p`cCMyd%dQrW?M?qf2{>slHK&?Tyzn$3N1z=x%%1(GVMCzi*5r;aJ z1;5`9LIFXN4=+w5u~YdH8mbr78Cs@B9UlTF_1Bnc4x|^iVw@6zgPGn(SnXg-l|{0t z${H2Cp!y5Mp|ir`>(|k8L~^WFp4TlLff*ut9EN$LfN%Vn!y_VBCM(NTyw^18ZG7ln zB6??9FN1W5{z#G#`a8y&u?*@4)ftdr&c&epk-1@lO0Dp3K>t>(Z&(}l2523WZf-!T z0WDf}KOFYSX|o zH?sD18a3i4M%GA&bVBhtoM?=eN|0-+IaAKEwuVzycZNUA>5&i8A(wE%|Vb9(`h7Tr+ov!oZkZG zZ#Op~D*2J8azs)<#eOWgb@|(9LOw?U(T?iC1i%OA#EqX6aEwH+n-rSsJ*<(uPCZs| z{WX1@Q#?s23S!Y-TYE~>t9)V_Wn2Td)ESWHQNJ;pE?as8aV#3;FxX1ZP@^WHE{L$? 
z&X@MFb#< zPg9(PPoOwzK=B&JzE2?eS9mfh@~I<}oaB4QiIE3}_(?@wNpqPT*7)s$XyxhEsbPp)Gi#AU9^O^oVlrZu3%&!;u` zAhu9OC6IRBF>4IW^n z!PyC=fh9^xX}~Ih2Un#_G)Xu`Ar=NgMCJFz{J(|@Ej;)z!k3RCCLg2V90epg$vT3c zkJCMlH#1I0m))30C_Nf+x_`ivao`JwPs0ReI*$MTvS7Sip}|2iq+8+w#h;|$DGKQR zQh@I>6#G8|+v|%g(#y}$sGgzV^Avo6f-h1)7od`a_&&TuNkZH5Hz*)3%K$*hH_ub@ z1qw)PQsT}PihYR!y4qO!nZ(ZWH!1gJ3Vw?MZkq}46-tuCeF!7VS1Iq17((|60nmsV z`K94uV|e;Lw3J^5G{_tV|1J^=@-p0I2s#0G0Sq99n}^`8H34_$8SdIg;BF3T>z027 zRF+G%zY6H8ZajCby#?W##8eWtS(ql+8xW#{kPfqqAG-m8Qx?w0)|g`fF5$Y~AeBic zpumKz&Y+qaWiwA(LRZEu=aVHMkRiq65O3t3V4jUuHWa2+OEl7IdcPc(-> zD_e#Id}^_3h7jtJYn=sgs1b4zne=@`e3$V=e~$p9bqJFxWCwzSa4gJwz;FVZ7T`E- z16U3*=*>~b9NCN|p+n!Wg$*p(hzl{aGSHtJyT%UC;fAwULQUXF;7`VhtZ~8n`xdT2 zkT6;UW=+vyF~!6-=Q|^FuNqJ7)nIpUM#j8$8af;2R}kGkS6c}1`H$!7(-Ww+81_PFoO_B-GeG70{_Ci!IQtEXk$jyL ze5LI1DPf9UoEQ;cq-guFl$I1IE#)vjsZ1;~(6!P+H!J|nC^7m^HHzuNHcHvziAFii z^ZHL6J}7h!{|k>&<&@_1B><5cKhSrwTE{V#^x?yYx%E+{)c;F-2``9Tx_R>O9RK&= zB^)xcVL{@VfS0n3&5|1w&^?X96BN$@B|JP10ZL_~r2$F8TtC2*0ZH3=+F5j# z8JH2F=O2Q2^obe>G9W4RhRM1Iuu`~2SY&i%TXjJc8AOuBT5FKuP+6=USg~u!DJfHB zqbz2{9MToGWLE(S=abpXodJA0KHe!6kF9Qlb_+?nRUxv+o83mvxGT_3jR=0;g z%6Z>$ial_qi|z}D+0(hzAckFVbIWCWPJJvE(CwX00J^}d=~Pm#b}kgGW$vR02>UCwOrwWGq8J15ZvZwE&@KQJuf=9;i2{H)dqv$N60q6UPBqG?@eQs~ zCMZs2=(B(^R2UMHM?e8GLzPnmLqQl~ic`=y@0IuAvcj;jwwlzMiiW)|n?c5DC=zc3 zGAXlflO&7deUT(1#h(Tz=~nvCMbP>TfByM7m?%Dr^sxYjVaxMWx<$b@1w#bze!&q^ zX70MBHl22P0+1YoF&?GeroGqr_HoB{Hrz+_H#~dH&QAc)90H>uKBAZL&F)epyKTUl zz@7;>qI+y9gS@6M((Yg{NQX3hCO?qZf;i5YL&g<|Be2hs+w2X(czp60hOR<-Wtt{D z)|g6Pf~gf%8L}m2*WY4_Z&Z})6PINS{Lq4o(8}j%L744WnEaVzCZDfzF1BWb&r_#H zXzD_T(0K(hXr@4#XF=p}ZqT)GhRB^_AUYTk`P)jrZ6-o;{D9cISAdC;G{FtD0`gtf zLNc0wbRn1$>Vp1cYzl;+TKFD$pcRKfzNA+I@(Soh*|MfQ7h+i4CVO?3pRtz-4t(Bi zO7JtPah<&Ul*up}@4(f>#_4EB!-vW+7Q(@e8^>-j_ow$*DIbWH^6=sE8*CJ~Po3*f zIY~Knzwol?b+(r>oJQ={@Xpqc#%FY0nwyeYCl-CNvK>8g|__6Oz zPBU31Uq-g|dx2;L5u6%@!ZV50j)QJei(09>qq5UYQJP#BRnkF{Uxd}YoT0lgzl(o)*c^ZF40YV=HzGK*I#?6lkdQ5JlngS_+bHR|X zKryBJXTo_nhNYU#31mFbSFyy*J}%-VrEs70>dNICHS%1#d7TDbra^;8Q~~H5efZ!p z2V+ujJHi&|_KqG$5R=^lgSxAE6K%-9Oj(?_?VEV{ ztuZI7>cdAq$*2yj9|3XRwKvgFTf1v`D%3e{aAII!oJw0J76-GzG`^^Uw5;oO6R24V)4;@jj}G)Baz zuLF6gmF8vjXHm|hf;(YJ8O@PnrLD~!TWVA3%@qhaB^=4}@P(XLU(VPq@h&&#U9bQv zKc#l~1hFH4X!tx$@ZkfSMpqFN{aHlVBDesHpiWwpWegP;5rpIeK4$WdP%~3U{~^bd zVoGXCieNs*WJQ{3LgN{!KYt8SeJlr_Ao<+m?7-NG*skG;NDSq%JK$oIqgCiDHDA(L z5A#?WU^E_J-aIu;OT*0*|7CtqgIfYQ-r$a%;03A^lWnAl-j>Jpa}P}n#{CR`-i64| zem+vgMAA)s1l{mLoFTI=L}5@MZ6jjX&SkA-sIA2IVKAv`p0+FJTy$f>{jSg@uN67L z*lDG1jc4!+_s>=V9>hfd@{egOy8I3e?jRXW^i7`W>bC;W%dXZ>!2P?G8t3oQi{GN) zT?*bqkd*0OX6zKnF_Uzn!3`Ebk^+8&5W{`ueX3jvAhAOtfY*V|x3?t4c483)CIxtX z4uXiq=P@?I+R+J;e?|dM#*`mWoM*vBd>D3-w?-@IBBEIYld``V<5B&LaYlbye-r>} z9$o{(^FPZ`$#;Y9FtV*&#uFt7w*G~#c)+NC#9c;iX`o)>N>FGj0h`n0wo4`8dA4$5 zrds>@p0lrk=1o$ULit(VsGz(Gw;VVWv)vAQ{55cEpljmU2B}xk5@b)bq}s2kmXM2B zo;#iABF4Goa6FsmB9_;M^Nq3*ZXgC;=ztfhUrzLMUU)qf)iU{%St3-JuWu_^l$=PR zR6iGNHG_WaBI#!0nR|P^&C)7dZ+bnCu{b<+GC!Qo_5>WyG>ZfuKzNFr*aq~c$VqJ0 zyeH*XdtvK3|F}hg3Y%fQB;o|+(oiA#1j^{&qs(&N^On4V3}r9M=@`E^>Gb?KZIXyu zeiuOo)rr{PLJOA_b@7{Uc1cXQS*jy6K1X*@afc|7_3dzdtid~F2~@Yhf5T57LvtS| z1d~Q;>X2(gQuLrhXyRuQ#F3Sc1Hl{R=zBuE;}{ahYml+R~_bxcdI)|_dl z=**Xk_+DgXNBxW>-wgQ`3SLA2&%I{5zrCUq$2i)oK$g77Cnk$l@FO9=(~`tK*wcg=ssEqp{^x?0}O9#kMSY6%l)FEMdbC(-}?o&dzFf zW^}s8-kq8)Ok`u3ih~1HAp}V6K0p=x0aU?L9(bg9fhwx{i7KAV6Hh$&1-|chx_f4K z73D&x;(?j!)2C0LTc7j&?&rn1xr%{b^ZD1p=0(H!H$G%P2bmXeMSp`NFaon@H1*wV zTJmgfJI#X0GOJhIE;U`r+r9F3rCE`@)2nXRnl;Iry}9kWWthhHd~+c!S!^z%b)mPk zeWG~+WyPQrxIuZ-X)a$if=W<*&j_kv<&NEa0ME5x4$pJ){2-p|!91Sl<@qF@7lK7R 
zFUs>uU|u&GOP^rw#(Kj{oc=&<<6eES6@}MZ+rwU{@VL|pw)|MNx_!UZZ}oPgZiLbW zT@r^;>_@|1H%=C*oy{nji@Mu>ys5%ybI=PCCyK*ive+BEGnGf8%q`THWC|ao)mAr(^4HScG>say;c)ajj1d(G~v!(COwL>9RSMI?r^Y!=P&^2~~@W^>Ayj!S6bx{NC#y1NFJ z##~n?kZ)LO9=D`&vlX{D@i{>xwU+Q+SGGLyioYRcI8sBwPY4NyYcK2f?Vv ze9ae(=r{5S>Oporm##rO?Dc$qI&i~OOL(t1zAW_uhvwW}v*t=k8xQ%9$lxymG2a(v zK_MstZUxdC?!}-SRB$f|{9G(cHJM*0+;0xlbq$ex9XSZEpB94KIf}BLHwS}2BM|uT zw!?T6&l`gtXesip5_#7zU43bazH21g?XVxW;_jgD0luxEHH+Z?C** z1h%eytEhStGq6t?`*u(|F!vmC%J+SK=@(%>_$d@Jmpp;kQ z1@<;6P25R&Qn^#ut4^vib~PV+0pE6VQutFNSow3C-ze=1!Knjsx<$6%UpnACPXr|G z)9Xo@9B8NpHzQ^TH~n~b7$((TYd2JW)E$RWQt6MjeJO~Nk_v~tRy!P@zSe&mygcx3 zg}d(zRN#SGjflJp-lyhM?QEyX1W1k7xzi-l4n}{gwr9r1#JppE$GYuYH#C1S6T91w z8&+biem<#8Com;;f!wE+i@k0Xw+Gw9=bmD(9HMB9&mBV@z6@W~UcJ~Gv|GLCIb^bi z^Eg3U7CV(39c7vo+gMNg`E7S;|%HM!dP&m!DYV`vT61yL5 zpmI0%;sKlJZC}07SWX=5VwB9k{OaXPFJHUyoBr3{y!86j>q-4m6orbcz$!$W zEr>|sLwRgnmPrl=QJ54&?`YIC2O{%H>MioUFz~fxObT+B*(ma$#BC3HdNeMxtnhSU zlX|uZTKhNxE6Af1FOQ>b^ehs?tmE%lHMzSMBntApT$C&!ebB5pD3Sk?xnwO^b#r{2 z{iPk}@_}2=8D7Q}eH}^1*n@IHmbYi`JWF+D%erHN{hbBd0y_f>1gn}jU?bMo!HU3k zto;I+j%iE^dp6nN!~zoq8+{)P5WRtE2KJv>RBc+H0SI%^a;+Z^NAZRHnM%)8q+b9X ztJW@YFNI!z0P3Ky&l|iVQs{688SUk=DzE+Oh{HvPtM~nGaKY>Jfa_j9LbxN;)@pBN zMS2Q(SDYzO5nV_raU<2uzTkHE+z-I87eK)^Hg1 zgGNaM_yO#4;sU_ZS#z#EQhwZtfz|+okg{7WbGW15MXH^>BO!?16b45yE2`1U1Q8{) zQE>p#1%lFat-5s_h?(})rgFj`Q&x9sDhTaL?KK%LxJa1IRsN%h)m}1AsuO3wI;jZYZR8;HOyl+fU~lZ#CMMwS98iz9 z$A6LI759m$nk3!AP}3>3)he)xc`6HRjX>+>;_EHH=kKN++kDL8CS*jKIohzWk|Mwt z47QVEdvk!#NQy0u#x-_RBe?u<2VArrr>rw6=bx+|L#1XJ7L10L&I%1Q3sS3DhCQd} z(kz3~rqgopO;KVEpH2#)P)0e2=TStygdy6vqSHu>Vg*=y7?`YpVe&cyR*Abc{!GR( zavkwFA+;$O_#?UrQ^GcS0?dps?cfm*bjS^3V;>5|KEQKe9+)2*$lG_JU>sPI2XDP= zY~bC8#?+#Wow#rpiXr(d)WcF-phB|e?z|g26IZxzOa*LG2K*{`Ycq!$)k$F=0PUEQ z>N3oBPk1}2OzImT9D(!q z-ZXA+yvf;D0B3$tWl~7*;Po@#TivRGTnZCs!`iU)%=Oo7fa^mooyXVDZwU8aKctHl zu!Pj!BU)8m>SFahXo*~`Q^p@nwK3UOn0%i>wZ!Z~uUUXbb9)49HRDWr>#5U0B?C=q zmF!(SXDZp}k8Sb{K4;XjyfN&HFQGlig4j=rddQ^O9`xHSuyb5VISS*TyRq?fpst{j zdRl4%^x(EaO^Q~95`riQjVD$#Oe*ccz~64gs=L!z5T4vD=L}f=I-A^JPqrGolh{xs z5+`WIt!BO74bTqE9O|%YH|MAfPaD30HcdBshB}D1LQ#pi)hxmL;ghc2Y|+#;4#W+> zr@qO&7_P4&w}VUk6$a>K!K~SJ>$E5uP&@Ek7bSxf^r-o$H9m8{rEQeeN%pz-$h`>sf(ii0YW>6C^u`%WfB^Y9F=h#n10qU=8Io^&4&Ly;t&|45enMZbjPey*FN zHnrH&C2;L>QeSID@zu-gDLcqjCHBqhCpV>O*VyZ71C7=)+m;VMn{E4tXsK;mQ*7Il zhv#|Rw59P}#3iQf_ra_F-?W7Z{l8_}I+q2g9T2AS3@{ToIMWpUFyvoeE3!7&E)Cq52ybOBwBcKWwBrGe+w2eT~#BhmF*qzm2Lzi?`9;fjg9*9%}7kf6=}TX{|xqhl{@z zdm*h*(ekJM(tBF6wBrV&RH2@+DQ1*@v7Xtfgbcb0P9*91kVTPMhW9dUBkycHP?(VR zr=)gmofl%^;MSD_>~Or z&{%Dj_ylwxMI|+hAL5GWkf$-~Ky7j0CrQoGqi7R@32yD&8m0S-bZ{A0^mj;rC+MYg zYrGFN7fLid!24Eg0%eHZP|-aw1K^EL673$Lmlkzws?D%EX}cDJ0<6wvT!h_8Z?CpG zO|d%7{R(Qswxro0mM7Z7@+`tvTtvSmj3c8@N_u>#+$?t|_I@p%+pohD?M|#^;~^;U z@HpF0@NH4MORdYt@O;DxlG|l7u=(jtQEU}RMLe}rTa-wfK_KkVfp-%{5~m1qsN56L zqtw|jYa!3{rP2IjE#%aFc>x5L`bR-?;tr&<1`gzYd)L$I`2Ef8^+OU!tmc%*|Ht!< z`6-`HO8r3}Uf6-Qn8kqhfFVWyi01RJFt^U+RV0YDfd?dZFo@M}^SKaoTZp%nHP8Mv z7FB0fHz}pxq+Vwi6!Xnex}wc;wrWAMl9u_yr;kzy)$V`6rH$kIlDU!_#>?<4(s|=r zE9UsoqtgDkZQO8w5*7bG+WiLR@g2^Eo@Z$H1-Idz6s@_9g8WC_o4<(GLZ;gXX44i! z8Z?N|ti!mC#L5tzmROm@U795wpKiK3F5N8acyzNOdVdvq`a-gJxi><1SWQ*`@fQW3 zX*t7QM6-y(bKXYSLiiO(hQCIf+Tyy=JUVw5(8kSehG+3W%JvP)C*O>%7zjuXlXx>pUshfJ&faw3N!K}>B*4NyL=H^`52&yrvYR8Vf-P#doHLLn$gS-R81bs)AV2rL7xdeoZ;XrRaa}Rvw>zca` zXgcMe-oyhxJusmRyp9|qB?!_!j9@YXrKZteKo05wJah0J5-qqp>AfK60gk{#=$wY% zE+{g#!t)UU9T@I7%oprR$k^l{$`&nUT@Ze8m!Y(*JfqSXMt06?E3gIDxFH&!u!Ly3 zzH3PVpzTKEM=A)zAiOV~V=sh>vRwp{N4$TvGJdr3y7-FkQ|AWML~4kMxdX^v{a<}Vx zqo^$eVN~y-Hyy4m$mb3qO1890kJyTjSZWO

@1QX~^eM5{%<*GC$*c}79BlyE zt~bz8*SKr>IyEJ4Noq^#NLu*V+7C`%h6KLVri_e})+y z+B7m}K|5_rIpp!+o_W`#0e1>h1+RrvK6KneeSn%N#MS$fPf97&W(0x2>a%$L(hSsJ zxCf|yBwLz+eC{MWGN0}8TxHg?&-aWhK(n$B@XgcPC zAil|v|HPVM0(uZzVlL_p)Rk2^Iufb=tJ9IjKkHKD)Su5)vypTgAHy)R6lbuFw8DWw zZY1q-<~o*M9xMtblez^7fShSD3<57)Q3p@-XXKlCZTbcO~WiB|xqZzw!7jm>grLjC)GyzI$MR6mZ-iE^i^m z3RjH1RU^P4AK^;e00W2dMmBDJWa)P3Lo2OwdC$n@rS-fNIVqh-yM??IxxDUTUWy!; zqszmDV2;jsm@~&4sX~)!ZNekrPI+xtYiK;=1>fG;aHky(W1P)Y9`GPriz6ppcc@Jd$D76rW6i#x86()>SCM6%1Xd6d-X;mtkKZ*u`j$2Z@ zQ13F~B<>fyq_O*-xP+yM29<jyT0 rICJDvXLc~(#>Bm-4_!Zpt<*Xy^=f^-t!6Gk)5b|RYoWFJ3FBF zK;dmhNJL`d(E+6-Vna3}3B_9RRB+AHB{w*|&z|{NWO5d>m{tf^zEQdQSP|KXw1Ov* zEi)W&musSz`>0gUOzqh|966Y0j{MXycg$YETJh$&xyOQEBReW$t$C$^Xzs=Qa;Ne? z70RN*Z~?qBLRMXYwm9rRH!M!)Rk)R?AngG?QsqWV1A3Kw@-)$>5e8B%uJ4%-XmA+4 zBnpExPfWQ8s(?;_P7d^sSz;zubPB6UOct8viScA)f1n}6PMpqqbb1<;nK>c^K_!?968!a&EJ*|EB( zVybd%Cu;17GuSV0;x310Mr6|HG=>FhvlwhVC-GwPf-EX(Dzc_Pk;&&{RU)GpYeGI& zB|fy9q*`*WSJ_x;aBobH(!8}yZ>8s~mo>F@;QMLI`i-upYj%0T?ghnN0Jr;U6@jX1(3z>ATJs#i|>eg@5cIs~+vPoAoeC|8~OFRRQuA7@;;Z zX@#1!%yi0V`5C2Oxfk(3es|yKQxS-t1=2zCi9CW8Cr^fSgtsU(k8uiAK7p=qwr2+G z(yvAur0y9KqvnVWEv}4cLOa)X@TiV1z-t3m=MyI5otVKZKtk}?VC#Eki^a;96z8rY zhz{LY9pOaxuWgbjGt9e(g%IOfZ&<0_<|2Qd`9P| zqkD)IP2Ek{PM_=$hE5ans5dpYj`gYxOI+pJE*hDhh3x!TR7ElK%4(TrG+jTQr%jO)Q$xTlS=vfVqV znL#tE3mn~&XsIp)B{;{bfJ?9rH8>;I8K9==*IV7-(?-C9%fcZc*1sauQ-WO;OkDri z3>)p9zuVav>y2Ko+2La?>IU1l$NFIyHMaMFxo#NEpej+SiR6od6$QH_%IF6@4#c9- zMpU&XXf?VAQSTrM$J$mSLdEpw?5h8vNVy?X{Ki((4};f)MY3~3+W^}Fg%`j8YwCh@ zy2S8r9)A^@rM6N4yq7>H6yX*2EIb9tZ9D}*cV0>FYA4h8b05)-^xNF4Z;G!%D85rB zB{W-bRr69HY}I1$4n(jCHg@FnYuz56S%9wxsikxIlEONj^lbj?LKYSiugt|}E_3O! F@-J~rqAdUb delta 1832 zcmZ`4OK99ibY}Gb?8owNZ8wP>JE2>s+k;ClC73v|g9|2*(om&kS$njxWl6gp$*wUH zd)QKXaY<1wB?NYREbXDUo(etno}&(d+aD&~#0~ z%YhZzt}S>asDxFw8rIxeSa<8;ggX&Bt`knWlNfCws#5JML^aK`2a-GW%kXCzSMnX@ zX=M>s6ca#*oUP=)n+v(AI@Jou7QhBT{6qeZy3l8LQshg%!k0_QImw8S)s#?a4gr?u z5LM0zMp0vILB8hez-zuyS`soOBa}&rN!3(JaRZk0Os4vCX@D^TPwNS0c4CW|3g9Z# zGN^ejr+w6_jSY1$On|yHpq4Y%Prjj_0$K8>sdXr^zY;r{0a^2B2lx%YM%4J93}Pe0 zIq;wFk07ii0lQEtguF1`kxcDO0Z^adQ&|*xYk=>dFaFN7u?If~iyDVI?C{6r3P?Pc z=JQ=H0L%ki1-Ld)KPQ=ynaLtZr4Qe5U-S9?gzG>77b~-)^<-)6RJ*VNeLcvbW!g56-+oi}#N^oxm#$n!nTyXJR8b zjuVeSD`rb2Rd4NxH`2cgHNIcxWr*AlCNT7`benV=spf;l|!1CG}6 zSH^Glc=fCmwZr@;^PAP8;+?dE<96b;!Y=ii6?TK?ws>w{sud@N5g+e9X6@d7oZqwJ z<_b8(0U(4;qGO%Z@jPi@9nXM7@XQ5}{H1jZJNXrRtuI~xh%f*Y0CDmE0R~!#hIj;T z5O9RzL!GS&os`Hx6<-ziNgI)&PRR&oBw`e0-Q#GBBc7vBk}liZFVvGbY406E6T5LDCSfIx z%PHVz&u()x#Sw2q(-19ZALqYTKiv70s~>Xo-X(gUbK8Ip0OFf~p!4uoq3<%bq4PFy zQ$SVX5C{4{FW~HL@!+4-9^vi$MeW_sY#z@)S@5yJRmYa~yNrJZn$zqfAnXyKLKmA@ xjiQ18Sf$kM!7wsDMaATWeYJBO?Ij1@s2$`x_3x+ZV1;pn#n{-vHnH)d^dCQ1zGDCY diff --git a/sgl/tasks/__pycache__/utils.cpython-37.pyc b/sgl/tasks/__pycache__/utils.cpython-37.pyc index f26496f07063693762bbbef11d5827416bf6bebc..f2ad67af476e8a0ef41db86b2f8766c4d67f9e88 100644 GIT binary patch delta 1801 zcmai!OKcle6o&5^+vADfNt}nBq|Li=(kE$Rs`5zlXdb1Ns`Md3BFD*$-64(>-*Kv@ zjFko3sG?lhAVlq?6_7&P;L@OOc&x$(wGs-_Y(TIEq;7x(hom)QV>2|vvg#WI+c2;TWTJ^Mv3sxXRCs~psKxD|+WMNduh#C78S%+RSHky(6 zDCklSGDSLhRnR5f8WZ%QYrx|ieAJP($d(+)mH^3Gb#7&xx~My|Hr=920g{8Pm?K$x zj)X}@)NN@=mtF^M+fu=cL|@>pDw1(=X*ZV?{#Vk&B|T4(wEv5wU3YIux>Ji>l6$zK zS9fKLa->uD;v*&ulaTcc<;5T&!F&`TqQgO-;SWR7qSLbpH5sSv%;BtI8}`qSH1h!h z5?J}|;Fq6YLGXbjSa-eCZXP))vRa#;{V4ssw}-!;m3Ax=#z(WFL4(KCh^VjJG);3_NiH9(SueeZ$HHhe>Hoz)eW!F)bk{KN{M zvDWeqm|^4PmtdZKU495IvOmiI0-gO{F~BFBs5}Vou-VF0xWVcJYvRke@^$uo;C8^x z4^Z>)cLMdil&#0RXmFu#vKqr0vD7EB~T*J`}y>q30TxRFNYE$7Vs6YL7rK`?mHtJjna%-hwmRb@=bB&4kV5{5z6+0rPrsE4f?jN&VTw 
zTglK*aRd)##9?L;!G!2N^sXX4LFAd%a2Tf9&4vzG%Nq?TsNQZIvu8}%B3ft*7H?8v zkv(o|gNN+TruDKJzKCH_<`vq8%T*xWU`F%)*c>-Q5`7y3fEspLiBm<<3@?qS$;%Nf zMbnXl5=$GUG%_7c?{fNfL{T+P`jFx)k#I8 zSJ+Q2$NKJVx22E&H&*l3bOC*Y6Y6Ohe^B#8V`yPf43{>k&dwj)T=^UH&#-7H)!&V# p*~q8S8b>4%^N1S={7C3s#C^mA#2VruVjb~_1zLOf*EZG~{2Om+q51#- delta 1819 zcmaKsT})eL7{||ZXiwjM*8wdb1u7qdj>$G1BV_Y~frG(THJfu1u@-uOj<&#i3T}8B z$V_&lnQ@OFi6%=cY+l#$I@{>m|k;?~LzrO6){poBqyo{_pcX zU+!8zVlf>uVhey?=u2om7GJgOhNtZHGRstysX>(M5O-?Im%jpFjdZeM# zX{Jl3AX+jZ^FHAx>nsattI$exrkh4(-FhCx4qKWfJ!0KZvlK%px-Dahnz9fzRZ1DN zULtndbw#&l6kYgLN?FLbEmS}3|}i`15X zr7mW+=+2Z?XK#b3!a2Qs@)b!JQni$J>yC^qE9q|HsC-NcnnPR+ObW`z!$6*Ck;GIy zmK@^>E!p{#;*oc%vyu2rB&CKElTnpFEi|04{MOkf3<(&R#Ls4Yy|M?&pb9*|WYZQJ z?-n~o6bzWl@H~5H?h&s?aKTy$J-JV;GE|YkpC?cj{uJR5p%)XD8vNPT28U6$@2u(~ z@gSj}FhKASj$og?5njcF{Rs%=KDAE(yoDC$DD-2}*->$Vd`boKj*<9`mQr~*KB;Nb z_>Hp~BKWKG7o|`GuDM=THc~(4-nYMxHo6HNghs+)!Z2ZkK$YZ`iG&(SYJ7;4*9il7 z)_njL@H=+{OyUFgMNsjqr{DLnSkxQ>l9&_&FDT;6=W-7`f0*D~ykFi9Q|R!H!yL|d zcfm=# zPq1{O91^}D93^bkJq(~6Xlz-e^)7+DIUX5SuF18+7w M94yG{n3uW!1|?sqMgRZ+ diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index eae4781..838e9b7 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -36,12 +36,7 @@ def _execute(self): set_seed(self.__seed) pre_time_st = time.time() - if self.__model.pre_sampling: - # ClusterGCN samples only once and the sampling procedure is done before training. - subgraphs = self.__model.sampling(None, to_sparse_tensor=True) - self.__model.preprocess(adj=subgraphs["sampled_adjs"], x=subgraphs["x"]) - else: - self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x) + self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x) pre_time_ed = time.time() print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") @@ -52,8 +47,7 @@ def _execute(self): self.__dataset.val_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) self.__test_loader = DataLoader( self.__dataset.test_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) - - if self.__model.sampler_name != "ClusterGCNSampler": # TODO: need further modification + if self.__model.evaluate_mode == "full": self.__all_eval_loader = DataLoader( range(self.__dataset.num_node), batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) else: @@ -91,7 +85,7 @@ def _execute(self): best_val = acc_val best_test = acc_test - acc_val, acc_test = self._postprocess(self.__model.evaluate_mode) # Test the best model, this part might have bugs + acc_val, acc_test = self._postprocess() if acc_val > best_val: best_val = acc_val best_test = acc_test @@ -101,27 +95,41 @@ def _execute(self): print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') return best_test - def _postprocess(self, evaluate_mode): + def _postprocess(self): self.__model.eval() - if self.__mini_batch is False: - outputs = self.__model.model_forward( - range(self.__dataset.num_node), self.__device).to("cpu") + if self.__model.evaluate_mode == "full": + if self.__mini_batch is False: + outputs = self.__model.model_forward( + range(self.__dataset.num_node), self.__device).to("cpu") + else: + outputs = [] + for batch in self.__all_eval_loader: + sample_dict = self.__model.sampling(batch) + output, batch = self.__model.model_forward(batch, self.__device, **sample_dict) + outputs.append(output) + outputs = torch.vstack(outputs) + + # NOTE: self.__model.postprocess now directly returns the original outputs + final_output = 
self.__model.postprocess(self.__dataset.adj, outputs) + acc_val = accuracy( + final_output[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) + acc_test = accuracy( + final_output[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) else: - outputs = None + # ClusterGCN for batch in self.__all_eval_loader: - if evaluate_mode == "sampling": + outputs, labels = [], [] + for batch in self.__all_eval_loader: sample_dict = self.__model.sampling(batch) + sample_dict.update({"ret_full": True}) output, batch = self.__model.model_forward(batch, self.__device, **sample_dict) - else: - output, batch = self.__model.model_forward(batch, self.__device) - if outputs is None: - outputs = output - else: - outputs = torch.vstack((outputs, output)) - - final_output = self.__model.postprocess(self.__dataset.adj, outputs) - acc_val = accuracy( - final_output[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) - acc_test = accuracy( - final_output[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) + output = self.__model.postprocess(sample_dict["adj"], output) + outputs.append(output[batch]) + labels.append(self.__labels[batch]) + outputs = torch.vstack(outputs) + labels = torch.cat(labels) + + acc_val = accuracy(outputs, labels) + acc_test = accuracy(outputs, labels) + return acc_val, acc_test diff --git a/sgl/tasks/utils.py b/sgl/tasks/utils.py index 83c2a14..9cc05ca 100644 --- a/sgl/tasks/utils.py +++ b/sgl/tasks/utils.py @@ -48,11 +48,8 @@ def mini_batch_evaluate(model, val_loader, test_loader, labels, device): val_num = 0 correct_num_val, correct_num_test = 0, 0 for batch in val_loader: - if model.evaluate_mode == "sampling": # clustergcn still uses mini-batches during evaluation - sample_dict = model.sampling(batch) - val_output, batch = model.model_forward(batch, device, **sample_dict) - else: # other models use a full batch for evaluation - val_output, batch = model.model_forward(batch, device) + sample_dict = model.sampling(batch) + val_output, batch = model.model_forward(batch, device, **sample_dict) pred = val_output.max(1)[1].type_as(labels) correct_num_val += pred.eq(labels[batch]).double().sum() val_num += len(batch) @@ -60,11 +57,8 @@ def mini_batch_evaluate(model, val_loader, test_loader, labels, device): test_num = 0 for batch in test_loader: - if model.evaluate_mode == "sampling": - sample_dict = model.sampling(batch) - test_output, batch = model.model_forward(batch, device, **sample_dict) - else: - test_output, batch = model.model_forward(batch, device) + sample_dict = model.sampling(batch) + test_output, batch = model.model_forward(batch, device, **sample_dict) pred = test_output.max(1)[1].type_as(labels) correct_num_test += pred.eq(labels[batch]).double().sum() test_num += len(batch) @@ -91,7 +85,7 @@ def mini_batch_train(model, train_loader, labels, device, optimizer, loss_fn): correct_num = 0 loss_train_sum = 0. train_num = 0 - + for batch in train_loader: optimizer.zero_grad() sample_dict = model.sampling(batch) From 41ae1c5b9bd5f3d467897366174c6c96f26e14ac Mon Sep 17 00:00:00 2001 From: infinity Date: Wed, 15 Nov 2023 01:54:42 +0000 Subject: [PATCH 04/28] add lazygcn model, but currently it only supports full-batch evaluation. add recyclingsampling task class. 
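
A minimal usage sketch of the new pieces, mirroring examples/sample_based_nodeclass.py driven by the new examples/configs/lazygcn.yml. The exact LazyGCN and NodeClassification_RecycleSampling signatures are assumptions here (only the generic driver script and the config pin them down), and the device choice is illustrative, so treat this as a sketch rather than a definitive recipe:

    import yaml
    import sgl.dataset as Dataset
    import sgl.sampler as Sampler
    import sgl.models.homo as HomoModels
    import sgl.tasks as Tasks

    config = yaml.safe_load(open("examples/configs/lazygcn.yml", "rb"))
    device = "cpu"  # illustrative; the example script resolves a GPU id from --device instead

    # Build the dataset named in the config (Planetoid/pubmed with the "full" split).
    dataset_kwargs = config["dataset"]
    dataset = getattr(Dataset, dataset_kwargs.pop("classname"))(**dataset_kwargs)

    # lazygcn.yml only defines a training sampler, so the driver reuses it for evaluation.
    sampler_kwargs = config["sampler"]["training"]
    sampler_kwargs.pop("inductive", None)
    sampler = getattr(Sampler, sampler_kwargs.pop("name"))(dataset.adj, **sampler_kwargs)

    # The generic driver passes (dataset, training sampler, eval sampler, **model kwargs)
    # to every homo model; the same call is assumed to hold for LazyGCN.
    model_kwargs = config["model"]
    model_kwargs.update({"device": device})
    model = getattr(HomoModels, model_kwargs.pop("name"))(dataset, sampler, sampler, **model_kwargs)

    # The task is looked up by name; NodeClassification_RecycleSampling exposes test_acc.
    task_kwargs = config["task"]
    task_kwargs.update({"device": device})
    test_acc = getattr(Tasks, task_kwargs.pop("name"))(dataset, model, **task_kwargs).test_acc
    print(f"final test acc: {test_acc}")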
--- examples/configs/fastgcn.yml | 7 +- examples/configs/graphsage.yml | 15 +- examples/configs/lazygcn.yml | 27 ++++ examples/configs/vanillagnn.yml | 1 + examples/sample_based_nodeclass.py | 5 +- .../__pycache__/planetoid.cpython-37.pyc | Bin 4399 -> 4396 bytes sgl/dataset/planetoid.py | 2 +- .../__pycache__/base_model.cpython-37.pyc | Bin 9590 -> 9606 bytes .../__pycache__/simple_models.cpython-37.pyc | Bin 11524 -> 12911 bytes sgl/models/base_model.py | 12 +- sgl/models/homo/__init__.py | 4 +- .../homo/__pycache__/__init__.cpython-37.pyc | Bin 715 -> 763 bytes .../homo/__pycache__/fastgcn.cpython-37.pyc | Bin 869 -> 980 bytes .../homo/__pycache__/graphsage.cpython-37.pyc | Bin 1045 -> 1080 bytes .../homo/__pycache__/lazygcn.cpython-37.pyc | Bin 0 -> 2728 bytes .../__pycache__/vanillagnn.cpython-37.pyc | Bin 1557 -> 1266 bytes sgl/models/homo/fastgcn.py | 2 + sgl/models/homo/graphsage.py | 3 +- sgl/models/homo/lazygcn.py | 62 ++++++++ sgl/models/homo/vanillagnn.py | 6 - sgl/models/simple_models.py | 45 +++++- .../__pycache__/sampler.cpython-37.pyc | Bin 14040 -> 13930 bytes sgl/sampler/__pycache__/utils.cpython-37.pyc | Bin 1093 -> 3162 bytes sgl/sampler/sampler.py | 45 +++--- sgl/sampler/utils.py | 46 +++++- sgl/tasks/__init__.py | 5 +- sgl/tasks/__pycache__/__init__.cpython-37.pyc | Bin 854 -> 908 bytes ...ode_classification_sampling.cpython-37.pyc | Bin 4345 -> 8043 bytes sgl/tasks/__pycache__/utils.cpython-37.pyc | Bin 10787 -> 10787 bytes sgl/tasks/node_classification_sampling.py | 148 +++++++++++++++--- sgl/tasks/utils.py | 7 +- sgl/utils/__init__.py | 2 + sgl/utils/__pycache__/__init__.cpython-37.pyc | Bin 0 -> 287 bytes .../auto_choose_gpu.cpython-37.pyc | Bin 0 -> 1163 bytes .../basic_operations.cpython-37.pyc | Bin 0 -> 583 bytes sgl/utils/basic_operations.py | 11 ++ 36 files changed, 370 insertions(+), 85 deletions(-) create mode 100644 examples/configs/lazygcn.yml create mode 100644 sgl/models/homo/__pycache__/lazygcn.cpython-37.pyc create mode 100644 sgl/models/homo/lazygcn.py create mode 100644 sgl/utils/__pycache__/__init__.cpython-37.pyc create mode 100644 sgl/utils/__pycache__/auto_choose_gpu.cpython-37.pyc create mode 100644 sgl/utils/__pycache__/basic_operations.cpython-37.pyc create mode 100644 sgl/utils/basic_operations.py diff --git a/examples/configs/fastgcn.yml b/examples/configs/fastgcn.yml index 71401e6..2fdcbb7 100644 --- a/examples/configs/fastgcn.yml +++ b/examples/configs/fastgcn.yml @@ -12,18 +12,15 @@ sampler: prob_type: "normalize" replace: True eval: - name: "NeighborSampler" - layer_sizes: "-1,-1" - pre_sampling_op: "LaplacianGraphOp" - cached: True + name: "FullSampler" model: name: "FastGCN" hidden_dim: 128 dropout: 0.5 num_layers: 2 task: + name: "NodeClassification_Sampling" train_batch_size: 256 - eval_batch_size: 256 epochs: 30 lr: 0.1 weight_decay: 0.00005 diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml index fa4b076..6a7f18b 100644 --- a/examples/configs/graphsage.yml +++ b/examples/configs/graphsage.yml @@ -1,29 +1,26 @@ dataset: classname: "Planetoid" - name: "cora" + name: "pubmed" root: "/home/ssq/test_data/" - split: "official" + split: "full" sampler: training: name: "NeighborSampler" inductive: False layer_sizes: "5,5" prob_type: "normalize" - replace: False + replace: True post_sampling_op: "RwGraphOp" eval: - name: "NeighborSampler" - layer_sizes: "-1,-1" - post_sampling_op: "RwGraphOp" - cached: True + name: "FullSampler" model: name: "GraphSAGE" hidden_dim: 128 dropout: 0.5 num_layers: 2 task: - 
train_batch_size: 64 - eval_batch_size: 64 + name: "NodeClassification_Sampling" + train_batch_size: 512 epochs: 20 lr: 0.1 weight_decay: 0.00005 diff --git a/examples/configs/lazygcn.yml b/examples/configs/lazygcn.yml new file mode 100644 index 0000000..d3be737 --- /dev/null +++ b/examples/configs/lazygcn.yml @@ -0,0 +1,27 @@ +dataset: + classname: "Planetoid" + name: "pubmed" + root: "/home/ssq/test_data/" + split: "full" +sampler: + training: + name: "NeighborSampler" + inductive: False + layer_sizes: "5,5" + prob_type: "normalize" + replace: True + post_sampling_op: "LaplacianGraphOp" +model: + name: "LazyGCN" + hidden_dim: 128 + dropout: 0.5 + num_layers: 2 + max_workers: 5 + train_batch_size: 2048 +task: + name: "NodeClassification_RecycleSampling" + epochs: 20 + lr: 0.1 + weight_decay: 0.00005 + loss_fn: "nll_loss" + diff --git a/examples/configs/vanillagnn.yml b/examples/configs/vanillagnn.yml index 4d91943..9af310b 100644 --- a/examples/configs/vanillagnn.yml +++ b/examples/configs/vanillagnn.yml @@ -14,6 +14,7 @@ model: dropout: 0.5 num_layers: 2 task: + name: "NodeClassification_Sampling" epochs: 20 lr: 0.1 weight_decay: 0.00005 diff --git a/examples/sample_based_nodeclass.py b/examples/sample_based_nodeclass.py index 1fc86ae..62da388 100644 --- a/examples/sample_based_nodeclass.py +++ b/examples/sample_based_nodeclass.py @@ -4,7 +4,7 @@ import sgl.dataset as Dataset import sgl.sampler as Sampler import sgl.models.homo as HomoModels -from sgl.tasks import NodeClassification_Sampling +import sgl.tasks as Tasks if __name__ == "__main__": @@ -40,5 +40,6 @@ model = getattr(HomoModels, model_name)(dataset, training_sampler, eval_sampler, **model_kwargs) task_kwargs = config["task"] task_kwargs.update({"device": device}) - test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc + task_name = task_kwargs.pop("name") + test_acc = getattr(Tasks, task_name)(dataset, model, **task_kwargs).test_acc print(f"final test acc: {test_acc}") \ No newline at end of file diff --git a/sgl/dataset/__pycache__/planetoid.cpython-37.pyc b/sgl/dataset/__pycache__/planetoid.cpython-37.pyc index c7122815dae981e98b12df0b63b972dea232c75e..ea0075e6df1061de1241253dff6c884a28519261 100644 GIT binary patch delta 33 ncmZ3lv_^^7iI delta 36 qcmZ3Zv|fqViI diff --git a/sgl/dataset/planetoid.py b/sgl/dataset/planetoid.py index 3fa1bd7..4d1c580 100644 --- a/sgl/dataset/planetoid.py +++ b/sgl/dataset/planetoid.py @@ -104,7 +104,7 @@ def __generate_split(self, split): train_idx = range(self.num_classes * 20) val_idx = range(self.num_classes * 20, self.num_classes * 20 + 500) test_idx = range(self.num_node - 1000, self.num_node) - elif split =='fastgcn': + elif split == "full": train_idx = range(self.num_node - 1500) val_idx = range(self.num_node - 1500, self.num_node - 1000) test_idx = range(self.num_node - 1000, self.num_node) diff --git a/sgl/models/__pycache__/base_model.cpython-37.pyc b/sgl/models/__pycache__/base_model.cpython-37.pyc index a503947e23b0eec4cb0e09fc3ad0d9e324cabe08..5f8f0077213fbce45fcc9af446151b82fd8da161 100644 GIT binary patch delta 874 zcmYk)-%C?b902fruiL$MyLZ2?(rjAJY1V~AqB04IC7I!rM#{8lt*hSLOoyJk{MB6s zf=_|O_aUetdnpJZg93Z%A&7$NA*jxGd#J%a`|f+rx!?0W_x$MJNcwyH zexHkfhT_BKn;rjG0ldy1VHhCd47=|FwBerTA)IhZq5#m1cchna3OnToAUnJA3P3L= zeG4#vuYH?P#zCbCQaGoa;chF3kIV?`V7szH(QUh?U|nG6ZP$!vOSFY2l$Q!TqF?-? 
z?3ZLlJ1JvOjn=Um3Z>=;DeALg6fpiC4txe9_aAO6Gm|| zm(3S&uW279f{*zA-UnOxs1P#yd-y8W%FzQ%qqvV~CoT||h$Jz<$#^YeZhqFB$}Y2u zl+=kqB85AlkG1D`SlNmU0ZDbsqO9Q4@Fz$+n_5ADjPtv0OMp>)9C-#~*jGORQgTB2!CqF@P1>07k=0H87lR&;ALZcOb0fO?Xi1s1yjutd>acZS7|X2S^O3|4Owh& zj`Q-V=FYWilokksm>?#JDZ(VK6Vrr5{Ocu8f9XZCMq-AzLCg|!M2Wb~DZ8!ws#$g~ z=FKHttG<&0@1!LjPBNNg5!3)H5;uuOqKSwROT;pfABnBhXga zV)dkHA|0MCR%;%mMCBgb!VfJA{~dqK=0m&uYf(!ws3Csuj>|QS@_LY!@Kx)Or|h++ bXR-d!L6NQ}v$6Y7hXN|7HCngUqb2_W44KCl delta 848 zcmYk)OH30{6b9g)fzAw_DJ>iMprJ3iP5-mq4zHfoupsyoH=*yx#!MS{MV4X*X?#W zca>XutL)6q>L<)Tha|4)hR>@vi>~P4ZM=hLZNHzzym;oC2x2Y^{QM3qy9GC@jU>hGp yWRG_ue!YlfkJ#Y1kuf9XvF4VK6*AV6#ceJ1Dh^4=vuJ5I6qC#f-w9ujFa8(TbGQot diff --git a/sgl/models/__pycache__/simple_models.cpython-37.pyc b/sgl/models/__pycache__/simple_models.cpython-37.pyc index 7d9a691429c7887be27ef6f343d6ade7984df1ad..7408637a28c5d303c03a7aa3473082d1d2b871d1 100644 GIT binary patch delta 1983 zcmbtVO>7%Q6rMM`YpDHPP|DPD9|btsg+2Th=SZIMIc#1uI1T;q=`+& zn~;=^+|&u-FH~U;NN`CcdO%1>RqjZf5gd@X7zuIZ02LBK;()|^yEsiMdSa#hcKl}E zn>X)$^LZ=xKt!HzX$doY?q#prm!I5=#CaU{GT%V!-k+JwG!t==*kU4KYYV}emTO1L zDbbCVZo3_ScL=!7tkZ{s@Nu%5!7w(rAb;QmkD&H z&i9O`ek4DRz~ENF(nD!H)VmHP}f>MF0;#Ibi`qn0^)0Pjozkh6br^O(3ObHT7+8IvJ{K$X};4FPdhgUjr-u8klv$=DCz8s1p5Jn2Zzw`&l8sly`BrM@s2e#4BZJjHan1mgyJZ4jF0$TL^Vyv9NASxlM+w zx{*l28TF}kW5g>eQNrTB-`ljJ^9Z_Vw1aT|Fg?-%1uD24L9jq}(6k_Xr-> zCiy-sZ?uP2en7BG&`(ez_>k@gT)@7itR}4ac;gzFGEMVC1j7Uc0*_mMTd6=}aMacg zYU^EU>xC0qWqMkr3Ek7GkO+w239SlyTD4wWv~hQJbueaoG{NAWOeJyfs7yH;Hr6-@ zJQOS4@7dl_igalkx9MbO8ZE(F`oc?;{xL!OH;dzfhH>bjyGMD^jjDW)-~!2;%N?k3 zyjZ1?Sk>4-GA(GXkzkl#phg0y?;9w8v{qjrIxH$B@)PBJ|_pR*wa~b26~&1;rFhR~OidU5*-WWaTS(5aiYW<(Pj1 zM&&Ao^GKG9ctY&!uv3=*O}eW{x~{Wv-T|n^5QnI*$Bia5h{!z|kRN%u9uwv`b^Ar? z;E^zs#p+a(W6y!3s$-xKq9I08 z>!URi&cdyNI~Rgkh~UCScdlF-#K)@Oq6Kl`%6d*h6vYeoyEEthXE^^qJL%Pn`bn?1 zSE9A?;#_9aYUn=sht!C@Rg|ml#`l;j4HqQS*PMk%f#zZ znBV7zPnZm6jcfEJp);-%Oc4Ea;D9j#+qfBi8shs_k8{Od&a@Sx$l5Y%U2Ux>;d^PR*pm7*)4u z)B1dTo;AGOxMLl__md~GHTemERPHHDb6Q%Wzae*ey1whA{>OJ`{I+kKZ#4#c=uUi- z3%3|G#wsI)Ret~ixD~%-Rk@%OOu1w6HW!TF*uzh2glMy{e*|ofMpy*;1!@6Q}(E diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index c6e2d93..2e3baa4 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F - from sgl.data.base_dataset import HeteroNodeDataset -from sgl.tasks.utils import sparse_mx_to_torch_sparse_tensor +from sgl.utils import sparse_mx_to_torch_sparse_tensor class BaseSGAPModel(nn.Module): @@ -77,7 +76,7 @@ def __init__(self, evaluate_mode="full"): @property def evaluate_mode(self): return self._evaluate_mode - + def sampling(self, batch_inds): if self.training: return self._training_sampling_op.sampling(batch_inds) @@ -89,6 +88,7 @@ def preprocess(self, adj, x): self._norm_adj = self._pre_graph_op._construct_adj(adj) else: self._norm_adj = adj + self._norm_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj) if hasattr(self, "_pre_feature_op"): self._processed_feature = self._pre_feature_op._transform_x(x) else: @@ -98,7 +98,7 @@ def postprocess(self, adj, output): if self._post_graph_op is not None: raise NotImplementedError return output - + # a wrapper of the forward function def model_forward(self, batch_idx, device, **kwargs): return self.forward(batch_idx, device, **kwargs) @@ -107,8 +107,8 @@ def forward(self, batch_idx, device, **kwargs): sampler_name = self._training_sampling_op.sampler_name if self.training else self._eval_sampling_op.sampler_name if sampler_name in ["FastGCNSampler", 
"NeighborSampler"]: sampled_adjs = kwargs["sampled_adjs"] - n_ids = kwargs["n_ids"] - sampled_x = self._processed_feature[n_ids].to(device) + batch_in = kwargs["batch_in"] + sampled_x = self._processed_feature[batch_in].to(device) sampled_adjs = [sampled_adj.to(device) for sampled_adj in sampled_adjs] effective_batch = batch_idx output = self._base_model(sampled_x, sampled_adjs) diff --git a/sgl/models/homo/__init__.py b/sgl/models/homo/__init__.py index 541f34e..a3eb663 100644 --- a/sgl/models/homo/__init__.py +++ b/sgl/models/homo/__init__.py @@ -10,6 +10,7 @@ from .clustergcn import ClusterGCN from .graphsage import GraphSAGE from .vanillagnn import VanillaGNN +from .lazygcn import LazyGCN __all__ = [ "SGC", @@ -23,5 +24,6 @@ "FastGCN", "ClusterGCN", "GraphSAGE", - "VanillaGNN" + "VanillaGNN", + "LazyGCN" ] diff --git a/sgl/models/homo/__pycache__/__init__.cpython-37.pyc b/sgl/models/homo/__pycache__/__init__.cpython-37.pyc index c6909b0e6f7581b4782e34d3235e439d17f90386..636c71820682035072914acc873b070f753bd0f2 100644 GIT binary patch delta 187 zcmX@j`kPhViIRxs7#JRdI55BqWHRq5_LIf2oZ zqlgb^n4hNfnH=SEfU@byd0?4a?D6r5IXUt1D;bJ}fhvoHCr2l1PyZX3&(H_^Nes6{8KapQhyGeT-?0wL) diff --git a/sgl/models/homo/__pycache__/fastgcn.cpython-37.pyc b/sgl/models/homo/__pycache__/fastgcn.cpython-37.pyc index 910aa41eb97d334cce6c85552b35259fba85a9ae..f09f40a14964a09ab896322f53a939b365726dc5 100644 GIT binary patch delta 452 zcmYL@ze~eF6vyvfa!G8PQbYt7s}4ddI15TempVv0s2~A3LNC_JwGEdPmuz({Qi6k9 z`#(8lauNI+oP4h$-tj)XUwq&9(|LFObKfs9d|!Ll(d*pbY8_#rm;%j)5CBllH6L1m zwaP7R4@Dq=wVAS&cxFmO*3=3d7;H3k6vaY*D@y)%9@EqDRY;!Syl`$*_}SwX6)k> z?}5qpp?Od#L^6pZV+0NgPs+h)kVhB_DZ83?cD&xC;X|sMcou+f-kfVhq+psMl zSY|eL_%tv3iS4Wh~{yjI(T_jgc}vr72d*<3A*zCemnKEjCo0y?Wp82|tP diff --git a/sgl/models/homo/__pycache__/graphsage.cpython-37.pyc b/sgl/models/homo/__pycache__/graphsage.cpython-37.pyc index 7b70e8861e883dc28598b98c947669b94e9eb4a1..ab53df435869a84cd831e74f6cd2432dbeb23037 100644 GIT binary patch delta 333 zcmbQrv4exxiI#vhat|XPBgf>sOxZ#zKu!^e&}1x92C-EqM={HoOBAQ)=;h|8q~;Xs z6=&ua~bhp^`<+llL>rnX`aYh=Ek(7o-*?mgE-|>%m>Z3bKO@ PM8Iu^uv901XO;s198XsK delta 326 zcmYk0zfQw25XSu{PTZE16{ysKcF9s^rXt1;l~5PNVgDx_P1lQ-;W5Z%0Np z9s=O#mFX{XFV4J)JHm`tN?_G+WUDxW3ZCxd&}?0#J3IR!c_Z6`+NME+qSzu?rRXP6*G`=V>OoOc+o*#9Qer%$R=b*y z70FFj!t6-`7{eHbnTC4BIc@G%}>N^C&GvRXj{hb_Ct{k<3oB zT<(2v0KJWesg{p_`_aeyAAVE`nZwZLemcw3Y?79HDxDquVFqJw=~=3@WcIyt%Ufl)SHybsbZq~s$?sdxyp!Y`Y1Kp}ybtr-_gR9$h?;6j$BM4|09*p5q zP`-Kq=A}w?fLCoG!@vgRp-heS%cF^~VVUKr*0xpV>8Vt}2vyChxf$NIfu7H#QmCkn zIG>ag!?{8O*&d%M$;bGDyqekD+^7`h%Q4q!G0TyG;agmOmgcVwb$O94_~(^6#=`>l z)^=ko4KL>fH%CgQLfb}a-D`0nKbvII`dl4VKutPNVz*#oU>k)GXWQ-yd?FS$<^^2l z8M>sxC$e4K9JQV+sci>)PHpHgQ`ruG27by5=c0Y1M-5cX-A&Z!XRG zKgl=m>>gEx+|~MI_wnBTyZPy2SC8{ur-H`0YBvY}9%tofcB<~cCAv`ccYzSvr(GJl zZ-<7g;rfixEqVsw>7rt^e(wwohy`1$A3)VV15%SIxuP|7e#Z=W_>`T~D<9gR8LQbV zz0i3kJs7ssU07}VJgZ7=)I2jh6;n{JtW<%l3SnOT5XqY;IRIy?GO4u`>qi`V>PI-` z;Yzwtc%O#+9EW4@&ea;+LF0y~HYhk)D|jQ4q5oh(F|1C-KLax+OD5eR3+RN1e+Yqmws=`t`u9n`e;w*d&20Vp5VUbozd+hCrUMzqV$Zm&)B+M1VELOj59l3dbDcIXLD_!klEED2=cYuuHkv|2q>Lo}+^Fn*EKP z_>UVu8!{W}`LkkT)DJ)!x@`2GrF`hwNG@bH2V}Fomtc4}FEjUSH%|IZd@rA9lU2p+ z{;%=W<_c#&UR(bVfEhq!g{Px?d6lKPz7NcL*)Ue$gOzFg~gHc)9fmNtg?%H~SR z(%9DFyurS~69=&$S5^Mu>Y57GwcS)hKmvA4BOxsSPR?@!vFgGsTH{gTkSVXKnNA#x zB1rL>u{2L^w4GMZv<;ghox%{Vdw9{`Bf%Y&3)BAVD!vC3E3e!Fbc$G?b?JZ&=nmAg zJOAG|83vHKV^dJ{+&@`(Fb1jrGw{{hz|;5h&Q literal 0 HcmV?d00001 diff --git a/sgl/models/homo/__pycache__/vanillagnn.cpython-37.pyc b/sgl/models/homo/__pycache__/vanillagnn.cpython-37.pyc index 
5c8a725d6406da9cf2a5002d1b29fdb7bf55195b..06f41a132cb9e3ab6fe84a42ec6f0c82660840ee 100644 GIT binary patch delta 360 zcmYL@Jxjzu5Qb;=D<51AIV?`FI#5ZYm|Cs~f{k3GpllN)8>?M~1ZxWq4(!5mzr>$m zVSiZI+WHIJiHcJ^yz?^qudFX zL*`bwNJlQ#F~S)bVdN5wY~&NyiNOoV)hB5e{bMn&DgBN34c#n+1~eMb%%J8bUcoII zWuj%G#R>#y89d3$E*>E&f6h)(`kuYx#NY92xgFl8{*@n3bJ41>v=JdDBmuedAL6j2 zvrSf2wY}-Oy0NyZ+I_kq}J|M^d@^tjYm*NQ^55{2==~-> SarfrIHghBinI<^GE&dJ1vODzv delta 670 zcmYjO&ui2`6n-x=`MKFus#Vd0J%}0)0ekQyJ*ZF$5)eV`T!xsLYU^aP%w(aIg@Rq^ zxdiX}CwTCm3H1-~>dB)sabIUJw}M7WF@6_JRS z!2ENdzyj+z>XSF+^iZjCwl;?xPk7Z#O{?cs%S&<8 z$9&r8+NIie{S7|TO(l(yf*;7Tooks9E$_M>aToAdKd|I4Japw`rse9=P0LMS0dRsT zo@}mH&DW}o>d;DbTyAMtA JDsEs({s3xNk=_6R diff --git a/sgl/models/homo/fastgcn.py b/sgl/models/homo/fastgcn.py index 503f3f1..7483316 100644 --- a/sgl/models/homo/fastgcn.py +++ b/sgl/models/homo/fastgcn.py @@ -1,9 +1,11 @@ from sgl.models.simple_models import GCN from sgl.models.base_model import BaseSAMPLEModel +from sgl.operators.graph_op import LaplacianGraphOp class FastGCN(BaseSAMPLEModel): def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): super(FastGCN, self).__init__() + self._pre_graph_op = LaplacianGraphOp(r=0.5) self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler self._base_model = GCN( diff --git a/sgl/models/homo/graphsage.py b/sgl/models/homo/graphsage.py index 0ea8a69..a8ac314 100644 --- a/sgl/models/homo/graphsage.py +++ b/sgl/models/homo/graphsage.py @@ -1,11 +1,12 @@ -from sgl.sampler import NeighborSampler from sgl.models.simple_models import SAGE from sgl.models.base_model import BaseSAMPLEModel from sgl.operators.message_op import PreNormMessageOp +from sgl.operators.graph_op import RwGraphOp class GraphSAGE(BaseSAMPLEModel): def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): super(GraphSAGE, self).__init__() + self._pre_graph_op = RwGraphOp() self._pre_feature_op = PreNormMessageOp(p=1, dim=1) self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler diff --git a/sgl/models/homo/lazygcn.py b/sgl/models/homo/lazygcn.py new file mode 100644 index 0000000..579450f --- /dev/null +++ b/sgl/models/homo/lazygcn.py @@ -0,0 +1,62 @@ +from sgl.sampler.utils import RandomBatch +from sgl.models.simple_models import RecycleGCN +from sgl.models.base_model import BaseSAMPLEModel +from sgl.operators.graph_op import LaplacianGraphOp +from sgl.utils import sparse_mx_to_torch_sparse_tensor + +import torch +import numpy as np +import concurrent.futures + +class LazyGCN(BaseSAMPLEModel): + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, train_batch_size, dropout=0.5, num_layers=2, max_workers=5, max_threads=-1, rho=1.1, tau=2, num_iters=1, device="cpu"): + super(LazyGCN, self).__init__() + self._pre_graph_op = LaplacianGraphOp(r=0.5) + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler + self._max_workers = max_workers + self._max_threads = max_threads if max_threads > -1 else torch.get_num_threads() // 2 + self._device = device + # hyperparameters for recycling + self._rho = rho + self._tau = tau + self._num_iters = num_iters + self._minibatch = RandomBatch(dataset.train_idx, train_batch_size) + # define the base model + self._base_model = RecycleGCN( + nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + ).to(device) + + def preprocess(self, adj, x): + 
self._norm_adj = self._pre_graph_op._construct_adj(adj) + self._norm_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(self._device) + self._processed_feature = x.to(self._device) + + def generate_taus(self, T): + taus = [] + k = 0 + total_taus = 0 + while total_taus < T: + tau_i = int(self._tau * np.power(self._rho, k)) + tau_i = min(tau_i, T - total_taus) + taus.append(tau_i) + total_taus += tau_i + k += 1 + + return taus + + def flash_sampling(self, num_iter): + num_loops = (num_iter + self._max_threads - 1) // self._max_threads + num_iter_pt = self._max_threads + sampling_func = self._training_sampling_op.sampling + + for l in range(num_loops): + + if l == num_loops - 1: + num_iter_pt = num_iter - ((num_loops - 1) * self._max_threads) + + with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor: + sampling_jobs = [executor.submit(sampling_func, self._minibatch) for _ in range(num_iter_pt)] + + for future in concurrent.futures.as_completed(sampling_jobs): + yield (future.result()) diff --git a/sgl/models/homo/vanillagnn.py b/sgl/models/homo/vanillagnn.py index c1f897c..b421fda 100644 --- a/sgl/models/homo/vanillagnn.py +++ b/sgl/models/homo/vanillagnn.py @@ -1,7 +1,6 @@ import sgl.models.simple_models as SimpleModels from sgl.models.base_model import BaseSAMPLEModel from sgl.operators.graph_op import LaplacianGraphOp, RwGraphOp -from sgl.tasks.utils import sparse_mx_to_torch_sparse_tensor class VanillaGNN(BaseSAMPLEModel): @@ -19,8 +18,3 @@ def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, basemode self._base_model = getattr(SimpleModels, basemodel)( nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout ).to(device) - - def preprocess(self, adj, x): - self._norm_adj = self._pre_graph_op._construct_adj(adj) - self._norm_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj) - self._processed_feature = x diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index f20dced..e9e0663 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -215,6 +215,21 @@ def forward(self, input, adj): else: return output +class RecGCNConv(GCNConv): + def __init__(self, in_features, out_features, bias=False): + super(RecGCNConv, self).__init__(in_features, out_features, bias) + + def forward(self, input, adj, recycle_vec=None): + support = torch.mm(input, self.weight) + output = torch.spmm(adj, support) + if recycle_vec is not None: + output = output[recycle_vec, :] + + if self.bias is not None: + return output + self.bias + else: + return output + class SAGEConv(nn.Module): """ Simple GraphSAGE layer, use mean as aggregation way @@ -283,13 +298,13 @@ def forward(self, x, adjs): return F.log_softmax(repr, dim=1) class GCN(nn.Module): - def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5): + def __init__(self, nfeat, nhid, nclass, layer=GCNConv, nlayers=2, dropout=0.5): super(GCN, self).__init__() self.gcs = nn.ModuleList() - self.gcs.append(GCNConv(nfeat, nhid)) + self.gcs.append(layer(nfeat, nhid)) for _ in range(nlayers-2): - self.gcs.append(GCNConv(nhid, nhid)) - self.gcs.append(GCNConv(nhid, nclass)) + self.gcs.append(layer(nhid, nhid)) + self.gcs.append(layer(nhid, nclass)) self.dropout = dropout def reset_parameter(self): @@ -310,4 +325,26 @@ def forward(self, x, adjs): repr = F.relu(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, adjs) + return F.log_softmax(repr, dim=1) + + +class 
RecycleGCN(GCN): + def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5): + super(RecycleGCN, self).__init__(nfeat, nhid, nclass, RecGCNConv, nlayers, dropout) + + def forward(self, x, adjs, recycle_vector=None): + repr = x + if isinstance(adjs, list): + for i, adj in enumerate(adjs[:-1]): + repr = self.gcs[i](repr, adj) + repr = F.relu(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, adjs[-1], recycle_vector) + else: + for gc in self.gcs[:-1]: + repr = gc(repr, adjs) + repr = F.relu(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, adjs, recycle_vector) + return F.log_softmax(repr, dim=1) \ No newline at end of file diff --git a/sgl/sampler/__pycache__/sampler.cpython-37.pyc b/sgl/sampler/__pycache__/sampler.cpython-37.pyc index 16b72c549612bc7e34d6a2e9e9989d3cb8d2f01b..87dcdf3e161bc1359c50891c61bd023c309f0efd 100644 GIT binary patch delta 3894 zcmb_fYit|G5$4|UaimC!BB=*OS)wSv&X#3a57&uoE4D1!R$@7+ABtJG%!zlQX&)ci zJIazuSx6)oZ4=j4Hg$uvDJ=9?kplg}0{z!EMG^FYHb{%0$rWh(w=bkX{}laEbY_o| zC>1S=pcMFVXLfdXc6Pqm)dz(?>cP{2fX@M+-*jHqB4^fu9ptBL`|eKN5Yp{u9cn-g zt~%797F-e29UwQWA&^5{j)B~whCvQTPIe_r)uxn^BIT zRUnhe>WX1xHC@kSwhU!LThXW;0hh!%4JUv^{1PmBFB)Ux&UYreL51$nAT%gSl+PZ* zRttbEa(`$96;V{wr{+%a`J364>oj}2^?dDp*UyMsb-CFYB!ELg(8 zpykt*1&wBm{GztyEp4xTPr60~xt(?UW=Mq1`MTSCVJ*vdRnaTjB&8*4iMm##N%ofS zSfCI62-uEAS##4~H(JvK8*3U$Hbe`j@Mp?YgKsIT8OA=)7*p(i(*zgMf}DVv;H`g`47vK$BJS$hi)>oMTC-k|lBesOhRYZn{^<8ONvhFFC7@ zN1mz(8ma1xIo5^i!jfyrT@@aAtL}9|997qn_`=Z{2iW=SyEQfb!Ar9EMMNp7Y$P>W;M%Y zmT2~d<-V*H4G5U4SSgejEEiRZISnG$5{!{vb}-oKl|6RsI@qOPYGixBcw$}0ICTIx zd=haBej*7@;v^x`MtsnY3hdoroIF|kQ}CB$=t(f61YrmO68H@KSRTWqc~vt~puqF8 zNLh~BX>^Kx6gtt0Ry2)p65%vLA3N33L#{HVr8|r%oA$xs&1H%un1~)^Yb_VZN%p&z zM0AJGSHSi*i-dnM%u`OFkK=5Q0l?}9u#Nc(sf&}=sf4Y5fY(_xqMzd92z8-Z!;r=# z-5Br&GzF&)4_%3_cvAXvV_F_uGt_QGv7-10M2!aP;h&>T|;2ee~dpFjzn^!1)?qk@?ojs1N z69`F!lPuhsRIm*4l-w*KsF*}ZA)H0v!xPv#hrkzJ7m=@Y7JM)ZpAJpS@zGKHWFs7iHkc|#x4Ar%_n~HF{tQaX78=#dba#8RP;iQ{5uuV-CGDc zLVddJoOIzpJVgf-v@n#mAY2G3#W2rKu+JpBuKyo6io|;pz8f5MNkdY*6r%;ad}|*i z-y&p=eRX#{&9B}$u-rjYE12PBKVQsc)KXCcM)OVZM8A#p^9cXpQiiS;7TS{{-(fR* zQp;tv9sRqh5kFd-@78Lz@?RO!3F4sOUt(| zfaKwlpf+(y;*zNPx#YzwZrwo~{0R`4wt`u`VwiB>7g+b+iNS3}1DMTg#y;EwX4oGR z0J&lQ`PI1GM2FdXv5sZV=y5b!MEER16~OXk49(;(2}5?;k-<%Zj6CpXKu8Ck*A%k? zkBe-nQZ(r)&{<+ZndirnZsHF!yASrdb@6boxI=ef#6Ycif0e+S=M}lry#NZj$i5}N zc8J@fqn2b+C0~Rjpf6$fj!hxaM_J*(iz7TxkD&=?PUE|6lLrmam;uZ4EP3#^lQ<*X zU|TGg>c?i(i?!v%3wYZm(Eoq6jy|mr4oSK@0cv$X&QpQI-pLWx;5-b&A z!6V0(`9R3HdJ(E`?ln+AEWW&<-@5!tn4JRX@nw8sf0F7htExx1jt)J_b&Lza;M;FY}=za$>wfM)+ z($Z~9%$D?>$ND)ioMYb~x$xZf`fhnLP+_z1AJo=k49RJ~$Y!vG`ARVW6esy_Z?ie* z4={Z+{xr7aE*=zm2}kh~;Wx7R^1^^oR;Z!Dx$s}g;}B`=8qo7aMbA;p2l@m+n*DmT z&&^*a$6h`9csT7fa{55U%k@!WB_&)(5e`q;W{ue-{1jPUV delta 3988 zcmb_fYit}>6`nh@kMZn!y|&l(uI>62$H}IS)5MQ7i5oW$yP+g*)4YfkGM(&+ZC?)*OY(}mdK`g%=*-`Tq-tvi!-SVIlrb`w zEmL#{TE+Y9yNxlJb80#r^ny;YU~wK3Z_0}-E|SXATiT#6E%Ro0dDxty-o1k;SSqEHe6r)+}92>_@ps=}n5fnQE`1BpxF z4bb37anjs~EgXUN7I9sBe6Sz-NR=1y4?)joO{Yp2 z<`pMh@9(fhd`BAyp)%ep{-EvY+ps2_A*vU+1&^Jv9A_JFoae+uSY zmfr?GY*^xn5`#CpC$CBKA!l_?T9Uym%9^yKERMMW=nJl}qtYk0FG2fWsH6ZVTneP5 zHTi1;FR6vq}PMqsy8qG>;>~W8YN1$wM?C?5+SqkDYVB3?mw6>E)=eL zL2l&dEWQO<&AOep^qg_c;?;xUed0&;?J@oo2&BWlX;b1a^;2UH+moF4c@T2YizO|= z0Eh_+kU_OD7MjCAxKbZC1A0&6W~qp&8FNT8pn7L#bjO`QVU!^ zVsXG4JdO+yLX9{RhNHOH+}T>uLbFg5o`K#s#A@@%ELthePN(a52ef!6!XpS2*(b1t zCg->nj>C8tLUp6-5l(6`C1AIMnGzF+x=-!Wx_LLGkCZr)xO;3KI;Pe$2ofJ5gTSb- zm3v#x%kdsm(2GEk*e^cUwl%#K=s9@c_p1p!<^gDNj6uE~A&QU&5PxdgHCQpt4xrZ! 
zJct6-4b}{NUkoR23=Cl}x)tnv7q%Zm*p09UA&f9Aevw=iJk`~YhwZx;o&7kHwAd%G zbr4~TcsJEIOMZt5m)x?NXOM9W;bRC-BluYdTOUWD2~}*>A)x8`afD4wKd6Q)GR^|b)MtuZ zw~KBiF3uq@io9&zXU&9>f`qYoY2ab9^z-tEoip+>*%Zw=RXM~zPox_v0Sa6%oVWAW zEy&yynG~axx1Wbj-=3GSl|k8%;pZO%3B|NsPj;XZm|0`a0guOcn5nypH8Tr&+Ka&A z#e?qge;39diq4)Ne1tHLiDd7*Rq?X_C5+?XTmKryxTP$@6@-eqhvPXae%m{C;r|fD z7?y9rvmV7kb&uMnruaNYaar`e!CV1`y?Gi<76-K%&*25*kRn zL=xJHyaM$~4?@ulLBWgAT_Q}BY(|MviBimZqUxYLYHJdc=x!04_F|b_(Qz#JS0h8}VcpPCL0>%B4*jhxmhHxFg({#szJ4*Jt<)j0C z2|ye1;aa?toodFy5FSAPK+{Q3@nWn;j>lB#Ax}QZ6(%=2k5ayufdv~wK8t@s-@E<6L$L|bn!)T zEd8aOq$Gu+s>_XR9`0NGCG6g$7|Y1GsN3=B5wZv+Mhb|Hm&U4l`=iN4&cMKD#r+*0 zOrl08YPA4YE63(ka<>=2n*7ktP{rT0afom?V=ojAUJ$>2^jMry8y&^xp`RzjzQNv( zsyWsb(Da`Hfpg{JV2`|)=4bqixQt3QzD;c}Tv^BllhR-;%v?MedT~e98XHG$wh@Xu zVqw>Toi%ZV%D%_49Ov3R;#J@kVS*!~bFfbIJ~p+VN)fHpU!oVV%sH@$7g{h1_ys|# ztE;fK|jgzg87DjE6Zy&+g>FEtGm0;j;*@0C<`+r@IE9 zvs{N)Onz~9EBk^d?oJ=TaHwSpFc6)@SCC5^QTQshX?E(O)vo|=w*0f*-)GBMJg3vm zzGf<_Uj~jBbZwrw>}kjC{M>0P@7SDEKm-lWjcXj&>D7VL<{kK4u>9+djZ7_L+xmh5 zSH(qs6c}`OtZm7QOyz9Dt*CrOg7nAlf*4$EXnCt*a3r<4irfH!C2@Y_#PY*A+6(Kt zX=mVTjIt=kIL9FL?_KEp99@v(b~!5fPd6~k#tA&pCTOTifJI|!gBn&7S{UAK9Q)Jh zPW}e!If+09i8`5V;o6W>Ft}rRA^**7F+JK9JOomncwux)h;ASI#8*f6C%pz|E;r;F z&b%{Jbh9~!uBiNy_|0gaIf{BI6~?edABtF1X!Z>bTy;cG9bd4`V$M2-nS8J5%^)APCBEt9|!s-_A^F_5(bEa4lLgs=yKtJqBuYN}F7??^lDaZmSH zYG-X%&&gYr6U7-$o*VoW{(`O?sO1N6$R&ARwe7LjnjWDpswhd>r0qDpBpHOtnwMIKr%> zx|nr!d=iXXnD?HeE!JLTx2W;I`l6<><8 zm7s<`ekG_y9ekaEo7j0qj7Fgo#li*lVD17{RxVnUPqZ0@<9=x}S?FxK`=~V54W?zO zF3q_eNAvVh+J7?B2BS}3-`*|fdfVFl?T7c@yOSRsZrf?TZPWQ8*Jk^m%5uB4IC3{3 zMJ`OK4knc}X_b{^h$TxHzlERaV)TrI{O9j~qxP>~Pa^t?b_2bI?rnYrdun#dR#0% zanAN{@87PHx_#PSu@k=RG`p7Fx~qa$AZ1UOe^i~*Zr!1k+pG>Y*;hedWtUf8Wh$gE zSPiOV6)tB9{HSUyHrad`gZT4vCxHb zK~_{Qoa?kO0k~n^bMWzUt<hRk@=X7emhO;#PrYL`Tm zxpJYc^upcD%O`UGAkFEcos>q8+eN9gmBqn)+-*=K=c(CqJcBSblca5X4bUL{csky= zm`i9rg%m=I>#QY@2bXxsV&TEt5K3-$i&$4C%?gbcFdPc1~#fd7}W_L$}`n(Y`LACUDe)rY?jkR1xY!dw^YOO@o<`btS@oiYK~tN9WsTwPrywKYF;zRDt;?QL2&MWc=AYR^U7}RXW z{@Ua8@163H{gFLjf5x5AI3_TJ0&ls4spNl{`a;DFRrNQ~&%BU)u~rS6ol(p4Oy^ok z*O79LGMm%5Cw+Dm&6k#xs+>q^zDufnj|hpnsTCV#5rK+4?4?3kPG4?gGVkdaZAoOA- zIIBd>aaRl6s~oqjz`c83Xv9u)v4t+O2McW)8@~cRu8WI9dT3~S4HqNs+Hho^T8k`L z+fOd~4QWB@FckXZ%}WYjM+j+5FBkHUC9fWCRX+3m3uL(Dm6!#y5VasekI6xQ36`Q3 ze?pPk?|slzg<$7*PzaaqB2OoJ)G`Eo^J5|;z{R*{;WQUdcFSy{%>}&A>~dVX$eVg8 zs6Kfqwckbe$;b2{^+$au%N1KrhV*?@W4`di~8$ zu@#^rjSXdh`8kmbT~YB%Q;kHEp@iP&l_mfE9{ubs2vx=N?r*RS54{xHu74KWiKa|h z%R!m_2?_STA)Mbk7l+F`OkzW8P>~)>rcdLxi-4oWQDYW80{nBw-|@Z(y10Ba)04_j in)vf;am|Q?x#Sm3DgG7ZyCwd}E|El?*9IFK-OmBOn7Q%* delta 156 zcmca5ag@W$iIf9d2=Kq6BA?9LJf5y?j8~Fm_qUBODofY`|k28~ZHbDNeiRahQ)b0UqY# zHCXI+Z>JD-zmvV~;i}Gy3zjuac~jD?E$fP07WAI7LJr`&{DRXoRR0rw9hQpKVn^C& zIWYr{D|JsOyKCA)l+{Mqde4#9XteKYdMfac)-`OaVMXCr~gTQqMyNNnx+id8sA@(YY(BXcKNS1nb@~{-i(Hi7y*Kg z^6NkZ&b~Dj)HzUL9f`=9S~J0%LlxJFNSq^;)|tqhnc{UWa_blpNFf6bIdq{1JFq)v zq6>T1sOW!V`+b48gUfRBaf0#Yci&9#8}(ZtD#?;W6aDCIq1jjF_rR0DJfr&=It3bL+?B~ LwoI^XQ+bDfmHtBN diff --git a/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc b/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc index 3cda834a6476d5441d414b1d1d8e176462e79252..fe0401289b85a8e0751ef722c718e263c80bb524 100644 GIT binary patch literal 8043 zcmcgxOK&4rcCIH`ReXq|*m`x_mb=>zj%~FwGf9A+j>psXwBv!H3EV;q6yg-B>lP(a z#Ugd9*p}of7eCJk?)O0Iq zWfKy0>b@W6-us>N-MZgsI12q<+6_PKD#|Zqq!b3;@8U^*3c(bnMv7a!RaZq^i?pHc z>MEt_kufw~Q-;;3GOTKfsthgHmMKQ$3~O#phRvuxY`6^>u0+k@lDjl)xvk-{yF6^W z?cs{MGVHh=Rrx?+Rc1X=m=!iob@v+LHggbnWPFvW50!51Ym80V>8hfVazEY+5i-8V 
z{2|`X`+n+wH1b)<5wq_5Nx17L2i=OW{2-WcKR6av5~f}fhD?~K-;YIW*pK_(BR>uH zy&}a5AN$e7PeZYCE^n3?Q2(bdcsx-kHTPr^ZM%L<4VIM z^?(H3nlO?u>WS)X^O(h|FJaYri9rtQzkAqqm~tqdUo7bO(F)Ah{D`&lP1Chd(I~|MvKpw=tv!p5#Rc zMQdsm)u9Lfin@uX__x)zdUE@c)3QA~RDntj;YV+J-oukrAb=KhN>v~%e-&}OUc-z3 z#$0?8Az{sQ`7&aO2BBz|cRi~x@0yM|0hsJMMH4A?uAZu@l1c!}0bd1l4a@-m2(k*m z(&qVEYB3FgRC=?F!&1s!3N?@{%>XO5@Idn&2pE^ z^!8KfD~RiZ4zmt4{!gis0@Tp<>M4!zeEqd}{eNj;jarzlLvEbSV0rUge6yUh&g|U4 zV}7L`75!X4Ro_=W{qMX|mH;3R?5|Y5o8F){@z9($@oeCE0ndx4l|RsOE4Sm@=}U-H znL-m6mO^e5?M>6Ivg|V6G-Fx%)I?iOUW;$0FPHh|fb?b#NK3y~2J7KWZ)O!%TT{+d zR!7*vyMecjcN1?1?Gal_ zQjDZ5qw%}5GU#_%?i-7>=WXh;P1;QsTY;&g9n#56djG91o(HiKEM0+w^;r)owbv;1QQ7%l!3q>sWJ2NJTeHvlpZGB zugTCOjDDWs^WmmRJ5V;Kr6c|lhRv^$*dSq(SSLYP&Mgvzf}GrPQR$8N&`(9(8;_E- z+=|`0gkVvF+C5qY58G@JfjO$CUP>!0ELuI_aFpK9;(!TIwo{kfyG}{E(!l4tYaENNcGrlrmI9YonwE)6~{F zGOdYF8=BQe&uFifX~-o+ zRp3-pn$^=X{Yw?tWT1>Xt!Ks=Q0MTYPn5JVZDuO9ChHPvRt7DkXt|j#lhciOdxJ(w zcq_xF8jyaL*EyGmF!HMLKOFrz%F>wrIM?=+Jw?MfbL(haruLp1sXo0%T%?b2NZWYv zxM^-nna(Qc(ZKkOVkVik0c1u!3-h8;mhr>?`b0VW_wzHvTvrM0a|@$aGpn3O5vqbf zRC^lmwz|WKD?Jc;FrEmDgN+ej?n(^lt`z`d&>`aYQLey%hIG$79FEU?&N!`b0X*+h zNS?3<6#6~{7~YIFRs8WdjEU)uc(5-jkKu3!2RJ&5vsFMBZ5#eB33*VzN6l0wV|a(~ z&Oyeyc5wo$qe(iRq=_)45fC-mRp_yPkcvu?DeAp`3>Ye6{9|hMp=_|=A5R~x-$fps zs^lI7juwuP^y)MNp*oI1Q+4zrlM^-Dd!bCd2GoV-ci zgxJ~X+C;Wu5FL;$$0OvML=>)j-r>ZLiX7>K3JZ_~ik2SpZ&3pT37jS%8l|rrgn&SA zF-N$MZ&M+18YQtKuP-62vs6WGuG9_dJ((ULY5tqjYFghptzMF}T4zQ{tKrm@v|3&p zfL6Cyjn$v1Zk07y69n6m1lz{8TNXDj(9XwUa2!P8jAt*j5)hcskKQI89Z$j`{$Cij zGHss$6ucG5m^HX4d`?Ze9lnfl@BxW7gn08J`}Wk-<~*5qsLC5xs!|f>3r+Sj?H3G+ z^u*VwiwjBLGqo1H`D#NdXc^z6ZmyHKLE;7K@M>z9@wlsXSF*ngJ&0%KLd;MeiHqHY zXUa>*1E#k`7ydd;rAuPI*qrWr_!}hNB=Hu+>_|yQCeIb^g0Ju-GXf1yjC<2BZjIVknpN{+c5D&T^)jM;}j{RE?*zK+psg5$N( znh6$3-pVE*e}!Ok&P*=lcKo-n&$#4fE4#!%i2Z$tdp}8}65Qn&EKCA8)$*DcrX8NP z4B`U0UN`~F1Bp!LyO(x884lt{3Hkeb0Q65=@C4N<=^QiuEA7yDSlp~&q|$*oL_xY9 z-G009ht6NV&{6-Db1&*AX)qd&-}y&snYyVcC*Qrw=`L@{w(muwz>ku55SX>K_w%38 z@1MVmsS>+;>?Og7hbOC&y^%KhGi_|$`aI{~p>gWMAT)Jr1;vv!giREzsN9iM67em7 z+>CE_y9KA@Xn~W4b+1CG&_He~&~=E!eoa9eI1)XWXT@5lombY7ao<(%mpl zMqF5vxKH%V6E$3JCZVK=QNiHGR0ns3@Gtj;fq@Qbg9HcUuHvpS#jOJQn5YSL^x;?H zdS5!6Vu`*-(B|TOaCfABAj?rId2GG@5&x7@99)O@LLTD6xO=_0lKT+p{0Ee2k!rjE z#5+UOZO&L=u@LT(B!^OevS)C4sLnOgM?QC}vTmOVyHxCmbDvtCka$e(l~*3LzujuF z?~i7IW4Fq~0Bv*8M1b5W?{OHw<92It_!%@sA6AaeM~>a*B8M6k2DTjIpHpA99FE*q zjtt7ZS##X);#v(G4Wk1=UiBYE6ka65%AO2vavT31}%i#bqW^=vq`CRL^h?3b3xE7KoNf zG%>3_#2222FcN@5ZAXF|91Xec=TK9ey+6Ys`0tbW69`;e$F$3_Fmlg!tzx(RB^7v= z#0i9J(K+w&@X$3WbX+9Jb)-GGU;u^d{})L64?GER4u#YY6>mz_o@`$QycXwX4zZ-g zxh6kG1iGL{9Y7cI(i7jem<9t*T7-)*t(Mv|r^R(LzGorbhW5%@xrXaoka%cNC+&6h zon1g;J;v9#M}W-(Kw-C5Xi=gSQN?qpBC+h(^~rxiMdpfG zPi^c%eR5lWi7NkzN0KH*-Gts=B)0d~RkXIi3~fPfXgf#%cX)_~d5;9$SV}s<_bIfX zceHYx2EqFz=HCHkZm1k2krCP;C9`WTf!E!4YLcXY)JvQ%a5X(-!veHxl9g{u1|Yw; zxhq9V9B=m~aX@H@i!3dUg|mnkzJ9AHIZ69bvP~axip$0od{Wt#pH#LbLQ{~_!2`_B zb7UpCAcZBO16f?lITWgqAeLG%toeKwOCqd$1xJ2|Jcfh_6b#+T`{gGn) E4}nRVUH||9 delta 1446 zcmZ8h+iN3L7@sqlnM`ITlS|WF)^0A{b|oOGRzX>%yIso)y67$nD^VGznW;@WlVs1F ztaVA~Lwpj1)B|2#Bz^Ezkb?N&L;V-_X_*%hd=Y&Xd=UJ;Nou=I_!BL#}r-BoA^>M7Z;E=b4y7~u}e)ES%2|^}(BeTo5#BxHIE*=R( z!OHWISk?-m91Qv{r8@sk924=l9=lEWiB`$`?QYxJwpnX8R4sSkZ4HJ zQ%opLQ+!4KLd$?@fCNkYM6U4PS zt?6O>f$c?pjsL9uxKc)H4xxw;M<^g%Mo1v2{7bz?R{2kQzx3omrB@)26729zW0AjW zRI3Kbku=GYtf-M9(TGY6(FC0RbL9AI<1g~wOEBnG7;E(hO_jRL+8KD>J5eL`!mLZl zg^AJ|z>C@mlYY-}J!_{&KenkuH!%@)e-o{!jV-iZ<-eNi^=L_C^fwTq2Cst^Ducdb z!!3=qY*tUuRemk?;?+CIxf`WDd)xJb3m1~!g4{O$Ds^AMOH7m8Oh0pgcREyIC*W;h zpU~$pgNqTd!=z=kJUa+luIE|SSwXn+4J6@3r#b*ifG?a%1Y%|lNxdG$>p6N2IrjiU 
zo%LvIx6$jvF`)P$G#B`{<&`&~&r141itSPmLL?1AH9QI;sGy+11T`J%J&iaQJ=?pK~}pr{{=8A#6j}Zqzx*^~JnN+syNO=01~a zW)Xe~9!vRMeoQW%qhNTvn%p#8JxLi}P*cxlQC^wJi1;T&wXF$L=6QnWSPlBRT;X;3 z$|tNX-rOtiR2oO}haV0R^C^e&z_M&v8xf-e*XqKvU0;WF4H)vKvU(I?I8{$!0bHy1Ve;c%eML(QH!2-ZGK5qisZ&CE z={7K@?6C$zNwqQ`3>CFLIm>XVyy;J%r&ec9fWCaorp@@^k*NC3jv|#HL2(LGj4Y~C OfHD-M9F5Zmh5i5%aZhys delta 341 zcmX|5PfG$(9Nc-i`+mFnvIuistptk{yBk3WVO@g$kkloJz#u3CH9OcYNOhE;-mR17 zp=%T6Y-q0G>T%5rl^~yuJh%jp4slofw?w7ox46vm7_Y*g&M>*iYs8PX z_qPt}W!d8?`OSCx0w7hSEU+M(=K5fujgAJ*)!NK8JMb*|Wj-1Of1o-mZ>$%4&9}92 zUo;Q`e`J>LEOiz<0qT;Ad+e#607FAPg#~b--Udj=gZhljGhC^J?Jx+HbNpm0FVO;~ zl{H#tD5z%gouRB&Qd0~qi-3}s80X@ diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index 838e9b7..7b9a205 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -1,5 +1,6 @@ import time import torch +from tqdm import trange from torch.optim import Adam import torch.nn.functional as F from torch.utils.data import DataLoader @@ -25,7 +26,8 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nl self.__seed = seed self.__train_batch_size= train_batch_size self.__eval_batch_size = eval_batch_size - self.__mini_batch = True if train_batch_size is not None else False + self.__mini_batch_train = True if train_batch_size is not None else False + self.__mini_batch_eval = True if eval_batch_size is not None else False self.__test_acc = self._execute() @property @@ -40,9 +42,10 @@ def _execute(self): pre_time_ed = time.time() print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") - if self.__mini_batch: + if self.__mini_batch_train: self.__train_loader = DataLoader( self.__dataset.train_idx, batch_size=self.__train_batch_size, shuffle=True, drop_last=False) + if self.__mini_batch_eval: self.__val_loader = DataLoader( self.__dataset.val_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) self.__test_loader = DataLoader( @@ -63,24 +66,27 @@ def _execute(self): for epoch in range(self.__epochs): t = time.time() - if self.__mini_batch is False: + if self.__mini_batch_train: + loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__labels, self.__device, + self.__optimizer, self.__loss_fn) + else: loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, - self.__optimizer, self.__loss_fn) - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, - self.__labels, self.__device) + self.__optimizer, self.__loss_fn) + + if self.__mini_batch_eval: + acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, + self.__labels, self.__device) else: - loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, - self.__labels, self.__device, self.__optimizer, self.__loss_fn) - acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, - self.__test_loader, self.__labels, - self.__device) + acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, + self.__labels, self.__device) print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'acc_val: {:.4f}'.format(acc_val), - 'acc_test: {:.4f}'.format(acc_test), - 'time: {:.4f}s'.format(time.time() - t)) + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: 
{:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) + if acc_val > best_val: best_val = acc_val best_test = acc_test @@ -98,7 +104,7 @@ def _execute(self): def _postprocess(self): self.__model.eval() if self.__model.evaluate_mode == "full": - if self.__mini_batch is False: + if self.__mini_batch_eval is False: outputs = self.__model.model_forward( range(self.__dataset.num_node), self.__device).to("cpu") else: @@ -109,7 +115,7 @@ def _postprocess(self): outputs.append(output) outputs = torch.vstack(outputs) - # NOTE: self.__model.postprocess now directly returns the original outputs + # TODO: self.__model.postprocess now directly returns the original outputs final_output = self.__model.postprocess(self.__dataset.adj, outputs) acc_val = accuracy( final_output[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) @@ -133,3 +139,109 @@ def _postprocess(self): acc_test = accuracy(outputs, labels) return acc_val, acc_test + +class NodeClassification_RecycleSampling(BaseTask): + def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nll_loss", seed=42): + super(NodeClassification_RecycleSampling, self).__init__() + + self.__dataset = dataset + self.__labels = self.__dataset.y + + self.__model = model + self.__optimizer = Adam(model.parameters(), lr=lr, + weight_decay=weight_decay) + self.__epochs = epochs + self.__loss_fn = getattr(F, loss_fn) if isinstance(loss_fn, str) else loss_fn + self.__device = device + self.__seed = seed + self.__test_acc = self._execute() + + @property + def test_acc(self): + return self.__test_acc + + def _execute(self): + set_seed(self.__seed) + + pre_time_st = time.time() + self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x) + pre_time_ed = time.time() + print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") + + iter_cnt = 0 + val_score = 0 + best_val_score = 0 + + total_iteration = self.__epochs * self.__model._num_iters + taus = self.__model.generate_taus(total_iteration) + tbar = trange(total_iteration, desc='Training Iterations') + + iter_id = 0 + generator = self.__model.flash_sampling(len(taus)) + + for sample_dict in generator: + + batch_out, batch_in, batch_adjs = sample_dict["batch_out"], sample_dict["batch_in"], sample_dict["sampled_adjs"] + batch_x = self.__model._processed_feature[batch_in].to(self.__device) + batch_y = self.__labels[batch_out].to(self.__device) + batch_adjs = [adj.to(self.__device) for adj in batch_adjs] + + for rec_itr in range(taus[iter_id]): + self.__optimizer.zero_grad() + + recycle_vector = None + new_batch_x = batch_x + new_batch_y = batch_y + new_batch_adjs = batch_adjs + + if rec_itr != 0: + recycle_vector = torch.cuda.FloatTensor(len(batch_out)).uniform_() > 0.2 + new_batch_y = batch_y[recycle_vector] + + self.__model._base_model.train() + pred = self.__model._base_model(new_batch_x, new_batch_adjs) + + if recycle_vector is not None: + pred = pred[recycle_vector] + + loss = self.__loss_fn(pred, new_batch_y) + iter_loss = loss.detach().item() + loss.backward() + self.__optimizer.step() + + iter_score = accuracy(pred, new_batch_y) + val_score = self._validation(iter_cnt, self.__dataset.val_idx, prev_score=val_score) + if val_score > best_val_score: + best_val_score = val_score + + tbar.set_description('training iteration #{}'.format(iter_cnt+1)) + tbar.set_postfix(loss=iter_loss, train_score=iter_score, val_score=val_score) + tbar.update(1) + + iter_cnt += 1 + + iter_id += 1 + + final_test_score = self._inference() + print('best val acc: 
{:.4f}'.format(best_val_score)) + return final_test_score + + def _validation(self, iter_cnt, val_idx, prev_score=None, val_freq=1): + if iter_cnt > 0 and iter_cnt % val_freq == 0: + val_y = self.__labels[val_idx].to(self.__device) + + self.__model._base_model.eval() + val_pred = self.__model._base_model(self.__model._processed_feature, self.__model._norm_adj)[val_idx] + val_score = accuracy(val_pred, val_y) + return val_score + else: + return prev_score + + def _inference(self): + test_y = self.__labels[self.__dataset.test_idx].to(self.__device, non_blocking=True) + + self.__model._base_model.eval() + test_pred = self.__model._base_model(self.__model._processed_feature, self.__model._norm_adj)[self.__dataset.test_idx] + test_score = accuracy(test_pred, test_y) + + return test_score \ No newline at end of file diff --git a/sgl/tasks/utils.py b/sgl/tasks/utils.py index 9cc05ca..efa8446 100644 --- a/sgl/tasks/utils.py +++ b/sgl/tasks/utils.py @@ -1,5 +1,5 @@ -import random import torch +import random import numpy as np import scipy.sparse as sp from sklearn.cluster import KMeans @@ -45,8 +45,9 @@ def evaluate(model, val_idx, test_idx, labels, device): def mini_batch_evaluate(model, val_loader, test_loader, labels, device): model.eval() - val_num = 0 correct_num_val, correct_num_test = 0, 0 + + val_num = 0 for batch in val_loader: sample_dict = model.sampling(batch) val_output, batch = model.model_forward(batch, device, **sample_dict) @@ -87,8 +88,8 @@ def mini_batch_train(model, train_loader, labels, device, optimizer, loss_fn): train_num = 0 for batch in train_loader: - optimizer.zero_grad() sample_dict = model.sampling(batch) + optimizer.zero_grad() train_output, batch = model.model_forward(batch, device, **sample_dict) loss_train = loss_fn(train_output, labels[batch]) loss_train.backward() diff --git a/sgl/utils/__init__.py b/sgl/utils/__init__.py index d493058..369b4d6 100644 --- a/sgl/utils/__init__.py +++ b/sgl/utils/__init__.py @@ -1,5 +1,7 @@ from .auto_choose_gpu import GpuWithMaxFreeMem +from .basic_operations import sparse_mx_to_torch_sparse_tensor __all__ = [ "GpuWithMaxFreeMem", + "sparse_mx_to_torch_sparse_tensor", ] diff --git a/sgl/utils/__pycache__/__init__.cpython-37.pyc b/sgl/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44492ea7c210248797f18894b603a9ac524ea91f GIT binary patch literal 287 zcmXv|!A`?43~kbNAhZeg2ON;59w0w}5aPftjLVRada<%LY9l30i4&>&Bz}S;AJHo( zet{Ft1eX2$Y(G1BST4^pvd`6v`k?%kihtuo?kVt!pvf9$GRrE)66c27yvm8^&8+3s ztSNR`IsaxPE!ni%4DnTmZY#$}uhdpqYF^?XJu0vx1V`@c4*ps}_2|4$FZ!I|xh$%4 z8A(=m&e2Xg#N)KO*h|zkxIuXtwCnMxpH2XzF#zwGFVc|&KWouBs|4bk*lnJ!&0!SS h8WDpwC;)V?1HjF2m`-o4Yobwi>H9b*Sg=LT*)Jv0OFjSq literal 0 HcmV?d00001 diff --git a/sgl/utils/__pycache__/auto_choose_gpu.cpython-37.pyc b/sgl/utils/__pycache__/auto_choose_gpu.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6ac54bb92150947568708684cc5d35b04319970 GIT binary patch literal 1163 zcmZuwO=}xT5bf&e8Od77ag4LX%kDBPAutC?Ib?6f?B>HxAa)j1+?lO6TiVeCKJ zTnvGWKX94{Xe5&mvjheJOo8H?O!5yP9Vw2*z&ZCQF5M4Yda{8&Av?-N&kA~4@`q!0 z;K}v@3syhjYb=0DJc^9Pucwur>U2bI8kE+>c^9Yo7mWnTVzxN{!U7~O;gqA-S@QQx ziW6tSiyI|->70s^FF`-}Uc}_)Q&zH5SGpzMI>9(#HB@H3lFQ~1KW!}8J$3{Or;c0s zKdy8Tcxzt+q8W6?`E;15WEdm)i-9+KiGZLjb}yQjN|^)MzQIKQ=1J( zhgoLS5Nqn?v$NI~6*bC|e0!BpPzn#`{;smS`SiS+-d3EA!q{x%;v%>LJ$H2C2hECS z1}P0wnI%<&z64L?p*^fTRp13YHZx;YQi&{AX|Q8XAL+;Zh=xHEntTIR;W`)41`pO@4fHiU!A1!JuTk^t3jss~ z-L92KAW#R)LWxaZ0=0Y}@3~PbRZ%%fSOlOc&$no 
zsp)x;ZK-$#DrTKKiLqyPcE0g9GoFjtY{E!>-oH}6IAg#4aA#A5dj}^WxIf*UDld6UcS5PoR+&R8mQ_ zdY7Of0FBMpcL1syj6I*a9;0vEWMxg2)!VAsyZ9PV z=AT_)%~u!XrgV|6@%b|s780ck7y3rI2-cFg;zJki_>Jeg8cNQ?<W-o&0AW`lRv8$`spD Nz4EIDpF8Gb@eh+PqgVg{ literal 0 HcmV?d00001 diff --git a/sgl/utils/basic_operations.py b/sgl/utils/basic_operations.py new file mode 100644 index 0000000..5766ba3 --- /dev/null +++ b/sgl/utils/basic_operations.py @@ -0,0 +1,11 @@ +import torch +import numpy as np + +def sparse_mx_to_torch_sparse_tensor(sparse_mx): + """Convert a scipy sparse matrix to a torch sparse tensor.""" + sparse_mx = sparse_mx.tocoo().astype(np.float32) + indices = torch.from_numpy( + np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) + values = torch.from_numpy(sparse_mx.data) + shape = torch.Size(sparse_mx.shape) + return torch.sparse.FloatTensor(indices, values, shape) \ No newline at end of file From d603f885b5499da2fd8525b71dda0b4bc0921fd0 Mon Sep 17 00:00:00 2001 From: infinity Date: Fri, 17 Nov 2023 07:11:11 +0000 Subject: [PATCH 05/28] update lazygcn to lazygnn, and now lazygnn supports mini-batch evaluation. --- examples/configs/fastgcn.yml | 4 +- examples/configs/graphsage.yml | 2 +- examples/configs/{lazygcn.yml => lazygnn.yml} | 10 +- sgl/data/__init__.py | 3 + sgl/data/__pycache__/__init__.cpython-37.pyc | Bin 617 -> 699 bytes sgl/data/__pycache__/utils.cpython-37.pyc | Bin 644 -> 2786 bytes sgl/data/utils.py | 49 +++++++- .../__pycache__/base_model.cpython-37.pyc | Bin 9606 -> 9740 bytes .../__pycache__/simple_models.cpython-37.pyc | Bin 12911 -> 11525 bytes sgl/models/base_model.py | 4 + sgl/models/homo/__init__.py | 4 +- .../homo/__pycache__/__init__.cpython-37.pyc | Bin 763 -> 763 bytes .../homo/__pycache__/lazygcn.cpython-37.pyc | Bin 2728 -> 4364 bytes .../homo/__pycache__/lazygnn.cpython-37.pyc | Bin 0 -> 4484 bytes .../__pycache__/vanillagnn.cpython-37.pyc | Bin 1266 -> 1266 bytes sgl/models/homo/lazygcn.py | 62 ---------- sgl/models/homo/lazygnn.py | 106 ++++++++++++++++ sgl/models/simple_models.py | 37 ------ .../__pycache__/utils.cpython-37.pyc | Bin 3828 -> 3796 bytes .../laplacian_graph_op.cpython-37.pyc | Bin 1160 -> 1090 bytes sgl/operators/graph_op/laplacian_graph_op.py | 5 +- sgl/operators/utils.py | 5 +- .../__pycache__/sampler.cpython-37.pyc | Bin 13930 -> 14035 bytes sgl/sampler/__pycache__/utils.cpython-37.pyc | Bin 3162 -> 1869 bytes sgl/sampler/sampler.py | 18 +-- sgl/sampler/utils.py | 31 +---- ...ode_classification_sampling.cpython-37.pyc | Bin 8043 -> 9688 bytes sgl/tasks/node_classification_sampling.py | 113 ++++++++++++------ 28 files changed, 264 insertions(+), 189 deletions(-) rename examples/configs/{lazygcn.yml => lazygnn.yml} (84%) create mode 100644 sgl/models/homo/__pycache__/lazygnn.cpython-37.pyc delete mode 100644 sgl/models/homo/lazygcn.py create mode 100644 sgl/models/homo/lazygnn.py diff --git a/examples/configs/fastgcn.yml b/examples/configs/fastgcn.yml index 2fdcbb7..5169934 100644 --- a/examples/configs/fastgcn.yml +++ b/examples/configs/fastgcn.yml @@ -1,8 +1,8 @@ dataset: classname: "Planetoid" - name: "cora" + name: "pubmed" root: "/home/ssq/test_data/" - split: "official" + split: "full" sampler: training: name: "FastGCNSampler" diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml index 6a7f18b..eb2ade1 100644 --- a/examples/configs/graphsage.yml +++ b/examples/configs/graphsage.yml @@ -20,7 +20,7 @@ model: num_layers: 2 task: name: "NodeClassification_Sampling" - train_batch_size: 512 + train_batch_size: 2048 epochs: 20 lr: 0.1 
weight_decay: 0.00005 diff --git a/examples/configs/lazygcn.yml b/examples/configs/lazygnn.yml similarity index 84% rename from examples/configs/lazygcn.yml rename to examples/configs/lazygnn.yml index d3be737..7950302 100644 --- a/examples/configs/lazygcn.yml +++ b/examples/configs/lazygnn.yml @@ -12,16 +12,18 @@ sampler: replace: True post_sampling_op: "LaplacianGraphOp" model: - name: "LazyGCN" + name: "LazyGNN" + basemodel: "GCN" hidden_dim: 128 dropout: 0.5 num_layers: 2 - max_workers: 5 - train_batch_size: 2048 + max_workers: 8 + max_threads: 10 task: name: "NodeClassification_RecycleSampling" - epochs: 20 + num_iters: 200 lr: 0.1 weight_decay: 0.00005 loss_fn: "nll_loss" + train_batch_size: 2048 diff --git a/sgl/data/__init__.py b/sgl/data/__init__.py index fa5d597..74853e6 100644 --- a/sgl/data/__init__.py +++ b/sgl/data/__init__.py @@ -1,8 +1,11 @@ from .transforms import random_drop_edges, random_drop_nodes, biased_drop_edges, get_subgraph, mask_features from .transforms import sort_edges, add_edges, delete_repeated_edges, add_self_loops, remove_self_loops from .base_data import Node, Edge, Graph +from .utils import RandomLoader, SplitLoader __all__ = [ + "RandomLoader", + "SplitLoader", "random_drop_edges", "random_drop_nodes", "biased_drop_edges", diff --git a/sgl/data/__pycache__/__init__.cpython-37.pyc b/sgl/data/__pycache__/__init__.cpython-37.pyc index 6977bc1972dd6fdcc7135530ae600a2aab3ab635..f7dfb1087ff8bf4cdc0c955df1617b9bc6a0ec45 100644 GIT binary patch delta 216 zcmaFKvYS=iiIDBu z3@NNBY$@z194VYBTq)csJSn^>d@1}X0x5ziLMg)O{82(FBEbxrq7!EwRo7&?#S@g6 zmy)0Blb@K9T6Bv$xF9F91i~o-SyLncBqql*Ix`ASUd$M CUCnNy6Ay2TzJpO}*q lAHR~JNCc!_WHK+46}JkIDZ?nkD8eW(IeG{R=$t9nbD2t4$FSUio;AJ!7BC_Z{zdOG{0L_WPgr^-i6! zzet(35Dz!e?XM7uDW0)zHS(_TdN2sPHO`b!!4uYvRH$kg>)$gKsrnP9>Ib45LyJ`d zT0_SNLAL?DsanumUT-SC&XV>soWa&oP6L`ir}J;2+nW#tD_Ox0gt@kNsbsg=FK5QQ z5Qjm@kGSp79y^_%V97f3v+RhQ8>G>SVmKjKyK^rIU9)fdLtFHQsdi!3w}lIGJ93fU z!?R69x3SUB^r%;CSr<&S0n2?5R5ndFXo|WqMHWhX3j=XX3l;IrDQ*gF&~S99rBWwVnZD9ao{yk za<7F2+ZJU|h8se0C4Lj*T3PdM8j4}UU7-4dByhDNH|ds(w%~O7(AE5qM6RjTrk4Go z(t9p2`6CykdFJXv{YaALA82NoW{%(Wfv7}`bKFvq%iW>s8=V$foxzm#h~hC{#_x(mpfwKaeN?hbncEv0KpImp#65G-9=%AMU_R;`$e*uz{RK`}Yuj1L`w!JLDj zH!LAK=5ZJYelCSNYm_=__5+A%`k%55?jaB_Lf|RL(T$MOx@-T~8>)P83K%Hnw6lE= zqF@IT<{Sx*f>81Uu;@IB!399%68I$ppCGaJ%(Eo~FLp=Tcqd(u>7i>4cL&l#M%yH) z5;%hk=+%%i&3e{Kc?=u>`8#Aqg_vc&f9|Bki?|_^DOeOC`zgfU3RSvo7>pWsnQGj2 zupMHih8}sU2V+r+`+(a{8x?O-+h9@M=uy^7HR4n=z2T-dSFoWILzFk@h?|Sjt-PP= zW86@7ilbby?EpQTc)?N7 zYE-{{*}4Okyp30r$h?ceP##S}Vit9fHaA5gs${%RYxA`A$ovrV9)QHyj#$V*h5Dp! 
zUZo9QBk?+kd3?L2Nq$pJ1UfzlW?rS&eRPPqWb_2^?J$lE%~3jMEtov{{QeiftGbF& z`yvl~s(%gFS|4To;@@z81E(zFz5sX*guPk>harU;fC?z+1g0Y@mi1||jLnc+)^!x+5 zzgsC1CBHbiRb5x#P-5&d1nwQ)7V%OXFSj}`{sXH19iRXJ delta 127 zcmaDP+QJ&*#LLUY00er4ZmBv<3=EG!92fxd861GPSOZ9;Fr+Z%Fyt~uF)>2e%u&oK zOexGQ3{fm8EGeu&nl*(nm_d_mW7j=KCQY`<4P2t!+(4a0oFGDE@+vMhE;b;OhmnJs GgBJkbHxrNm diff --git a/sgl/data/utils.py b/sgl/data/utils.py index 14d4a12..27cb1cf 100644 --- a/sgl/data/utils.py +++ b/sgl/data/utils.py @@ -1,6 +1,6 @@ -import os.path as osp import torch - +import numpy as np +import os.path as osp def file_exist(filepaths): if isinstance(filepaths, list): @@ -22,3 +22,48 @@ def to_undirected(edge_index): new_edge_index = torch.stack((new_row, new_col), dim=0) return new_edge_index + +class Loader: + def __init__(self, seed_nodes, batch_size): + self.seed_nodes = seed_nodes + self.batch_size = batch_size + + def __iter__(self): + pass + + def __call__(self): + pass + +class RandomLoader(Loader): + def __init__(self, seed_nodes, batch_size): + super().__init__(seed_nodes, batch_size) + self.num_batches = (len(seed_nodes) + batch_size - 1) // batch_size + + def __iter__(self): + for _ in range(self.num_batches): + batch = np.random.choice( + self.seed_nodes, self.batch_size, replace=False) + yield batch + + def __call__(self): + batch = np.random.choice( + self.seed_nodes, self.batch_size, replace=False) + + return np.sort(batch) + +class SplitLoader(Loader): + def __init__(self, seed_nodes, batch_size): + super().__init__(seed_nodes, batch_size) + if not isinstance(seed_nodes, torch.LongTensor): + seed_nodes = torch.LongTensor(seed_nodes) + self.batches = torch.split(seed_nodes, self.batch_size) + + def __iter__(self, *args, **kwargs): + for batch in self.batches: + yield batch.numpy() + + def __len__(self): + return len(self.batches) + + def __call__(self, bid, *args, **kwargs): + return self.batches[bid] diff --git a/sgl/models/__pycache__/base_model.cpython-37.pyc b/sgl/models/__pycache__/base_model.cpython-37.pyc index 5f8f0077213fbce45fcc9af446151b82fd8da161..66018317710c0c8ebd1a63b892c6e70591dd40cf 100644 GIT binary patch delta 1127 zcma))NlX+`5Qe)3nCaQaA#oPg8J5_nD1tH|NB|)!OElmB<1&nPE6D0=KtPs*sELwb ztdkx*c;Tke2?q~ej0q>V#3eyJaw8^2PZ|@eny8p~u#@`o-e2!kSG}q$P4msJI;Ybn z(USREpyT(z#?ndt9sN;b5hKH zRZ5j&jy6e2$EbA-BKXYu5-RbIEeO^4$X1NwwgT|*zND0b9wyUG1j#fRs!QlvTW!nX z=7yFbRqHPv7DJ~sUDwq387(m~Dsa`-;@L}QS7bdgFx=lecve(loxL!vhQxLpu)l*^ zY;aUS3}+p8LBp`~CiEqjoSy+kaM9fajrhgg7T8C}g%QYM1kyx&FfpJ}u(&7tP(6F| z35rX2d3}JfopF%S#%N}=5MtK#5Ndo}a5X|+>|k_aG#H_0TnN^}GQJ8{`uUt-v?}t(`fxhm4@Loj z4)R_j(|NpPaCjuJZK2dVDD@98&1t81l1sQ4#vfy}LT7pjyv2K&HB`x`%(eA`x*_8@ zITa8_?moU^6tIg{-!PgxQ^w;e49DNzAEX=8W(E=FGjRqmx={{QLO0fj!huE>{&C%n zGofsAg1shO427$Ecxhq$d8Ko_n$uIklhUQMw#W!Rd8(K;YWnFQL8GM<4-X7uT9SrJlL3v7}x9KPIr1LBQ delta 1065 zcmbV~OH5Ny5QaNY`nsi*R%vOel(rCW8ta2skw#PEa(dib6sa%}bcksJCkpnuPX}3IlqNVL_!hYOWRnSPmc3hXShhc!=e;>+lF4+b%;l*4ytx4}P)tZ92t0?2HP+O|ncAn!N|`PU#pt zMT=v}7Nwx972TXPz}U{{!c|8N4REt_KRm}SXWdTD6O7}s?qAB34kEx=eug5cAVTgu zt~ssr>n)ZDhRfz*hD5l`U7W;1bFEL}&$1w_V9>RXgpIn^bNYJ0*+V%p2t|sG^A$b2 z$or0~?m0M(iw3Kk*BfM+Jc0&nbdSLs~%4!CuB?R zrj6ijkE@O4iBZNFW1LaWa5E+tlZ+|GG$C7%8Hm#r;#vHei!u^E;x&}Kub>1!c&8-g zPk*0g%;$iMIPGhIQheeI{@-c3XQ;y)vJ9)%fzhl2Lp5Daj>|=279*7xq!29xpSzgB LmC8_lvnBis@KWVi diff --git a/sgl/models/__pycache__/simple_models.cpython-37.pyc b/sgl/models/__pycache__/simple_models.cpython-37.pyc index 7408637a28c5d303c03a7aa3473082d1d2b871d1..48321029acb979e874842a63aa56721c40fe8ce4 100644 GIT binary patch delta 694 
zcmZvZKWGzi6vyx7uIXJRSKH=nVl@~pF}aIDQYk2I+N}~r2f?Kt=B$;ZX?}Nw`iFQL zTmMu8;SD-ExH&ik5y8R1&B?*#5XC_R2L*BH)cRh68N`F%=XZR6-*>dARdCAFaR{$x3u*;Zdr)oo)eq$Cfz*4SZ#0z{gMK zVPbt>5t=JX9cddMube8XfhN=lyRI8$_p1f!Q)X) zC3hMme3^3~z+`>~wC>4#5oTA}r&Jd`Kh)f%m6|IY-xUGZc#KnwGUE;*+)=GB%zO2V ze$836hX1zw#wtM!w~8hjwxQf0Bfi%rVI7~RGX^)dP0QoHRM0gVp-Ceo@v;4Nu1(@U zZTD~QwXN|f-)^b>yKb^@i_s#4LsjRBE5vOU$`~Y5kiu_;^LD_3POu{Vx-}Mz0e11* zIw99>9EgL%vk99(zif&(O>i*cF5?~}VEiW$ai0zbNw8^y@qm$LEHNJ9GUZXAi01mK s>~5317Y3=sIThBd34BsI3Lbte9f2l(DNf*_=>+V-6Vp@5cy>JX2XxJwr2qf` delta 1775 zcmbtUO>7%g5PtJ^*MCWzX!Bz`Y0@NiT5o<@iquewR3s85Qi|Lvtw6Fs#`4>gxQG;()}=I&OY|6D#evv-4)= zee>STY^8q{;pbXfLJXfzQn&2MCw9Y89);cH*U-BAM>09hL{!8!nTXk1-e1+y?WpM% zJ*eri+ws3cz%6E-IY_|6zDg3q*enf#rMd8LvpEo4Xq&;bCIZ3`zD-E$+e~L^1AT(E z(5#)Co=8oUipx_uojBk_Ujtta-*ObeVGg*g%J;YnE_bzk4VYjU08*AF$xJR1=uD0G zUWk83ejI_pz4oIGE@U$E#d#-_N#%6)>0oaKW9IojU(+47r(s+jWy943yBm~Gx}GH1 z7N6rWhI*FQy_G3kK{0ShasDZ!d-$GL4n{lg5{H~BrgyY5z>scu*n8Nhc7vZPcDfb$A(!m8F ztjR_7sAu-z_p>tqf2wB(E`wA{0|{7C9}ZkWyxtpwU)0Y7^G$D{holXqtFnWA@RwR2 z{5Ve9no)i=`H&I0QEdQUq@y)S3#VJDGB6&RNK=SpNF)FDv)_B{-u#KYvAEM@lSCwRoMyn`GVj7`jN>Mi@UB8}5KCv6sUL7RVm@1SO>|-zIp6;J}am;8vDMZe%9-CvP*2eI-%-rl-~WZD!uYI2rkeBaJd6jF6GLM!(A?zfGzC%DonDM z$6vy^u;#hU# z?8-H7#`Y{4O9r$13-ai@FrI;1ukPo=Jz}wm3I2nzUd;xRAGWx$I6*z1k zjOLY8I_yoa`tWi}(krp%f7I9wtg)_(QQiSK!5|7|eTp^ir7^|0?54mUcwD`47Nx7d zPS^Q(oRP-rRFk96fTPw&pKhaxBfU}LHRrDSa`arHAC&=|nayBoX0O?2cA0)NWG2lK Nv#+hy974_6zX2+Uc?U4}Wn}c5T*>%aA&NaGv8pmXFRzFXD0hoJK0YxgCq8~9 PLy-`WUnDfSg{d3>Ju(s- delta 63 zcmey(`kR&8iIHs%r= diff --git a/sgl/models/homo/__pycache__/lazygcn.cpython-37.pyc b/sgl/models/homo/__pycache__/lazygcn.cpython-37.pyc index a22cbeb58f2f4c380bbdbd793e87edcf3fec818f..f528da64bd668f8119070189f1fb4d58d1b407aa 100644 GIT binary patch literal 4364 zcmcgv&2JmW6`$EJE|-){$?~_83P@v=Z6c?iNekB>No)hHtvHbp1w@SPdS@k(UhYz} z!&nlQKmi%Bd@zFE+McN1`acx?4|*-oK@UND2z=}{m-hE&MNw)Dv_OF_xo_UgyqTT% z=Dpv%d3RKm-#!Wqsb^eXC<}s#~eucRG&Y?bPji z9nbJiTI>5A-|%i)@6;jdra>n#WuwzHrOOCqqrMow`DZg_+Wsx#bOZ2A_@Cj zp~~@>@kV&iq>`-2<#J53PoDY#Q+xy_VI3}5#}a3svNbO1B6w2o*y5aMh~^XCam0+6 zeZo4fI4@da4lSOT7bnqD6AR)LN?)87XHeFkGKX~n=yE~V?nFD=P-OYVFY$>ltM7er z?S%#Ig-hh0c)tAYZ|M2QHEqX(Vce|57?m}=Hr`?fL@~unwrh>8ea4xx1sB$Hw1iG+ zfx4v)>XkLnTImb>kt5$J{VY)axW3PXbDp6mC>x-SvJToT1JD^YTk_tC(gAIiUg-+= zxjUXK8~Zf(ykJnpTVBTk@O*I%cIHyj;v_jw+tqHimPATB*+y4rFN@QtD0D4LqwPe3 z5i%d-L)AX1-C{UMq$H)ZABNqmtHMx{9rR2%kV&{sTN&m9eJWHkLVva%7EynY5`*Ts zF!?M>k2N*JezX-n&gCQem}@p&pIA>+m<{`(+K@>k3SCJ}YPC=#pLOFzJE7djK~*%A zWMX0Q-m} z3q`jt7wBsN&3|B~#!)p)WwKhE)@Wmi+N+k09jJ%Zy29}H?5j5}Z{+>ta#3tvUb%he z&2)R~aXL*x*=69ZZ7C6h|3w#IRWjuFj zvQ05<6FpQ?GVv(mWCf9-2vh59F)hO4=<{oc9M;NwT0-d3@*6Waki21DVH3Q7m0a^EgV2 zt6(NW{%2+6_h}`hiTnXb+mb&d{zt^y@@?Xc0w+{bsF!LWe@wD?50>6ODbJv$XTvzp z3MGfJ3M0{jvuC;N6DvKM_BA4}lH^=?2nPwmDprLKFg6}anPf`W)`nFe&)+vVa4*l( zn+Mj_b~jV<0!?+1rgCMJttZ+}lT6P~rzPO&(@msR9>lDVs6Fj?{36PWh~28(ng9CE$K(|B zxOmBU@8K;32)S&C*c(}R?X6`T9{drLpVAT2L-dXvlx}J5TNIrmM|pHWl=fL1KqKm7 z<67zLbIvf*d7F*w(h?R*YqPbAxqEbM*vpb)<^_?quU!J#Lf$}4E`n$)8)$csKZaj_ zgX(s_n`ti^43bPtjK=S4{-_hE93f!FcU;U8cB@D@3FLEQ5U< z`~rT~V4V-V+c`C{)Uh$QV#;~wp^{!>t#5%Tao7Um7D_f|umY^I0v*A#p0O9z%|U4` z--7pl6GP-}B3D4#j@kZB?O~*lk!*E>ih(y29e;=eWo?+IU(?6nces|z$B`5iDg|vI zdaLZ$c{h@fQxxJLHK5mvRyyBu|QVCl{0KqS54#>ymXDT)G@ve z$8gn5Zx(q5>$QaQoPU1(88={G0kj9OpDXJXIFEv@S4nz2Uk3Yt{s*(8M%fsj+-Dyn z^V$64Dw{(0<`v@MY_({cf_#PMwviJ^ znnT{DYAq@vDWh%e5Q^)1S0z#b(g1o<91r`$6y7gMErZCqStJosVC20j5mnTEK%G~K zTqAOw$ax~~g6L*c0D6&bB^aPRq(H+|wVRdq9z-C0RIWn0<0Du@s%oEd7vefV-(jXI zzD}+2cAO@iIwf}%8xS~8nhkFTPR(f0)~k@a!bjVDuL8+WXt4$wXOG*F^eIRch5&Ez 
z)7-ZLK5v5Fx90hIygR2~z6HmQLIO$?_GEWu0vyQ+e`Lv93IMkYz}vc{?4C1r;mMxy z&V><07f1Z?vA@p<6gMxeGRPx!0YIV;{^?X;xbxa-)Kxex77^NuIKR4ZscxPv!kk?L z3`PP6I8hSc>wk$Xgb0RlIEo%n}D3~WxZ zvw3)oVqG6XpdbK2a%lV30@6X>+F6*`?x^Vh*>bNOwHzJC$8C20{;{T)S?eY#FNpk` zt!`1)#`%hIRs5Hg#~lWY)2ww5!8ZL{Fo*77n= zPaGx7ro=bHaB~=?RgX#P<Ta=AbO|h~`%ob%*Ql;h zr;f->Z%opGfIdxqIqrcGa@=BRa^0#J3IR!c_Z6`+NME+qSzu?rRXP6*G`=V>OoOc+o*#9Qer%$R=b*y z70FFj!t6-`7{eHbnTC4BIc@G%}>N^C&GvRXj{hb_Ct{k<3oB zT<(2v0KJWesg{p_`_aeyAAVE`nZwZLemcw3Y?79HDxDquVFqJw=~=3@WcIyt%Ufl)SHybsbZq~s$?sdxyp!Y`Y1Kp}ybtr-_gR9$h?;6j$BM4|09*p5q zP`-Kq=A}w?fLCoG!@vgRp-heS%cF^~VVUKr*0xpV>8Vt}2vyChxf$NIfu7H#QmCkn zIG>ag!?{8O*&d%M$;bGDyqekD+^7`h%Q4q!G0TyG;agmOmgcVwb$O94_~(^6#=`>l z)^=ko4KL>fH%CgQLfb}a-D`0nKbvII`dl4VKutPNVz*#oU>k)GXWQ-yd?FS$<^^2l z8M>sxC$e4K9JQV+sci>)PHpHgQ`ruG27by5=c0Y1M-5cX-A&Z!XRG zKgl=m>>gEx+|~MI_wnBTyZPy2SC8{ur-H`0YBvY}9%tofcB<~cCAv`ccYzSvr(GJl zZ-<7g;rfixEqVsw>7rt^e(wwohy`1$A3)VV15%SIxuP|7e#Z=W_>`T~D<9gR8LQbV zz0i3kJs7ssU07}VJgZ7=)I2jh6;n{JtW<%l3SnOT5XqY;IRIy?GO4u`>qi`V>PI-` z;Yzwtc%O#+9EW4@&ea;+LF0y~HYhk)D|jQ4q5oh(F|1C-KLax+OD5eR3+RN1e+Yqmws=`t`u9n`e;w*d&20Vp5VUbozd+hCrUMzqV$Zm&)B+M1VELOj59l3dbDcIXLD_!klEED2=cYuuHkv|2q>Lo}+^Fn*EKP z_>UVu8!{W}`LkkT)DJ)!x@`2GrF`hwNG@bH2V}Fomtc4}FEjUSH%|IZd@rA9lU2p+ z{;%=W<_c#&UR(bVfEhq!g{Px?d6lKPz7NcL*)Ue$gOzFg~gHc)9fmNtg?%H~SR z(%9DFyurS~69=&$S5^Mu>Y57GwcS)hKmvA4BOxsSPR?@!vFgGsTH{gTkSVXKnNA#x zB1rL>u{2L^w4GMZv<;ghox%{Vdw9{`Bf%Y&3)BAVD!vC3E3e!Fbc$G?b?JZ&=nmAg zJOAG|83vHKV^dJ{+&@`(Fb1jrGw{{hz|;5h&Q diff --git a/sgl/models/homo/__pycache__/lazygnn.cpython-37.pyc b/sgl/models/homo/__pycache__/lazygnn.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4c39297adfa65860083dd5212a1d774e3fd45f8 GIT binary patch literal 4484 zcmcgv&2JmW6`$EJE|;Xf?KrlRstDkgMJnZxHvMqzeA()U!z7j)1#E)tdS^wEUM{KG zVJrztpnwcmIRrs(ZBA4#MgNKZ1w9w&pogG61U~kfOZ$7Xq9`>6TA)COoHuVicix-# ze)Hz_Mk8QIf12A$e!tAv->EW94wzSw^9)2V!Bf`f7UQgMwQWuG`Bj&{XBi44s z6|pFm(Bg?@v4WP0xF9Z~^u;sc63Xgh=CF1ENv`V3o#!- zWCaVIYy~YAT3%MJu!Zvp#)eK|fqI1vS}A-`zo-iLfg|r0)htleNpOraxJ)b5iaKb$ z2tXS}4RlV;tA)Znvp`ekdZ#W ziDThX*1=0sc+=Hxww**uJK0WGXfKP?D9?2zOQZcnf)R2s9E?=+f_C%KFp<*3H0_6B zH|we}lw_Q`5e{V%cE||A!BCQV^o3B#2!patm`D9#O3YxW7sKSEC_UFy5Bt$x_;4T} z&=*}}D0;q=s4yG#L$xE5NaVU4HLX=bk$lvR6YYd@X8@|Ak*;7)q?FY4%`U!}7(2yG zT84xi$t2e`su?L!vD84Rfv>v<<#Hy_HWZ?K-ivO8I}Hr$ub;LCKr`ST!*&vBmz$aU_S-(~Ju;4Gi7@B=uo z@rbqQamM5Ul&GvC8%zH)m?K-T2flpqi1*lp7yOtX@z7E>>eeBjI4Iq+E&EM}q0dv5 zp5LpgK-DI7STtW@hg?{n@K3LO%pSAH{IR`llUF?QRYT1k*@yg)zlkxJJHG`~dzam1 z@56FYU-*w}Sy;sikM8%oN?t})`%~iwURrA}Y5RUx zV2q#aCGiNRF26~m=Ex_=h$NX?_FDbjtrMjC_T5LXybufix}oed2#Wye)r7yb)ktCAoU41oB51^dj=f z)|<=n5^8!bj0ah+wrt2c-kyKrXODqBrIov{vPhe!hFq#p*Ko#PiBk6%T3 z6~0@R2TNZ)`IwA?5-E=G`3>Y;fRM?C@V&8xZ11&j>+lCmen7WY58gX*P`ZV6Y>{`4 z9p%xTQrMSpr;VwPO)7s6}T{cfhcXgEwVF;yDBr}=|+payUOGr#R(m9SfS!l^sxMkmP< z8L7k!dI`(rDYtWqC_QSqb;)H!p8mW%asBS4lO_{ zn~)Ja>luAf-Jpimx()092ByfL5qTA)>6qhhSMEm&fzDn#C>dBo-u6ehVYWwU`W1N$ ziig_+`7n}#Od_WPq|>ikGYQU(iV2h{5zG@vzkr;6AzI7*xLJD;TqHUEA^)DTCrd?e z4Cw!GZd@yBljURf9)hXeKW?!Zbnjj#o&wUAd>fJ)i-S9NJ3H!Wz)eatz)iknfO`SJ z%>mrj5jzjq++f@ZvS}Uyor5z_OcgKtf3ZX0nF!rha| z=V{p7T#7()&7~S^a ztRYpkPZ11pZJ_TcQzc)e)_6ZolXjJYz>*CJoTtr(Hxp-SG-&IU$X(*2ZN686`^`2xSleJik*jMMulFC!mZeEJcby9x;?joaS=ql|+iIHBPEwgSK% z0`T^(DZA%PTv)PataD{d-o@cReBvK70>#~HTMXk7x&R>22m5qNFg$o+E9xp-7prjX zRoq`)m{fO=twNn$0}N|xYX__EJ*8u{o3B#HyN?mgs@%X1|kT+`Z1Av zL<9)Tn2=0v6EUzk!_NBYHHv-x90PI!5QK=fZ>=C4^sR%HsqW4W{Xc8&xwD$1>-fCR zF5f%X^fYb#goNKA@^8BODYgAqx@yeg04OWMys5H)%qPk^yNrGqd z(vmgeNa2!Rk!9N)xTe)=;R?l#Qln>87WI=b)ImrOl2J;0Jq&k8QCjwx_+GYXi7g`U 
z6X_7?66t~9Iaa!z={q%Yx_ zu~t4a)^hWrDql&t_R3b^;@M_N7m>bjX4dRwB`xZ^NPPo5BO`>q`P#V`WBHV<(uk4j zrg<3#Ud1%|OPcHAHHG($0m-FV{I_-z{*`A -1 else torch.get_num_threads() // 2 - self._device = device - # hyperparameters for recycling - self._rho = rho - self._tau = tau - self._num_iters = num_iters - self._minibatch = RandomBatch(dataset.train_idx, train_batch_size) - # define the base model - self._base_model = RecycleGCN( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout - ).to(device) - - def preprocess(self, adj, x): - self._norm_adj = self._pre_graph_op._construct_adj(adj) - self._norm_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(self._device) - self._processed_feature = x.to(self._device) - - def generate_taus(self, T): - taus = [] - k = 0 - total_taus = 0 - while total_taus < T: - tau_i = int(self._tau * np.power(self._rho, k)) - tau_i = min(tau_i, T - total_taus) - taus.append(tau_i) - total_taus += tau_i - k += 1 - - return taus - - def flash_sampling(self, num_iter): - num_loops = (num_iter + self._max_threads - 1) // self._max_threads - num_iter_pt = self._max_threads - sampling_func = self._training_sampling_op.sampling - - for l in range(num_loops): - - if l == num_loops - 1: - num_iter_pt = num_iter - ((num_loops - 1) * self._max_threads) - - with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor: - sampling_jobs = [executor.submit(sampling_func, self._minibatch) for _ in range(num_iter_pt)] - - for future in concurrent.futures.as_completed(sampling_jobs): - yield (future.result()) diff --git a/sgl/models/homo/lazygnn.py b/sgl/models/homo/lazygnn.py new file mode 100644 index 0000000..dc30cbe --- /dev/null +++ b/sgl/models/homo/lazygnn.py @@ -0,0 +1,106 @@ +import sgl.models.simple_models as SimpleModels +from sgl.models.base_model import BaseSAMPLEModel +from sgl.operators.graph_op import LaplacianGraphOp, RwGraphOp +from sgl.utils import sparse_mx_to_torch_sparse_tensor + +import torch +import itertools +import numpy as np +import concurrent.futures + +class LazyGNN(BaseSAMPLEModel): + def __init__(self, dataset, training_sampler, eval_sampler=None, hidden_dim=128, basemodel="GCN", dropout=0.5, num_layers=2, max_workers=5, max_threads=-1, rho=1.1, tau=2, device="cpu"): + super(LazyGNN, self).__init__() + if basemodel == "SAGE": + self._pre_graph_op = RwGraphOp() + elif basemodel == "GCN": + self._pre_graph_op = LaplacianGraphOp(r=0.5) + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler + self._max_workers = max_workers + self._max_threads = max_threads if max_threads > -1 else torch.get_num_threads() // 2 + self._device = device + # hyperparameters for recycling + self._rho = rho + self._tau = tau + # define the base model + self._base_model = getattr(SimpleModels, basemodel)( + nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + ).to(device) + + def preprocess(self, adj, x, val_dataloader=None, test_dataloader=None): + if val_dataloader is None: + self._norm_adj = self._pre_graph_op._construct_adj(adj) + self._norm_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj).to(self._device) + else: + # If dataloader is provided, it means that we conduct minibatch evaluation. + # In such case, we could prepare evaluation minibatches in advance. 
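+            # The submitted sampling jobs run in worker threads; their results are
+            # collected and cached later by val_sampling()/test_sampling(), so the
+            # evaluation minibatches only need to be sampled once.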
+ self._val_sample_dicts = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=int(torch.get_num_threads()*0.4)) as executor: + self._val_sampling_jobs = [executor.submit( + self._eval_sampling_op.sampling, val_dataloader(bid)) for bid in range(len(val_dataloader))] + self._test_sample_dicts = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=int(torch.get_num_threads()*0.4)) as executor: + self._test_sampling_jobs = [executor.submit( + self._eval_sampling_op.sampling, test_dataloader(bid)) for bid in range(len(test_dataloader))] + self._processed_feature = x.to(self._device) + + def generate_taus(self, T): + self._taus = [] + k = 0 + total_taus = 0 + while total_taus < T: + tau_i = int(self._tau * np.power(self._rho, k)) + tau_i = min(tau_i, T - total_taus) + self._taus.append(tau_i) + total_taus += tau_i + k += 1 + return self._taus + + def model_forward(self, batch_x=None, batch_adjs=None, use_full=False): + if use_full is False: + return self._base_model(batch_x, batch_adjs) + else: + return self._base_model(self._processed_feature, self._norm_adj) + + def flash_sampling(self, total_iter, dataloader): + min_iter, max_iter = 1, self._max_threads + count_iter, max_cycle = 0, max(self._taus) + pre_cycle = np.asarray(list(itertools.accumulate(self._taus))) + sampling_func = self._training_sampling_op.sampling + + while count_iter < total_iter: + # adaptively update the number of sampled subgraphs + curr_cycle = self._taus[pre_cycle.searchsorted(count_iter, 'right')] + curr_iter = min_iter + int(curr_cycle / max_cycle * (max_iter - min_iter)) + curr_iter = min(curr_iter, total_iter - count_iter) + count_iter += curr_iter + + with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor: + sampling_jobs = [executor.submit(sampling_func, dataloader) for _ in range(curr_iter)] + + for future in concurrent.futures.as_completed(sampling_jobs): + yield (future.result()) + + def val_sampling(self): + if len(self._val_sample_dicts) == 0: + # When val_sampling is called at the first time, + # it would take a little more time to receive the subgraphs. 
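+            # The futures submitted in preprocess() are drained exactly once and
+            # cached in self._val_sample_dicts; subsequent calls reuse the
+            # cached subgraphs without any further sampling.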
+ print('Waiting for validation minibatch...') + # Order won't be the same, but it doesn't matter + for future in concurrent.futures.as_completed(self._val_sampling_jobs): + self._val_sample_dicts.append(future.result()) + print('Validation minibatch is ready...') + + return self._val_sample_dicts + + def test_sampling(self): + if len(self._test_sample_dicts) == 0: + print('Waiting for test minibatch...') + for future in concurrent.futures.as_completed(self._test_sampling_jobs): + self._test_sample_dicts.append(future.result()) + print('Test minibatch is ready...') + + return self._test_sample_dicts + + diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index e9e0663..deea9ff 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -214,21 +214,6 @@ def forward(self, input, adj): return output + self.bias else: return output - -class RecGCNConv(GCNConv): - def __init__(self, in_features, out_features, bias=False): - super(RecGCNConv, self).__init__(in_features, out_features, bias) - - def forward(self, input, adj, recycle_vec=None): - support = torch.mm(input, self.weight) - output = torch.spmm(adj, support) - if recycle_vec is not None: - output = output[recycle_vec, :] - - if self.bias is not None: - return output + self.bias - else: - return output class SAGEConv(nn.Module): """ @@ -325,26 +310,4 @@ def forward(self, x, adjs): repr = F.relu(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, adjs) - return F.log_softmax(repr, dim=1) - - -class RecycleGCN(GCN): - def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5): - super(RecycleGCN, self).__init__(nfeat, nhid, nclass, RecGCNConv, nlayers, dropout) - - def forward(self, x, adjs, recycle_vector=None): - repr = x - if isinstance(adjs, list): - for i, adj in enumerate(adjs[:-1]): - repr = self.gcs[i](repr, adj) - repr = F.relu(repr) - repr = F.dropout(repr, self.dropout, training=self.training) - repr = self.gcs[-1](repr, adjs[-1], recycle_vector) - else: - for gc in self.gcs[:-1]: - repr = gc(repr, adjs) - repr = F.relu(repr) - repr = F.dropout(repr, self.dropout, training=self.training) - repr = self.gcs[-1](repr, adjs, recycle_vector) - return F.log_softmax(repr, dim=1) \ No newline at end of file diff --git a/sgl/operators/__pycache__/utils.cpython-37.pyc b/sgl/operators/__pycache__/utils.cpython-37.pyc index 0bc3309f62f64d762cae870992598f75023d801a..3ce4b41a875ded214c501d2d7c8408e912ddc351 100644 GIT binary patch delta 635 zcmYjOO=}cE5bc`Jp4pwrX0wSNTn`3`8%6M-k+_F7MhqInpahxAs@t7Sq8pN)iJV z9C=Zpc1K=ULuKjE>Y>#eS-m4GR=Ka8BGjw|WFIxx9PjSjRt{dVj~Iml^UGUo1*(a1 z2r8)x*;i7g-6wJ_OU6`pW9w-*-BH;=!}R>=iKO;lTW$2$v&6jf@0^&Tkx;|afEYrn z26f7))tvG9SGZ(;@M%cQUp@)*dy_#4aM^qgJ1{VR!i!j+uiS$x=2c}IgxQa#c!-lu zf?D3IS6AVF2gkRn>MAC+Kx7lqcUNbLXunz{x{fq-)9hjI@i5rkN_Tdpd@hIfqrpJ6 z@mb4NH|V`3qMPPR)PP&&dvpdAu5CIi*fyu)?uF4Dua`aBmZ_p2ez06ek5{CVVV@iKtn+hVHEQY2uLg}7p6FjcWnn8$Jx7pKxZPM zd;!>w%12Fu((?gmp+J`k(IL@+M9GX%qO3H(nc12D+^p_v<#(leHJf!9o^RuIu~J~{ z2PH=*g+T)?c>SjOs9NA@?(9RSz+Il%XPqL?au59i&+$C^MGg;Gt@!c7(AaDBNK1#N z8)(ljVqoYz7f~qH#FIVpw8R7N%x;#AXK9Vu)dy!rY8f5pb={Vr^!{Y^w8A zq~pCE=`qx4Mw+6jajZop=4Nv)HY0G4q!vo8xzswAVl=fQW9CG-0@B7j9i{OR`h5ww zs?-q!G{)4Sd40u^6$}%QiPz=a55%bNkG8ibK_|n-KnfA~Bhe3&ob)%x&-`FVhW{_r znzg);gc9L4;SM4(g0S0rnxuEf!OpG_FGZ4iytO5-V6A4zyVQ4&u#V_>&jr29mm(N; zCqiye<;*c^7^%D^={2owIZ!wdzrd+*po@bXD1MCMA_#T~rRkwkZAm&wo$latJTkbR zApQ@oZf?%~TSQD0xfh-%Pe|Th@;-R)EIbTDADMdlI?o@c;mf)^K|>Hk(6N^g;fT%{ zC3J!)N}Mwkv+3u?Zboc}DZBnlEt^e^YG}|1&u|0J0I0zY;zS_R38yjD+tBaDq@^k^ 
zhKXOvvf`@FE7gM3uW;c~ONFjxIQA|0$Vlz!0p4LFe2*JEW4gg7;OP&1ywy5&nJF0) zX-8;DDz1~r7{4v$O_k^Z1( wLK||P6{%8LUKE_qVKGR{kdZa}Y1c^Ez#yXfo2;23R_T6s{`_x>bU{~s0138AW&i*H delta 470 zcmY*UJxc>Y5Z&3mOYY9Q7)dE;B-KIMR0cnyr6^(=W3e2!7cgvcVfWI=StSrHEW-YQ z2>u9bYYRK!YX67sMj<%xW@lz+=Dm4y9^9Z41Rm1%G`vW@7QxG|Ge>=37PDt=kFXLe z&rr{rBMbHHD-^nmkNQzJw#;A(sT%_P*U#@dPd2K zhU})t!ulEvUOvu+^ueu4TQmDI33&g2ncL~G23;w)^1(foZAY6 zEd=KWD(Lh@h~Nh^{?PD4O^s217=94{=no^ma3p-~fwAODT-d)jRRiHWzlPv7%? z-S?dHyyyIyzCP`~+wFFk@Yz{Cn&^AaU67a;#{7fundE3cH_LQe$fPS%W}eAg^OCO2 zn{~VHILqd&nYxPAiPk6KJOA3D5^7u(|Eo9q&zi=ROnt>p5hY=F&YR~=WYaCP?6PUm zG$k)E?t;C6dsz>h3pBADd=OCC0(>86WH(+6rm7^panXz+0o`=Kq(%!NOXnXncfaNc0g!m#UkAynMnSTMC9QV zH}VlJndFJ2mQQ6fu)o1)FGpnyc?w&T*LX6KzbP!O-5%9!Xw$iLLXT+rq|p;er_d$) zaOB_GV$7{F=7bSXoDCOGd%k2f!?+2Y5yTNxGw+AAw~_V18E>z*pY(?bdI@&GH{KR@ z2>$kVvNq_b-otufwz^X~h{j{^ZuM8OE#$Tt0TZbt`$%TIhU5T2kf0smWq7aVI8!0! ztMk$L6268RL8gE$FzV};MsRom7JS_uRHn)fs&-zF5?@6SCa5A9A_x&sDSQO}^mR#L zM=T_GGT%$E$R+?p-v`$+NQVoiZb-b9pZns8{PX za5WG*;1k+~eD!4?=_hOT+K@?<@5y8^$}j9UM6Q4y#9WGy>+}%Hb|Kg6wJ1A;yh;zj zK3ls#JS$%^>8oG0%(KUsuvqm5NC)>}8}YTY=&b}Uf+zyqsM+LRhr$w5+92-pNyLN? zp{wrk));vbF){uQk~;`G3Dy(X2;%Ud{j4&e2KfdYLORsgLX6{rhdm@U5o{&cM$k;q z3-5%QhpZ&7K~UXge%6xW9)f!b1_(rWgruDW;-d;lb^G~hps35e3??Pw3n_v%uD)R#*iz)?`#rF@9j(S10l`E|y#VEoN_e6n@ z;XyG*hl;`Phil;u_I2^E@DcMNv1+tn14#sYe&4QEiH_ubI4apP+05|-&+;8ODt9<7 z2iFM1YE6)&Ba~!~Ynfz%KLl?y7DU7U2I0sT|EJ*CJW4>Uz!F3zDyN)|j-8s~`z<Y>=gW=2H7UCx1Ey3H#fsVo1aaI{IqVv@JZW`!IeE# zCej5XpWp_yAbi~3%0}Ra_CcqhpeXYUbjP+y^BB9oScpZVc=4O~r|e`uihETbKSOS! zt6m~$A418oJDW)!Ok|8KFSn$majwm9DzIeFW6K%2tdS7C&vPPTBAXqX*77`cjNeUp zDE-P@O3s0FR?AmX(!@OcJR)9vqPQ0z*1b-80yRhBaQEcNA&MqCm2b}xl4!d*O@Nc% zfAv%W?H<|@%B@(e%;iz46q{1(u;G7@lU)sqsq*_MS06&jWlW@UGaW`w<1Qmnl1ach z=3wu}2DTlhH%4V~^xg>3Yag~MOoy}L56syizP_E5jJ)htqEk!+(6KD+xa#5ld3IRk3wQvY)~fm|X4o^#_Tl#48QNaw z>~B%4F^Yv9sP~)KQNMsIymDSStW9aWazw$rwjH3-d60+TTP?h+N@~y86<33#-yh=D zVTU}J-!w6*M%sUsxoUYBbGfA);Zc-5Qm*6GD0`*6kVj#gt0hrCrCoIJ*yS>+S6Wma zpI5N5YVO$OFcNUyHD29oC41sy zcbNt!qL08?U%S&x4RiR&R~I&OJQE)_{9pP^4wgg z%Si>nbTSgyih>9eP=2uwUQ5Q{NGPO6y)jV_-zCG451sRmqOsRdg`b>8!K*G~SKxHj zD#`5!?j~NV$U@g&ymle4-Nfv_cun?l;`N-?*iW?el>&;i{aXMrQ__|^D5S{Ra=Hp+ zyOXtGYMZb{Z)pi;!xx&t&dj-D>{p1phRdAX17Ftctiz#;Mk>9Dpqao&um}Mj4)*8< z3e-(fw!nNaC|mkAyj1&a)7_+`;b9xwNNy+SAm}7;5iEf|&#Yq9z?b1z1I$?yPT~v}BBG^RGOR$+hhQ~?jgV9JW+d@S)(ez%%pGAU$ z`SEX&WbYkJ&R>gEDf?*TMw!@#|HO8J9i-62ZjzG(k^yp~y|g5<7_MSTwWN3uVL%@o z7pdI1ouWHniHD$NVH^8AKd^8>*(19{jp-diz~`4W#)4ua*5Rt)GIQo+M&!hLTou_) z=N69=$hL5jG6)4NEz;&tMht={J}$}sGt9(4`5*GTU^;QANdDDCeh<5Y{HVdB%jS;I z4q1X~(z2qqh7@#Q74A!J{m)wwraN#Mf0h@)&3vkRiMz%f7Nhy^lGhoVfL9vgh#xex zYX{In9E9H+>RClT+Q?Z{a*j^-V;o^ve zDW8?D)XOc<*z$eN?UA|%v*VVX5te*!C0d)<9_VZBE0+wEC!U3itt-`MF?MJEW^1#e zpQh}kdwCK9?e+3rj=;)}4e(*D2d3Na(H5hRI4O6f(P!ba_RStr7T9n}{{pT|44TvB zwUv{ncjs3f#e_Ndqoco5+OQcVYLvD;g9@a{*;{&1LlW(jgUQau#PhT~MsS?q1VX`U z4W;a~7|Pg|n1%N`qwEyi>@-%;M2o{Yh!P%5Qt~`$Wld#bj%2x)G|9c5LtQGrcFFr} znr6Dg2#Q&9lQ-rZN%tcZyxVf-(C&*LpgffR_E`$$z1dvaE~TV{`NfNf_#KeN%>i^RQjel$0^aM|AKOaN#Hiv2 zg(x6smVyo-j?$+^oj_5C(e7v^Nk*M4UF;%1Q0a_aq}-Wa}o-3F-;xzKRJr-qXbD;fr7)#1 zw=hI8r*NgPrLea!MzOdvq;RBgwlJh{HZw)BrZ5IGXmW3Cn83!T$yCGxG@*zWNNBQ6 c{>rT|`5~LgWNRLN9uA-o2O|d)2P>Bh0936RM*si- delta 1323 zcmZ`&TWb?R6rP#AWOHqhN`n-vH*mdF)YiLFTeS~j5Gx7+S<{`7CZtK7*@A?WJfzSk zLCAxz?N9Iz_z(O8`|AHtzq8wlpsO%B0cbLs;E{My-`8V~p?5rPWoXD{H0CO1-u>srVAYSYV&LYv{ z1FWZ@qzVdghM^QzQVLPk1^D}R?3m3BrPGW~<5Qe(Z>|SPU6eUi6rac0AP8G^V}c;9 z5Q@`J8F@-!q}3_$?&P76iTQRT!buJv8bk(>MGToaz?azw;WZKPGxlL-i$|hqHEn!x zU@lSg)!Gt0@s4r!x$qwC0b;D`3q+LZYqim;hsw`r5}v+GKrh6uPI!ymaIMn_)q%S9 zGmCi;J4v!fE)E#AUv;wJ`!O_WL 
z#szV4zrLZO-Z*&3=lUMKoHQeeeF2sVzQPN(=W6TK*iTJOlsmU=?e89IWt5uTHqE|tsqf6sCIW&i*H diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index 02e88aa..9990341 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -43,14 +43,14 @@ def _pre_process(self, **kwargs): if "pre_sampling_op" in kwargs.keys(): if kwargs["pre_sampling_op"] == "LaplacianGraphOp": - graph_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5, add_self_loops=False) + graph_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) elif kwargs["pre_sampling_op"] == "RwGraphOp": graph_op = getattr(GraphOps, "RwGraphOp")() self.adj = graph_op._construct_adj(self.adj) if "post_sampling_op" in kwargs.keys(): if kwargs["post_sampling_op"] == "LaplacianGraphOp": - self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5, add_self_loops=False) + self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) elif kwargs["post_sampling_op"] == "RwGraphOp": self._post_sampling_op = getattr(GraphOps, "RwGraphOp")() @@ -96,8 +96,9 @@ def sampling(self, batch_inds): Method: Neighbor sampling Outputs: - n_id: global node index of each node in batch - adjs: list of sampled adj in the form of sparse tensors + batch_in: global node index of each source node in the first aggregation layer + batch_out: global node index of each target node in the last aggregation layer + sampled adjs: list of sampled adjs in the form of sparse tensors """ if callable(batch_inds): batch_inds = batch_inds() @@ -167,14 +168,14 @@ def _pre_process(self, **kwargs): if "pre_sampling_op" in kwargs.keys(): if kwargs["pre_sampling_op"] == "LaplacianGraphOp": - graph_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5, add_self_loops=False) + graph_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) elif kwargs["pre_sampling_op"] == "RwGraphOp": graph_op = getattr(GraphOps, "RwGraphOp")() self.adj = graph_op._construct_adj(self.adj) if "post_sampling_op" in kwargs.keys(): if kwargs["post_sampling_op"] == "LaplacianGraphOp": - self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5, add_self_loops=False) + self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) elif kwargs["post_sampling_op"] == "RwGraphOp": self._post_sampling_op = getattr(GraphOps, "RwGraphOp")() @@ -217,8 +218,9 @@ def sampling(self, batch_inds): Method: Sample fixed size of nodes independently at each layer. 
Outputs: - cur_out_nodes: array of source node inds at the first layer - all_adjs list of sampled adjs (torch sparse tensor) at each layer + batch_in: global node index of each source node in the first aggregation layer + batch_out: global node index of each target node in the last aggregation layer + sampled adjs: list of sampled adjs in the form of sparse tensors """ all_adjs = [] diff --git a/sgl/sampler/utils.py b/sgl/sampler/utils.py index 80c65db..70cfe09 100644 --- a/sgl/sampler/utils.py +++ b/sgl/sampler/utils.py @@ -47,33 +47,4 @@ def __iter__(self): def __next__(self): with self.lock: - return self.gen.__next__() - - -class MiniBatch(object): - def __init__(self, seed_nodes, batch_size): - self.seed_nodes = seed_nodes - self.batch_size = batch_size - - def __iter__(self): - pass - - def __call__(self): - pass - -class RandomBatch(MiniBatch): - def __init__(self, seed_nodes, batch_size): - super().__init__(seed_nodes, batch_size) - self.num_batches = (len(seed_nodes) + batch_size - 1) // batch_size - - def __iter__(self): - for _ in range(self.num_batches): - batch = np.random.choice( - self.seed_nodes, self.batch_size, replace=False) - yield batch - - def __call__(self): - batch = np.random.choice( - self.seed_nodes, self.batch_size, replace=False) - - return np.sort(batch) \ No newline at end of file + return self.gen.__next__() \ No newline at end of file diff --git a/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc b/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc index fe0401289b85a8e0751ef722c718e263c80bb524..8d28e8b20cbefd6e02bcb885dfca5e194822ed5a 100644 GIT binary patch literal 9688 zcmcgy+ixS+d7l{$$>H!$Nz`T6Hodz^%*58(aoPZztuO1%I%yZ%z$;*jX`R6=&rqU7 z4yiN4wWQ88X|<@00(KmrD4Hv%zU85Bc`W*t0QnDsc_<3>$v|J*haeCAecuc(*2-EL zXiMRIbME)=^81c%mCL4r-(LBF`-4{%SQSoZ`_)0ss&PEmuMZklW3Xf` z4VJCt!HTsqXj;v|sXiC!`Z7BWx8z8Pin`mjG z^uE$ng)Xv>3s%O}MNZ@&E61vp6^1B0R;-*bMG@cp`--ku1{PY5FWnmo_pQF;`@L?j z;{?5-XWw@QSdq6K8(zO}_lLg!$(J$Cy-tzF#E-sNgbV_oX`lpZs2pn(?F1(lB5r3w zWjC{{BbE&{&|HuYwTW?}gqctm>W3(|3t<-242oesC>?9^ouC|4!u+x3X<;5E)i4KI z3k}eESO9H=Cg@UF5XxRbh5|JO^|TzCob%BvXC+i7O^l}4t3tt2^?NuttuN!!dv0%g zC$NRvaSpfcwAbRSKOVVKma&@Hux*@6VB7H*>~prpwgoN(&NNmJ<7?;F>^l$Lz8_yb zzt#Zf(3cv`!5lf#8MuKf{rF4gw;GOu-k^8n%BM}j9Su7>er)#qp63US*Ky;FAILbr z?FLSOhFP=ZI!1V(glAv9Tb>HoG%PB#{r=l`U+NzoT=%#8*Kv3E z{Oca>TvB(={C>*(o1;Uygeg@J{AWQFt)k^rlLG#7>IOpc*H9bk(Y159Wpm0@9hEx5 zkG>S%Lhy4SkQT(HDuk~564Lm*fRFs@j6kV?J*Rtl1u37nMbXY5dfH$%bO~)pfbq}) zx(H}+^;lJvfCX3t@s)_K-Rv#|5NI9(OPiHzfgv;qHFZr9nV;xF|5=7wqh>*N0|+Rz zD1w$mmSpgUBm~PySLDB?rB$SAD4Ukn)BMJ%{3WC_yUW7Z)8wxMGk{P--_2v1;bQw$ zZ|yUBSfw5&YoO~VQ?b0UklskwoDqdcN63Dx9wg&jI#zEh@BTi@r8OXscMBh@ax1t@ zeIn4BHW1bkof1$-;`n)p`nE#h07TuIgx=GRdF zIES+{Lw!%jY%Ygpq=!Y?hyBYBaHf;1CrUDtw2mfkOzU9WI^P>ZG-iEfxKFxGhFF5C zqZ_guX6gGgo$pB&7k!1)QN|Iwv3hXS{El=-a@cX9g}v>T7=mY=o8|#qXr8T z79muC&ZDC`L?m^U(Sp~n`Fb5wq@^^j7;tT$TiYtq1ck6+UK`n!IU;CPpv$g6hfSpj<@Z?{2C8zlF?67d=AVs)&psGS}>F?Oj|aI zY!D&iT&@#&j!1zBNkU1YFwS*{a^M7U$sP^;Aiaw15{uxth;6s&6l`3xc>?MvAJ_q% zaBR@&(VdA49HX_xS=ME-8pP)O;>Cs8g^Tm)MPoDBrZ%Un2SckCA0mj;egah-#aWEPOew<|`5EpqU z*A~4_5a*J@xYX@=5JO2y-ltx-xWfd0oMx?WqpXME-vj|`frZ$r({#j2V1tTkW*T6t z4Pvbfl~R5MB@L#q0>M6}cJ#uzmfymn+qxucQdU9k-D?*JYVQOoqy@y~)Yk2|5P|l5jW8cyp-N;lGO#0X=aUdLR7I^U zm2~4R&hH>=_M5?KI_Vj#t}(2JDVM@(_8%~;X3qg&bzPK1`LSx{MMYErunh*-1zfvE zeC3SYdB^P>cKYrVXU})#69Dti!rA|dIsJd&?4S{vyYvynELEfxG3FjsQw(N~@>S3W1~KE%jp!)}(kzo6+6PXKXWWbo-; z%?u^7LkkEQ1ph923X~5{0dJ(?GXw?jW|{{z5DEa``T_tq&jR2g8)Xn+bim9h?v^HH zGQv{$O8sRf70@b-g*@s|znMu5W&?~c#2YX_0B{d>!oqFkWA)wdMkc(9WSE#D&(v7Q 
zb!D=Q_zHTZ-kV`@A*af^gY<+EmT$uu8bs_sVa(lI%-$L`nnP=(QGD{ zsSl1l6E?zStY;;xlV21)JPHv<&6N+7$+L(rp|lw`C(osA)+Vhq_B?xNOCQ1*L)|MU zO0*j4VV2?Xzl3^}4Ob($8vw7b!f>lFVB%q_o;D=G;y1JAIB9#$1V4oBg;FD3Suuz}zG+5INoAdyGB?33QuhUR;>1|wXZ z@J&Baf2i&2_YxtiZ1%SNHqM;yxH&hL1#D^z4j~WnTE+JN3UOm z|5FF$=FR@FDd0>jn_sBiru`rF&HWHdGP=d4lTXJdk7|p6NWe zfTy6oC3}wdDX&Il3jAA;q&HeeRZi3SBpFlqA7EJdM?{81UL!IfGK2l{5hdRv@;Z?} z2Z{C0SU50|553Ng9D3w8E^oUYEebH!8T)dds#Usho9x&#qdn7iJvpY5dWwHtVRv08 z7)y74QqrSwZW9?%-2ze=6}M?$zZTMUJw`Pg7cACl_~m8 z)8#}9r;y#!Y0*}mTlPd;NH=3BBJx+%dFo=KOC9Hv%YI;`kwYsl-45&(8CMV? z^VNRjc7SlKqUXM6PpdB^4z0?(gge!7&#-VZT4IwKGFd=%<}^5C1+YR$Pfj%x@dT^v ze@4Xbf+T){7Sv-C;DvC!Nvaa@Ww@*_tIO;IXaK5kDVM{GfLax*33Y0qj1+6)1xWk@ zCA2MZPfeA43K!COQ|AD)Ooxmo`jKa!1OqF+XURjL3Uw8Nv56F0* zjKE6(2WnIc^OG_#cx{={{Jw}a??2pw;N|Xz8_mvPhonb=p9`7Fu;5%E8 zPsZ4H8LrPStNb1|^nXS4%I~8mi`{*=q|<9QH_iT4Hy*U>iNyGrO8y4K zGN@Tsy8Bj^VuwivF`^`c(=xeR68*kan+=ip7JRFED#t2QM#AjSi55>PN%lO8m2Q$E zez%JA$_I;yL-`Lh^g?xVYd^{6AEDNNA^0zVC}3lHd1{io53Ee_DRaYD0VZysMAwcs zKi>h(E>Q(z%-9re0@B%wOsRq?zk|FPQ?iYk7hqS|O2mxBn3~WR7*hjFz&Qt&$S`9v z`!Zu@!I)WQOk92e76&Ev#FJ2JW#;KAg?lZMwKx+2Ce0y4tyxZy8IvCQ4YFOqc-#(* zY2pDyi%OV936>*f5tPW|A}lrcrhn6TYG#h95e1KY~H$81>g@XrAcVM^>l zpMz0;gbn+@@3UC)(gI7qd65&Jsn*ZSk+*1>XK`dXQDJAZr2Gjr`zc7fl6diQ zT>a20@RvF8A5k0^=bF&B>hnIQJXkxOWmPz5&V|WDW^9<)x(j?*TBv_uEv8aVTx@YR zEB*j8#lPC50F{|<<6K5Odg*f=%A#*j_AwaV#}YNrHYI6!xlM%L0oYR{dlZ{{14&d! znxxz%!s0tM8(yIdHXF##M3T{`mk)LXuEG)cEpwecP_VYx+GHy-3lxtRa5_XB5czR8 z*}ud8pBNmwOKOgTUf_L9FM9ot^!?-Mzg(+jq9_$LrW9|0F~h$0W8%?4VdEwh}QYjOM22?sjix zcduv8nDdNQ;Vcq?91aF7At5mewRsHev|GZMYTB#TWfAe?To1Jf0r$>)qzzCxrG4nSyDU2mg@?}$|$df(A zS4~yW)YJTeL@4ofvnV)hh|#H?f!-fl=2F*eRvz!?a+Qq z>8~j*C%=*QsQmbb{BNq&UB?Ke+dFcpUNj@wM zTOYT8FTwoM<1dwONUN8Cd^MAu;ZCM=H?-DwJnxR!Y-cLBL~&sw{3@zkO#V=9&R#=O z*mNB!v7c9w`dso+^+HD&s^R!$1hLasfXcL;A+tjlK4yoVB40_qSbK5#O_aGUragPj z^@2}$o8N%B&E%)GZLNZ$)#Pmb%ndxJOb7oCT^@dwU&I+SN=RoF%W^zB2pre*EbCAq zoaZ?d!c*jBfO0afHx;B3o3@kR)hnN^y+_t5lbQ0MXv(g_RHi*3d(>3(=rc7|XC+pC zK+OWHu<8S1>hBSSm_-P_iR?u3Zu7)(p6?XYI(Rzx!?y~L3{Sw5|I*FVByEgIM64zw zuvWU&k$4`7Akt~^33%}yLL+&jc{=eMy4HeW_FVF-)>4u-E|yRkpF-X$@=Ajt#9w#U z<+;q1I~+Kh0iQ*&IfNq!M?u{;=u+~_)RIItT|l;HPo;9Ksodvj9*1C!$F> z5zk>Co&aN>(5j?S10VQnbQT_sCVxG$oIIL5+X+D8|KBKaem_Jt5~@l50xQ0+O;mBu z7|+6KtI7HHO7drIam3VEh6m0jR^KPRB0O@>XpvY$DlJ9Yn8t;$9BE8rg@;nCgQSp3 zD`7P%>`|CE_UXN^MaFfqPv5y0m!cBWX9!?XpfLjEsv^A|r7ls>RKy9=D=u zJQ+IJ(s7rXA?XI z^nZ-YUN;Dxfj_+Ri)4Mq8MQ@jM)lpN>K9JlkNIa%lbk8uAPCG-p1=cVz}>8fd0UXX ztu?3=3ou@QYH?}d{-xFD*W*ryi;BTP1HX*mBfNyLg>Yanze1oA@BuJ0!|Lz&7VaU) zDjRMeO?9o%-U+yr{CIZyTGp_h)Ek#s>#iN{aQDz??xALpx%;TD2w7raZE%}$6UU0P z!|a64O^vdWwPpvdh3E5*l+u+u6(%*n^+xYqqoX6`Qk62|_p&%=fSoOzlFB zhcgh#KdS{$Qt%{r7jhGDUlVf39S(wUy?c+pgB-)^_SaqR_MKa9XD&}lH(`Rmj9!#^CNhIF9Q(BfI6L~I&Dh| zg}+p#Wm*QhE>)!?v_%`VA+q)$V<%dpr0x&>)eky^A4qpDN}wN*@~ zXGgb0_&Hcm@P|1P8wot#U=1LmjBY91jA*2^p;OzEfF?Z>NszE8P!c3872a}BNmac0 zNT-V6XaWmk38iB#)FUlYAtw}~!fg`Cw+XwQ7c6oAppc8uI=EhWG=Cj-!QVo77a%KI z{Q>vUm&j_LEhy6dI8fQ7~Ob zvpe#wJCIn){XySa^9IfqB!x~XxBdljak0yJ=E{j{{c%U|@1h7^mW~1q(EQ2v#mx9b zXypzNf%O3h{tKQ!bl`LrEI$UZkW4I|8%c*<1~dn`mkmPD&ElsYWhw?0mv7Pt0=!3ko r#fB|!jSe<{=(J3~45@3!bFWYk3$SI<;Xc6?go{@O4=%{{5mEjJN^9Vf diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index 7b9a205..64fc623 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -1,10 +1,11 @@ import time import torch -from tqdm import trange +import numpy as np from torch.optim import Adam import torch.nn.functional as F from torch.utils.data import DataLoader +from sgl.data.utils 
import RandomLoader, SplitLoader from sgl.tasks.base_task import BaseTask from sgl.tasks.utils import accuracy, set_seed, train, mini_batch_train, evaluate, mini_batch_evaluate @@ -141,7 +142,8 @@ def _postprocess(self): return acc_val, acc_test class NodeClassification_RecycleSampling(BaseTask): - def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nll_loss", seed=42): + def __init__(self, dataset, model, lr, weight_decay, num_iters, device, loss_fn="nll_loss", seed=42, + train_batch_size=1024, eval_batch_size=None,): super(NodeClassification_RecycleSampling, self).__init__() self.__dataset = dataset @@ -150,10 +152,18 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nl self.__model = model self.__optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) - self.__epochs = epochs + self.__num_iters = num_iters self.__loss_fn = getattr(F, loss_fn) if isinstance(loss_fn, str) else loss_fn self.__device = device self.__seed = seed + self.__train_loader = RandomLoader(dataset.train_idx, train_batch_size) + if eval_batch_size is not None: + self.__val_loader = SplitLoader(dataset.val_idx, eval_batch_size) + self.__test_loader = SplitLoader(dataset.test_idx, eval_batch_size) + self.__eval_minibatch = True + else: + self.__val_loader = self.__test_loader = None + self.__eval_minibatch = False self.__test_acc = self._execute() @property @@ -164,7 +174,7 @@ def _execute(self): set_seed(self.__seed) pre_time_st = time.time() - self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x) + self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x, val_dataloader=self.__val_loader, test_dataloader=self.__test_loader) pre_time_ed = time.time() print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") @@ -172,17 +182,17 @@ def _execute(self): val_score = 0 best_val_score = 0 - total_iteration = self.__epochs * self.__model._num_iters - taus = self.__model.generate_taus(total_iteration) - tbar = trange(total_iteration, desc='Training Iterations') + torch.cuda.synchronize() + train_time_st = time.time() + taus = self.__model.generate_taus(self.__num_iters) iter_id = 0 - generator = self.__model.flash_sampling(len(taus)) + generator = self.__model.flash_sampling(len(taus), self.__train_loader) for sample_dict in generator: batch_out, batch_in, batch_adjs = sample_dict["batch_out"], sample_dict["batch_in"], sample_dict["sampled_adjs"] - batch_x = self.__model._processed_feature[batch_in].to(self.__device) + batch_x = self.__model.processed_feature[batch_in].to(self.__device) batch_y = self.__labels[batch_out].to(self.__device) batch_adjs = [adj.to(self.__device) for adj in batch_adjs] @@ -198,50 +208,83 @@ def _execute(self): recycle_vector = torch.cuda.FloatTensor(len(batch_out)).uniform_() > 0.2 new_batch_y = batch_y[recycle_vector] - self.__model._base_model.train() - pred = self.__model._base_model(new_batch_x, new_batch_adjs) + self.__model.train() + pred = self.__model.model_forward(new_batch_x, new_batch_adjs) if recycle_vector is not None: pred = pred[recycle_vector] loss = self.__loss_fn(pred, new_batch_y) - iter_loss = loss.detach().item() loss.backward() self.__optimizer.step() - iter_score = accuracy(pred, new_batch_y) - val_score = self._validation(iter_cnt, self.__dataset.val_idx, prev_score=val_score) + val_score = self._validation(iter_cnt, prev_score=val_score) + test_score = self._inference() + if val_score > best_val_score: best_val_score = val_score + best_test_score = test_score - 
tbar.set_description('training iteration #{}'.format(iter_cnt+1)) - tbar.set_postfix(loss=iter_loss, train_score=iter_score, val_score=val_score) - tbar.update(1) - + print('Iteration: {:03d}'.format(iter_cnt + 1), + 'loss_train: {:.4f}'.format(loss), + 'acc_val: {:.4f}'.format(val_score), + 'acc_test: {:.4f}'.format(test_score)) + iter_cnt += 1 iter_id += 1 - final_test_score = self._inference() - print('best val acc: {:.4f}'.format(best_val_score)) - return final_test_score - - def _validation(self, iter_cnt, val_idx, prev_score=None, val_freq=1): - if iter_cnt > 0 and iter_cnt % val_freq == 0: - val_y = self.__labels[val_idx].to(self.__device) - - self.__model._base_model.eval() - val_pred = self.__model._base_model(self.__model._processed_feature, self.__model._norm_adj)[val_idx] - val_score = accuracy(val_pred, val_y) + torch.cuda.synchronize() + train_time_ed = time.time() + print(f"Trianing done in {(train_time_ed - train_time_st):.4f}s") + print(f'Best val acc: {best_val_score:.4f}') + print(f'Best test acc: {best_test_score:.4f}') + + return best_test_score + + def _validation(self, iter_cnt, prev_score=None, val_freq=1): + if (iter_cnt + 1) % val_freq == 0: + self.__model.eval() + if self.__eval_minibatch is False: + val_y = self.__labels[self.__dataset.val_idx].to(self.__device) + val_pred = self.__model.model_forward(use_full=True)[self.__dataset.val_idx] + val_score = accuracy(val_pred, val_y) + else: + val_scores = [] + val_sample_dicts = self.__model.val_sampling() + for val_sample_dict in val_sample_dicts: + val_batch_out, val_batch_in, val_batch_adjs = val_sample_dict["batch_out"], val_sample_dict["batch_in"], val_sample_dict["sampled_adjs"] + val_batch_x = self.__model.processed_feature[val_batch_in].to(self.__device) + val_batch_y = self.__labels[val_batch_out].to(self.__device) + val_batch_adjs = [val_adj.to(self.__device) for val_adj in val_batch_adjs] + + pred = self.__model.model_forward(val_batch_x, val_batch_adjs) + val_score = accuracy(pred, val_batch_y) + val_batch_size = len(val_batch_out) + val_scores.append(val_score * val_batch_size) + val_score = np.sum(val_scores) / len(self.__dataset.val_idx) return val_score else: return prev_score def _inference(self): - test_y = self.__labels[self.__dataset.test_idx].to(self.__device, non_blocking=True) - - self.__model._base_model.eval() - test_pred = self.__model._base_model(self.__model._processed_feature, self.__model._norm_adj)[self.__dataset.test_idx] - test_score = accuracy(test_pred, test_y) - + self.__model.eval() + if self.__eval_minibatch is False: + test_y = self.__labels[self.__dataset.test_idx].to(self.__device, non_blocking=True) + test_pred = self.__model.model_forward(use_full=True)[self.__dataset.test_idx] + test_score = accuracy(test_pred, test_y) + else: + test_scores = [] + test_sample_dicts = self.__model.test_sampling() + for test_sample_dict in test_sample_dicts: + test_batch_out, test_batch_in, test_batch_adjs = test_sample_dict["batch_out"], test_sample_dict["batch_in"], test_sample_dict["sampled_adjs"] + test_batch_x = self.__model.processed_feature[test_batch_in].to(self.__device) + test_batch_y = self.__labels[test_batch_out].to(self.__device) + test_batch_adjs = [test_adj.to(self.__device) for test_adj in test_batch_adjs] + + pred = self.__model.model_forward(test_batch_x, test_batch_adjs) + test_score = accuracy(pred, test_batch_y) + test_batch_size = len(test_batch_out) + test_scores.append(test_score * test_batch_size) + test_score = np.sum(test_scores) / len(self.__dataset.test_idx) 
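+            # Each batch contributes accuracy * batch_size, so the sum divided by
+            # len(test_idx) is the exact accuracy over the whole test set rather
+            # than an unweighted mean of per-batch accuracies.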
return test_score \ No newline at end of file From e52b04eaacf149179ffe2c4b2d2c792618a2d253 Mon Sep 17 00:00:00 2001 From: infinity Date: Sat, 18 Nov 2023 10:58:59 +0000 Subject: [PATCH 06/28] reorganize the code structure; fix some bugs of ClusterGCN. --- examples/clustergcn_nodeclass.py | 5 +- examples/configs/clustergcn.yml | 1 - examples/configs/vanillagnn.yml | 2 +- sgl/data/__init__.py | 3 +- sgl/data/__pycache__/__init__.cpython-37.pyc | Bin 699 -> 724 bytes sgl/data/__pycache__/base_data.cpython-37.pyc | Bin 10901 -> 11995 bytes .../__pycache__/base_dataset.cpython-37.pyc | Bin 11913 -> 12046 bytes sgl/data/base_data.py | 30 ++- sgl/data/base_dataset.py | 4 + .../__pycache__/planetoid.cpython-37.pyc | Bin 4396 -> 4274 bytes sgl/dataset/planetoid.py | 5 - .../__pycache__/base_model.cpython-37.pyc | Bin 9740 -> 9277 bytes .../__pycache__/simple_models.cpython-37.pyc | Bin 11525 -> 11723 bytes sgl/models/base_model.py | 51 ++-- .../__pycache__/clustergcn.cpython-37.pyc | Bin 1183 -> 1063 bytes .../homo/__pycache__/lazygnn.cpython-37.pyc | Bin 4484 -> 4418 bytes sgl/models/homo/clustergcn.py | 7 +- sgl/models/homo/lazygnn.py | 56 ++-- sgl/models/simple_models.py | 40 +-- .../__pycache__/base_sampler.cpython-37.pyc | Bin 963 -> 3412 bytes .../__pycache__/sampler.cpython-37.pyc | Bin 14035 -> 9898 bytes sgl/sampler/base_sampler.py | 76 +++++- sgl/sampler/sampler.py | 243 +++++------------- ...ode_classification_sampling.cpython-37.pyc | Bin 9688 -> 9347 bytes sgl/tasks/__pycache__/utils.cpython-37.pyc | Bin 10787 -> 11243 bytes sgl/tasks/node_classification_sampling.py | 105 ++++---- sgl/tasks/utils.py | 71 +++-- 27 files changed, 347 insertions(+), 352 deletions(-) diff --git a/examples/clustergcn_nodeclass.py b/examples/clustergcn_nodeclass.py index 96ac77b..91ff3f5 100644 --- a/examples/clustergcn_nodeclass.py +++ b/examples/clustergcn_nodeclass.py @@ -1,6 +1,5 @@ import yaml import argparse -import networkx as nx import sgl.dataset as Dataset from sgl.models.homo import ClusterGCN from sgl.sampler import ClusterGCNSampler @@ -20,14 +19,14 @@ device = f"cuda:{args.device}" if args.device >= 0 else "cpu" dataset_kwargs = config["dataset"] cluster_number = config["sampler"]["cluster_number"] - dataset_kwargs.update({"split": f"clustergcn_{cluster_number}"}) classname = dataset_kwargs.pop("classname") dataset = getattr(Dataset, classname)(**dataset_kwargs) sampler_kwargs = config["sampler"] - sampler = ClusterGCNSampler(nx.from_scipy_sparse_matrix(dataset.adj), dataset.x.numpy(), dataset.y.unsqueeze(1).numpy(), **sampler_kwargs) + sampler = ClusterGCNSampler(dataset, **sampler_kwargs) model_kwargs = config["model"] model_kwargs.update({"device": device}) model = ClusterGCN(sampler, nfeat=dataset.num_features, nclass=dataset.num_classes, **model_kwargs) task_kwargs = config["task"] task_kwargs.update({"device": device}) + task_kwargs.update({"graph_number": cluster_number}) test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc diff --git a/examples/configs/clustergcn.yml b/examples/configs/clustergcn.yml index e105376..42c9297 100644 --- a/examples/configs/clustergcn.yml +++ b/examples/configs/clustergcn.yml @@ -5,7 +5,6 @@ dataset: sampler: cluster_method: "random" cluster_number: 10 - test_ratio: 0.3 post_sampling_op: "LaplacianGraphOp" model: hidden_dim: 128 diff --git a/examples/configs/vanillagnn.yml b/examples/configs/vanillagnn.yml index 9af310b..9eae60b 100644 --- a/examples/configs/vanillagnn.yml +++ b/examples/configs/vanillagnn.yml @@ -15,7 +15,7 @@ 
model: num_layers: 2 task: name: "NodeClassification_Sampling" - epochs: 20 + epochs: 50 lr: 0.1 weight_decay: 0.00005 loss_fn: "nll_loss" diff --git a/sgl/data/__init__.py b/sgl/data/__init__.py index 74853e6..0fc4ad7 100644 --- a/sgl/data/__init__.py +++ b/sgl/data/__init__.py @@ -1,6 +1,6 @@ from .transforms import random_drop_edges, random_drop_nodes, biased_drop_edges, get_subgraph, mask_features from .transforms import sort_edges, add_edges, delete_repeated_edges, add_self_loops, remove_self_loops -from .base_data import Node, Edge, Graph +from .base_data import Node, Edge, Graph, Block from .utils import RandomLoader, SplitLoader __all__ = [ @@ -18,5 +18,6 @@ "remove_self_loops", "Node", "Edge", + "Block", "Graph", ] \ No newline at end of file diff --git a/sgl/data/__pycache__/__init__.cpython-37.pyc b/sgl/data/__pycache__/__init__.cpython-37.pyc index f7dfb1087ff8bf4cdc0c955df1617b9bc6a0ec45..0733336f1fd5588b2a396e227840afc945312dbb 100644 GIT binary patch delta 239 zcmdnZdWBWpiI=oWWyK~81~gi{2vu1Ik5bVdgreh`xnNcd@rOuok$p-}`9 zxy4#ql9^KsRR|JbkB?8x$%&6&$xtK;(jq!JlF5ok4ak&Xlwp)$6kwjbkV%S#iIIsB E02)0xd;kCd delta 194 zcmcb@x|>zqiI|2dg?7!_oBptjL8j*a%>=jiUcMv zVRT^U12X+Ig(rVtj8MMCT3V8sQ(Od*21?vwkB?8x$%&6&$xtK$(jzjtfXRwS1;~_P VlwlNM6kwiwfJtQX1tt~{1^^SzES>-W diff --git a/sgl/data/__pycache__/base_data.cpython-37.pyc b/sgl/data/__pycache__/base_data.cpython-37.pyc index c9854cdafdb2928aa219b8927fbbe2e4ff7cf5d8..7603232cf284e7f5201419ec498ec015f9de809b 100644 GIT binary patch literal 11995 zcmcIqO^{pJRepbxo_<>W<96Ha@i=PR$yml|I}jWg9H-+Mdon{x3^T6rq+z0w^F;tNhh7}8Hs?8$6h6Rf(Sg>G0Q5403#Vp<;3!sV$l?@9vVFBNF z?$gtgEZK4uNcHZ!zvu7Xd(OG%zFRXh1p~jkwfFt{R}AA{S?K?4WY+Km3W;ZU=C-lU zx3z8TnCqsg>#TJfb@sNilUdK0#v6v`c$rTOFLPwAXHm*}Ih1lJMUD-)V%Zbt;{ zs-|-B5`uJ6L@~*!Ih&!P~BWX$eNADOfyxY){v^Aemva^_4*Pq#_Z1b*-cu9P;!az5tctSHcVM{u4es+%R{ql*M=mO!H^RiMpT=al@H34 z(y{S-#uM`otv%;mJr`K>R@lKzS*zC+XMx+n^FL}e{lPT-tZ;*Ct=ZYxy-CRn=~<(D zL6aVK>UzZL4Xp6Sc5oA!et%57%)g!pbrQn`FzO{Q%OIbCSnh*2%$dM`#dKt@XOI`QgnSlxQA)_?kQbGNyosh*_dc(dM|}Vm4K$9NLr|TMFr~)-Z9J6tKdw)j=Qe3c621RK0#keupZ73*wZQ-i9 zSUD1{O?<_ajG}zPR^$?ByW{y$K`J%c!FH0J#BKKlJi(Wc&~UhB*1Bje;0+^^{IjP1 z7p&9J_2hNDU7>A}g-n@J-SzeQ&!IfQoXiaIcO4zImbJPIf@r8Y(U-(D5gR!gk4;@g z)2k^QWx*i=;$oN%esv;zEp>vsnL?r^)}Q1d9Eo+}Oo(-YE#9zVTdcl_TE)2;&rUqG zt+wYM+)5#w9dIRWLntST?b1X%3Vxt*XFLK~S3~A6;fae)lt;?K3zFEd5bCIIXWd%~ zCTBXzDW>rfQdiLQb^_1px*{6vp6MV@mc9Bipx%87GPdM&kYB;j2{f;w>HQSsxxsX5 zg+fQItWKkgqVZKU@M1zEcUN_w)51gb^Q_}?SHHl7POwa4lnwk4Zjwkwm{7!*ng`Jl zq8?sC1Kr$U5eZbnsgl$Gm$S?%_Gscse$-%Vc!DmH#N;mgKbzbljPZQm0E?kL5jiy| zzgm;3h(Ep)L;G(ii20-Mo)}s&k}$2o#lG3>Wjy0yUfn@UX2`JKK@VGuGfkT-t2`g~ zz7Ml!)C=uG2%{+m5{6XZd?w}AplMO)|!- z()i}<>5?HARp(>-4N4W)pn91JXHiv{tT3UfffXg0a5DP_ zZ?0d#YO3`?7uK}AAj+sl8~mRqkc$N3#$yiDg>><-n$-q3PYDVzb&Va&gUH0`OKJfy zwS5x9T1UTtdrKQu8rL&4u=vAiwF;Jtvr8#9x~@4F=esjx=LHSyF9XX2Z9;N?3s2vy zYVdtwO{c7WWit3%X!_pM;A25g2c8yr0`=cO(@_e1S%i~J##HdvhZe+6BX~F-S)?bK z9X=h8l1)>gR9~OW(i)n6KLtN4EFBElP^!__*oS%>yB|@8E$;l&p?&OV_g!|r86v0> zTn@AGQGxDwPIkUkn_@dwUD;)+9hK||*;{(&2`{nDUIRU&i=B0UfJe?_hTUJ((+P@> z()j66-$c_N6E3&8VRf~j4u=HAv!Dp-mypqd3au?$)a2RFInE3T3yL>b4Ra}BxqV?x zt(fv!WFwkL7#d0GrqjF>2pWun*)>i;kyH=}JOXKXj`bC1_j#k`vin~LR zeYM7f-m^$8g=t{TzlPTD;t6=VhBZIix8)_LoQExEdlOIcqiomk1Z6X2%P;+(*m4Jl z=UF)z7jQ_P^X8Mo^MZFGIXo|VC-vbOeU`jak}vz^ty%AMQgXay?~Js}`EzJ_DJePL zS?^_OnaB8Z-WMdlfPBSUk^Bkd&wCdnzvwSw^u?s)cwh9sBrPZXlW2LRUlRYK8fD&4 zjok;eCiOjkeDgFftMM{2T#Yz2+c+&dy}}XC%x1_lvD3@-++O~OZUP(w?d)q`3Skie zs}HC$C4(wzy=qYh<>=R@yf;}AvKqwYTVty%IWCj9125<}?N}fq*a0doP5DUdRMQZ2FrNqWKge 
zecE^-Ogz6$z{IA%O<_v2q)(AhGxFn`k8Gt0v%_^h0TG-24o`or&8AqIe9C_skF05K zstI`5^bb$td5&{30S`Co1Uzi|rxc$4-g%~vH37}qWHfC0XFPqe<+Rv_9TN*3WLhL( z0e52yTCK0cEUg9zsQMg#;qw8aY1fxb2bJ0cF|jEcfbGWg5Bm$#&@VWiW4C`iKr}>u z0{#4k$Q!Uf2J zo~m?*A*yWjmRY-qOp2-s?2)MPWBUni^G+d{ozSMfMN%A18%3h%Vny~&<6r`TMdEL6 zH-f+q{vCY>^up4Vx>g*!!P~sCsCE2W=zo~pb`00S=@<68CWpRgu zXigg+k!cngO0gg8X49}N6zl#AkG@5;cwU{3VH)L04AWr!H-U1V;lsxa zi!qie*5Y{~hEH<%fZ=xKHp&Udj7e#Zg|L|8UuW$I$E^DwJSmQ=4xH~=b?G_YH~kD* zcsj5~BP!7~9dQ#!>=x_Rm<-(VH(C5uCcnl+W?_V|WNoU~(dEzZ1ecK@UXU61>w{bw zQ8+vq`8$lt$+#wlBtOb_4Nvg9NRpV`>AoMYFTrpZE*D_93d^Rhzp|<8uWai2E1UA) zvnl^Q>ka((tT*uAvjhKK{O9GU_$?n{$c}7)ffro{i4?(Wr1wv;GFsSMb%fvOJs+Iw zWpH`C^r#5ERSX%bTL_^T81&j(xM!di_hG1AL#^Cn{upxYW_t8vDiKRXY7oaw0AWv|H)Rq+Okk1SyC32bV&D=2B z=O^?n4f;H`_HM+VE);R6Nu4_>jJjMK6M^+w&2XiACK)#xV_oT9zx%+iJoFEPN@t@& z(@8ZS$h|0Y5{w*-3eh>hHTz0;`F-8#JA=kXMbIhQtY_c9UVo?aaUd74NJY)+Lx2V; z72>NW+LwR>FhUh90!U+JfOn<)W||Sw31d8jBfjDt;F!c)^+a^Kp3)l#`cfu!5;3kt z!_kIN;>UC_B8T1ey{|+0;sczH3L~f4**#o;`Dyrw7}q-+8@M+KqBG;01>y|=*GG06 zABK$5L)nd|lr5UlsPJB6dtYBw%iRjDdFaSTF8vqb(!~vv&dtd~jc(IH&N9yL_VXznboVKzkEt==!ktMTio<_@3+kKMm84QsPGt%yF=m8 z@X6;}U=*Klof>9v@UamMI)8lg5gyyfjQ13`Y;gE0+VHsu4B={WyG7?80a7TuYa)p( z`rVg)yo=(;Bgq}t7Vfyv&kgh1Jc@KOweF?c{e91`FcMKD>Y7dD#!csUe+O{K``yYz(}zSZ z+rjFdPA?@<3-R`!4ilI0cue?Og&NpP;hG=C6-F{t3jmHyZ2rJl4VHy&ev?fS+%I71yoQHyz@yI-QmD&uh+O6-MfTj%QjwB4U}-L##X+nn9z^DXb$WlPfR zzGSw3urNjRi`X3n5d}eo2`WPUpdcu?4+Bv|`~z=-i1=bU@)xv$qxKYuK=r@A`e;(r4dCXLSrmP3(loxkZWT}ruTCGm96y7|+c@}%Wi zSpw!sxtH%5jq1{Uu2}FVOrMEKRv=+AHLc}Xa)G~jj~W?so+XxR!z*9Oze_V!ULsA?vkRt-L4ZhFku z>9u6jNIgui`$pt9T&$fgNjvCIpBiuEZp0)O;SbinDil7ls?Ds%OrzkOLQDRAN+k!% z)7b_NbW=59NI%q zRdvu|IeMalzghk)o6j07x5zsAt4Jc-h*1{=ytlljk;@s3!4I(xgPk=&iG5R`lezFy zZLn36R|bP8-CbU;usq33j`@N6AW*ri3L*&Mp|?Y04_0GX!-eHnXXbTYKnpCi_H{Al z6&bZEGk|l8-0M~lqyl;kCPc0Emt6-%~BEQ=jkv#c;ySQfKOmxb<0%GNa> zivsP2*aU%1VLcGYJVRy~GRN+O5D~;MDlu9tv^Ynn=)1ZWDNg^?HHF*~-$2p&c-!ly zy3JjL~p){Cc24w!bg@+^HxyZ&$R4W)oyw4uBbv~#;WXgP;!Pu^dfL{FFSA% zGcsaTiteN3a;!A5$lqoDGi4Y39UhI|h1oFXu(g0*pi>Qdq9b79fIBb&y+j}1WTGjQ zXIdUjKlcj78+X|@(J#s~MJsuY)JP+}+SqTi(Lo&8fUeLVjT1H#Ey#fh=uK{SpUp%H z9hiXLp=*)7cBYVV2PU9*X{f2dL@7Ej0lh~zn#STH`+|9!BlrQGZQg9VfL(Sl4Co{J zu(@DZFmH3nu2N6S^p;gWScRi?D~KELVMEL1nPFa|D=lY6cY-Z5FwGgiTN2m@w5%~sonOg`;QpL@}w5w|&swx(>cC^7N_0O1Ac@a3$X}YjxP1WD=Q#&Dlwnu*E>#FX@=3Ies!(v?Dj~z z?*)DcVhUnE#9=N9)nIwf8;d3LttwTWGjhqCkyh2dm4!`tNxRimCaj7*DVi1)P_*x{ zEBF*&%c{7jrSe9ZW%w74z05Gp#7-;IaI9?UvlASeq)%c8JF!H0l2w7kmO`33Nmd0))jZA@U7E&mWc*6Ru`cFYR8fVk zjesHQ2wOSpT~K+Nvn8n!gjpKtduCi5@U_N7qZoy;LWV^>ifS!6XaEfI_M5{+SHxXa zn251``QUFYH@r#Sb z2ih)T7z?llsKqxb8kiL{6Upr2fc*~%!)Lqw6(mYe6dVqju(KwGs|#5Eo#}Z!pE5?V VCnmn4{E}Y^N>bJVahWMR`;@;E>WOw<9Qt5dsoT zjK&-j7R2X9H!7wYjbZD?g^|S7Xwan#78*4%F)`j#j5^7D^WSsN{pa3u&!vGYgJPp7 zCdcvTRL4=}+hfrMrYyjfv<{qy=Qbri;0cs2#LYv0Zou0UC;OrerZEeaOF79#Qb-FZ zLuw#`x1ILwKJIanzY!C;nP##&_-wOZ_eT`n_a_tA65Uwllu`|g&irbKFFrGhlPM0I zaJFE5ss+ZeB~^tGMpGy8hM103qAFB!zbU*chJ`AwX3%ae1PN=b%TR@D)&lV1S8EyO zr!}PPVx+=!T%54hI$1L=hnp4yw9g=GA3&6iD!g!%8T?GH6j1e%SVt zm*-vn3VexPE|`pmgZQgxM_UV}nslvK2G+{8T)_VRg_Y3KLj;9Dw(Kxr)00#up^MJ+o%pV$446_>2nMjjI2*(K24XmGo?_L|xGTjl614o zOBf;y6a0i?4g*vs(_n%SA_OqtZiTw&BllJ_wBzd%86Mz|l4jq()3i}uJAn$u5(%*r zc9ZHM^y0Xu)JBJuQ#S9>Lt!27d!F)k7%bHs|FslNISIQtaKb7z?xE2>!hS*#f$n{? 
zbTO||8kQ00idh9AR^14x^Spr_LM@jUaL`-fASLG*7FAc6u4%zYtaNM!?|HW*#A@G- Wue{R<)Xc1wFSg^Luh=|k6@CL;*6`l| delta 963 zcmYk5-%Aux6vsWTv%2c+uDhxGt3}IS)2~T3M^5t0AT4k9H^9Sa;StlcHo$ zK@dSvhY0!sJtaL9#t=#WKrcZ(_z*)6fe=JW_0Ur<-Ln!nFdxo2-+Rv7d+wP#6SvNZ z-J)2W$DeFful)SASO-v~H$x$qlY!wpDX5>!pRwSU$&M|hwQwGnOI)ax4#5R{Uz&nT z7#8P*Fi+3mil_-q++tW|&7#+8$EYacOKUN>an1S||JZh6+Gc?u&fE6mimkJxnN}Yl zwBT=>tGtz)SNN;ua})t1KXI?DT#wj)0bIu##{%5OCyo@{MYB@{#3kpV`0#&59xQn= z>~`WSA?9FJJp5>f@ z&)k}aFPQMAg%DT1;=1=De)QV0-#3O)pIv|JW2W{lF0sQLMnO8IYKEX)&ZcG7O%XxO zrc;_>Qe`by+%QjPvvMZM$|xy<$*R8)!Kvz1*wV}P%oRW{-l%IDJWjd!nsQnKc7oed zHS3}T`U=}e=qDrzxdxmfYr7^{1B5|B9U*?dzA|J8Y&tU|8>Uo7V0yA68FIJ9(Xiql8)l4aIg-bG3pPwmD03oo1w5_^kE4Vog`=6ViLr*MhPj3%jX9V>lXJ2cpTguiJ{=R*q9CC5Tiki2x$((4 ziN(dK#YM#+5n;dl63^U%oYdUZypq%u*P^2Qq9Pukj3(>kFMJx4b@LGc&gOMv?3uiY*FMsxhH(LN z3j0FFqFp5{KsIX%$3n&wPDzGl##+`Ah8k8N%><>Hp)^aLSPf$eS2k18i4wLZ#ypM^ z_7v`B#wNxZrW)oNmNe#I22GyHB76!$QC!J6rNt$wMd`_Tw;1D#%orvc@#!&gOitvJ zFytyONh~TUF3&8IGCr>~Hz~Df@*X}lc_pAc10x3`AEOkb2$&WF(p-!}KsE;>*XHkhyBPru)JGBk diff --git a/sgl/dataset/planetoid.py b/sgl/dataset/planetoid.py index 4d1c580..a1ebf96 100644 --- a/sgl/dataset/planetoid.py +++ b/sgl/dataset/planetoid.py @@ -108,11 +108,6 @@ def __generate_split(self, split): train_idx = range(self.num_node - 1500) val_idx = range(self.num_node - 1500, self.num_node - 1000) test_idx = range(self.num_node - 1000, self.num_node) - elif split.startswith("clustergcn"): - cluster_number = int(split.split("_")[1]) - train_idx = range(cluster_number) - val_idx = range(cluster_number) - test_idx = range(cluster_number) elif split == "random": raise NotImplementedError else: diff --git a/sgl/models/__pycache__/base_model.cpython-37.pyc b/sgl/models/__pycache__/base_model.cpython-37.pyc index 66018317710c0c8ebd1a63b892c6e70591dd40cf..59154ffad496d39a0d65e6008ce128e33c6b889e 100644 GIT binary patch delta 3396 zcmb_eU2GiH6`ni0J3Bl3V`n!Zb{yB6#QE6}7Doh=gpfc&f=dz;Vjy9GPRBbFd!6;} za%MKf=|*mqNR?WIgsucHtz=WF4}C)sf(Lk{Rr}ORRUd}(KzSho@r)2sd%ipCwQ=04 zmAb3@&7E`aJ@?#mzH{!pe(s}*+{tX#()jnzTbI4m18?RA#9MFfeWK=SZoF>Irz>fp zJ*Bw`*SM{@hNsVEZpA8DmLy%1C1zKWV@b-jSYmZ0eJn}48J1+al73fAX{GF4>^4~v zVd7-HS-tx1heaV0cRm%vB9Y_~Oa4>;axu-lFx8@F%b9TkkNWwYGD_&o%#a-{v^>i9 z-aR#G6Tg>YRDd|Bk^7<1YzM7&uzt&QtLX>bJ@?Ab67{jj z%$j>WG+ggmt$L4@+4-jH)$J=ydd(sCVR_T|(-D}4+!tvo%P*4U;(kn3`>x#u z8_)rU!*blbdXnk9t(c_rB)FH7e=!UD(5rIm z{o>u8KgS(Us|Z)C=C~&c z4PHn`05CGl+1@p$-eywS@GpUBl$GD4$_JGL^oT#%@6B4Vw&z==>x{f)T`WGz3V0mh zTs>ZNuH!5_FaKywE{6SfH~5}wPlNXxFA4kBO?#GA_&S58-xxXR_}( zubW#Hn8!an(N1ymaen!0D^pbIv9Y0tc+T_kaC+gf>>>D%PB4VYS;u#RfapoQW$b{QhR+!`Z(mo5QHn=_EMiz(ub2&tL6Hjm zYW%@urS#Se#nmW=DkMXxu=Pmw6<1IFSR1=Di@GH zb-A29pj-*%I0EuQxfJ!Mqb+G2!48Qn$JiZ$i>XSsuJNI{5~whts1_B6bq7!>%E}^g zi5tV~$t5r3ay=VBBG{p9fY}IUM+_pMF%D?Fqc6c8Az)ps!|%)b%|s^=*%x1}I$>m9h!+_U{9(GaO~5&CNF%G|`WmE#&VuZC^af@1-0)%nup|JOWTz zZ*r>ejS(yma*`egfb<0BZdADFFa~;CuM_|kA{3L;+63g>VO(H6jckai2CGC^&NgLLkET==UH z7wQUJ=GQpsoCaTFG5L?JQzM8CUBH;(v|3)n4HG_B zlprd%!}7U-eWD<54tQcpem=0rC}RE&**Ey$5G>=vKL?328mab6ZvNx4f7dDbtHE>o z!GV=bQkS0MOz%w5Ixze_ncsHj5!9hlQrSa4=6-bjJ%e5bfJ;2O{!&r;M>^$q+Xlph z+%i-e+$d9F#0T=y(9sLfPk_mixe@&$f~kzdqW{FLNgn@+{A{QoCYeMp$g|sj77-}` zkA!K$g}P8P#{)G~TNK7vKR;V!n{9IFS}0Ql&AU2$VQKJ z`HS5{HLxH|VYnUC>c0962o3&fYS%qV@L=>sh-_1Pts$=B9XuVH$M_487aRv!yTv{u Q@_gO!KbAk1AJ3ot7pS*~7ytkO delta 3750 zcmb_eYit}>6`ni0J3ITD^&@utNXXi5(rg}XLen-)L#iZ=(-Ox?aMd(!)5&CKcD?pK zoIA6QZN^!Fk%9t3O)sbt2+5WZ62ExlSAT%`frNy_55(igAjA)ZM1VgagrKMr=ex6y zG)WOi?CPF5_sp4d&pr1$=iKkV{KaYGR3?*D=(p=%uQ>+~-!+EVox4w-eD$niNA1|> z6g%dq3(1YhT#6)`9Vbb=B}tP+w-Y2uv?LjlB<&PQQZ0#LvsopdzK<1VzWcKrW6^tm 
zWjk3+BSAuJi~OmH)q}(muP<38%X4^^<~${ZR;k2IZ*|Gxy`8xum3nBNkMKT-BbL1o z#EXvQHMqmKlPDxZc$r+{4%sd^u8TMBHX;g3GL1%B)m2U0ioTl~CLtdI85Ox$F_i|J z0of_;#`YiE0rq1=T5IxM#0Ofv;Vm`1hnGw()m^W(<{t6)SmjXYX4zf|G}~D&7as94 zQ?1)h#VpqOO^e${#5>wg*miL>K9oC%l@EZ(jbTskCy!Uku2-m6mrfiPUx^<*gy?XD zgg*t+iJ)|u8zjqXGFG9bgW`kuWG)X?7|pA^y9yTI`-wa!&gri=aZG#+WFN>R2;K%6 z(=1dh*DW}eifML3kY{`lO43f0IGu)zGuBz!(6hRsM>ymK8l4{m__=Qpag;e`D|4#d zcUz1nCbSS_;zD9Pd0UApb7?|Uzt~9J*c~HneiQ_L28QV@Ta^ZZ$V5CjR^@3ioSf8U z3{>7HE+hxGNe4UOl;55aE6Gpgj?o9S0UTF%N8wgJL9%J_$K=eI4D?56Jxy8{#qUz9 zFT6lppu1LesZy?$_;H$2oFx#F*DbFwZ}{-SpvRggA`TG~A0sd}Tqg zmBj=f0k_F~w#9tKRBM_qpI(II>PE;%eZ^)QDxV@*Y*X`9sT(I`&*S_gW@*#oR?Q{!R?U?# zJYi~vX~I?>-X;JwN5jRjARBvBjiu#tZRC{YIrp>yvqgTZ*Z;UR+{)k zi_b`BnrT0?rTPq?_=#D?OZw5TDR$=TguTx)7%Qq2XJaO>2F8@-dS_41%pwgO9%N^n za%uj0owo&PnIMi0WVwOfPK*LYQ~p$KK#d zkJ9WAX$GkWY}1*`64y=okCTcAt5`x)*7i%-ojiYW`i%Hie5@Jto~|@pYDZo7`4p*! zkj^^AqEqn7%Z@314s`B#X0cJJOa-YadbBpUwlhbWQhUH>d)NZttR)yojbO}+wd7GxUa zJdwFXI~Xe*#UxCn8E}{IT7&E6%A78rbQl+8inYVmX$mz`b`f2L4KiK+dsrWfGDGr) ze5NUtbxLL``1J2#QPohzd)Xb$=jpqAh998^e+J}P5Lvo)^nZtYs~m-vqmFVE!^ia= z{v2M&qXZ!*yB!D3;3XO|`E_Q*)6sNP*Yp&PNQ561Z}t2pCELbj7}3PZ-W}N>W#XoR zn?;>-L)_?n{xDi&CmoQH!4AG~@7KM*VC;>1-|4@=GF|5gj3PZS zlI>~(dEXY#43y_Sq(#zzJa3@KG$nDq+ zl4wQL80k_qr>&T{gwJq;N(j*c@-SV0fM3>9X&3s=7vN~Qaf!qM$ zAT9`+075&h@CMjrkeeW3c7R&}Sp``G`7{yQ{=7vM3tMgK=)nZms}PAHKC?Xjs7yS~c8m8W zMoXwZLBcInCL3P4;)XXv?R?#CP_M_$D7>0q1}F2S3lGU5T!DU^uIlu`3BavYwbRs6 OsC7j8j`vOVo%=UsTi>Ms diff --git a/sgl/models/__pycache__/simple_models.cpython-37.pyc b/sgl/models/__pycache__/simple_models.cpython-37.pyc index 48321029acb979e874842a63aa56721c40fe8ce4..4d05b2de9353e321cf73c35c5774984a9bf8c5af 100644 GIT binary patch delta 1163 zcmds0%}x_h6uxI>+D@lX2tS2tiqld`rxe1AibmhC_*m>oc`+Y+sV%($V7}tBs@g1gadDQZ}k!hkWipDu(XXfOa#oGBcKz032J zC2<~0iL(hCG?6GwRib!~<|0n_8-fol*<+Xz_-qGuCC;0-Jc(=zMp+wfu=`hAnCOn| zD{1D-zC^3|R2W~(tupdka_FbEEwyCF3HdBy``laz&M-?pmC_8^W?K}Z3#@2I!s5hCUc$@`=&ByCZ9Ch}7QcfIj+J1CGb;~l-L38I z*U_&B_0snaTu)!Ob^GB#t*(0wy<(bSXvTm+ewo;`u3WbreoRDhQJu;-^JsPx^~P4! 
zZESjbwWbl}eH1>RKxDn%*x2nV+}`IrOCD)!gAY?YfD~DIge0crj-kOo!4?hS&|nlB z)Poka3~fo3#QEfAL{k<1SbXnKi~B>jXUqLXTn` z(t+kR`xit1WDT`KZ4eO=RY$Sz;PBTvtl@gc6q=(s`Y}}pH5%-O>dc0~&M<>PpPc-aJWb%BeJRGNZ$I;$dPNswpCn}%iT?1B8AOv?}D zd*Ge0OzDbWR9>@@->#jWH$#4O!cVoY;Z9@u+RoH1OtDkjoV4{2jYlX$C4muwWslLA zz|H;+ce!#OR^+Qn<^RK7`vhHRi^uB!mCrD==Xe5ssx%KY+0u)mI_=1G|&Lwg3PC diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index eb1c33f..9ad19cb 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from sgl.data.base_data import Block from sgl.data.base_dataset import HeteroNodeDataset from sgl.utils import sparse_mx_to_torch_sparse_tensor @@ -77,6 +78,10 @@ def __init__(self, evaluate_mode="full"): def evaluate_mode(self): return self._evaluate_mode + @property + def processed_block(self): + return self._processed_block + @property def processed_feature(self): return self._processed_feature @@ -89,10 +94,12 @@ def sampling(self, batch_inds): def preprocess(self, adj, x): if self._pre_graph_op is not None: - self._norm_adj = self._pre_graph_op._construct_adj(adj) + norm_adj = self._pre_graph_op._construct_adj(adj) else: - self._norm_adj = adj - self._norm_adj = sparse_mx_to_torch_sparse_tensor(self._norm_adj) + norm_adj = adj + norm_adj = sparse_mx_to_torch_sparse_tensor(norm_adj) + self._processed_block = Block(norm_adj) + if hasattr(self, "_pre_feature_op"): self._processed_feature = self._pre_feature_op._transform_x(x) else: @@ -104,36 +111,14 @@ def postprocess(self, adj, output): return output # a wrapper of the forward function - def model_forward(self, batch_idx, device, **kwargs): - return self.forward(batch_idx, device, **kwargs) - - def forward(self, batch_idx, device, **kwargs): - sampler_name = self._training_sampling_op.sampler_name if self.training else self._eval_sampling_op.sampler_name - if sampler_name in ["FastGCNSampler", "NeighborSampler"]: - sampled_adjs = kwargs["sampled_adjs"] - batch_in = kwargs["batch_in"] - sampled_x = self._processed_feature[batch_in].to(device) - sampled_adjs = [sampled_adj.to(device) for sampled_adj in sampled_adjs] - effective_batch = batch_idx - output = self._base_model(sampled_x, sampled_adjs) - elif sampler_name == "ClusterGCNSampler": - batch_idx = batch_idx.item() - sampled_x = kwargs["x"].to(device) - sampled_adj = kwargs["adj"].to(device) - effective_batch = kwargs["effective_batch"] - output = self._base_model(sampled_x, sampled_adj) - ret_full = kwargs.get("ret_full", False) - if ret_full is False: - output = output[effective_batch] - elif sampler_name == "FullSampler": - full_x = self._processed_feature.to(device) - full_adj = self._norm_adj.to(device) - output = self._base_model(full_x, full_adj)[batch_idx] - return output - else: - raise ValueError(f"{sampler_name} hasn't been implemented yet!") - - return output, effective_batch + def model_forward(self, batch_in, block, device): + return self.forward(batch_in, block, device) + + def forward(self, batch_in, block, device): + x = self._processed_feature[batch_in].to(device) + block.to_device(device) + output = self._base_model(x, block) + return output class BaseHeteroSGAPModel(nn.Module): def __init__(self, prop_steps, feat_dim, output_dim): diff --git a/sgl/models/homo/__pycache__/clustergcn.cpython-37.pyc b/sgl/models/homo/__pycache__/clustergcn.cpython-37.pyc index ea288f991ce281dfd802490ba5f786dbd5759ab4..c1d662793efc8e618d7ce3f95dcb387d2020acbf 100644 GIT binary patch delta 164 
zcmbQwxtxR7iI`1TR5YbQ`l12 zTNt8PQaDmLfi!D70}DeGTQGwr_r{&)80(oq#sC3`W&~nqATDMB5-AKoHH>~mQb4vQ zbCD>Jy2YHBl68x*0xSk0q(Q=51x2X^Mfu68#l_q}elZh}U|_5gbO(87Z)mcFMqg+p{)4^ozytpP5<*D$X|MbTXc7|Ncha>5%~5{-&iQWVd^fKi|K(VD z!m2P{7U^vd1QGJR*WzjkdQi*A}s|{C`*Kru;3_SM-J3bb2LF00zE7`#n5mJ zu~vdoXgcPWrsY_|LfI*ccg3lS_mESYB1~o4Dq)(ht;o(WlV*us(f`8p*fdurg1B|} z$q(W5HQ4ZpLm6=-w(l{Sr_5lb)uJP_5oR)Ll{y7hW|dXqC~TAsu^Q~CY?#$yM`L@~ zUU=(|Nr5;;aIrsQ^zYFk7;U2~LY`j+ca(DrXkNBnwOcv3afgGy3#J)+K%p$^zsy7W0(<$kJ@ane~-_f5Ep(nTT{lX z=inl?oU2LyPS{ELf$R_EL=+dE%R{@YQR5{090z%`6v*1*=@y9lJb5l zbv@SQm0U>b--pP)hHE+OM7iqqdVa)gh1=Ob`k})j=G%eKt+pTe+)I5o z_4>&>*@fbxYj2>%>j+l>?1H!jZ=y*>gu&7KiSN$$gJ6)tUFO1=`GGqh^ZOoWj zV)Y2h$)Wdv&!D^6OfyfP$7qSl$<4&~cxxevdFnGRV)hv_KSY>B_z2-3!c?|k9;{(u zIYt^owYr@YVtPCK&RjSn&MfoWC@2chMM~tVJRf+;!k{|ftdp5-U3wP_2_Nin91Z9& z`4M1=G77zZGbKx;3%$yuyAp4v5|guUt$oW9?rvz%bNVLPgyDLxF-L$~SOFS&vqWeX zw=~y*2Xldv2d7fGTi8%Q#@K`j@Kya0+A99X)}pYbZrOrKC%Rc{-V}LSJU*BGZcS#T z%Gh#=e+-A^697*j{*LNRCis`X`@*nAq{3xqEbx(ERP%zC#4 ziBDR75WxK3z9E$Pdnkk#<@k|6{|SOl%fm)ZHw{%E796{t*OGS7z`AeT@u0bo;$jGm zB-zvI+x`jA%lDv=Xt87)k0k* u>W96D&@^ujjZ9#Lic{@eWvAnEeGGi|1AiPFI;jDev^E$=k`DQSpY$)RXXDWT delta 2074 zcmb7FO>Epm6rQoyyI!x?yV)c(ZIcFQQ|hLKs)GK|kkU{ReyUB=7Nv4Q#@+G$)%I@3 zvox&&kw~QGTKPgbr6(>_Lh7wY&KwaE782YLM{bA{@6EQ^2yrNu=JT7G_ujnUneSiT zJZQ}qiw1%32W!RuX867}wjRNZ5k>z5IiB32 zf9=&lQv;O~^4mGkPdPV%=0@_Ra`I3X$W(XTR@4%%%QHAFx?S$O4eoWDZl}xD$G<2F zrK8EO>K$zlFqp|Fne$5;Pz2#_AnmyV3P$6nYCe3!LqB3Z~C&yM5%M6#AURH?Y~_}v_ja?@9}zHKV1r>lGx1?CW> zZ<&Gl{Z+Y-!AK3Ln=bFv{5VDfVlEXvt8pr5^6dczL$1GnDcw>{2gA0IHL=*SZ4qyVYLTaBiWMY2*j zxhP{mWqzM#HKi_8jj^jdFu!S0hP4$n;LxZxC>fd9OUaLgGvx!wr-t4tzl6svCllt; zQ+tubqDjrz_js)tcewBwm%+P&oAU_oAY4G$kFbz@VUAa@rW`|bpiteGfZ$$FzBQZE z(q@%kM@CtDA0s7GX4L~PZuW}=Y@JWe70*ti-}qtWVNu~(@kIbZ)<_##jM1A4Ul0mY zl1IfohbgXZ=ho=9{*XMRl-xSJLSPqKnMw(0xJ)%UYh7Ho^L!B$@$&#s0lXE7UH;T? z>pg_42n+y%j9&3Nf=t>VbLLKcgXtduLyR{Yp_ZyEBXor7%H5HR-rYNk)#^;2^#t?+ z@YepPhttpJp)8o>o;ADv)XVqL=qAD-XVS+HaP7H#EDl-}4EI3pnUm01Pn{fn=HxTp zqV&=>syqh*$!ckG1fMu=0oY2lYOB@if?Y_yD^0cWM(0(86@+UD_+Ic9LK^@c1h)od z64(47fJb5{3(yp=BNJ8<$Cso4BgR9Z=ZB3U-88cLux`pdyOJ!IPfQ$y-E=f=1T(Vj z;+Y#B^rfV^l5Cg9*Hz%?KSX(C>xotn%e#%stLsx#phdl~dxtM0U4aX+*5lld#7w;> pdfbou;YN!F#tu3#)o0k&nu<A^qjw8&_jWq{4e;_?+t&56w)F~&D%FW_RYL` z-+XU$e{-|p;JNtIZ<4QCj`JU?EFKriJ1FWC2;m5pIU&QFXME^}9-}rls|>4Q)#@y( z4Snt~XIKv#w#CafhRv{P^-8umY=tf6Jaj}=)Sfw_cE-ak!S_%kPTg)2O~Jy8)=yepz6e7vioE*f~(#I|UP&1Wq1MN4dTKw8t^j9o?wQiBtk>8lx z{QA-+wZmfL(3$c-Fp3@=uJ?@BEAr$)xK+v|Qsj7=52B(B+mGThi~DJuKa_EK^y_i} z$MO989zw-rSpV?Jyd!Lvh0?1ZcQ%X{A;zX@EhaL`<6*L>MOpHBl9}!3>Sw08x(MT) zBx7aVvM5b$kmy)zY5e(|%Cw?>kt;1neLb5hqK`{i^b@5zp7B(Y9U7NbQ9XGQ%Yicf zfINY7e|O|fxbk9m@2D6idrCdt+kg0IHyc0QQ-f@8zUaLJU}5G~BD>|-_)(PRsg5E# zUg~w)4OU(Jc#}1l&vsCntTnG)N3F#t*H=*0T^Q*zGJuC*M!JO;Vag~6`^>E+(kvjt% zmrt9gOyek>u|M+1-u}!nV9QN&-Qk9J-!zR&dslpa(YTYPO81LldG|+Tm8PgVlb@_@ z0Z=oK-)v3YJ0PGqQ+J^(zMeezAWIPF0MPk76-i*{46F}9ng{wQ2@b(uL2(!)DQYr^ z1Nu5Bpye2%lR+YX(5Xpqts}P}Ohw@wY56MYjgRAOl>9=6l{N9pQ$Z3=dKmDrExzsI;L#${;tY2y&<-y{I8B^ z{|%V@0vQeLw%rTs^HXr34AkNv(&I8QbxKHa2AE1ZcmSOjKTVaXjq;RY{hV4u--0F8 zW78NSQ{_oII8x6EXl7HT!$=>=L>(2G82shW^!(! 
z?-ltEb)ZINSx6mBJ|nmt#0M#j34qD|D2sJkC zxjMw)R0q1CPly5!RJYSyg`;gVF0f=a<|&o7KoTAd&r9Q~(a?Cv1Pa(aipxaOd6l;z zI`tKd66uA0zsMr{{f}wTE{QIQJ?d>Owr%G6RYY$@vq(>4|4xNOBB)K5pqV64Bxw<0Oarc1kTcp|j2yakF0V;~|zkup4E^ns3Y?CjOm2HS{n+DG> qZrzbiGi4bptW*1}PN=f$`!=VHe`mHycZo%8g2hD+l(a)6#QzJOI9XBv literal 963 zcmZuw!EVz)5Z(1U={9jh3mo7N@S*V!AcTsbLLdt%J^8Y7z1uX!_PV>SpjA#ZXa1u< zz#s4vd*u_jabjj2M@iL*W_CQ;ef#E(eLR_r2-vUtFX}fX29~B7+%;SpoHpZFWzgq)0|3$pT4bup$|gOop(Aav+DW4&+#Jxxb>>P>uj8a^vIN zs##t&r7}gZ21Q8pPvG_yge6@P0TdttJj6f-5ZRJkhUa7%{G=3UBDK_osjk(Bc;Hwr zFI{9iY*FWxa)+YPwiUj%sOLgAE^3V06kH_mu`#+(*2c_*R+Z-lxwzriZ@HP}RWdvId{o|iPwc!**4HIx&{hYl%(S_2T!=-zXobKYt=FCoyI*kod`pCB zzp;R^c*7Hzdo@Bxe5W+^JmSD`D4Bf>2-}4*W_u8@WWBER_QlRMTa9jn@2`4?rx3R# zKR|F0mZmXv&r;of#8FkMZk2p*j5ccmZgiHn#RW7e?E@Z;_j*sZ$n5moZ#XSbYBXH2 z>BpHEM2eYj-Gwk`>ke8iY?B+1y;Zf<<{^toH~aN$4*;Oo<>js}_=c=^LDQ%49u5y* zA0gaGh)T=r67xuitLwb%{xJ~%dd1BX%<$9MV&Es^hk{L^dwDtKoHNc0R^T$fK6cFU N*WJBuaOYoe@*f=**RB8n diff --git a/sgl/sampler/__pycache__/sampler.cpython-37.pyc b/sgl/sampler/__pycache__/sampler.cpython-37.pyc index 2c2c6a84ed36b0bf747b44c149b91d619c1b05ca..e523474ba0c3d43bfb0571d88e8599a4e6565fa7 100644 GIT binary patch literal 9898 zcmeHN%WoXXdGG4i^gKACxXiAo-6&&ugJ^6m1mun6;zN?(_f>b#km73P zZ1$EORCT?&p5Np5eKqeaEj0{0fBt8`7l|(z#@|z>d~77%LQol^&dr4H;m845QrkgK&f%w}WYv zH)uBgR*?3$^SYl-{CCEa{++Uz(p@XJRWu&t7JcU4xgW@(!U@)WzdsI?^8FtgKfiQk zd$JQXHjZTw*GiW-iu6!R(d9ah849C6h(%;DqQ}Z7qiucQ(Q0C*ry$! zhu*YF*Adh*A{<2-#K6K~h|R7k&!9%0Mbs5?h0-rp(Mw3>Ub!WdQ%ox-MlKt4-5fhj zJByue`2Nl$+#OT8<@*^2^*mo)*p(V%REsdM$6A^M)7caYa(wJOt#1uGz8BAqT)IE zBD9IsW|u3H-2{!kc0X28Cn-iMiN0zge;M89Hc|7&X4jR(j$Ed9&7*9cB9dKQOCQ(^ zRIo|~#lGW7k8XAC{Hza7(tHK>if56^n|?nS_x-7yY$-{*Dt`qLXp!lURx+mr zCIV0pF(*FgM~dZ~qMb3!7Jjo=o?6aA<9>OHv=#Ibv%?=DkzzRqLbeeb`vQ_qq~^#v zFy+Ne9NI@H8yfqNeEZhk*HQy$#gfpT4hN@vW$V`3f0w4 zAZ75NGa2Z@4iQL)YN`+3kJ9Z)c)WvhUo@FPs#UFjdpE^!>Ui^VysMqzc(N6Yt5Kpo z8aWF3+Z{F8m3cuNIsXb|7*Uc1F)^B6n!!Y8pszT^c27B30);G|4mysTIra4ol@~zN|vJt=t2V zj)Sc+M1HK|M5RH}k8&rSNJxJ77I&6A$?gthG>PWCfwa$SSNRM6B#Crpu2Y&^SulZ% zE+d&j2u>#|EooblpzhX~6Y@rRBB9DXOor5u>MiwmrJoK{pAG@cq&o#MM0>hU$-b`Z zHSs~gdf^F{VhMa1gDDclM$>~*X$jA?g)NrFi=1wlPm79D*gyGiJpEGg()x+Pn7JCD zYmyfkArSFhb6~C+`&MR#<{eXBN+B^aYh-~b?BqvaF*_9_d*qB*VeLbEznR&GXS3#!5XRl!2h^vDu#(wF zf@Y2M3z?k(qaCt{W~ge1j^SXVGhC%r>W%5 z<(H`&hc!WNsofntg0&MTz))C5?gD~^lbzh{Z%?pExl6p8gn6?+N%{d;8X>Q%C=KJm z;PtTEV%F(3paSj{9MYKYP48^6g9U4XKYEQM+AEl;*Ors}m}nrOvC_#Dn2*w!3NpPW zt8F^1XixqIO}t4Q&m0kfFLff+qn>29F52HfP!|vxwH4Dd+m>x!GTWjp&Wj6Rxwg4% zwh+%gUkdI5iL9R@OrY!OxVoNP!%Lp0h%~(h3KuCwj4f$g`3gl9t}9^zBLo5*^f4m@ zz^l6Fc`q~M=%ciHN=8>-LD6$TEVgYH09HM3Z$O&gxV~Aes?yBVH-%pmUlG`_H|1~P z)7;@R-Y6jxm3<1xw2jX-WD*)~&t5t{_y23e#ALtGwPlAU{yB=C2Ny{?=LBI94!wMG zs-K94M(E)<8d^X;KrJJ40r^;rd?4zMAs^=$^0~;!2Z(3^`P3GW&yPQA9M#V;>W#Da!Q1^Gpa2wE(H z8TlF&tRvDO1z4~Ulvk*%MW7~53fxko9AxucW@%fHB-%X$MY2p|78{6V8JK0IL@Xyc zm8CMUZ-rIJeyXs_Un2toKY>-W#s|KDqa(CF1Q~}QcLM$+Cj5OTbI``h9L5-A95|r^ zjL}P7I6~|LI7<%9eV~SscUVVD;0)jr#t#`>BO`o&WW!{zGZ*~=%M_@gR$wKE%e~Ax zY^6(w?a&26F;`*AF#@qNe8JeF7F>Udl|B=Af#ZU~z|eJ$F9?w=LWxc~>1?4$!;T8{ zoXSkQrQTdXJqvvin?48Xfsc70zyNmFXdmPQ3pD9J=U=x z?AvJoB((Fh*T=gmjpQ7C%zmBevnYprV0_K)#LJyQ6u|okN`^TK_}S5LT&ss;u!Vdi z3|qd5#N1t)T8FR?MvxB2+@>zOO>68n^-vGn1Wr;hf|Jy5;k|@+8&;3up9#rN>(1V< znCLXi0VIC4^Y%$PzqA>p_b2ksUgrk6OuKWfCVvOBlr-mC@GdQD?-G}hJOrBG%g+tu z1dz5LPapW2k$lZQlD$kGbwC1?9Z-2al>vly9PYv2H&55@1!FGs{3?}3FmkCOe5Pfx zU75yy2%0PWS$QoCQuw;lXZW=K3R7U8rs2F&_(MHGzuEbxPWltDS_fUh;RDBZ;l+X@ z2sya!*{20}KTg8g%-elULjDeWTA0h{g)juRBKd?m!!j7VeFte7P{GGLwiv+u4caI= z6$8e=?7}GvSVV9oi_ixYt|O>FK}2dS6i`C%B8A1k*uId0We#A*Cd)`cEs^u?&^m%| 
z7;n3HJM<2r#C+NamporE)=8ahp{H6V){G&%V96?SAlo$>sm`S)%z84ssfTcvI8f_2`1}Ha2b`|B9DspaH%=5ZsZPjI##pKHQ?LA1p)IWd^OC-^c~w~ zQMlfiU2J@dt;A=vxdqGc53x(~JBY~QjFU89@wIgczkJziejfhHB$A+DB10d_B#p_d zg%R^LBJxBHBI!rrFj787EsRV%PNSW?t%e03Xy;4|d13*bTcTPbex3D@ou{}%SODY$ z?KMkcgBr{|<}8fgt80f4)Wk_yPsr#e2wK+L4M_QhrCGu=UC4S5hUKbx8hca#iIg5a z^d>5!z|P;)cnpd<5RVOi;bIA85OYVyySVQ`jT6$9f<2p`U~wk%y=4*0E1oy1t}Dpp z;2@Bk;Z?+-v)e21%o>O2e*_Sgi65AE;Z^6wgU#lO)++{4T{Ab=za+`XhM7CR&mdd( z=Z$Odv!q0fZ~QHdLPNU7>}%IgDu?x_>0NfQifHgd%o}IQmpq z%0t|VNI0nNc)B>JGRZ2CGTH=Owp z`H{(IW#ax0!1h=PDb?d8&PQTzMSt^=xb1$=cw{_+E03>GywpWa?a({I+3P+>ppk1d zM(wbU7&+L7Vqf5n6aA)*vP`%Co_GY+2*vF_ghV-PYAWKg`th=ts85{#HDm9O%Rb-) zl<=sV$hYpe9|`FpWzmR-OOGMDAdT3?2lw62xK*25t)rFwj8cKat8p~t6oU9wG~jJ| z1>1y83c+NwQRLgF)X>zMg$*LLA-_QpO|ICn+$Oz-%Wrpw9Jf+Wf<@jgTKy+NL8_m6 zd>)!7u`s^8PjeA+53i3Oo51jxTI^87f;^@aImYB2M7fKbqcjS&9Vn^9Q@dLXN%DFG z^asQMN3_FQT>d_N_cld;K+$bPkB0-Np2u-#g48&hbA+2vn^%d3qhZA*auI)eIlM0@`T30)I@y?*ECB+D)`ExOOXNOc1C){B)5wG z`WpD~us!doEGJFgh7o#^~%*wJ51-aXEA1Ta;Wg?s|~q`hI>| zN%C6-wj*pwo&px=aA|z=fkvnU@-d{HOVM6Gnx-9OJ3w%J;+^~3F=JR<(M%@VWWtS3 zd0#VruxN!5je;cUa)tS2v-^Ukx%a7r>F`&P${X)ZlA*q5)UF4IZ-JRe+4M>jP<5nT zZDfGvjcV|{+=fM^^dQUBFA;R!VEjs7`E%m#IvOaEu+HmrrJ?bQh0BoK>QBa+afqw5 z8;%V5|3Xl!hzx+0Hf}G@;SOUJ?+f_NzIX~j^3bP1=Wg@F6-Hk3VGs84j}S6T*+kNJ z83d9f@rI-sJnL42jM^U}DB5JtvPG-yIo>(%jQ0{7!SLwN7Rz6w=qrfu52C_qmcL8c z7MsB9+5`sovrkFVCil+ae>&^h>mt{;V5kimqW}N^ literal 14035 zcmdU0OK=>=d7jtq&OWeMfFJ>o6g7H8=Gqh}`5`-uB2gkuMwSdIA}z&RR*eSRgT*d( zXO}$#64)%9ia|wk5;}6(iOX>wK$UYy&ORhL<&cxFN#!u7ROO^YQaSsQeE;7wvkyX) zb8{sP*Y$-r{mWN1?O*62y)0xd;|_m{L})_qY7O4t_X-W0^Ln?~D;b)u^~#M3mzdqTUbRu>yw#oW)fzR<=ei5M#l~Xt)rrOl zw3P2I^-eZUqO2fnQ52;ut8q${MPpFduEbuE$o;mp%NJU-Ux%hSwW(eTm0KPjuW!dhNjV z0MwbhB2EOFy*mm7^(2iU;t|aZbf!Fin z%H$%|O{}Kh-Hc6aU7X*(_b;`doxix%@A((Q@Ycl}*I#+EySIBW zY;`Xtt){yrP0^rDdg*h zT*OaY+VCO}O*;@Fm#bu%s%#zww)+g8Lf(*>F4Gt4d0G-$0xU_%b$flW)1`dHWinDV zJO#N*P0uB3(Db`q*By6U*X2pnOX{DOn#RauxAlt6B~6rM7f}XX;bkNbA~TT_d5~KU ztO^Fct2gqZN(_u#fl1KD(#*&6>%DqwB8r~lXu#o?=a_~ zkCYa(SL8YLK*T1i$ncX|vl=}vAgPS_UV6>%b)rd|W#yv?VD>Ks?(+bfC zT}0+YSTExaYuIbzQ{qvP+sf@5yPqHFhuV>LReS6FzBw{S#&$kj8JUru{PJ^Tiu}l2 z)(-S<>9*C2KFn7nP*Fzng9kZ`#4>q@G%g$1PHwFw<7v*CkpM=BjpB6mv)V zehx%2e^(pju3`@Rxls=9mjXK~iyF--s$eXo0~*U5L=LaWb{?%wXT(Y5OPoK;<>lxkwKlRx#ZhTgKFIA?MiuV8 zlJ#E03?CZhf24_Xe`?Z9`TLxB_=rU6DM50>S&xe(-UhP25mL6l;YND{KQ4E@Jzu(E zd+3L8DcI?`ToA?u=?}VI(;uFGDYy+?BAji1?@nI|2a<57=LgXxM-oHDwO2fPv+V`f zN&LP#XrUC>KR*w7smr+d`kkagR#5yi{X#P(ToOwp+iu_vi6dHA(6=i4dZdl?1O2{-qBrcLkfCU`-lBSI{vAo?j7I-IUZ=&i0DvfO(0W`KO;PC?*wX~76{`iM< z{mt{TNJV5@4XaoC!KWezDs9l0ku!Xo#<<~av`OPTm`HP{>qYH;u;yIp2O$P8@rM6o zRui;%_>*6GfpB1I1mPxLZbwekr%x2Utlg|H#}>9fjOSl@^~wvcyma&H?#pk!@cQ)| zap`O8FQxhF3t{L>>iEl2_GMgX`H>ezGPaZb596xa#E2rf(~MjwRu+Zvf;(%CKwp!s zBOk$ZmcgI&vnq5DTk5XxlMjyr|`iRSk)kG^D8^Itp!<}Ai^#jm!F5L2<`iaWq zV@T?AN(d9R#q;67Z?-qvO;;)SILE)S&4vUFB_4rv=v}+n@4D0j?+|a!cu(SLx_`>z znxRUl(g||=6q*aaf<)7+_;=o@=oNhtcSWyqeu;BbSJiC;%G0RguIME^FCsmqmn?c) zeYa@T%;DJw(RZc@iMB{lT)`cF3yI=0Sp~Og;5vwo$mo~{hI|x4#5zJ*OWTL7wr}n} z6G5PM@&`Jr5o@F!7NAPBjy&4$?Zugp7PT1d(O;`h}eR%|J* zaTL}Vui*Lz|A7kzu$?YMTvR!XJGGewv2R!}gDt#n*W2j&aj6})!P{QY^y3_(nK*xw zaZH?pVj1kkIS3fYMCRWUF4$8kn+%62e$RV@fgcE%nJgB-=Om-+1;w{3pxW3*pFDuZ zTysadQ7Zyx3UGU(&Oj0VuBr=VQ`I%f7_sWg7c>0HcI`KaDX4j^Z0mWjn{8OSrPqkv zkY6%}_gP}82`!16n`C0rMOuzHjk&A@;s%#MpAiqdV{96$8jLVQ7%*Iw42+ov%Jt%B_ht_#X%&wb@_EPA%B&U7bxKlOyGTBQP>Z09{eNvy*S_8>SK-M z*Qm!bG-}fW6XTAHp&yC%=H}C)USWP|l%RNSD-MvK#`}$hWZ4ql15-3gfxnw@P@^j0 zJpdO=xCbJ53UZt$*Ft~LDD&dc7Wun8&bUk_qgvXqF_)}y(jCK^F}Ash-gj|_1eCP= zqG20V(=yIOxmESk`Z@4f)u@5*s>blKRAk>!P;TuQX>l`SI1}xqclcR6ya$F8+ICT% 
z9eXWCP+Y(TP&{%C*vJ4#=c-Hls^`bnD zoFs>ZJWC0oVfhjztCSD{N$N#DL`gEn682(FpyvWpL{rOpI; z`1za(e@^|f(@AHi)AIB<|EGKjIhd31KqS6|98_8%^CcAEOMqJ^^CeJiDf1%D1{WcyZP&J_g__{JwP3TyIH9w82tqD}U#Hc#=5vckq2yNQbJ|5+YNmKiiplYS+ zoXz&GFET5eVHhEDpD-M(a~#unk>e9)W~+RBI&&=cCO8L}sVv9|3@~A&Gv230`Phd+ zH@BQkmVkB&5RV7K=NXi?AV7zya_Z3^vNi31V?lS-yc~}grH=!2E>|R!$N|H zb6jqZ%)?4FcUZ+?ggr8rv@`G+K_QzXe8C8^;9a?ovj1B^B(@5^A$EN>Px*K@=)n~IC^kZX zNYVCid|FhHw3vbXxID4Qz}Cv6-LL>N!`K)+-6$jr+bE`sCmN-s%pE*E%P7Q$|B5># zAfXi(^%by4m2db|;YqE5VOI3vqq9Q$eqYr8OLU2FiOkA*YIctQd*G6MV2SYs%Vz>y z$_}M%12NW|9jV z$fPh*=na!S4-ln@m$2w0mJhJ=qQDT6jFyxFLyl4A@QW;Pl5z-qR*8IealzU2Jw!)9G{Ke3r#tLbR_eD&mKI&)8lV}F z@g9;*N21!*$OLh56cFJT!5VPyAy~u8f;Dm+tdR$PqH|pw(baRqzsh9aac(_Grc2HX zv*PLex*y?4aeLR{1ETuaHXs$AOaSJ|wPY%ZR68FCJ_eY1U%3e{IX9L2{I=hOmUoVM zS?*w>H*r*N$9CFDoL8T#!+UyR25;l&DpjklfVcoOXIdp!kGxIE9ZHsvfQ(O4o=AFp ztiMRKOh`v7=Y{k4|>9zmy#9 zFGcJ%PLA|pHHt9||0ZBF0qs0sv53sb5_!OIcsTCS0eOy(;Hkcd`fqZ5k$^bWp~pN% zP+~|(K>~SbL(-nsVLa(@gdFl8BW9vU&bX4bCPs; z@<1d>t0*5wlIW`$I7$5ZfkV}wz3MV+T95}!W%r^Wr7$E?eYpTu_17Ebm;Ugs+Vj|j&1vz$$r9OSfu6SRpm4=U&Ugx8L z2fT9=4qOsWvA!G6`& z06AM9K70&gOd4$Wi3Jl4h2-su=(GcadQWpFVl+_(0YNceT$4)zE63E9(9d|Y*>>Qq z8NZbEYM>*s6^NkNxNc*S-ddzUokKbn0bCsE-Lqg5#VAnENCOrKH*k- zFgd0)&&KEhKK^wYlagc=efapt8P(MK5di01qsg_k_Y7C1I_C{eO-?ONbMNcBpMKBi zu;U=HooW-4()ffcX5bJCcs@gDsmZ=EIE)42q+^Wj?T6s@&ae&cnehWEDn;% z4votE9ZFX5jr#K2`n$041eC~KQsNawIb2qlE?cH?V9%e`qPeL09^&Pwca+s&JRl4N zIfH8PYw}w(M&ziksWQBkl%?-4pq@vC*utVZDo2Txw6=I`ZH{Vht--h{f+b7C7cyCW zB^9?MyPQmR;oMinI<hNwPN0#A^97PU4tb|SVLxWgI}9=i=G zHaS`rOp{}sm9aFy-gtmn^VBpg4L48oVR=wRFam}5K*vtV0@aDhC!?|6mb50b{d@;y z{c%6TpLZcL-p@v=gh;Z9kE0u2h_iH<3lkU=NZW`UwsTc$8EPxBBp6KMny2Yyl8ZDJ zLiqAM3R0138rfD3#xwXY_s^#SJcx<@<@abTy8J#3ZkmiH`mXDZqWXgY1P!d$Paqb& zrIfDxHhuUmCGS!4yGY^^>1D=Dp&ScIJKWh|10>GlFA*Yy%=`hpT=Ze9!y-VeL5{Vz zB)@iS5e3G11d8tXA)C)*OoX*#r1S&oktbuy?^2#;F$Vm1QT8^ipo5HN5j;x&)r=?f zv&LEdv-%SNPD_Xt7+&}!tdi{p=`cDhxPm*3k?j5>sd&JrpK|ycG#N2ARD5b+^E1Wq zQVnEYPQ^6STSxk#eWZcrO>&dM=~?QOQC~qY4kC>C#13x!O-O5?YLeLoxmD2;Y)`bL z40wdzsk1$V;*3k@vr^>vU5Y?n$V!pRT0qE6(gJdjg*IfN`ZJ0_CJS#Srdn#BvPgvM z^0hr>i&7veJnHBD-KIZ?96UOydFI^NYO}S9fSZ2bWekq^oHY8TyEg&DQ;#AoF9=VO z6CVNnAxh+HNaE5f{h+1d!W5*)F_>r4jEihj!tn9cLH~fi$ok!tFH$?y{FsXIrzhRM z7bQ&+Ny|S%k^*!hH3Z=zOhu`E69F!N#4ylzX9 z@Z{sV#wq+^%UT~I&x2DXTw}>Ve@Il!%GBxH7x6`7g4l3H38d*#H0l diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py index 50af7ec..aa7ed26 100644 --- a/sgl/sampler/base_sampler.py +++ b/sgl/sampler/base_sampler.py @@ -1,17 +1,87 @@ +import numpy as np +from scipy.sparse.linalg import norm as sparse_norm + +from sgl.data.base_data import Block +import sgl.operators.graph_op as GraphOps +from sgl.sampler.utils import adj_train_analysis +from sgl.utils import sparse_mx_to_torch_sparse_tensor + class BaseSampler: def __init__(self, adj, **kwargs): - self.adj = adj + self._adj = adj self.sampler_name = "None" + self.sample_level = "None" self._post_sampling_op = None self.pre_sampling = False + + if "pre_sampling_op" in kwargs.keys(): + graph_op = kwargs.pop("pre_sampling_op") + if graph_op == "LaplacianGraphOp": + graph_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) + elif graph_op == "RwGraphOp": + graph_op = getattr(GraphOps, "RwGraphOp")() + self._adj = graph_op._construct_adj(self._adj) + if "post_sampling_op" in kwargs.keys(): + graph_op = kwargs.pop("post_sampling_op") + if graph_op == "LaplacianGraphOp": + self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) + elif graph_op == "RwGraphOp": + self._post_sampling_op = getattr(GraphOps, "RwGraphOp")() + 
         self._pre_process(**kwargs)
 
     def _pre_process(self, **kwargs):
         pass
 
-    def sampling(self, batch_inds):
+    def _get_sample_sizes(self, **kwargs):
+        if "layer_sizes" in kwargs.keys():
+            layer_sizes = kwargs.pop("layer_sizes").split(",")
+            layer_sizes = [int(layer_size) for layer_size in layer_sizes]
+            self.layer_sizes = layer_sizes
+        else:
+            raise ValueError("Please provide layer_sizes as a comma-separated string of integers, one per layer!")
+        self.num_layers = len(self.layer_sizes)
+
+    def _calc_probs(self, **kwargs):
+        if "pre_probs" in kwargs.keys():
+            self.probs = kwargs.pop("pre_probs")
+        else:
+            prob_type = kwargs.get("prob_type", "normalize")
+            if prob_type == "normalize":
+                col_norm = sparse_norm(self._adj, axis=0)
+                self.probs = col_norm / np.sum(col_norm)
+            elif prob_type == "uniform":
+                self.probs = np.ones(self._adj.shape[1])
+            elif prob_type == "locality":
+                """
+                This sampling strategy refers to GNNSampler [https://github.com/ICT-GIMLab/GNNSampler]
+                """
+                min_neighs = kwargs.get("min_neighs", 2)
+                sim_threshold = kwargs.get("sim_threshold", 0.1)
+                step = kwargs.get("step", 1)
+                low_quality_score = kwargs.get("low_quality_score", 0.1)
+                locality_score = adj_train_analysis(self._adj, min_neighs, sim_threshold, step, low_quality_score)
+                self.probs = locality_score / np.sum(locality_score)
+            else:
+                raise ValueError(f"Unsupported probability type: {prob_type}. "
+                                 "Consider pre-calculating the probabilities and passing them in as pre_probs.")
+
+    def sampling(self, *args):
         raise NotImplementedError
 
     def _post_process(self, adjs, to_sparse_tensor=True):
-        raise NotImplementedError
+        if isinstance(adjs, list):
+            if self._post_sampling_op is not None:
+                adjs = [self._post_sampling_op._construct_adj(adj) for adj in adjs]
+            if to_sparse_tensor:
+                adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs]
+        else:
+            if self._post_sampling_op is not None:
+                adjs = self._post_sampling_op._construct_adj(adjs)
+            if to_sparse_tensor:
+                adjs = sparse_mx_to_torch_sparse_tensor(adjs)
+        return adjs
+
+    def _to_Block(self, adjs):
+        return Block(adjs)
diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py
index 9990341..408d438 100644
--- a/sgl/sampler/sampler.py
+++ b/sgl/sampler/sampler.py
@@ -1,19 +1,12 @@
 import torch
 import numpy as np
+import networkx as nx
 import scipy.sparse as sp
-from scipy.sparse.linalg import norm as sparse_norm
 
-import sgl.operators.graph_op as GraphOps
 from sgl.sampler.base_sampler import BaseSampler
-from sgl.sampler.utils import adj_train_analysis
-from sgl.utils import sparse_mx_to_torch_sparse_tensor
 
 # import metis
 import random
-from sklearn.model_selection import train_test_split
-
-LOCALITY_KWARGS = {"min_neighs", "sim_threshold", "step", "low_quality_score"}
-UNI_KWARGS = {"pre_sampling_op", "post_sampling_op"}
 
 class FullSampler(BaseSampler):
     def __init__(self, adj, **kwargs):
@@ -22,10 +15,13 @@ def __init__(self, adj, **kwargs):
         """
         super(FullSampler, self).__init__(adj, **kwargs)
         self.sampler_name = "FullSampler"
+        self.sample_level = "graph"
         self.pre_sampling = False
+        self.full_batch = kwargs.get("node_ids", range(self._adj.shape[0]))
+        self.full_block = self._to_Block(self._adj)
 
-    def sampling(self, batch_inds):
-        return {}
+    def sampling(self):
+        return self.full_batch, self.full_batch, self.full_block
 
 class NeighborSampler(BaseSampler):
     def __init__(self, adj, **kwargs):
@@ -34,60 +30,16 @@ def __init__(self, adj, **kwargs):
         """
         super(NeighborSampler, self).__init__(adj, **kwargs)
         self.sampler_name = "NeighborSampler"
+        self.sample_level = "node"
self.pre_sampling = False def _pre_process(self, **kwargs): - specific_kwargs = {"pre_probs", "prob_type", "layer_sizes", "num_layers", "replace"} - for kwarg in kwargs.keys(): - assert kwarg in specific_kwargs or kwarg in LOCALITY_KWARGS or kwarg in UNI_KWARGS, "Invalid keyword argument: " + kwarg - - if "pre_sampling_op" in kwargs.keys(): - if kwargs["pre_sampling_op"] == "LaplacianGraphOp": - graph_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) - elif kwargs["pre_sampling_op"] == "RwGraphOp": - graph_op = getattr(GraphOps, "RwGraphOp")() - self.adj = graph_op._construct_adj(self.adj) - if "post_sampling_op" in kwargs.keys(): - if kwargs["post_sampling_op"] == "LaplacianGraphOp": - self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) - elif kwargs["post_sampling_op"] == "RwGraphOp": - self._post_sampling_op = getattr(GraphOps, "RwGraphOp")() + self._get_sample_sizes(**kwargs) - if "layer_sizes" in kwargs.keys(): - layer_sizes = kwargs["layer_sizes"].split(",") - layer_sizes = [int(layer_size) for layer_size in layer_sizes] - self.layer_sizes = layer_sizes - else: - raise ValueError("Please provide layer sizes in the form of either a list or an integer!") - self.num_layers = len(self.layer_sizes) - - if "pre_probs" in kwargs.keys(): - self.probs = kwargs["pre_probs"] - else: - prob_type = kwargs.get("prob_type", "normalize") - if prob_type == "normalize": - col_norm = sparse_norm(self.adj, axis=0) - self.probs = col_norm / np.sum(col_norm) - elif prob_type == "uniform": - self.probs = np.ones(self.adj.shape[1]) - elif prob_type == "locality": - """ - This sampling strategy refers to GNNSampler [https://github.com/ICT-GIMLab/GNNSampler] - """ - min_neighs = kwargs.get("min_neighs", 2) - sim_threshold = kwargs.get("sim_threshold", 0.1) - step = kwargs.get("step", 1) - low_quality_score = kwargs.get("low_quality_score", 0.1) - locality_score = adj_train_analysis(self.adj, min_neighs, sim_threshold, step, low_quality_score) - self.probs = locality_score / np.sum(locality_score) - else: - raise ValueError(f"Don\'t support {prob_type} probability calculation. " - "Consider pre-calculating the probability and transfer it to pre_probs.") + self._calc_probs(**kwargs) self.replace = kwargs.get("replace", True) - # When layer_size = -1, NeighborSampler always returns the same subgraph given the same batch_inds. - # So we can cache the subgraphs to save the time. 
def sampling(self, batch_inds): """ @@ -98,11 +50,10 @@ def sampling(self, batch_inds): Outputs: batch_in: global node index of each source node in the first aggregation layer batch_out: global node index of each target node in the last aggregation layer - sampled adjs: list of sampled adjs in the form of sparse tensors + block: sampled adjs in the form of sparse tensors wrapped in Block class """ if callable(batch_inds): batch_inds = batch_inds() - if isinstance(batch_inds, torch.Tensor): batch_inds = batch_inds.numpy() @@ -110,12 +61,12 @@ def sampling(self, batch_inds): cur_tgt_nodes = batch_inds for layer_index in range(self.num_layers): cur_src_nodes, adj_sampled = self._one_layer_sampling(cur_tgt_nodes, self.layer_sizes[layer_index]) - all_adjs.append(adj_sampled) + all_adjs.insert(0, adj_sampled) cur_tgt_nodes = cur_src_nodes - all_adjs = self._post_process(all_adjs[::-1]) + all_adjs = self._post_process(all_adjs) - return {"batch_in": cur_tgt_nodes, "batch_out": batch_inds, "sampled_adjs": all_adjs} + return cur_tgt_nodes, batch_inds, self._to_Block(all_adjs) def _one_layer_sampling(self, prev_nodes, layer_size=-1): """ @@ -123,7 +74,8 @@ def _one_layer_sampling(self, prev_nodes, layer_size=-1): v_indices: array of target node inds of the current layer layer_size: size of sampled neighbors as the source nodes """ - current_layer_adj = self.adj[prev_nodes, :] + + current_layer_adj = self._adj[prev_nodes, :] if layer_size == -1: # in case layer_size == -1, we simply keep all the neighbors @@ -148,67 +100,19 @@ def _one_layer_sampling(self, prev_nodes, layer_size=-1): return next_nodes, current_layer_adj[:, next_nodes] - def _post_process(self, adjs, to_sparse_tensor=True): - if self._post_sampling_op is not None: - adjs = [self._post_sampling_op._construct_adj(adj) for adj in adjs] - if to_sparse_tensor: - adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs] - return adjs - class FastGCNSampler(BaseSampler): def __init__(self, adj, **kwargs): super(FastGCNSampler, self).__init__(adj, **kwargs) self.sampler_name = "FastGCNSampler" + self.sample_level = "layer" self.pre_sampling = False def _pre_process(self, **kwargs): - specific_kwargs = {"pre_probs", "prob_type", "layer_sizes", "replace", "pre_sampling_op"} - for kwarg in kwargs.keys(): - assert kwarg in specific_kwargs or kwarg in LOCALITY_KWARGS or kwarg in UNI_KWARGS, "Invalid keyword argument: " + kwarg - - if "pre_sampling_op" in kwargs.keys(): - if kwargs["pre_sampling_op"] == "LaplacianGraphOp": - graph_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) - elif kwargs["pre_sampling_op"] == "RwGraphOp": - graph_op = getattr(GraphOps, "RwGraphOp")() - self.adj = graph_op._construct_adj(self.adj) - if "post_sampling_op" in kwargs.keys(): - if kwargs["post_sampling_op"] == "LaplacianGraphOp": - self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) - elif kwargs["post_sampling_op"] == "RwGraphOp": - self._post_sampling_op = getattr(GraphOps, "RwGraphOp")() + self._get_sample_sizes(**kwargs) - if "layer_sizes" in kwargs.keys(): - layer_sizes = kwargs["layer_sizes"].split(",") - layer_sizes = [int(layer_size) for layer_size in layer_sizes] - self.layer_sizes = layer_sizes - else: - raise ValueError("Please provide layer sizes in the form of either a list or an integer!") - self.num_layers = len(self.layer_sizes) + self._calc_probs(**kwargs) - if "pre_probs" in kwargs.keys(): - self.probs = kwargs["pre_probs"] - else: - prob_type = kwargs.get("prob_type", "normalize") - if prob_type == "normalize": - col_norm 
= sparse_norm(self.adj, axis=0)
-                self.probs = col_norm / np.sum(col_norm)
-            elif prob_type == "uniform":
-                self.probs = np.ones(self.adj.shape[1])
-            elif prob_type == "locality":
-                """
-                This sampling strategy refers to GNNSampler [https://github.com/ICT-GIMLab/GNNSampler]
-                """
-                min_neighs = kwargs.get("min_neighs", 2)
-                sim_threshold = kwargs.get("sim_threshold", 0.1)
-                step = kwargs.get("step", 1)
-                low_quality_score = kwargs.get("low_quality_score", 0.1)
-                locality_score = adj_train_analysis(self.adj, min_neighs, sim_threshold, step, low_quality_score)
-                self.probs = locality_score / np.sum(locality_score)
-            else:
-                raise ValueError(f"Don\'t support {prob_type} probability calculation. "
-                                 "Consider pre-calculating the probability and transfer it to pre_probs.")
         self.replace = kwargs.get("replace", False)
 
     def sampling(self, batch_inds):
@@ -220,20 +124,22 @@ def sampling(self, batch_inds):
         Outputs:
             batch_in: global node index of each source node in the first aggregation layer
             batch_out: global node index of each target node in the last aggregation layer
-            sampled adjs: list of sampled adjs in the form of sparse tensors
+            block: sampled adjs in the form of sparse tensors wrapped in Block class
         """
+        if callable(batch_inds):
+            batch_inds = batch_inds()
         all_adjs = []
 
         cur_out_nodes = batch_inds
         for layer_index in range(self.num_layers):
             cur_in_nodes, cur_adj = self._one_layer_sampling(
                 cur_out_nodes, self.layer_sizes[layer_index])
-            all_adjs.append(cur_adj)
+            all_adjs.insert(0, cur_adj)
             cur_out_nodes = cur_in_nodes
 
-        all_adjs = self._post_process(all_adjs[::-1])
+        all_adjs = self._post_process(all_adjs)
 
-        return {"batch_in": cur_out_nodes, "batch_out": batch_inds, "sampled_adjs": all_adjs}
+        return cur_out_nodes, batch_inds, self._to_Block(all_adjs)
 
     def _one_layer_sampling(self, v_indices, output_size):
         """
@@ -244,7 +150,7 @@ def _one_layer_sampling(self, v_indices, output_size):
             u_samples: array of source node inds of the current layer
             support: normalized sparse adjacency matrix of the current layer
         """
-        support = self.adj[v_indices, :]
+        support = self._adj[v_indices, :]
         neis = np.nonzero(np.sum(support, axis=0))[1]
         p1 = self.probs[neis]
         p1 = p1 / np.sum(p1)
@@ -259,58 +165,33 @@ def _one_layer_sampling(self, v_indices, output_size):
             support = support.dot(sp.diags(1.0 / (sampled_p1 * output_size)))
 
         return u_sampled, support
-
-    def _post_process(self, adjs, to_sparse_tensor=True):
-        if self._post_sampling_op is not None:
-            adjs = [self._post_sampling_op._construct_adj(adj) for adj in adjs]
-        if to_sparse_tensor:
-            adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs]
-
-        return adjs
 
 class ClusterGCNSampler(BaseSampler):
     """
-    Clustering the graph, feature set and target.
+    Clustering the graph.
     """
-    def __init__(self, adj, features, target, **kwargs):
+    def __init__(self, dataset, **kwargs):
         """
         Inputs:
-            adj: Adjacency matrix (Networkx Graph).
-            features: Feature matrix (ndarray).
-            target: Target vector (ndarray).
+            dataset: the dataset object; its adjacency matrix is converted to a NetworkX graph internally.
""" - self.features = features - self.target = target - super(ClusterGCNSampler, self).__init__(adj, **kwargs) + super(ClusterGCNSampler, self).__init__(nx.from_scipy_sparse_matrix(dataset.adj), **kwargs) self.sampler_name = "ClusterGCNSampler" + self.sample_level = "graph" self.pre_sampling = True + self._train_idx = dataset.train_idx + self._val_idx = dataset.val_idx + self._test_idx = dataset.test_idx self._sampling_done = False def _pre_process(self, **kwargs): - specific_kwargs = {"cluster_method", "cluster_number", "test_ratio"} - for kwarg in kwargs.keys(): - assert kwarg in specific_kwargs or kwarg in UNI_KWARGS, "Invalid keyword argument: " + kwarg - if "post_sampling_op" in kwargs.keys(): - if kwargs["post_sampling_op"] == "LaplacianGraphOp": - self._post_sampling_op = getattr(GraphOps, "LaplacianGraphOp")(r=0.5) - elif kwargs["post_sampling_op"] == "RwGraphOp": - self._post_sampling_op = getattr(GraphOps, "RwGraphOp")() self.cluster_method = kwargs.get("cluster_method", "random") self.cluster_number = kwargs.get("cluster_number", 32) - self.test_ratio = kwargs.get("test_ratio", 0.3) - self._set_sizes() - def _set_sizes(self): + def sampling(self, cluster_ind, training): """ - Setting the feature and class count. - """ - self.feature_count = self.features.shape[1] - self.class_count = np.max(self.target)+1 - - def sampling(self, batch_inds, training): - """ - Decomposing the graph, partitioning the features and target, creating Torch arrays. + Decomposing the graph, creating Torch arrays. """ if self._sampling_done is False: if self.cluster_method == "metis": @@ -324,29 +205,26 @@ def sampling(self, batch_inds, training): self._sampling_done = True - batch_inds = batch_inds.item() - effective_batch = self.sg_train_nodes[batch_inds] if training else self.sg_test_nodes[batch_inds] - return {"adj": self.sg_edges[batch_inds], "x": self.sg_features[batch_inds], "effective_batch": effective_batch} - - def _post_process(self, adj, to_sparse_tensor=True): - if self._post_sampling_op is not None: - adj = self._post_sampling_op._construct_adj(adj) - if to_sparse_tensor: - adj = sparse_mx_to_torch_sparse_tensor(adj) - return adj + cluster_ind = cluster_ind.item() + if training is True: + batch_out = [self.sg_train_nodes[cluster_ind]] + else: + batch_out = [self.sg_val_nodes[cluster_ind], self.sg_test_nodes[cluster_ind]] + + return self.sg_nodes[cluster_ind], batch_out, self.sg_edges[cluster_ind] def _random_clustering(self): """ Random clustering the nodes. """ self.clusters = range(self.cluster_number) - self.cluster_membership = {node: random.choice(self.clusters) for node in self.adj.nodes()} + self.cluster_membership = {node: random.choice(self.clusters) for node in self._adj.nodes()} # def _metis_clustering(self): # """ # Clustering the graph with Metis. 
For details see: # """ - # (st, parts) = metis.part_graph(self.adj, self.cluster_number) + # (st, parts) = metis.part_graph(self._adj, self.cluster_number) # self.clusters = list(set(parts)) # self.cluster_membership = {node: membership for node, membership in enumerate(parts)} @@ -356,22 +234,22 @@ def _general_data_partitioning(self): """ self.sg_nodes = {} self.sg_edges = {} - self.sg_train_nodes = {} - self.sg_test_nodes = {} - self.sg_features = {} - self.sg_targets = {} + self.sg_train_nodes = {cluster: [] for cluster in self.clusters} + self.sg_val_nodes = {cluster: [] for cluster in self.clusters} + self.sg_test_nodes = {cluster: [] for cluster in self.clusters} for cluster in self.clusters: - # split train/test within each cluster - subgraph = self.adj.subgraph([node for node in sorted(self.adj.nodes()) if self.cluster_membership[node] == cluster]) - self.sg_nodes[cluster] = [node for node in sorted(subgraph.nodes())] + self.sg_nodes[cluster] = [node for node in sorted(self._adj.nodes()) if self.cluster_membership[node] == cluster] + subgraph = self._adj.subgraph(self.sg_nodes[cluster]) # map the global node inds to the local node inds - mapper = {node: i for i, node in enumerate(sorted(self.sg_nodes[cluster]))} + mapper = {node: i for i, node in enumerate(self.sg_nodes[cluster])} self.sg_edges[cluster] = [[mapper[edge[0]], mapper[edge[1]]] for edge in subgraph.edges()] + [[mapper[edge[1]], mapper[edge[0]]] for edge in subgraph.edges()] - self.sg_train_nodes[cluster], self.sg_test_nodes[cluster] = train_test_split(list(mapper.values()), test_size = self.test_ratio) - self.sg_test_nodes[cluster] = sorted(self.sg_test_nodes[cluster]) - self.sg_train_nodes[cluster] = sorted(self.sg_train_nodes[cluster]) - self.sg_features[cluster] = self.features[self.sg_nodes[cluster],:] - self.sg_targets[cluster] = self.target[self.sg_nodes[cluster],:] + for node in self.sg_nodes[cluster]: + if node in self._train_idx: + self.sg_train_nodes[cluster].append([mapper[node], node]) + elif node in self._val_idx: + self.sg_val_nodes[cluster].append([mapper[node], node]) + elif node in self._test_idx: + self.sg_test_nodes[cluster].append([mapper[node], node]) def _transfer_edges_and_nodes(self): """ @@ -381,8 +259,9 @@ def _transfer_edges_and_nodes(self): num_nodes = len(self.sg_nodes[cluster]) self.sg_nodes[cluster] = torch.LongTensor(self.sg_nodes[cluster]) row, col = np.array(self.sg_edges[cluster]).transpose() - self.sg_edges[cluster] = self._post_process(sp.coo_matrix((np.ones(row.shape[0]), (row, col)), shape=(num_nodes, num_nodes))) - self.sg_train_nodes[cluster] = torch.LongTensor(self.sg_train_nodes[cluster]) - self.sg_test_nodes[cluster] = torch.LongTensor(self.sg_test_nodes[cluster]) - self.sg_features[cluster] = torch.FloatTensor(self.sg_features[cluster]) - self.sg_targets[cluster] = torch.LongTensor(self.sg_targets[cluster]) \ No newline at end of file + self.sg_edges[cluster] = self._post_process(sp.coo_matrix((np.ones(row.shape[0]), (row, col)), + shape=(num_nodes, num_nodes))) + self.sg_edges[cluster] = self._to_Block(self.sg_edges[cluster]) + self.sg_train_nodes[cluster] = torch.LongTensor(self.sg_train_nodes[cluster]).transpose_(1, 0) + self.sg_val_nodes[cluster] = torch.LongTensor(self.sg_val_nodes[cluster]).transpose_(1, 0) + self.sg_test_nodes[cluster] = torch.LongTensor(self.sg_test_nodes[cluster]).transpose_(1, 0) \ No newline at end of file diff --git a/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc 
b/sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc
index 8d28e8b20cbefd6e02bcb885dfca5e194822ed5a..86b9346fe8420064a0169ca8a6b7a24043c4b1b2 100644
GIT binary patch
[... binary deltas for compiled __pycache__ artifacts omitted ...]
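
With the reworked samplers above, sampling() returns a (batch_in, batch_out, block) triple instead of a dict of sampled adjacency matrices, and the per-layer adjacencies travel inside a Block object. The following is a minimal sketch of driving the new NeighborSampler directly; the toy adjacency, the layer sizes and the batch are invented for illustration, and it assumes NeighborSampler is exported from sgl.sampler the same way ClusterGCNSampler is in the examples:

    import numpy as np
    import scipy.sparse as sp
    from sgl.sampler import NeighborSampler

    # toy symmetric adjacency: 100 nodes with ~500 random edges
    rng = np.random.default_rng(0)
    row, col = rng.integers(0, 100, 500), rng.integers(0, 100, 500)
    adj = sp.csr_matrix((np.ones(500), (row, col)), shape=(100, 100))
    adj = adj + adj.T

    # layer_sizes is a comma-separated string, one fan-out per aggregation layer;
    # the keyword names mirror examples/configs/graphsage.yml
    sampler = NeighborSampler(adj, layer_sizes="10,5", prob_type="normalize",
                              replace=True, post_sampling_op="RwGraphOp")

    batch_inds = np.arange(8)    # target nodes of this mini-batch
    batch_in, batch_out, block = sampler.sampling(batch_inds)
    # batch_in : source nodes required by the first aggregation layer
    # batch_out: the original target nodes (== batch_inds)
    # block    : the sampled per-layer adjacencies wrapped in a Block; a model
    #            then consumes it as model.model_forward(batch_in, block, device)
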
znxxz%!s0tM8(yIdHXF##M3T{`mk)LXuEG)cEpwecP_VYx+GHy-3lxtRa5_XB5czR8 z*}ud8pBNmwOKOgTUf_L9F%8B{?T45a5WymqNk;A<%pww4iGuz=7R5#>Qd4-=+fllp$G0J{rO`ION$5~Qbjvs`tYYNa?)Z`{bv_mM*m~V35(-q1=Wr964}#6ZfruCPEH8n!F5)5~qDUi&o3`BOah*Ra ztO^pr2)36JNn!VDzEO$HLj}qBF@Y^FEew(NO%V~PUTq$HHs8|x_`J4N^9*6SmavZl zw#}J7k^GLy*3#s!&$X84KCIy~X20fTkVxCyqR^4Z`KuFYtuSd{W^w}f0pxkJITH_GEkxZ-R)EbUkYLu%%v|e*uIc*L{I{B!4(Hw+r@}jx5#3o8^`Eaq; z2;yO@IqL?7DyDJku~NN=_B6UIEHAWbjfxW_4%M3zO9s#2O8gdtapqPXx7e7S8c(QJ zI8*ksLhA?E@6aF(%Z}+$upk2|#3b}UAEfk7#LQ(q$#3-VojAK}QVjf12^;v_g^%+n z!gAFL!w0HfRhJ;NBN433Qp7VcbEZ;QWJzJnjb7o23+5YYvf;fNagPLnsY2egb zUcMDN(4e-WH1;(ih9azh$4DBI@f)d40YN@Md!0{9O@J$>RU>tlu}iQBvZlM zq(~{*vBpCvI$P*1kysA>9bs|I_k>M^VR12A@r(K?`kl$4B1~KsjsqziH_Gt0U1U&Z zsEK()mZ}tEZ^gmjuwRB#=ZAq<~~V zrOzO#BkFwr!pn)H0Kb%`Jqdec!`|BSQwmST5+uv@S;wvNV!i1&E%^t#A138L?SI6m zuJ|r_A@$=}$Vj%Ox7?|U^GDQ`Cb*0lnTH4t6A+ysc66#*YB+p?JY|Bt@>u#d_>H`j z9s*Z>kv<7W}l+DD0>>}nL#(i@XK2ZUu6Cpg}=zZca6fVOm-iJmYnS#G5H-7 z?4Uf|{XaM+Pi1#vm@l)}!do)ea~4j?zxKQvqm1zvv-@6ut;Lw{(gF#H#MzC}X z*WqUt_VjIqO30Lkt3&Oo{92z9BW-AJt_0d1JSO0aU={3w}62qFYg0xF9_ z8`i^BAuW*?wrs8s7%A<#|lZ=JutkiTXE)Kzm?eXHP< z55BsfBCY5vmWr>Us`9d$jTJ$V3@F4of&~J)%&wXO*emt?@W^Y_Pa+A7dTF9s=ai3e zYntOxl6!|nVR7N8;xYm*nf~KJ1#s-XKzayvjGx;4H!O^6B-P+0SVn<1fjJ>m^)}yEhEL$Fja*0e&X) zBb&zNXyoGrPap(Ft>NB$lPbR;y1$0+7K#_doaqvGs(cEKLX->gnfL;(;Xgk93;E2* zwvj)r7DNUwmYljg_$M@BkgGIyVVl?Ge@3#gzfp5m+J%{W>0KJ$iu{S(e3_!xGp|vI1Mfx>@Qguu^kfXDwaZ_w9-UdLadNw=}gPQcDKE= z#a3r+6d!y5B3|XiL}o#wd??oCBNF2?F=}`*CgYt>RMo50{D^TyhzlJ_9bH z2V912-BC-D9d=ayhktFY#-4!OU23H#r?Q#c1ltNG@20Ju2Z%HDO?R9+O?MLpHrQ>&f1(wd_T=i9x+Z zZLqDpS!;@^tQ-8esu)gXj9C^2nJ5YFs zY+bfPJf=4{21TmuS;&g*K|qDDK7_q|SZ`}fgexrS6&ADkv|+MEB5_tLLO;Kvch&b` za~HyXgaHJhqe1Qowr&~$+rsI{wij<1bwfcnf_0%EokV0EG{g>+jo?W_PVxJpak8H$ z!%1?G=fibtvbaq(ZX;L2;zqdx1Q;Qk@%SI*m%=|OSUk3e&qod_i;1|AEj_3Oh82G# zc4;A~w&)`(dA3@r~pI{%d@J zT;vz(?p;jLI0z+# z3K!Ti5ex(mf)^nOkf^HEI#Q{14!**vOwKT8-M}lE!g2on8oK{7PW}iXiGcNF{Rkf; zh{9jO79N^Wt+J76OqlEw>`ftjig1NDHMNmA?`zsbZt$6=GwZHl_jQEj_^yJP2#KiJ zXS}7kHzK^Q@Wl(wN21%w!y?G|V^9amsFvBy@`L7MkFX{-N9x$JL&1-OeS#1{jbKGk z*oOmq5s)N7aN@DWn6Tq>gy#`fql66bZ`ZZ93sQPulIzQ+kC|D)$EY^NOe0n4Sfr({ zWee%!{Vh+E+vT@f-XlZHZ2W)JB4Dh3qGm;KJ29^35Gu^^@zxhMF9pMH;Ru#yR1;C8 zFd+5?_AVfNiBRUhwQdI0$JTEncgjQSr%B_o-h|YITa{JB!GCVckh8pRLnpb*r#IZI z&BG?H$DA`+CvF!-c!v+QchMPWCMbKKzuC^4FJl$)FMtG}f{6GmXi&)z(c~J!?(yzK znN0HsiOzitcd!r&0In~SD?HgP6zzJ4Gw?freq6=gF-LC={n(_*t= z@u})sp6^&k7Wny&+D-RyUJDZ}*Fu|(BVe?yZjEK8<~A;F_yhLCsqUEG*Ne^F2nP_3 zBIFR}5H2F1hS?2-n+UfN?jYPnxQB3`pY7bWMx${Wq%j(%QL0jl2IyL<(*$j%bu{!3 D+Z+&V diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index 64fc623..887bf7f 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -12,7 +12,7 @@ class NodeClassification_Sampling(BaseTask): def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nll_loss", seed=42, - train_batch_size=None, eval_batch_size=None): + train_batch_size=None, eval_batch_size=None, **kwargs): super(NodeClassification_Sampling, self).__init__() self.__dataset = dataset @@ -29,6 +29,10 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nl self.__eval_batch_size = eval_batch_size self.__mini_batch_train = True if train_batch_size is not None else False self.__mini_batch_eval = True if eval_batch_size is not None else False + self.__determined_sample = False + if "graph_number" in kwargs.keys(): + self.__graph_number = kwargs["graph_number"] + self.__determined_sample = True self.__test_acc = 
self._execute() @property @@ -43,20 +47,27 @@ def _execute(self): pre_time_ed = time.time() print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") - if self.__mini_batch_train: + if self.__determined_sample: self.__train_loader = DataLoader( - self.__dataset.train_idx, batch_size=self.__train_batch_size, shuffle=True, drop_last=False) - if self.__mini_batch_eval: - self.__val_loader = DataLoader( - self.__dataset.val_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) - self.__test_loader = DataLoader( - self.__dataset.test_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) - if self.__model.evaluate_mode == "full": - self.__all_eval_loader = DataLoader( - range(self.__dataset.num_node), batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) - else: - self.__all_eval_loader = DataLoader( + range(self.__graph_number), batch_size=1, shuffle=True, drop_last=False) + self.__val_loader = self.__train_loader + self.__test_loader = self.__train_loader + self.__all_eval_loader = self.__train_loader + else: + if self.__mini_batch_train: + self.__train_loader = DataLoader( + self.__dataset.train_idx, batch_size=self.__train_batch_size, shuffle=True, drop_last=False) + if self.__mini_batch_eval: + self.__val_loader = DataLoader( + self.__dataset.val_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) + self.__test_loader = DataLoader( self.__dataset.test_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) + if self.__model.evaluate_mode == "full": + self.__all_eval_loader = DataLoader( + self.__dataset.node_ids, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) + else: + self.__all_eval_loader = DataLoader( + self.__dataset.test_idx, batch_size=self.__eval_batch_size, shuffle=False, drop_last=False) self.__model = self.__model.to(self.__device) self.__labels = self.__labels.to(self.__device) @@ -71,14 +82,14 @@ def _execute(self): loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__labels, self.__device, self.__optimizer, self.__loss_fn) else: - loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, + loss_train, acc_train = train(self.__model, self.__dataset.node_ids, self.__dataset.train_idx, self.__labels, self.__device, self.__optimizer, self.__loss_fn) if self.__mini_batch_eval: acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__labels, self.__device) else: - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, + acc_val, acc_test = evaluate(self.__model, self.__dataset.node_ids, self.__dataset.val_idx, self.__dataset.test_idx, self.__labels, self.__device) print('Epoch: {:03d}'.format(epoch + 1), @@ -107,12 +118,12 @@ def _postprocess(self): if self.__model.evaluate_mode == "full": if self.__mini_batch_eval is False: outputs = self.__model.model_forward( - range(self.__dataset.num_node), self.__device).to("cpu") + self.__dataset.node_ids, self.__model.processed_block, self.__device).to("cpu") else: outputs = [] for batch in self.__all_eval_loader: - sample_dict = self.__model.sampling(batch) - output, batch = self.__model.model_forward(batch, self.__device, **sample_dict) + batch_in, batch_out, block = self.__model.sampling(batch) + output = self.__model.model_forward(batch_in, block, self.__device) outputs.append(output) outputs = torch.vstack(outputs) @@ -124,26 +135,31 @@ def _postprocess(self): 
final_output[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) else: # ClusterGCN + val_outputs, val_labels = [], [] + test_outputs, test_labels = [], [] for batch in self.__all_eval_loader: - outputs, labels = [], [] - for batch in self.__all_eval_loader: - sample_dict = self.__model.sampling(batch) - sample_dict.update({"ret_full": True}) - output, batch = self.__model.model_forward(batch, self.__device, **sample_dict) - output = self.__model.postprocess(sample_dict["adj"], output) - outputs.append(output[batch]) - labels.append(self.__labels[batch]) - outputs = torch.vstack(outputs) - labels = torch.cat(labels) + batch_in, batch_out, block = self.__model.sampling(batch) + output = self.__model.model_forward(batch_in, block, self.__device) + output = self.__model.postprocess(block, output) + val_local_inds, val_global_inds = batch_out[0] + test_local_inds, test_global_inds = batch_out[1] + val_outputs.append(output[val_local_inds]) + val_labels.append(self.__labels[val_global_inds]) + test_outputs.append(output[test_local_inds]) + test_labels.append(self.__labels[test_global_inds]) + val_outputs = torch.vstack(val_outputs) + val_labels = torch.cat(val_labels) + test_outputs = torch.vstack(test_outputs) + test_labels = torch.cat(test_labels) - acc_val = accuracy(outputs, labels) - acc_test = accuracy(outputs, labels) + acc_val = accuracy(val_outputs, val_labels) + acc_test = accuracy(test_outputs, test_labels) return acc_val, acc_test class NodeClassification_RecycleSampling(BaseTask): def __init__(self, dataset, model, lr, weight_decay, num_iters, device, loss_fn="nll_loss", seed=42, - train_batch_size=1024, eval_batch_size=None,): + train_batch_size=1024, eval_batch_size=None): super(NodeClassification_RecycleSampling, self).__init__() self.__dataset = dataset @@ -189,27 +205,24 @@ def _execute(self): iter_id = 0 generator = self.__model.flash_sampling(len(taus), self.__train_loader) - for sample_dict in generator: + for batch_in, batch_out, block in generator: - batch_out, batch_in, batch_adjs = sample_dict["batch_out"], sample_dict["batch_in"], sample_dict["sampled_adjs"] batch_x = self.__model.processed_feature[batch_in].to(self.__device) batch_y = self.__labels[batch_out].to(self.__device) - batch_adjs = [adj.to(self.__device) for adj in batch_adjs] + block.to_device(self.__device) for rec_itr in range(taus[iter_id]): self.__optimizer.zero_grad() recycle_vector = None - new_batch_x = batch_x new_batch_y = batch_y - new_batch_adjs = batch_adjs if rec_itr != 0: recycle_vector = torch.cuda.FloatTensor(len(batch_out)).uniform_() > 0.2 new_batch_y = batch_y[recycle_vector] self.__model.train() - pred = self.__model.model_forward(new_batch_x, new_batch_adjs) + pred = self.__model.model_forward(batch_x, block) if recycle_vector is not None: pred = pred[recycle_vector] @@ -251,14 +264,13 @@ def _validation(self, iter_cnt, prev_score=None, val_freq=1): val_score = accuracy(val_pred, val_y) else: val_scores = [] - val_sample_dicts = self.__model.val_sampling() - for val_sample_dict in val_sample_dicts: - val_batch_out, val_batch_in, val_batch_adjs = val_sample_dict["batch_out"], val_sample_dict["batch_in"], val_sample_dict["sampled_adjs"] + val_samples = self.__model.sequential_sampling(do_val=True) + for val_batch_in, val_batch_out, val_block in val_samples: val_batch_x = self.__model.processed_feature[val_batch_in].to(self.__device) val_batch_y = self.__labels[val_batch_out].to(self.__device) - val_batch_adjs = [val_adj.to(self.__device) for val_adj in val_batch_adjs] + 
val_block.to_device(self.__device) - pred = self.__model.model_forward(val_batch_x, val_batch_adjs) + pred = self.__model.model_forward(val_batch_x, val_block) val_score = accuracy(pred, val_batch_y) val_batch_size = len(val_batch_out) val_scores.append(val_score * val_batch_size) @@ -275,14 +287,13 @@ def _inference(self): test_score = accuracy(test_pred, test_y) else: test_scores = [] - test_sample_dicts = self.__model.test_sampling() - for test_sample_dict in test_sample_dicts: - test_batch_out, test_batch_in, test_batch_adjs = test_sample_dict["batch_out"], test_sample_dict["batch_in"], test_sample_dict["sampled_adjs"] + test_samples = self.__model.sequential_sampling(do_val=False) + for test_batch_in, test_batch_out, test_block in test_samples: test_batch_x = self.__model.processed_feature[test_batch_in].to(self.__device) test_batch_y = self.__labels[test_batch_out].to(self.__device) - test_batch_adjs = [test_adj.to(self.__device) for test_adj in test_batch_adjs] + test_block.to_device(self.__device) - pred = self.__model.model_forward(test_batch_x, test_batch_adjs) + pred = self.__model.model_forward(test_batch_x, test_block) test_score = accuracy(pred, test_batch_y) test_batch_size = len(test_batch_out) test_scores.append(test_score * test_batch_size) diff --git a/sgl/tasks/utils.py b/sgl/tasks/utils.py index efa8446..966e2f3 100644 --- a/sgl/tasks/utils.py +++ b/sgl/tasks/utils.py @@ -33,13 +33,12 @@ def add_labels(features, labels, idx, num_classes): onehot[idx, labels[idx]] = 1 return np.concatenate([features, onehot], axis=-1) -def evaluate(model, val_idx, test_idx, labels, device): +def evaluate(model, all_idx, val_idx, test_idx, labels, device): model.eval() - val_output = model.model_forward(val_idx, device) - test_output = model.model_forward(test_idx, device) + output = model.model_forward(all_idx, model.processed_block, device) - acc_val = accuracy(val_output, labels[val_idx]) - acc_test = accuracy(test_output, labels[test_idx]) + acc_val = accuracy(output[val_idx], labels[val_idx]) + acc_test = accuracy(output[test_idx], labels[test_idx]) return acc_val, acc_test @@ -49,32 +48,44 @@ def mini_batch_evaluate(model, val_loader, test_loader, labels, device): val_num = 0 for batch in val_loader: - sample_dict = model.sampling(batch) - val_output, batch = model.model_forward(batch, device, **sample_dict) - pred = val_output.max(1)[1].type_as(labels) - correct_num_val += pred.eq(labels[batch]).double().sum() - val_num += len(batch) + batch_in, batch_out, block = model.sampling(batch) + val_output = model.model_forward(batch_in, block, device) + if isinstance(batch_out, list): + local_inds, global_inds = batch_out[0] + pred = val_output[local_inds].max(1)[1].type_as(labels) + correct_num_val += pred.eq(labels[global_inds]).double().sum() + val_num += len(local_inds) + else: + pred = val_output.max(1)[1].type_as(labels) + correct_num_val += pred.eq(labels[batch_out]).double().sum() + val_num += len(batch_out) acc_val = correct_num_val / val_num test_num = 0 for batch in test_loader: - sample_dict = model.sampling(batch) - test_output, batch = model.model_forward(batch, device, **sample_dict) - pred = test_output.max(1)[1].type_as(labels) - correct_num_test += pred.eq(labels[batch]).double().sum() - test_num += len(batch) + batch_in, batch_out, block = model.sampling(batch) + test_output = model.model_forward(batch_in, block, device) + if isinstance(batch_out, list): + local_inds, global_inds = batch_out[1] + pred = test_output[local_inds].max(1)[1].type_as(labels) + 
correct_num_test += pred.eq(labels[global_inds]).double().sum()
+            test_num += len(local_inds)
+        else:
+            pred = test_output.max(1)[1].type_as(labels)
+            correct_num_test += pred.eq(labels[batch_out]).double().sum()
+            test_num += len(batch_out)
     acc_test = correct_num_test / test_num
 
     return acc_val.item(), acc_test.item()
 
-def train(model, train_idx, labels, device, optimizer, loss_fn):
+def train(model, all_idx, train_idx, labels, device, optimizer, loss_fn):
     model.train()
     optimizer.zero_grad()
 
-    train_output = model.model_forward(train_idx, device)
-    loss_train = loss_fn(train_output, labels[train_idx])
-    acc_train = accuracy(train_output, labels[train_idx])
+    train_output = model.model_forward(all_idx, model.processed_block, device)
+    loss_train = loss_fn(train_output[train_idx], labels[train_idx])
+    acc_train = accuracy(train_output[train_idx], labels[train_idx])
     loss_train.backward()
     optimizer.step()
 
@@ -88,16 +99,24 @@ def mini_batch_train(model, train_loader, labels, device, optimizer, loss_fn):
     train_num = 0
 
     for batch in train_loader:
-        sample_dict = model.sampling(batch)
+        batch_in, batch_out, block = model.sampling(batch)
         optimizer.zero_grad()
-        train_output, batch = model.model_forward(batch, device, **sample_dict)
-        loss_train = loss_fn(train_output, labels[batch])
+        train_output = model.model_forward(batch_in, block, device)
+        if isinstance(batch_out, list):
+            local_inds, global_inds = batch_out[0]
+            loss_train = loss_fn(train_output[local_inds], labels[global_inds])
+            pred = train_output[local_inds].max(1)[1].type_as(labels)
+            correct_num += pred.eq(labels[global_inds]).double().sum()
+            loss_train_sum += loss_train.item()
+            train_num += len(local_inds)
+        else:
+            loss_train = loss_fn(train_output, labels[batch_out])
+            pred = train_output.max(1)[1].type_as(labels)
+            correct_num += pred.eq(labels[batch_out]).double().sum()
+            loss_train_sum += loss_train.item()
+            train_num += len(batch_out)
         loss_train.backward()
         optimizer.step()
 
-        pred = train_output.max(1)[1].type_as(labels)
-        correct_num += pred.eq(labels[batch]).double().sum()
-        loss_train_sum += loss_train.item()
-        train_num += len(batch)
-
     loss_train = loss_train_sum / len(train_loader)
     acc_train = correct_num / train_num
 

From 6d7e54ca165efeb846b9860249684e65e688a843 Mon Sep 17 00:00:00 2001
From: infinity
Date: Tue, 21 Nov 2023 10:46:31 +0000
Subject: [PATCH 07/28] support multi-worker data loading; support
 multi-subgraph mini-batch training for ClusterGCN.
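
The ClusterGCN pipeline now takes separate samplers for training and for evaluation. Below is a condensed sketch of the construction path introduced in examples/clustergcn_nodeclass.py further down, with values inlined from examples/configs/clustergcn.yml; the Planetoid class name, its keyword arguments and the FullSampler fallback for evaluation are illustrative assumptions rather than something this patch prescribes:

    from sgl.dataset import Planetoid
    from sgl.models.homo import ClusterGCN
    from sgl.sampler import ClusterGCNSampler, FullSampler
    from sgl.tasks import NodeClassification_Sampling

    dataset = Planetoid(name="cora", root="/home/ssq/test_data/")

    # training partitions the graph with METIS; evaluation here simply falls
    # back to a full-graph sampler instead of a second clustering
    train_sampler = ClusterGCNSampler(dataset, cluster_method="metis", cluster_number=10,
                                      post_sampling_op="LaplacianGraphOp",
                                      save_dir=dataset.processed_dir)
    eval_sampler = FullSampler(dataset.adj)

    model = ClusterGCN(train_sampler, eval_sampler, nfeat=dataset.num_features,
                       nclass=dataset.num_classes, hidden_dim=128, dropout=0.5,
                       num_layers=2, device="cpu")
    test_acc = NodeClassification_Sampling(dataset, model, lr=0.01, weight_decay=0.00005,
                                           epochs=20, device="cpu", train_batch_size=5,
                                           train_graph_number=10).test_acc

With train_batch_size set to 5 in the config, one training step draws several clusters per mini-batch, which is what the multi-subgraph mini-batch training in the subject line refers to.
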
--- examples/clustergcn_nodeclass.py | 30 +++- examples/configs/clustergcn.yml | 12 +- examples/configs/graphsage.yml | 3 +- sgl/data/__pycache__/__init__.cpython-39.pyc | Bin 577 -> 672 bytes sgl/data/__pycache__/base_data.cpython-37.pyc | Bin 11995 -> 12287 bytes sgl/data/__pycache__/base_data.cpython-39.pyc | Bin 11024 -> 12433 bytes .../__pycache__/base_dataset.cpython-37.pyc | Bin 12046 -> 12807 bytes .../__pycache__/base_dataset.cpython-39.pyc | Bin 11775 -> 12673 bytes sgl/data/__pycache__/utils.cpython-37.pyc | Bin 2786 -> 2786 bytes sgl/data/__pycache__/utils.cpython-39.pyc | Bin 654 -> 2816 bytes sgl/data/base_data.py | 8 + sgl/data/base_dataset.py | 26 ++- sgl/data/utils.py | 2 +- .../__pycache__/__init__.cpython-39.pyc | Bin 1258 -> 1192 bytes .../__pycache__/planetoid.cpython-39.pyc | Bin 4255 -> 4338 bytes .../__pycache__/base_model.cpython-37.pyc | Bin 9277 -> 9523 bytes .../__pycache__/base_model.cpython-39.pyc | Bin 9501 -> 9286 bytes .../__pycache__/simple_models.cpython-39.pyc | Bin 6970 -> 11395 bytes sgl/models/base_model.py | 13 +- .../homo/__pycache__/__init__.cpython-39.pyc | Bin 629 -> 723 bytes .../__pycache__/clustergcn.cpython-37.pyc | Bin 1063 -> 875 bytes .../__pycache__/clustergcn.cpython-39.pyc | Bin 799 -> 992 bytes .../homo/__pycache__/fastgcn.cpython-39.pyc | Bin 1465 -> 990 bytes .../homo/__pycache__/graphsage.cpython-39.pyc | Bin 1480 -> 1094 bytes .../homo/__pycache__/lazygnn.cpython-37.pyc | Bin 4418 -> 4445 bytes .../homo/__pycache__/lazygnn.cpython-39.pyc | Bin 0 -> 4508 bytes .../__pycache__/vanillagnn.cpython-39.pyc | Bin 0 -> 1276 bytes sgl/models/homo/clustergcn.py | 16 +- sgl/models/homo/lazygnn.py | 9 +- .../__pycache__/base_op.cpython-39.pyc | Bin 2690 -> 3114 bytes .../__pycache__/utils.cpython-39.pyc | Bin 3399 -> 3703 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 282 -> 330 bytes .../laplacian_graph_op.cpython-39.pyc | Bin 1094 -> 1104 bytes .../__pycache__/rw_graph_op.cpython-39.pyc | Bin 0 -> 1014 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 938 -> 1008 bytes .../pre_normalize_message_op.cpython-39.pyc | Bin 0 -> 859 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 247 -> 273 bytes .../__pycache__/base_sampler.cpython-37.pyc | Bin 3412 -> 3416 bytes .../__pycache__/base_sampler.cpython-39.pyc | Bin 790 -> 3453 bytes .../__pycache__/sampler.cpython-37.pyc | Bin 9898 -> 9429 bytes .../__pycache__/sampler.cpython-39.pyc | Bin 9891 -> 9401 bytes sgl/sampler/__pycache__/utils.cpython-39.pyc | Bin 0 -> 1919 bytes sgl/sampler/base_sampler.py | 6 +- sgl/sampler/sampler.py | 157 +++++++++--------- sgl/tasks/__pycache__/__init__.cpython-39.pyc | Bin 826 -> 876 bytes ...ode_classification_sampling.cpython-37.pyc | Bin 9347 -> 10348 bytes ...ode_classification_sampling.cpython-39.pyc | Bin 4302 -> 10317 bytes sgl/tasks/__pycache__/utils.cpython-37.pyc | Bin 11243 -> 11185 bytes sgl/tasks/__pycache__/utils.cpython-39.pyc | Bin 10956 -> 11193 bytes sgl/tasks/node_classification_sampling.py | 67 ++++---- sgl/tasks/utils.py | 18 +- sgl/utils/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 293 bytes .../auto_choose_gpu.cpython-39.pyc | Bin 0 -> 1185 bytes .../basic_operations.cpython-39.pyc | Bin 0 -> 593 bytes sgl_dair.egg-info/SOURCES.txt | 8 +- 55 files changed, 219 insertions(+), 156 deletions(-) create mode 100644 sgl/models/homo/__pycache__/lazygnn.cpython-39.pyc create mode 100644 sgl/models/homo/__pycache__/vanillagnn.cpython-39.pyc create mode 100644 sgl/operators/graph_op/__pycache__/rw_graph_op.cpython-39.pyc create mode 100644 
sgl/operators/message_op/__pycache__/pre_normalize_message_op.cpython-39.pyc create mode 100644 sgl/sampler/__pycache__/utils.cpython-39.pyc create mode 100644 sgl/utils/__pycache__/__init__.cpython-39.pyc create mode 100644 sgl/utils/__pycache__/auto_choose_gpu.cpython-39.pyc create mode 100644 sgl/utils/__pycache__/basic_operations.cpython-39.pyc diff --git a/examples/clustergcn_nodeclass.py b/examples/clustergcn_nodeclass.py index 91ff3f5..04ea21a 100644 --- a/examples/clustergcn_nodeclass.py +++ b/examples/clustergcn_nodeclass.py @@ -2,6 +2,7 @@ import argparse import sgl.dataset as Dataset from sgl.models.homo import ClusterGCN +import sgl.sampler as Sampler from sgl.sampler import ClusterGCNSampler from sgl.tasks import NodeClassification_Sampling @@ -18,15 +19,30 @@ config = yaml.safe_load(open(args.config_path, "rb")) device = f"cuda:{args.device}" if args.device >= 0 else "cpu" dataset_kwargs = config["dataset"] - cluster_number = config["sampler"]["cluster_number"] - classname = dataset_kwargs.pop("classname") - dataset = getattr(Dataset, classname)(**dataset_kwargs) sampler_kwargs = config["sampler"] - sampler = ClusterGCNSampler(dataset, **sampler_kwargs) model_kwargs = config["model"] - model_kwargs.update({"device": device}) - model = ClusterGCN(sampler, nfeat=dataset.num_features, nclass=dataset.num_classes, **model_kwargs) task_kwargs = config["task"] + + classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, classname)(**dataset_kwargs) + train_sampler_kwargs = sampler_kwargs["train"] + train_sampler_kwargs.update({"save_dir": dataset.processed_dir}) + train_cluster_number = train_sampler_kwargs["cluster_number"] + task_kwargs.update({"train_graph_number": train_cluster_number}) + train_sampler = ClusterGCNSampler(dataset, **train_sampler_kwargs) + if "eval" in sampler_kwargs: + eval_sampler_kwargs = sampler_kwargs["eval"] + eval_sampler_name = eval_sampler_kwargs["name"] + if eval_sampler_name == "ClusterGCNSampler": + eval_sampler_kwargs.update({"save_dir": dataset.processed_dir}) + eval_cluster_number = eval_sampler_kwargs["cluster_number"] + task_kwargs.update({"eval_graph_number": eval_cluster_number}) + eval_sampler = ClusterGCNSampler(dataset, **eval_sampler_kwargs) + else: + eval_sampler = getattr(Sampler, eval_sampler_name)(dataset.adj, **eval_sampler_kwargs) + else: + eval_sampler = None + model_kwargs.update({"device": device}) + model = ClusterGCN(train_sampler, eval_sampler, nfeat=dataset.num_features, nclass=dataset.num_classes, **model_kwargs) task_kwargs.update({"device": device}) - task_kwargs.update({"graph_number": cluster_number}) test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc diff --git a/examples/configs/clustergcn.yml b/examples/configs/clustergcn.yml index 42c9297..660e9ce 100644 --- a/examples/configs/clustergcn.yml +++ b/examples/configs/clustergcn.yml @@ -3,17 +3,17 @@ dataset: name: "cora" root: "/home/ssq/test_data/" sampler: - cluster_method: "random" - cluster_number: 10 - post_sampling_op: "LaplacianGraphOp" + train: + cluster_method: "metis" + cluster_number: 10 + post_sampling_op: "LaplacianGraphOp" model: hidden_dim: 128 dropout: 0.5 num_layers: 2 task: - train_batch_size: 1 - eval_batch_size: 1 - epochs: 30 + train_batch_size: 5 + epochs: 20 lr: 0.01 weight_decay: 0.00005 loss_fn: "nll_loss" diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml index eb2ade1..4c4b488 100644 --- a/examples/configs/graphsage.yml +++ b/examples/configs/graphsage.yml @@ -11,8 +11,6 @@ 
sampler:
   prob_type: "normalize"
   replace: True
   post_sampling_op: "RwGraphOp"
-  eval:
-    name: "FullSampler"
 model:
   name: "GraphSAGE"
   hidden_dim: 128
@@ -21,6 +19,7 @@ model:
 task:
   name: "NodeClassification_Sampling"
   train_batch_size: 2048
+  eval_batch_size: 64
   epochs: 20
   lr: 0.1
   weight_decay: 0.00005
diff --git a/sgl/data/__pycache__/__init__.cpython-39.pyc b/sgl/data/__pycache__/__init__.cpython-39.pyc
index 410cdf53e3c506869810aca60d60b41da760f567..64c3ad3845a4ba82d08157af05506ea934352ce5 100644
GIT binary patch
[... binary delta omitted ...]
diff --git a/sgl/data/__pycache__/base_data.cpython-37.pyc b/sgl/data/__pycache__/base_data.cpython-37.pyc
index 7603232cf284e7f5201419ec498ec015f9de809b..1bc787f6a9aa250d2d71fd5dab3a3d404f7ab328 100644
GIT binary patch
[... binary deltas omitted, including the adjacent sgl/data/__pycache__ patches listed in the diffstat above ...]
diff --git a/sgl/data/__pycache__/base_dataset.cpython-37.pyc b/sgl/data/__pycache__/base_dataset.cpython-37.pyc
index 235854927d7ae6d8e38b15f9239ddcc09f54bc60..14fd9a7539e0ae32c8d4ba8083608ff710e75c8c 100644
GIT binary patch
[... binary delta omitted ...]
diff
--git a/sgl/data/__pycache__/base_dataset.cpython-39.pyc b/sgl/data/__pycache__/base_dataset.cpython-39.pyc index e5d24b4017b7877fd9948c735afa653808a76314..a88fef899315ee474dfe21d9330f38f471f4d7cf 100644 GIT binary patch delta 4106 zcmb_fYiwLc72dhKci&z=)^9tG^Rl~kyzwK>yJ_=s-kSs)Lbfg2_1=4@!IO5s%uv!^=F+~j zDtZ32Uw#AW0DYB0&M2Sf;r)w#WP0M z=J^F0)#ZztzL2p^nzH5}{_!<-di59!u#>t^8`SpR^(?b=2hl@XQqrA7@6(1Q-9_|> zHY(|EqQ|uTlHR6`(~1XH&nx|mb&Ffx2iunkm~*ycOCce@mD;wjTU zZ(5}Xdi|5^oF9gF6F6## z=ZiLP1g)5s6Ew2(x|Zh_?}N-9Kog)J@RTnf2Fl&;-gvyGz`o$+-BLi20uEbv$_VsE z#CL+{`sAjEX`E2pSdM_VDWD@F78>7ZbFrq)F)Rf>JTBSv7q)N|Mv`vlIf4a=ST2%bd>ISxdK*6;1Il-X(q&jSN=u z*-VnPj5#5jXY<8OUR#(E?a_%uMZRS#`M?|50;K1}*P^EfD$=c6Ne6A+vLRj+|A-EW z&l16ZP%rK(78hcy$GSiehnDZU;+@VIW;zbWoHDqHdp*pH<^4v{F zy{&T;cd=|_uwBPDZJGs+Ev~qw=J?PcHr1LF##V=lvUzZLfrH~wnpv3LZmo`?8xoih ze@e9cFCO!gWHF)g)AZqI026?-02$JApf;WXCy0<@wFP2UJLM*}zLMEBssE=>GI6zh z3dzL5skb&c)q+Shu9e%mn7F3)ppM6?>o|%O&FERX%=Ox}e3`7@6=vPV(Q3U2^m2u) zWo@CBBDH(sFLf7W1y+uY8|%%dk|ZVLzR+X6aF z4-kXOY3mRpYO9|nQ;8mNFSci@hnP-~Pb1dR}SVv1oiZ@G;yN-YKN+@ z@Y>HybmhlG;@$ca2d_fo8sKw)TJ}CqjCbiLqYz{qW3hvEuD;MP$NYF>BqHKljYA4= z6~ArV-O)}nRfgo{i9mfId`;JPA^N-*P$Tv-g5MXvZtoEvG!444=od9*5$F~1Y18C$ z^;7XlJoOwoWb|j^a`W$Zy+9OC0BY@Jh#0>RKMCv+(=9*n%j_Q# z^{qX9cwIQ&d3|wHl$=N=GppNqTQ@S9->puz-eByD;-j{mORh2&QVG&wTvQfSXU^s~ zFzbWGFdAZIHcb-=uY32t0;WnWq^#FikV?^A@DC(rAx|D^ia{17RG9ME=Zi6sYQO3R zB}wd>n}vsero^2esy@vAMc2t%;mMK4r{ZA8#SQ0JZl{(w7@vtBcU;*ZerPLy!1x=9 z_lQ%SW3KN9wvs(ejDL!6c2@RyY%94JiSaKH?YiXJJHD0NOT_q(c(dz*eBAJtw~{+S z427*e>MD4zy#}E#5jeg=-m>YwE7R+eXE38$r9$3zyp$@LZpWW3mGq+KzSQI`gGrYQV-G;M?Z5^R}z_X!k5hA?37-#OrB_U z%LWm2!sQ~hJ3U8lTs|%i_g?lNf&PTJ-+OOBcCIHK0|(uSTUk7dGt-Ni_19E@tsdX; zx^HQQ1Ri;z&p z&0C;w(Q^x614he)j>@yeS>5sCv0RVoHvkI&M1?OB;I)x6O@3Qk>FY|$0IBlt4?X}9 zM70tnCuHb$HdD&lH~E|3-4NgC>pgP^q)dzq*IO7a0lor2+#Hs7oBn;wybV|;NC)Y; znxO_z;`_wE``RN_;T668?QC-OaR0}Q{zM}MT{I_lu|wie;*(LxB^WQ@qc}=(*tK4z zP8=tE48@f*jneVoh>3xd$tN+MhefE%n&UxGvR0}>DSkfC-zGzVJR~Ad3WW>t_kj^M zF8T(w-c5_v7qi^Tl(=5GS69T32A7vMT|8(!d=jwbC3+p)R{_%iq=lo2xb+}A0ZuZ1 z5r8wualqV?c?Z{{q4L;+K2R4ueXx@j?1H0n4C#cs`db+xddUQR+^UHCsq(^o~nT(<40tO15jRid#d?em_hF z#Mg&*PRWqt@|HKO^l8monv;IRStoeRFtu_)KaOK#brPUk2j%}Liz*GQfkl0>B~LJ> F{ugj6K}P@p delta 3388 zcmai0U2I%O72et3yFdPq|6*t3bmOG!*c-==NlHTclP1ou<2IooH>I~XdvBag?$0v! zZm|(4NsDMDBtkSoh)D3*rVm9R^Uw!WB*bHpKokj$Dp3RvR0361A9)Dpo9o&ed!1PF zcjuh*oij6M=A7Bv)d!c`&9=5&jQ;8)v-aBu@3s%$yhoB1vzQapqc79rN>1nr^-a1d zJG=Z7cVx?{Vx z7uwb_&V9~DRvnb{*_`Jbtak1#J^NS|jXQV3ikkf}Pac_Ht_ zIPM2hbIlvJRpS941!pIq3$P0?0%+Qx_$<>Hj`qg8TM}+*XLl+>ToEK&+zNzPhiCD8 z>k+d&MYFu3X6Y=H#JOUAkD7_`9x`l*PCA(lxmoj!nzeX^{EnAb_**vPXI?WMu&;`_ zw&%xI*u9&vwa^Yi(65Wht^x5;Td(+Ke(N;czRX*}Yq-W$+YB4r-k1OceulskPj{Z4 zYWB;+uj-gVu!-BV@D;!9JhivkO@kYf*eVpZhL7hKKI-Q*pL`?lC;Ht$X5w{y&y>(}+I_7y+!?$2jSJE|@kd z{xa}RS{?2Iaci)+3q4&XZrDrfJY?Q57VNMV+OA>zYWWX?ml!)OibGG_jGQq?l0a8g z9vD`w8uByHCJ-@)o4BGqONw$L>UeW7rl`~!c8dw=rbQFJPqY9h6Jt51u>xV7#UCV- z?M(b`Xg*5H1esfI7AyRa4y`EWhxbf07fE_4TIdK|*QMws()~%?A3m*$f2E3}Ew!R! 
diff --git a/sgl/data/base_data.py b/sgl/data/base_data.py
index d474050..a5b15a7 100644
--- a/sgl/data/base_data.py
+++ b/sgl/data/base_data.py
@@ -3,6 +3,8 @@ import numpy as np
 from scipy.sparse import csr_matrix
 
+from sgl.utils import sparse_mx_to_torch_sparse_tensor
+
 
 # A lighter wrapper class for sampled adjacency matrices,
 # as the Edge class seems contains useless information
 class Block:
@@ -11,6 +13,7 @@ def __init__(self, adjs):
             self.__adjs = [adjs]
         else:
             self.__adjs = adjs
+        self.__device = None
 
     def __len__(self):
         return len(self.__adjs)
@@ -23,7 +26,12 @@ def __getitem__(self, id):
         return self.__adjs[id]
 
     def to_device(self, device):
+        if self.__device == device:
+            return
+        if not isinstance(self.__adjs[0], torch.sparse.FloatTensor):
+            self.__adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in self.__adjs]
         self.__adjs = [adj.to(device) for adj in self.__adjs]
+        self.__device = device
 
 
 # Base class for adjacency matrix
diff --git a/sgl/data/base_dataset.py b/sgl/data/base_dataset.py
index 887f5d0..d3fa8fd 100644
--- a/sgl/data/base_dataset.py
+++ b/sgl/data/base_dataset.py
@@ -1,10 +1,10 @@
-import itertools
-import numpy as np
 import os
 import os.path as osp
+import numpy as np
 import torch
 import warnings
 from scipy.sparse import csr_matrix
+from torch_geometric.utils import index_to_mask
 
 from sgl.data.base_data import Node, Edge
 from sgl.data.utils import file_exist, to_undirected
@@ -110,6 +110,24 @@ def val_idx(self):
     @property
     def test_idx(self):
         return self._test_idx
+
+    @property
+    def train_mask(self):
+        mask = torch.zeros((self.num_node, ), dtype=torch.bool)
+        mask[self._train_idx] = True
+        return mask
+
+    @property
+    def val_mask(self):
+        mask = torch.zeros((self.num_node, ), dtype=torch.bool)
+        mask[self._val_idx] = True
+        return mask
+
+    @property
+    def test_mask(self):
+        mask = torch.zeros((self.num_node, ), dtype=torch.bool)
+        mask[self._test_idx] = True
+        return mask
 
     @property
     def num_features(self):
@@ -122,6 +140,10 @@ def num_classes(self):
     @property
     def num_node(self):
         return self._data.num_node
+
+    @property
+    def processed_dir(self):
+        return self._processed_dir
 
 
 # Base class for graph-level tasks
diff --git a/sgl/data/utils.py b/sgl/data/utils.py
index 27cb1cf..7ed48de 100644
--- a/sgl/data/utils.py
+++ b/sgl/data/utils.py
@@ -66,4 +66,4 @@ def __len__(self):
         return len(self.batches)
 
     def __call__(self, bid, *args, **kwargs):
-        return self.batches[bid]
+        return self.batches[bid]
\ No newline at end of file
[GIT binary patches for sgl/dataset/__pycache__ and sgl/models/__pycache__ *.pyc files omitted]
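For context, the three new mask properties simply turn the stored index splits into boolean node masks, which is also what torch_geometric's index_to_mask (imported above) computes. A minimal, self-contained sketch of the equivalent logic; the names idx and num_node are illustrative, not part of the patch:

import torch

def index_to_bool_mask(idx, num_node):
    # Boolean vector with True at every position listed in idx, False elsewhere
    mask = torch.zeros((num_node,), dtype=torch.bool)
    mask[torch.as_tensor(idx, dtype=torch.long)] = True
    return mask

# e.g. train_mask = index_to_bool_mask(dataset.train_idx, dataset.num_node)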
zbyP~$Ax#gL$<1P-JJ7+vZnHBmM?6}~E>Xcy&`VsGzuJ7vxQ^=WF>lWoNwV!JP)lFi+_pSxe3IU}CTgzZsgo`CppV)&-0`MYt-0#r3k-S+H|d5)3KJ6mM3`D^m3_)m g*cZf)&@T^~a10c5dkH8AanaB{6v-XQ6?5nR3ngGMzyJUM diff --git a/sgl/models/__pycache__/simple_models.cpython-39.pyc b/sgl/models/__pycache__/simple_models.cpython-39.pyc index 4dcfdbc540795a34fb31224c76718da3b68ffe98..fd41b8640f753bda743976f98114353fb8d7367c 100644 GIT binary patch delta 4420 zcmcgvTZ|i58J;s2&-l8I_nK=>dc!o@O)HcrC2a(0cMFK6)ke^mrqhYf?5-Vq>^(Ei z#dvGVrZ;FP$pi_cK6oVr3NHxl6Q~d64GA6)5-)=UZ%CmCA$X~Z5Pbg`+na1c6$z=c zn&1EYb1w5=zVAQ#iz6GSb51s!QgEfe_Po3Q(AC_&-UoQ%3TJht9RIR1lh0~HD{Cya z=?zU)n9?xH38pM4T;GdHxd1&huNSd4)s2}^>Tu$Ti&3KokeVMz`wX;@Oc zmxd`lFlAuM@GLCZfh7w|j*r1IHn8Mi$@5*X>>60cU@7o%SjGpIJS-D@5|+tWSNfu+FNl2V#U+@x?UhHJ)`wsx~0F=mDDsXu4qTOI9L#%97_>(`jM^`h}d#vTlx zFyCd5h98>$W=FygtOk2vYdIM*b~5}h^#k^7cs>0ywiM21USXdLf0g+IJ01Qo+h#9> zdva&k^II2lV~o8NzA^SD<6$8`&+M(!`K*?xVSoQ%_|uvD!?(w8vTtl%nkX>#TKLxF z3#__zbMifvUEBJ}%yq_o9G=*-Hhz`H>kt;pO^TPGRIcut2zz;>_ZS{rh4NQcIV)>C z`MR>qxW@I%$ucs;6PJ~;!A)-AH}SfnD`gWwNJZwt$;FdR?|f$=Rm8PaZM5oc5yoQO zx#WsNMZemp)*Vp{n#I*1X!(!N%{k&?_55s8tjsxQ{kcbv96t8&?BU0b9X=|Tn6J<$ z4lmLZxHa7Vi{ZbLrQW8x$-0W2>}p~@F#3tW>|0&fq&4Mbx=KII;mM}T6zG{v)E`&y zwd$fcgJnutL8}T1pD-Yzbk(z$T_*NQ44BuDx z^vL6rH1rt=&H|+mk85gQN)%DQ~6>3yTi(n^=bzSa_4X*2N@#wKomC;itJSFjwe zzn>6u_$ir@-f)7|NcY!7AiaWG?YO;2=T6{6R@mOw_#8?tyOw8L}J@_iw)QDijH4&R#s5C zPEc)n#S6|Q%Ndv{4c&2w5V6_Rn*%4rkMZvJKEY=EX!;5sPnkC#4 zPJvRx?m#``>J<%G*B0XvS0mOE_hDEG+I*DsP|i*rHC|#7yEsfZ_UKU=N7nJW(>TkW zC;x({f8tg;_YcE6Yui800CGAN>8?9I z_>S1Am@LsoP6JZl7Ye%M7^x9o#BkEYWsH#@Bwi)0_ZgCRf&brNm+eI%=aU`N)(tANzTP!9*-?dt*9`FmSWtd) z5)}eSpQ@!~#suDmE4K~SR(QOp(`>OH zCPBTRLh%}2R6}%2ga%M{BDAk>5~vIc4E0GMRNX)aL8d{TwiP71I%hM=275|5Gqs)= zjZ=r?dTYmTF{1r&?FwRx1+b`{l5Is_l;1aBVuIpWO2=4u6M9U6vev(d$4FbL_##a& zr0b?94+*|Vk?E@tNE3C($ndBiNBX$8TIEr~tJEFe7vGk{!lJqFK2+;4-UNaIK#)f> zonoCVMU1*=cp{yYsD=j$9Lx|!#f83xMmZ?(8>;vgElUjyy4+o_;6$?{%hA}5I!@6S zQ0dBYpP*fLwz-N%bu_#vb#zPM{TWIZg&lwzH>lIGbE@zHkYo)cHb}`4+3&Qj4!Wcc zh~EwBht&3Eg$H^~QZ1xXGl}}vwY4I2jcamF4+)DWd9`>eP8d29yX) zD!;ca-neCp(HzeTEN*PDZ=x^Ma5@7QCwf=ZCFO=WAY&F;fp9>g-iojh0Ex`b?5S0^ z=sP$#!X7uPVx#Q`#j|eFU28jaIy){b4o;M_;s9peNJU!R_2P<<)gzO*PK%x}=$bRPIL z9)?G9mx8G-Yxi_0PevM!Dcjpt2!F$in1*(goN>3VKIDuZQVh}nQEZb2VIEDrYm6#% zYTX=-8Bdh zxGXM9rn=f@@lk^tf!whXC`caU# z30E@~_$KRFs!2D|N9`-NWhLnYWqPm6V_puX?o}MfSo#>ap!52*=KX3u!p(wLSDfNU qt5bA&JXss7#o2|fm6f7NKkeO=>&PY(f5DjMXZ%J#<(q!gZ}|n;@MXvV diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index 9ad19cb..e73f8b2 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -86,13 +86,20 @@ def processed_block(self): def processed_feature(self): return self._processed_feature + @property + def collate_fn(self): + if self.training: + return self._training_sampling_op.collate_fn + else: + return self._eval_sampling_op.collate_fn + def sampling(self, batch_inds): if self.training: return self._training_sampling_op.sampling(batch_inds) else: return self._eval_sampling_op.sampling(batch_inds) - def preprocess(self, adj, x): + def preprocess(self, adj, x, mini_batch_eval, device): if self._pre_graph_op is not None: norm_adj = self._pre_graph_op._construct_adj(adj) else: @@ -104,6 +111,10 @@ def preprocess(self, adj, x): self._processed_feature = self._pre_feature_op._transform_x(x) else: self._processed_feature = x + + if mini_batch_eval is False: + self._processed_block.to_device(device) + self._processed_feature = self._processed_feature.to(device) def postprocess(self, adj, output): if self._post_graph_op is not None: diff --git 
[GIT binary patches for sgl/models/homo/__pycache__/*.pyc files omitted]
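The reworked ClusterGCN below takes separate training and evaluation samplers instead of one shared sampler. A hedged sketch of how it might now be constructed; the dataset class, import paths, and sampler keyword arguments are assumptions for illustration only, not something this patch defines:

from sgl.dataset.planetoid import Planetoid          # import path assumed
from sgl.sampler.sampler import ClusterGCNSampler    # import path assumed
from sgl.models.homo.clustergcn import ClusterGCN

dataset = Planetoid("cora", root="./data")           # constructor signature assumed
sampler = ClusterGCNSampler(dataset, cluster_method="metis", cluster_number=32,
                            save_dir=dataset.processed_dir)
# The same pre-sampling ClusterGCNSampler instance can plausibly serve both roles.
model = ClusterGCN(training_sampler=sampler, eval_sampler=sampler,
                   nfeat=dataset.num_features, hidden_dim=128,
                   nclass=dataset.num_classes, device="cpu")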
diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py
index ca0950d..b4c38b8 100644
--- a/sgl/models/homo/clustergcn.py
+++ b/sgl/models/homo/clustergcn.py
@@ -1,15 +1,11 @@
 from sgl.models.simple_models import GCN
 from sgl.models.base_model import BaseSAMPLEModel
+from sgl.operators.graph_op import LaplacianGraphOp
 
 class ClusterGCN(BaseSAMPLEModel):
-    def __init__(self, training_eval_sampler, nfeat, hidden_dim, nclass, dropout=0.5, num_layers=2, device="cpu"):
+    def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, dropout=0.5, num_layers=2, device="cpu"):
         super(ClusterGCN, self).__init__(evaluate_mode="sampling")
-        self._training_sampling_op = training_eval_sampler
-        self._eval_sampling_op = training_eval_sampler
-        self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device)
-
-    def sampling(self, batch_inds):
-        if self.training:
-            return self._training_sampling_op.sampling(batch_inds, training=True)
-        else:
-            return self._eval_sampling_op.sampling(batch_inds, training=False)
\ No newline at end of file
+        self._pre_graph_op = LaplacianGraphOp(r=0.5)
+        self._training_sampling_op = training_sampler
+        self._eval_sampling_op = eval_sampler
+        self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device)
\ No newline at end of file
diff --git a/sgl/models/homo/lazygnn.py b/sgl/models/homo/lazygnn.py
index 174e004..f8a590a 100644
--- a/sgl/models/homo/lazygnn.py
+++ b/sgl/models/homo/lazygnn.py
@@ -32,10 +32,13 @@ def __init__(self, dataset, training_sampler, eval_sampler=None, hidden_dim=128,
     def preprocess(self, adj, x, val_dataloader=None, test_dataloader=None):
         if val_dataloader is None:
             norm_adj = self._pre_graph_op._construct_adj(adj)
-            norm_adj = sparse_mx_to_torch_sparse_tensor(norm_adj).to(self._device)
+            norm_adj = sparse_mx_to_torch_sparse_tensor(norm_adj)
+            # If evaluation is full-batch, we can pre-move the full feature/adjacency matrix to the device to save time
             self._processed_block = Block(norm_adj)
+            self._processed_block.to_device(self._device)
+            self._processed_feature = x.to(self._device)
         else:
-            # If dataloader is provided, it means that we conduct minibatch evaluation.
+            # If val/test_dataloader is provided, it means that we conduct minibatch evaluation.
             # In such case, we could prepare evaluation minibatches in advance.
             self._val_samples = []
             with concurrent.futures.ThreadPoolExecutor(max_workers=int(torch.get_num_threads()*0.4)) as executor:
@@ -45,7 +48,7 @@ def preprocess(self, adj, x, val_dataloader=None, test_dataloader=None):
             with concurrent.futures.ThreadPoolExecutor(max_workers=int(torch.get_num_threads()*0.4)) as executor:
                 self._test_sampling_jobs = [executor.submit(
                     self._eval_sampling_op.sampling, test_dataloader(bid)) for bid in range(len(test_dataloader))]
-            self._processed_feature = x.to(self._device)
+            self._processed_feature = x
 
     def generate_taus(self, T):
         self._taus = []
[GIT binary patches for sgl/operators and sgl/sampler __pycache__ *.pyc files omitted]
diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py
index aa7ed26..cd30230 100644
--- a/sgl/sampler/base_sampler.py
+++ b/sgl/sampler/base_sampler.py
@@ -66,9 +66,6 @@ def _calc_probs(self, **kwargs):
         else:
             raise ValueError(f"Don\'t support {prob_type} probability calculation. "
                              "Consider pre-calculating the probability and transfer it to pre_probs.")
-
-    def sampling(self, *args):
-        raise NotImplementedError
 
     def _post_process(self, adjs, to_sparse_tensor=True):
         if isinstance(adjs, list):
@@ -85,3 +82,6 @@ def _post_process(self, adjs, to_sparse_tensor=True):
 
     def _to_Block(self, adjs):
         return Block(adjs)
+
+    def collate_fn(self, *args):
+        raise NotImplementedError
diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py
index 408d438..8dcb511 100644
--- a/sgl/sampler/sampler.py
+++ b/sgl/sampler/sampler.py
@@ -1,13 +1,14 @@
+import os
 import torch
 import numpy as np
+import pickle as pkl
 import networkx as nx
 import scipy.sparse as sp
+from torch_sparse import SparseTensor, cat
+from torch_geometric.utils import from_networkx, mask_to_index
 
 from sgl.sampler.base_sampler import BaseSampler
 
-# import metis
-import random
-
 class FullSampler(BaseSampler):
     def __init__(self, adj, **kwargs):
         """
@@ -41,7 +42,7 @@ def _pre_process(self, **kwargs):
         self.replace = kwargs.get("replace", True)
 
-    def sampling(self, batch_inds):
+    def collate_fn(self, batch_inds):
         """
         Input:
             batch_inds: array of batch node inds
         Output:
@@ -56,7 +57,9 @@
             batch_inds = batch_inds()
         if isinstance(batch_inds, torch.Tensor):
             batch_inds = batch_inds.numpy()
-
+        if not isinstance(batch_inds, np.ndarray):
+            batch_inds = np.asarray(batch_inds)
+
         all_adjs = []
         cur_tgt_nodes = batch_inds
         for layer_index in range(self.num_layers):
@@ -64,7 +67,7 @@
                 all_adjs.insert(0, adj_sampled)
                 cur_tgt_nodes = cur_src_nodes
 
-        all_adjs = self._post_process(all_adjs)
+        all_adjs = self._post_process(all_adjs, to_sparse_tensor=False)
 
         return cur_tgt_nodes, batch_inds, self._to_Block(all_adjs)
@@ -115,7 +118,7 @@ def _pre_process(self, **kwargs):
         self.replace = kwargs.get("replace", False)
 
-    def sampling(self, batch_inds):
+    def collate_fn(self, batch_inds):
         """
         Input:
             batch_inds: array of batch node inds
         Output:
@@ -128,6 +131,8 @@
         """
         if callable(batch_inds):
             batch_inds = batch_inds()
+        if not isinstance(batch_inds, np.ndarray):
+            batch_inds = np.asarray(batch_inds)
 
         all_adjs = []
         cur_out_nodes = batch_inds
@@ -137,7 +142,7 @@
             all_adjs.insert(0, cur_adj)
             cur_out_nodes = cur_in_nodes
 
-        all_adjs = self._post_process(all_adjs)
+        all_adjs = self._post_process(all_adjs, to_sparse_tensor=False)
 
         return cur_out_nodes, batch_inds, self._to_Block(all_adjs)
@@ -179,89 +184,81 @@ def __init__(self, dataset, **kwargs):
         self.sampler_name = "ClusterGCNSampler"
         self.sample_level = "graph"
         self.pre_sampling = True
-        self._train_idx = dataset.train_idx
-        self._val_idx = dataset.val_idx
-        self._test_idx = dataset.test_idx
+        self._masks = {"train": dataset.train_mask, "val": dataset.val_mask, "test": dataset.test_mask}
         self._sampling_done = False
 
     def _pre_process(self, **kwargs):
         self.cluster_method = kwargs.get("cluster_method", "random")
         self.cluster_number = kwargs.get("cluster_number", 32)
+        self._save_dir = kwargs.get("save_dir", None)
+        if self._save_dir is not None:
+            self._save_path_pt = os.path.join(self._save_dir, f"partition_{self.cluster_method}_{self.cluster_number}.pt")
+            self._save_path_pkl = os.path.join(self._save_dir, f"partition_{self.cluster_method}_{self.cluster_number}.pkl")
+        else:
+            self._save_path_pt = self._save_path_pkl = None
 
-    def sampling(self, cluster_ind, training):
-        """
-        Decomposing the graph, creating Torch arrays.
-        """
+    def collate_fn(self, batch_inds, mode):
         if self._sampling_done is False:
-            if self.cluster_method == "metis":
-                print("\nMetis graph clustering started.\n")
-                # self._metis_clustering()
+            if self._save_dir is not None and os.path.exists(self._save_path_pt) and os.path.exists(self._save_path_pkl):
+                print("\nLoad from existing clusters.\n")
+                (self.perm_adjs, self.partptr, self.perm_node_idx) = torch.load(self._save_path_pt)
+                self.splitted_perm_adjs = pkl.load(open(self._save_path_pkl, "rb"))
             else:
-                print("\nRandom graph clustering started.\n")
-                self._random_clustering()
-            self._general_data_partitioning()
-            self._transfer_edges_and_nodes()
+                if self.cluster_method == "metis":
+                    print("\nMetis graph clustering started.\n")
+                    self._metis_clustering()
+                else:
+                    raise NotImplementedError
             self._sampling_done = True
 
-        cluster_ind = cluster_ind.item()
-        if training is True:
-            batch_out = [self.sg_train_nodes[cluster_ind]]
-        else:
-            batch_out = [self.sg_val_nodes[cluster_ind], self.sg_test_nodes[cluster_ind]]
+        if not isinstance(batch_inds, torch.Tensor):
+            batch_inds = torch.tensor(batch_inds)
 
-        return self.sg_nodes[cluster_ind], batch_out, self.sg_edges[cluster_ind]
-
-    def _random_clustering(self):
-        """
-        Random clustering the nodes.
-        """
-        self.clusters = range(self.cluster_number)
-        self.cluster_membership = {node: random.choice(self.clusters) for node in self._adj.nodes()}
-
-    # def _metis_clustering(self):
-    #     """
-    #     Clustering the graph with Metis. For details see:
-    #     """
-    #     (st, parts) = metis.part_graph(self._adj, self.cluster_number)
-    #     self.clusters = list(set(parts))
-    #     self.cluster_membership = {node: membership for node, membership in enumerate(parts)}
-
-    def _general_data_partitioning(self):
-        """
-        Creating data partitions and train-test splits.
- """ - self.sg_nodes = {} - self.sg_edges = {} - self.sg_train_nodes = {cluster: [] for cluster in self.clusters} - self.sg_val_nodes = {cluster: [] for cluster in self.clusters} - self.sg_test_nodes = {cluster: [] for cluster in self.clusters} - for cluster in self.clusters: - self.sg_nodes[cluster] = [node for node in sorted(self._adj.nodes()) if self.cluster_membership[node] == cluster] - subgraph = self._adj.subgraph(self.sg_nodes[cluster]) - # map the global node inds to the local node inds - mapper = {node: i for i, node in enumerate(self.sg_nodes[cluster])} - self.sg_edges[cluster] = [[mapper[edge[0]], mapper[edge[1]]] for edge in subgraph.edges()] + [[mapper[edge[1]], mapper[edge[0]]] for edge in subgraph.edges()] - for node in self.sg_nodes[cluster]: - if node in self._train_idx: - self.sg_train_nodes[cluster].append([mapper[node], node]) - elif node in self._val_idx: - self.sg_val_nodes[cluster].append([mapper[node], node]) - elif node in self._test_idx: - self.sg_test_nodes[cluster].append([mapper[node], node]) + # stack len(batch_inds) subgraphs into one graph + start = self.partptr[batch_inds].tolist() + end = self.partptr[batch_inds + 1].tolist() + node_idx = torch.cat([torch.arange(s, e) for s, e in zip(start, end)]) + global_node_idx = self.perm_node_idx[node_idx] + composed_sparse_mx = sp.block_diag([self.splitted_perm_adjs[batch_ind.item()] for batch_ind in batch_inds]) + block = self._to_Block(composed_sparse_mx) + if mode in ["train", "val", "test"]: + mask = self._masks[mode][global_node_idx] + global_inds = global_node_idx[mask] + local_inds = mask_to_index(mask) + batch_out = torch.vstack([local_inds, global_inds]) + else: + mode = mode.split("_") + batch_out = {} + for one_mode in mode: + mask = self._masks[one_mode][global_node_idx] + global_inds = global_node_idx[mask] + local_inds = mask_to_index(mask) + batch_out.update({one_mode: torch.vstack([local_inds, global_inds])}) + return global_node_idx, batch_out, block - def _transfer_edges_and_nodes(self): - """ - Transfering the data to PyTorch format (except for sg_edges which are coo_matrices currently). 
- """ - for cluster in self.clusters: - num_nodes = len(self.sg_nodes[cluster]) - self.sg_nodes[cluster] = torch.LongTensor(self.sg_nodes[cluster]) - row, col = np.array(self.sg_edges[cluster]).transpose() - self.sg_edges[cluster] = self._post_process(sp.coo_matrix((np.ones(row.shape[0]), (row, col)), - shape=(num_nodes, num_nodes))) - self.sg_edges[cluster] = self._to_Block(self.sg_edges[cluster]) - self.sg_train_nodes[cluster] = torch.LongTensor(self.sg_train_nodes[cluster]).transpose_(1, 0) - self.sg_val_nodes[cluster] = torch.LongTensor(self.sg_val_nodes[cluster]).transpose_(1, 0) - self.sg_test_nodes[cluster] = torch.LongTensor(self.sg_test_nodes[cluster]).transpose_(1, 0) \ No newline at end of file + def _metis_clustering(self): + data = from_networkx(self._adj) + N, E = data.num_nodes, data.num_edges + adj = SparseTensor( + row=data.edge_index[0], col=data.edge_index[1], + value=torch.arange(E, device=data.edge_index.device), + sparse_sizes=(N, N)) + self.perm_adjs, self.partptr, self.perm_node_idx = adj.partition(self.cluster_number, False) + self.splitted_perm_adjs = [] + for i in range(len(self.partptr)-1): + start, end = self.partptr[i], self.partptr[i+1] + node_idx = torch.arange(start, end) + perm_adj = self.perm_adjs.narrow(0, start, end-start) + perm_adj = perm_adj.index_select(1, node_idx) + row, col, _ = perm_adj.coo() + row, col = row.numpy(), col.numpy() + num_nodes = len(node_idx) + sparse_mx = sp.coo_matrix((np.ones_like(row), (row, col)), shape=(num_nodes, num_nodes)) + sparse_mx = self._post_process(sparse_mx, to_sparse_tensor=False) + self.splitted_perm_adjs.append(sparse_mx) + if self._save_dir is not None: + torch.save((self.perm_adjs, self.partptr, self.perm_node_idx), self._save_path_pt) + pkl.dump(self.splitted_perm_adjs, open(self._save_path_pkl, "wb")) + print(f"\nSave Metis graph clustering results under the {self._save_dir} directory.\n") \ No newline at end of file diff --git a/sgl/tasks/__pycache__/__init__.cpython-39.pyc b/sgl/tasks/__pycache__/__init__.cpython-39.pyc index 66524d09b1b6b06b66e255afdbad23e6353b7a0c..a5faead87d113183a97cf2700d33b50f6bc6361d 100644 GIT binary patch delta 340 zcmYk0KTg9i6vkt>`5%*nE|m~VB_t4HL}Frqoq=v%EEAi`!cL1525x}ZBHy4J3kP84 z1~BCUEF6IW+hX8Z`t*DD_ddV3=sQ9bp)9aOfnI6k;hAr3*n0ydSAigD~Y zw4C6?b0kxh;mmVvImfx@L>3BRG$|OjfGKH3(J#%eMgU4zK4EIPMt=X4#Ue zjTG%_T(~yA=FMXxxq;w4$@Ro8r=j@(i@I<)brM>+I+~duaB=1vT()b*Nh2S+o@<+> zBj<~)jHS)KBOA?`XnL`3b>f~$UrLj0vGq U67mx0+RfMI3tmp=AJ}E=Wj4h@}tE1&Imq03NWSI4Ti#0vGcJh&}Qfu(2|t zF1$vcfrW(`+Z%V%>67ksr~4+KgjSTM$j0jFx_KSx62FWWXLlG0B&IS{gc0x1I%1JG zk+Dix;vHG1EcK3MrgD~hCz7fen|Y_QP$erd+9Hv?VRkTnY%m^wmS&B2G=+&h4}mM8 z`NgNn)bjRxzF(NHU|}}F7&-72a8rD{}cGiW2ol{ZN)In|31Gvg4+JrqH)hhB>!zV%eVhkkE%_2a(+ zmN4JEdGGCe^Jd<>88*K2;9TrPG^#83JGpqteD&~Wu?I|5nRv+V)qS>-?J@jzxY^rM z7Bv1RJ8by-6ma_t71(VAHe~;REFfk5Qq}`W%Fv9q7HJO1QP3L=+Ddb%MRtY_<-C#! 
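For the graph-wise ClusterGCNSampler above, collate_fn receives a batch of cluster indices rather than node indices and stitches the cached per-cluster adjacencies together with scipy's block_diag, so several METIS partitions are trained on as one disconnected subgraph. A self-contained toy of just that stacking step, with a made-up ten-cluster split standing in for the partition produced by adj.partition:

import numpy as np
import scipy.sparse as sp
from torch.utils.data import DataLoader

# 100 toy nodes split into 10 clusters of 10, with one cached adjacency block
# per cluster (the stand-in for splitted_perm_adjs).
cluster_adjs = [sp.random(10, 10, density=0.3, random_state=i, format="csr") for i in range(10)]
partptr = np.arange(0, 110, 10)   # cluster i owns global nodes partptr[i]:partptr[i+1]

def collate_fn(cluster_inds):
    cluster_inds = np.asarray(cluster_inds)
    # Global ids of every node in the selected clusters, in cluster order.
    node_idx = np.concatenate([np.arange(partptr[i], partptr[i + 1]) for i in cluster_inds])
    # The selected clusters form one block-diagonal (disconnected) subgraph.
    block = sp.block_diag([cluster_adjs[i] for i in cluster_inds], format="csr")
    return node_idx, block

loader = DataLoader(range(10), batch_size=3, shuffle=True, collate_fn=collate_fn)
for node_idx, block in loader:
    print(node_idx.shape, block.shape)   # e.g. (30,) and (30, 30)

The real sampler additionally maps the train/val/test masks onto the stacked node ordering (the batch_out dictionary), which this toy omits.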
zr(~>y2o66k_CPyE;D;;Ky826WII>CXqp`C2J?4H05AdPRq;87FT-~?Mlv?F;KG%<8 zwc*DCX=biPbLgwdI2wq?O;W{4gLZ0ETt6{Yn>n*Q5E*RzAi*SWIc}>+YfmP$6nKiS z!yp@g`U&sre34}L0Rn^CkcUjdII~%jo_W^G?9e9Qx*3u;%!ni8U5;Xnw zT?$;)V|lh$U5X9#$92KQS4r>3TB1~KfQKSYW3gY|h#l^XllTPvTvRPcKSM0c>Ou(2 zmT{gdpKcN`)=BO8FM0;ltT5_q(DPsIn2ZrUiJU}8b6v`DqGFh6nvw~x6JF-&6`3HO z+M16ka8ZV2T3zcwyhLa zXx5~3#l=z5YU+pss+ibri(^n6Cn$7_ed>ioE&=|CVZemCni$>nFhrb+I0SeE@HpTw z;8DN{fF{Sc)Jwcu*~zZGMbf!nCcx2~_=eE?e7XCxy33Qq0GojWNzIA48c+T$BMy+4 zADwobCh=Q5p^l{l`+>Tds<6k^kUf_F4$TWRjGrvm=G>MeinSTn?Wm@`m7P=1+kcLr zzr;~hPd^z6QdECP-+vI7gm?@PG%BPZRY+C9c>wbBBhz)K)Nn-w8V7JfeJ%4KdqK@- zhS;2XHS-#)sRLa{bFb5+n-ww7ZmA#kjInvO)bkX(sCsj|Sy3I$ z{hM7@ zR|yJsaGxxy)WE|Tz5TjPRdL|y+;O@Sf+ABRC0P&N=^^p%;vWZ&unsPo;6;IM7dx?J z*T$CDTcTWIEeu@~U7au{Y5)_U%eGdm6m%#mcmjD5@LRx1!^qvfWGB?MZNoct&A=kx zs+G<-HC--0-W2qwCQz)aT4r#7J*9RJ?q}~WP7jusGrL`S$+h?3x?p%;;bxDM?N!0!O>==1-0sGnU@ ze;*oSHx~=Tv#f8eIY(fn!-$EX>D2XmtISi|wvVtSb$t6$A2Kc^#XLRyuyer?qwvWA zexwRJM$bwZaPeapFyTil&XjN**G~(2anBdsShe-%~wM=Lt;9@UkaC=9X>RkM)W{Ru3;peZdj!=M2`)_&40SDP1h zEid^o^z-U^VfN7@5cQpK3eq{i9N;1VcZz11cpK8YfcF760ZV|70H3JFeGm1=_`N*E WxA8R3@fde`g5)GG@OyXE2I2`Lgzjt(B-!3O z+dx)TD4EGq^3Z*NhfXU^m^6e8Ej5HvG9lrWJ4ZR@58Owk{WVqg(U-BX=hXzx#;q+`6rKC+Ph=1$P)ot{aP+}8X zb%;aQ&p<92sBJxMFrxem_y#+OqCg}{Qz42=YKE)64iVHSS9d7K-mIBiIYr#WX_bwE zmCDT~sac>)aFY|RPb7SWTbIeZ1a_rrX0^QzX063$B3ANLLvytzakXxxp+mw*&4#4d z4TEdk>{gnjk!&gr+EiXpYDsQ{CO0n-m>V-)U}y9wwKTZq@We8g9z4aR`jb9o(5GcC zfsSXM;u8HoE>Z40cFCFfkb?{^WqG<`H5Hyk#QNs-eT*nSg-WV0DvE@N2F1#B*{@DA zOnE-Vo`)eyBMWAWp;y3?vzO6)L=4-#U06PrM$bsXMvT(_^cGNzOL|mi zE8-Mg`{PJIQ4C8bqm(?xE4Y3sWZ;9svIsBXhL}dMgKY&-Jf!&}8(|V4AX7B>Qewep zwc1WrkX67Ms-kwej=cgqp5T(~2~bKd9TFEaeTNKa1p^x6Q8Of#NYD%o4XMVphAzis z!+<2p24O#e7v+ba z8>6DC4RJsuFk`q{O5#fEd7=j%fD$ra|>ERlP!^IwL}^=8$1n{fh8{khdUTA#IGagVN5run}FU(KyR2FX0;4( zKGU$`J;N`@@V( z0ow_YO14FsKkc4~%)ko^kk69GMsR5t!fqhqI2i?9zt((YA3%ntESyKVhzpbORk#6;7Grp9E;#x#TSAmN}?VsrXm*(l-ql_+`I z^QK|K_KO!CM!2}@xMZ)$W=6VR$1^$rBD++c4ML9<%2PqW#IZ~d@x|H9R|eLa?G)qL z!$vEdo7v~KVK1>g2zZyWCGntm3gI*YMiLpPD_*f0Fb_2o2>ZmkwwK6xQEFRD=EU1= zmq|qox9`qfg2f=daxI9ZlNj(}=NH~-|A$8IiZi(pGB2*>j>6Vt*CujYywLR@@|M`# zy%Szg@7_q7;#&80^1e8}>W*;&Pr4)qdajZc;>(^KxxVmC&sHM29hXiPsSC^^wNi;l z$l@)yPhK1waK;e+fbaoAoCKUz5h@4?1Px*7Rg;Z7s|r_A!LPdBnJ6{yhbP7Fd+DwZ z(fBIDvebQqI$8fKIK%9YTk(rLchY0mPk9~rj$Q{b@5shyH^jGn+q3fcC5MW!{$sh9;lho}0Q9m+C){K0D!a4rS^rK#_KVNg zcGTOJbbw5wJOp#e77vyQlc|##lDU#;k~v+>BxX%!Jm1ms4HY*L4kG+-juObB-wg~7 z$t+>}M0&Y+!Yj)JMbB7-9C1UKyw_`y4{fkL@Y7pY# z+Lfy$>u=zm?Fe{F$G&|eHVz*6(PLlOC%A&88n>hq1@wpALhWA=ZX^6v*6!Y5FA2rh zgCpebLgzCxWc4BkCH;_yd6UJ25$~<@i4f^$H;{W`%d?-a#sKm*n};DvcxOE}jBm~% zToj!{!xJ?$pzOD3KvUwk)oi<}iV@*76ap&xG`a6Di4tKsJ7Wy?g`K_7mQ(A&re zri(1_*fa=vjr|IBNoW}Oa~7~SK@tbrSBkCc2gogPY<+Ii-_dRjTL}Y`wVJ{i#vIu} z$)7&EvGu~gp?+Sxw|-^}e`#aJ9>STAFo$p+;e7<`4fZj@9fVI2?jqbnxQ~D*vvygw fU_-vwqJ1 1: + local_inds, global_inds = batch_out pred = val_output[local_inds].max(1)[1].type_as(labels) correct_num_val += pred.eq(labels[global_inds]).double().sum() val_num += len(local_inds) @@ -63,10 +63,10 @@ def mini_batch_evaluate(model, val_loader, test_loader, labels, device): test_num = 0 for batch in test_loader: - batch_in, batch_out, block = model.sampling(batch) + batch_in, batch_out, block = batch test_output = model.model_forward(batch_in, block, device) - if isinstance(batch_out, list): - local_inds, global_inds = batch_out[1] + if batch_out.dim() > 1: + local_inds, global_inds = batch_out pred = test_output[local_inds].max(1)[1].type_as(labels) correct_num_test += pred.eq(labels[global_inds]).double().sum() test_num += len(local_inds) @@ 
-99,11 +99,11 @@ def mini_batch_train(model, train_loader, labels, device, optimizer, loss_fn): train_num = 0 for batch in train_loader: - batch_in, batch_out, block = model.sampling(batch) + batch_in, batch_out, block = batch optimizer.zero_grad() train_output = model.model_forward(batch_in, block, device) - if isinstance(batch_out, list): - local_inds, global_inds = batch_out[0] + if batch_out.dim() > 1: + local_inds, global_inds = batch_out loss_train = loss_fn(train_output[local_inds], labels[global_inds]) pred = train_output[local_inds].max(1)[1].type_as(labels) correct_num += pred.eq(labels[global_inds]).double().sum() diff --git a/sgl/utils/__pycache__/__init__.cpython-39.pyc b/sgl/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cac455710fe332e6bb0f8dd87dec1f60027ffc2b GIT binary patch literal 293 zcmYjLy-ve05RUBzD%1-41`J3j1Na3HVxkM}QY54-CzDu-B|8p2M@k=wmtf=}vNH7* z7!c=*fz#c0f8YJ6BuL?~CzSluT<2yX^r&9oBEx>yg^0Xwu_qA-j(-QGb?$%E?c0WpN jtQ?}TC;?1oBEW6GpXPVY*Mrsf>3y6NEZ9PD#(#bR??Oyp literal 0 HcmV?d00001 diff --git a/sgl/utils/__pycache__/auto_choose_gpu.cpython-39.pyc b/sgl/utils/__pycache__/auto_choose_gpu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98928a8df7553e2caca3bcec071b8a2149d67e54 GIT binary patch literal 1185 zcmZuw&1(}u6rVRcyZJDwRbru{vWN(KNa8_nr6_8v&{m-pAs}JROp}%U%FINPumQ2X zdhjBGEi^~}rM-IUzu=|5*;1|5fq66U&6|0@_kM5udc8_;+#RloPZl9RjB6hqosnDD!;UEcN|{ZCf>p^ zzD7LA*+!r=u8tIXB;%ejpA8e*#`zeh`iO!9Cn1@gt}p{mGdQAXwdV92;q1_w(Tu#a zj#x(LAlLqPKD?8WBb(b9-Z(@rAO%#Sos9DG9z80}$rG{%Gpq2c{PD{y1YX^efJi#^ zemd$#BI<^SzZ-n`@|8?kOON&RLHs`8f!|b7;5LVDle?QDO60`N+`bg4+kAB2U4QxN zOkI;gxT|i@ym~e$ZcO6aIModleUgzkjTvZy3Bw$DSVoAS>5w5LI5z$s&8;BTxvkPL z(5TVYy&T5hiLBu1&sKYoM53+Ku)V#$(F`Z!w(5uNkq$!D_D4GD_6A9!L>KdFrIWn6 zViM~mQM$TdP#7*8peIvzUFh|6bedCZB}{sLs8(@v=AMI`S~7CIa!wP4#QHH$qP%2Q z1CXV@9^{S~;{`lc6QxCzvm_O9UXjA*VGs-DF@~ literal 0 HcmV?d00001 diff --git a/sgl/utils/__pycache__/basic_operations.cpython-39.pyc b/sgl/utils/__pycache__/basic_operations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88549097102c3a5747c83bdd20c6c732d711a58c GIT binary patch literal 593 zcmYjOv5wO~5Zzs`om?(CK!8L+;WhTdfcr)}Q8#Y1DviP?k3B~JxXH?zDy_Fwvv<)Y zpvpeFz?#o4$V}U&4x!)!5p-i?|l#skvP-LlTv{hJ`8cwOV z>*0tvMMEF;o?h2F% Date: Sat, 2 Dec 2023 01:16:08 +0000 Subject: [PATCH 08/28] save calculated probability; fix bugs in inductive-learning; add batch-wise evaluation choice (as what PyG does) --- examples/configs/fastgcn.yml | 7 +- examples/configs/graphsage.yml | 21 +++-- examples/configs/lazygnn.yml | 12 +-- examples/sample_based_nodeclass.py | 7 +- sgl/dataset/__pycache__/ogbn.cpython-37.pyc | Bin 2850 -> 2856 bytes sgl/dataset/ogbn.py | 2 +- .../__pycache__/base_model.cpython-37.pyc | Bin 9523 -> 10649 bytes .../__pycache__/base_model.cpython-39.pyc | Bin 9286 -> 9683 bytes .../__pycache__/simple_models.cpython-37.pyc | Bin 11723 -> 12552 bytes sgl/models/base_model.py | 77 +++++++++++++---- .../__pycache__/clustergcn.cpython-37.pyc | Bin 875 -> 1369 bytes .../__pycache__/clustergcn.cpython-39.pyc | Bin 992 -> 1379 bytes .../homo/__pycache__/lazygnn.cpython-37.pyc | Bin 4445 -> 4447 bytes .../homo/__pycache__/lazygnn.cpython-39.pyc | Bin 4508 -> 4510 bytes sgl/models/homo/clustergcn.py | 11 ++- sgl/models/homo/lazygnn.py | 8 +- sgl/models/simple_models.py | 40 ++++++++- .../__pycache__/base_sampler.cpython-37.pyc | Bin 3416 -> 3740 bytes .../__pycache__/sampler.cpython-37.pyc | Bin 9429 -> 9426 
bytes .../__pycache__/sampler.cpython-39.pyc | Bin 9401 -> 9387 bytes sgl/sampler/base_sampler.py | 50 ++++++----- sgl/sampler/sampler.py | 8 +- ...ode_classification_sampling.cpython-37.pyc | Bin 10348 -> 10727 bytes ...ode_classification_sampling.cpython-39.pyc | Bin 10317 -> 10264 bytes sgl/tasks/__pycache__/utils.cpython-37.pyc | Bin 11185 -> 10793 bytes sgl/tasks/__pycache__/utils.cpython-39.pyc | Bin 11193 -> 10756 bytes sgl/tasks/node_classification_sampling.py | 81 +++++++++++------- sgl/tasks/utils.py | 77 +++++++---------- 28 files changed, 253 insertions(+), 148 deletions(-) diff --git a/examples/configs/fastgcn.yml b/examples/configs/fastgcn.yml index 5169934..9feaa22 100644 --- a/examples/configs/fastgcn.yml +++ b/examples/configs/fastgcn.yml @@ -8,11 +8,9 @@ sampler: name: "FastGCNSampler" inductive: False pre_sampling_op: "LaplacianGraphOp" - layer_sizes: "256,256" + layer_sizes: "2048,2048" prob_type: "normalize" replace: True - eval: - name: "FullSampler" model: name: "FastGCN" hidden_dim: 128 @@ -20,7 +18,8 @@ model: num_layers: 2 task: name: "NodeClassification_Sampling" - train_batch_size: 256 + train_batch_size: 2048 + train_num_workers: 3 epochs: 30 lr: 0.1 weight_decay: 0.00005 diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml index 4c4b488..8bfbfb3 100644 --- a/examples/configs/graphsage.yml +++ b/examples/configs/graphsage.yml @@ -7,21 +7,30 @@ sampler: training: name: "NeighborSampler" inductive: False - layer_sizes: "5,5" + layer_sizes: "10,5" prob_type: "normalize" - replace: True + replace: False post_sampling_op: "RwGraphOp" + # eval: + # name: "NeighborSampler" + # inductive: False + # layer_sizes: "-1" + # prob_type: "normalize" + # post_sampling_op: "RwGraphOp" model: name: "GraphSAGE" - hidden_dim: 128 + hidden_dim: 256 dropout: 0.5 num_layers: 2 task: name: "NodeClassification_Sampling" - train_batch_size: 2048 - eval_batch_size: 64 + train_batch_size: 1024 + train_num_workers: 5 + # eval_batch_size: 1024 + # eval_num_workers: 5 + # eval_together: True epochs: 20 - lr: 0.1 + lr: 0.03 weight_decay: 0.00005 loss_fn: "nll_loss" diff --git a/examples/configs/lazygnn.yml b/examples/configs/lazygnn.yml index 7950302..9f64cc7 100644 --- a/examples/configs/lazygnn.yml +++ b/examples/configs/lazygnn.yml @@ -7,23 +7,23 @@ sampler: training: name: "NeighborSampler" inductive: False - layer_sizes: "5,5" + layer_sizes: "10,5" prob_type: "normalize" - replace: True + replace: False post_sampling_op: "LaplacianGraphOp" model: name: "LazyGNN" basemodel: "GCN" - hidden_dim: 128 + hidden_dim: 256 dropout: 0.5 num_layers: 2 max_workers: 8 max_threads: 10 task: name: "NodeClassification_RecycleSampling" - num_iters: 200 - lr: 0.1 + num_iters: 100 + lr: 0.01 weight_decay: 0.00005 loss_fn: "nll_loss" - train_batch_size: 2048 + train_batch_size: 1024 diff --git a/examples/sample_based_nodeclass.py b/examples/sample_based_nodeclass.py index 62da388..b6ec4ff 100644 --- a/examples/sample_based_nodeclass.py +++ b/examples/sample_based_nodeclass.py @@ -19,6 +19,7 @@ config = yaml.safe_load(open(args.config_path, "rb")) device = f"cuda:{args.device}" if args.device >= 0 else "cpu" dataset_kwargs = config["dataset"] + task_kwargs = config["task"] classname = dataset_kwargs.pop("classname") dataset = getattr(Dataset, classname)(**dataset_kwargs) training_sampler_kwargs = config["sampler"]["training"] @@ -26,19 +27,21 @@ inductive = training_sampler_kwargs.pop("inductive") else: inductive = False + task_kwargs.update({"inductive": inductive}) 
training_sampler_name = training_sampler_kwargs.pop("name") + training_sampler_kwargs.update({"save_dir": dataset.processed_dir}) training_sampler = getattr(Sampler, training_sampler_name)(dataset.adj[dataset.train_idx, :][:, dataset.train_idx] if inductive else dataset.adj, **training_sampler_kwargs) if "eval" in config["sampler"].keys(): eval_sampler_kwargs = config["sampler"]["eval"] eval_sampler_name = eval_sampler_kwargs.pop("name") + eval_sampler_kwargs.update({"save_dir": dataset.processed_dir}) eval_sampler = getattr(Sampler, eval_sampler_name)(dataset.adj, **eval_sampler_kwargs) else: - eval_sampler = training_sampler + eval_sampler = None model_kwargs = config["model"] model_name = model_kwargs.pop("name") model_kwargs.update({"device": device}) model = getattr(HomoModels, model_name)(dataset, training_sampler, eval_sampler, **model_kwargs) - task_kwargs = config["task"] task_kwargs.update({"device": device}) task_name = task_kwargs.pop("name") test_acc = getattr(Tasks, task_name)(dataset, model, **task_kwargs).test_acc diff --git a/sgl/dataset/__pycache__/ogbn.cpython-37.pyc b/sgl/dataset/__pycache__/ogbn.cpython-37.pyc index f6be8936642e7b9be9103ecc066ac80f215fe174..2cce8e43301b1ad5296c6dd0830488f3006734f9 100644 GIT binary patch delta 39 tcmZ1^wnB{8iI9C_g#1xHvUsvp4f~b^ziP40Heh delta 33 ncmZ1>wn&WEiIA>DXN zNa<@=BJYZ26+Yw&Q^FL1IWm$G*&xh#TBeZCn?=jW+9|@CV2ZU>A&DCCA~SQg@j&kd zm~}(|L0r-gnT~bp;PDf~hmRMimG7%n&Ja`-1pbK@YumhseXzt1@~`TW39rGLH}f05 z^CCZA*HNyS92(7pe20Qi2X7gkr{HlG09-OAl|m(@6m}&_rd8I52X;eAp+a2o1}Uj9 z7F{JNlp=OJ$WqOk9IUxA6~Q{~Mk=K28%Er4C4>@z7*-Tlb;CErlICh|Br8HZrIDf0 zRG2lwCfH*Do@UIOb`BzIEP8FmEGuLz$FZmpgePm6#W`kq@v2rZzvxAZ1%q1ixs1gc zz)z~K?B~P_jhn^HBG(a#u4k}JBBA9x_L z>_;ky{eCcIPXoX#a@kXDV5O0x3v2|Jjw96YyFELX&--WHw7;Stlm$ltBY`v%jch*K z|J!CHG=m?W3n}Z<6ypZRKtduFk(Lr1M+$0E@{fWEhLl=1kYJz;2J45>s97}g1(RAV zU_uCU=hPgCodlCbcpUt)KoVprS(}pT2dc${eF^oC!Q=c6fV-r^G{00IYf^04B_*W@ zwX=+G$yNBx`kiH0Fd}ZKqEV4bP!C}@Ol2sXt6(cv>RdUfmnpRpMH#fQ03+2=U>l+` z*s88dC9wU~_oq++<|1wccA&D`#d;cECKpLDDeAh8s3r6sn zm@DBtNDWkg(5fR|1Jc;IUwB4g*3%x)a&(A211E;!^E(}|W$DaNf$7yc)AN?e?3_KZ zOP>YNvGiHfar7sBeI4|5)6t8QIL*u%t7tJtFJyHroS^dc?3@i1Td@05afW@%`4<=ksSfa?#kOD7*qxa6a?uM^)AVy(?|8H9&gwv%u8qb*m-9sXI%>m-jzua9e|1X!=cv&e?Y_u^=(L~!~&84X^QP0%t<#y9ta=E*H zd!{vPY{Ak{L9`tx^g&Dz6p<>D7bz43LG(q4h!6LqAih|9645tj_L4TZ%lDi4=9~R? 
diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py
index e73f8b2..19f528d 100644
--- a/sgl/models/base_model.py
+++ b/sgl/models/base_model.py
@@ -87,49 +87,88 @@ def processed_feature(self):
         return self._processed_feature
 
     @property
-
def collate_fn(self): - if self.training: - return self._training_sampling_op.collate_fn - else: - return self._eval_sampling_op.collate_fn + def train_collate_fn(self): + return self._training_sampling_op.collate_fn + + @property + def eval_collate_fn(self): + return self._eval_sampling_op.collate_fn - def sampling(self, batch_inds): - if self.training: - return self._training_sampling_op.sampling(batch_inds) + def mini_batch_prepare_forward(self, batch, device, inductive=False, transfer_y_to_device=True): + batch_in, batch_out, block = batch + + if inductive is False: + in_x = self._processed_feature[batch_in].to(device) + y_truth = self._vanilla_y[batch_out] else: - return self._eval_sampling_op.sampling(batch_inds) - - def preprocess(self, adj, x, mini_batch_eval, device): + in_x = self._processed_train_feature[batch_in].to(device) + y_truth = self._vanilla_train_y[batch_out] + + if transfer_y_to_device is True: + y_truth = y_truth.to(device) + + block.to_device(device) + + y_pred = self._base_model(in_x, block) + + return y_pred, y_truth + + def full_batch_prepare_forward(self, node_idx): + y_pred = self._base_model(self._processed_feature, self._processed_block)[node_idx] + y_truth = self._vanilla_y[node_idx] + return y_pred, y_truth + + def inference(self, dataloader, device): + preds = self._base_model.inference(self.processed_feature, dataloader, device) + return preds + + def preprocess(self, adj, x, y, device, **kwargs): if self._pre_graph_op is not None: norm_adj = self._pre_graph_op._construct_adj(adj) else: norm_adj = adj norm_adj = sparse_mx_to_torch_sparse_tensor(norm_adj) self._processed_block = Block(norm_adj) - + if hasattr(self, "_pre_feature_op"): self._processed_feature = self._pre_feature_op._transform_x(x) else: self._processed_feature = x - if mini_batch_eval is False: + self._vanilla_y = y + mini_batch = kwargs.get("mini_batch", True) + if mini_batch is False: self._processed_block.to_device(device) self._processed_feature = self._processed_feature.to(device) + self._vanilla_y = self._vanilla_y.to(device) + + inductive = kwargs.get("inductive", False) + if inductive is True: + train_idx = kwargs.get("train_idx", None) + if train_idx is None: + raise ValueError(f"For inductive learning, " + "please pass train idx " + "as the parameters of preprocess function.") + if hasattr(self, "_pre_feature_op"): + self._processed_train_feature = self._pre_feature_op._transform_x(x[train_idx]) + else: + self._processed_train_feature = x[train_idx] + self._vanilla_train_y = y[train_idx] + def postprocess(self, adj, output): if self._post_graph_op is not None: raise NotImplementedError return output - # a wrapper of the forward function def model_forward(self, batch_in, block, device): - return self.forward(batch_in, block, device) - - def forward(self, batch_in, block, device): x = self._processed_feature[batch_in].to(device) block.to_device(device) - output = self._base_model(x, block) - return output + return self.forward(x, block) + + def forward(self, x, block): + return self._base_model(x, block), self._vanilla_y + class BaseHeteroSGAPModel(nn.Module): def __init__(self, prop_steps, feat_dim, output_dim): diff --git a/sgl/models/homo/__pycache__/clustergcn.cpython-37.pyc b/sgl/models/homo/__pycache__/clustergcn.cpython-37.pyc index e845138149a176a6551ca8b3d55c35a56efe512c..e012d0b3cae8f17e7a35a1e424e5bbbabc6fe925 100644 GIT binary patch literal 1369 zcmZ`&OOG5i5Vrf>y)$81l#n0+xq#G%%v?Dj0$HKuuvw6n(hAZOD6K1 z5-F$s54#5rT=_XZMdE<+7dTPXoyjf-+^Xtwl^<30eSS2ZW(3CXfBqtqfRMj%vKufb 
z--4L0KnWtKA{AXwingb`%3t_{?W>>)7a=7_LGLp|~+Ft`*-%4TjK7+w@(Y zxzO~j)H;OSE@7;!%Z@RPJ#tglYROjELe{kIHtRHpc)en~Jkb3i!%S|v^w1WsYup|? z18{~674$GLQZ3#7e~*#cO|DOXR3>P?kRNxBfQRhBkQOZKG{$$x~wed*qT@?}}~{n0q#U zKTmCV>0^WvOD+SO*zn4;(LJ)wY-ZD?2awYfK=J@kXbNeJ6bv!#wbXauMwhX3UYAO7 zwsC3Ku(2QdV~FK@`YtXB&v;j?^`~ejgAD+?bhH3}SfHBXSC@eX49i*=x3^Nw*=`b) zHGA&j4M5*_YZq;h5<-6t2af7K9{ouLv;e5#E{wMvXtQkeZ(NHh#0>&SXhsvd4_^lI z;)@&I{r}y$uMrxJFG?q%@(4*<^GY%X4M1k~3hfDF&w8%L$bhMS0PBInBR2}E6!7|e z;yKly;i30YA^A}ZLYmN9@!&bhV|N$0GaI=xGbKE}cVspr$8X;VUo%}JOWYldHSo50 h*Jv}_J)_}i4d-7cbEB2X*JDZFP|OF5p8z2z-oK^~b|3%% delta 466 zcmYLEO-lnY5Y23|AC~P_6sh3H&BIC&FP{7WEz*mM;6+v;)F$G>u1eBgB;Y~OlLy&9 zpufno>n{-hgfmrf!o0l6%zH`R8(%@P(`+^vM$$Xg3*K`wd)vWEF$D!HAp;=p6<*0R z4y?x%SK^*2q4^!pgz}GA8hnyL3N`O_4{&Sk^sPSH-ap*i-M>=07*P8!#4oXnNHohp zvCLD!BRdC$cJ~u;#vhqrnLy|8`F#yGsus+295^Pq^AN?*ciai*yU+!SKlv|#i(aW8 zDuyaiox*(e0L$&D_U-6e8$-9&k(BvmK9sWhfV@q=t{WG*Qu>v9wJNbyLbg(YJP>i$Y2>PcsB5 sm^tDL1oSjY>mfuit`E@VO{}kui+|k@BUK%3lviq0=q(^A>0^k!Z;=XIcK`qY diff --git a/sgl/models/homo/__pycache__/clustergcn.cpython-39.pyc b/sgl/models/homo/__pycache__/clustergcn.cpython-39.pyc index 956a7442e31e2cfaab59ee59ed268d278f28a9f3..4b5691536ccfcf0c9e4d80806e8b59b81b6f8faa 100644 GIT binary patch delta 490 zcmZ`$%}T>S5Z=Eeo0JCepsgV2K_MXGJM`eK7d-`KNwZrUjg4#)wX&fX!L#-S>Onl~ z+nAH!8~6lHta@{o`F1|OnQs>6-TCnRsNMDuoTt}g#o4|8yt55Vp%6z$=d|Eu5IuVw9sWSs_hAUM@hW_jtkgqyw}c?aUvGBpkG8Ek@poVCr5E&Zb43FUixMxwpeZe=Cuqk diff --git a/sgl/models/homo/__pycache__/lazygnn.cpython-39.pyc b/sgl/models/homo/__pycache__/lazygnn.cpython-39.pyc index 105e750b7088e64258385c8af220a453fb643beb..511201fd8358abdced87a088a318982b417eaa66 100644 GIT binary patch delta 39 tcmbQEJWrW7k(ZZ?0SKmUjZOWskynk4n=3g#CnvEaH9jqGvkhAUHvrfS3yA;# delta 37 rcmbQIJV%)~k(ZZ?0SNB?j!b>Ekynk4lcP8>w;(4oFMYEeTLL!#%#{mE diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index b4c38b8..e89a8ef 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -8,4 +8,13 @@ def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, dr self._pre_graph_op = LaplacianGraphOp(r=0.5) self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler - self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) \ No newline at end of file + self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) + + def mini_batch_prepare_forward(self, batch, device): + batch_in, batch_out, block = batch + local_inds, global_inds = batch_out + in_x = self._processed_feature[batch_in].to(device) + y_truth = self._vanilla_y[global_inds].to(device) + block.to_device(device) + y_pred = self._base_model(in_x, block)[local_inds] + return y_pred, y_truth diff --git a/sgl/models/homo/lazygnn.py b/sgl/models/homo/lazygnn.py index f8a590a..61ecdaf 100644 --- a/sgl/models/homo/lazygnn.py +++ b/sgl/models/homo/lazygnn.py @@ -43,11 +43,11 @@ def preprocess(self, adj, x, val_dataloader=None, test_dataloader=None): self._val_samples = [] with concurrent.futures.ThreadPoolExecutor(max_workers=int(torch.get_num_threads()*0.4)) as executor: self._val_sampling_jobs = [executor.submit( - self._eval_sampling_op.sampling, val_dataloader(bid)) for bid in range(len(val_dataloader))] + self._eval_sampling_op.collate_fn, val_dataloader(bid)) for bid in range(len(val_dataloader))] self._test_samples = [] with concurrent.futures.ThreadPoolExecutor(max_workers=int(torch.get_num_threads()*0.4)) as 
executor: self._test_sampling_jobs = [executor.submit( - self._eval_sampling_op.sampling, test_dataloader(bid)) for bid in range(len(test_dataloader))] + self._eval_sampling_op.collate_fn, test_dataloader(bid)) for bid in range(len(test_dataloader))] self._processed_feature = x def generate_taus(self, T): @@ -67,12 +67,12 @@ def model_forward(self, x=None, block=None, use_full=False): return self._base_model(x, block) else: return self._base_model(self._processed_feature, self._processed_block) - + def flash_sampling(self, total_iter, dataloader): min_iter, max_iter = 1, self._max_threads count_iter, max_cycle = 0, max(self._taus) pre_cycle = np.asarray(list(itertools.accumulate(self._taus))) - sampling_func = self._training_sampling_op.sampling + sampling_func = self._training_sampling_op.collate_fn while count_iter < total_iter: # adaptively increase the number of sampled subgraphs diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index f7ed80c..665e7b0 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -285,6 +285,25 @@ def forward(self, x, block): raise ValueError('The sampling layer must be equal to GNN layer.') return F.log_softmax(repr, dim=1) + + def inference(self, x_all, subgraph_loader, device): + # Compute representations of nodes layer by layer, using *all* + # available edges. This leads to faster computation in contrast to + # immediately computing the final representations of each batch. + for i, conv in enumerate(self.gcs): + xs = [] + for batch in subgraph_loader: + batch_in, _, block = batch + block.to_device(device) + x = x_all[batch_in].to(device) + x = conv(x, block[0]) # one-layer sampling + if i != self.nlayers - 1: + x = F.relu(x) + xs.append(x.cpu()) + + x_all = torch.cat(xs, dim=0) + + return x_all class GCN(nn.Module): def __init__(self, nfeat, nhid, nclass, layer=GCNConv, nlayers=2, dropout=0.5): @@ -318,4 +337,23 @@ def forward(self, x, block): else: raise ValueError('The sampling layer must be equal to GNN layer.') - return F.log_softmax(repr, dim=1) \ No newline at end of file + return F.log_softmax(repr, dim=1) + + def inference(self, x_all, subgraph_loader, device): + # Compute representations of nodes layer by layer, using *all* + # available edges. This leads to faster computation in contrast to + # immediately computing the final representations of each batch. + for i, conv in enumerate(self.gcs): + xs = [] + for batch in subgraph_loader: + batch_in, _, block = batch + block.to_device(device) + x = x_all[batch_in].to(device) + x = conv(x, block[0]) # one-layer sampling + if i != self.nlayers - 1: + x = F.relu(x) + xs.append(x.cpu()) + + x_all = torch.cat(xs, dim=0) + + return x_all \ No newline at end of file diff --git a/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc b/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc index 996b186ce21decf01df9c7bbea5711c62ffea4b9..3e89faa3c0813e31a25a8c5b323cbd94d17793ad 100644 GIT binary patch delta 1578 zcmZuxJ!~9B6rP#=x!t}0^TjweNlutpoJ10$q@XYff<&1r9!5G721qY5<0zd zq0xg%SnX8{*$8T3y;rB?b;3Yub~=u`)y44DL@R4nD68p~IY<1ZoL=_sn{ZuqGyy)X%USB`dlxsv~C ztWL^+HQ{?3KDkB-$>;&}k+Va##|l~F1Y|7(zGVxZ(lI-Lzl?oBxW11#z0m2Ien`P? z9GECV*3H&Q#(rxdW}zBZ-U-*u8P&bF9%kUEo$jLW>lt>hHcxNc$u#*V?(!dJ_ z+kuz*yi;uADB1vh-1qt;7=fJ`P|?fzi=|&Ciha=EU0(O@`<-;tKiy#zrE#Y_r!<%g z#bT-S#7j4oK8{C`GW`!>=A>d#qVxa?QASZ2GH_IyOob}T#1qnYaYe$0`2qEIsPy=? 
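The inference() methods added to SAGE and GCN in sgl/models/simple_models.py implement the batch-wise evaluation mentioned in this patch's subject: instead of sampling a multi-hop neighborhood per evaluation batch, each layer is applied to all nodes through a one-hop loader before the next layer runs, as PyG's examples do. A rough, self-contained sketch of that loop with hypothetical dense tensors and plain Linear layers (not the SGL models or Block type):

import torch
import torch.nn.functional as F
from torch import nn

N, F_IN, F_HID, F_OUT = 1000, 32, 64, 8
x_all = torch.randn(N, F_IN)
adj = (torch.rand(N, N) < 0.01).float()   # toy dense adjacency
layers = nn.ModuleList([nn.Linear(F_IN, F_HID), nn.Linear(F_HID, F_OUT)])

@torch.no_grad()
def layerwise_inference(x_all, batch_size=200):
    # Propagate *all* nodes one layer at a time instead of materializing the
    # full multi-hop neighborhood of every evaluation batch.
    for i, layer in enumerate(layers):
        outs = []
        for start in range(0, N, batch_size):
            batch = torch.arange(start, min(start + batch_size, N))
            h = adj[batch] @ x_all          # one-hop aggregation for this batch
            h = layer(h)
            if i != len(layers) - 1:
                h = F.relu(h)
            outs.append(h)
        x_all = torch.cat(outs, dim=0)
    return x_all

print(layerwise_inference(x_all).shape)   # torch.Size([1000, 8])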
diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py
index cd30230..05ea9c7 100644
--- a/sgl/sampler/base_sampler.py
+++ b/sgl/sampler/base_sampler.py
@@ -1,3 +1,4 @@
+import os
 import numpy as np
 from scipy.sparse.linalg import norm as sparse_norm
 
@@ -44,28 +45,35 @@ def _get_sample_sizes(self, **kwargs):
         self.num_layers = len(self.layer_sizes)
 
     def _calc_probs(self, **kwargs):
-        if "pre_probs" in kwargs.keys():
-            self.probs = kwargs.pop("pre_probs")
+        prob_type = kwargs.get("prob_type", "normalize")
+        save_dir = kwargs.get("save_dir", None)
+        if save_dir is not None:
+            pre_calc_path = os.path.join(save_dir, f"{prob_type}_sample_probs.npy")
+            if os.path.exists(pre_calc_path):
+                self.probs = np.load(pre_calc_path)
+                print(f"Load from pre-calculated sampling probability from {str(pre_calc_path)}.")
+                return
+        if prob_type == "normalize":
+            col_norm = sparse_norm(self._adj, axis=0)
+            self.probs = col_norm / np.sum(col_norm)
+        elif prob_type == "uniform":
+            self.probs = np.ones(self._adj.shape[1])
+        elif prob_type == "locality":
+            """
+            This sampling strategy refers to GNNSampler [https://github.com/ICT-GIMLab/GNNSampler]
+            """
+            min_neighs = kwargs.get("min_neighs", 2)
+            sim_threshold = kwargs.get("sim_threshold", 0.1)
+            step = kwargs.get("step", 1)
+            low_quality_score = kwargs.get("low_quality_score", 0.1)
+            locality_score = adj_train_analysis(self._adj, min_neighs, sim_threshold, step, low_quality_score)
+            self.probs = locality_score / np.sum(locality_score)
         else:
-            prob_type = kwargs.get("prob_type", "normalize")
-            if prob_type == "normalize":
-                col_norm = sparse_norm(self._adj, axis=0)
-                self.probs = col_norm / np.sum(col_norm)
-            elif prob_type == "uniform":
-                self.probs = np.ones(self._adj.shape[1])
-            elif prob_type == "locality":
-                """
-                This sampling strategy refers to GNNSampler [https://github.com/ICT-GIMLab/GNNSampler]
-                """
-                min_neighs = kwargs.get("min_neighs", 2)
-                sim_threshold = kwargs.get("sim_threshold", 0.1)
-                step = kwargs.get("step", 1)
-                low_quality_score = kwargs.get("low_quality_score", 0.1)
-                locality_score = adj_train_analysis(self._adj, min_neighs, sim_threshold, step, low_quality_score)
-                self.probs = locality_score / np.sum(locality_score)
-            else:
-                raise ValueError(f"Don\'t support {prob_type} probability calculation. "
-                                 "Consider pre-calculating the probability and transfer it to pre_probs.")
+            raise ValueError(f"Don\'t support {prob_type} probability calculation. "
+                             "Consider pre-calculating the probability and transfer it to pre_probs.")
+        if save_dir is not None:
+            np.save(open(pre_calc_path, "wb"), self.probs)
+            print(f"Save the sampling probability into {str(pre_calc_path)}.")
 
     def _post_process(self, adjs, to_sparse_tensor=True):
         if isinstance(adjs, list):
diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py
index 8dcb511..c4affc5 100644
--- a/sgl/sampler/sampler.py
+++ b/sgl/sampler/sampler.py
@@ -4,7 +4,7 @@
 import pickle as pkl
 import networkx as nx
 import scipy.sparse as sp
-from torch_sparse import SparseTensor, cat
+from torch_sparse import SparseTensor
 from torch_geometric.utils import from_networkx, mask_to_index
 
 from sgl.sampler.base_sampler import BaseSampler
@@ -80,8 +80,8 @@ def _one_layer_sampling(self, prev_nodes, layer_size=-1):
 
         current_layer_adj = self._adj[prev_nodes, :]
 
-        if layer_size == -1:
-            # in case layer_size == -1, we simply keep all the neighbors
+        if layer_size < 0:
+            # in case layer_size < 0, we simply keep all the neighbors
             next_nodes = np.unique(current_layer_adj.indices)
 
         else:
@@ -91,6 +91,8 @@
             for start, stop in row_start_stop:
                 neigh_index = current_layer_adj.indices[start:stop]
+                if neigh_index.size == 0:
+                    continue
                 probs = self.probs[neigh_index] / np.sum(self.probs[neigh_index])
                 num_samples = np.min([neigh_index.size, layer_size]) if self.replace is False else layer_size
                 sampled_nodes = np.random.choice(neigh_index, num_samples, replace=self.replace, p=probs)
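The two files above work together: _calc_probs builds one global sampling weight per node (by default proportional to the column norms of the adjacency) and now caches it as an .npy file next to the dataset, while _one_layer_sampling renormalizes those weights over each neighborhood before drawing and skips isolated nodes. A standalone sketch of the same steps on a toy scipy matrix; the cache file name and graph are made up, and this is not the SGL sampler code.

import os
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import norm as sparse_norm

rng = np.random.default_rng(0)
row, col = rng.integers(0, 100, 500), rng.integers(0, 100, 500)
adj = sp.csr_matrix((np.ones(500), (row, col)), shape=(100, 100))

# "normalize" probabilities: proportional to the column norms of adj,
# computed once and reloaded from disk on later runs.
cache = "normalize_sample_probs.npy"      # hypothetical cache path
if os.path.exists(cache):
    probs = np.load(cache)
else:
    col_norm = sparse_norm(adj, axis=0)
    probs = col_norm / np.sum(col_norm)
    np.save(cache, probs)

def sample_neighbors(node, layer_size=5, replace=False):
    # Restrict the global weights to the node's neighborhood, renormalize,
    # and draw at most layer_size neighbors (all of them if layer_size < 0).
    start, stop = adj.indptr[node], adj.indptr[node + 1]
    neigh = adj.indices[start:stop]
    if neigh.size == 0 or layer_size < 0:
        return neigh
    p = probs[neigh] / np.sum(probs[neigh])
    k = min(neigh.size, layer_size) if not replace else layer_size
    return np.random.choice(neigh, k, replace=replace, p=p)

print(sample_neighbors(0))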
diff --git a/sgl/tasks/__pycache__/utils.cpython-37.pyc b/sgl/tasks/__pycache__/utils.cpython-37.pyc
index 2f36ff370d90bf0037b186b1088a350a3c9834fa..6748656b1276377253d9b31b4d00bd729ab82b52 100644
GIT binary patch
[... binary delta data omitted ...]
z!n_%k`A-tCz5Nx;7z>^vR-ZCfFUG=L)lxPD=ANZHDP+l9BSZQA4GrYASoi@Qy(PjYK zWhQ#~es6tr8eNLMj+Kz?jhhy=eWDl~N7MstgCTsxyCi3;o$*Z9!B`??YmU`Bkg=7{ zWN+8u4lxJ#ebA&s3)WW3OxwQhWN&A|p&_KO9H7Igxzs>UhqIs)Nmt;2a1lzuvP*js zsYI+Z&KIgBxmTe@A}Ud3H=s@e#DHf7sUiV+8sdZ=uqXAUVp)OqoU8))Pq!x{-c|`;(SfikA73-8~s8zZIR*9;%lA&?T zK@Klg=d(BK!zGm z>eI{oB8zl0q|!O0*TDh-RkixD$v}7mVGBUC)De~8oY`b0(Rv)XjsK}Hj^f2h@f|?d zA#6ujkFWt@2ZB({PX3awenkxY^eX^(#%}z|r~tK~fv~7`19T7|s02eOLzh7xGHvJY z`u+}4?0PATy6Z}$OtH;^;nP5xXR_w%=h^WxQR?$r$974cN+MdHbRy>uaCbn1zJ9vL^ zIr)HJ4$dY6{FmStB*~A2Hb>4uXBrbMiLI6s4}9tI(ceRl+~gkmGz; z`9c!o@0I^cPVy6xjWEHFk(bG3{v>jhoZ}-EcRX=i=`2rHj#SP9Eu`@cCy_cx;C!OX zI;=pij4rBbBio+eC9QyQw%pTW6F5{c34%EHMzs17&Lez@P^<;ogKm=Fs&3tM0nLjD zc%e?n%?P^U?7u<_)7POp#bcS>7QKY6-3XTvzTs^(bIAhUS+kH_=Vxk8H{zWxDAKEF zeT(2MjMf!^Xu!ElM)|twt3zV_39It6)AvQTU_wNqp{fj0j_-V#-X8VM*hu!_^+I8Q zMDcx38*s{!$)XTTVx}h05a|$!abZUi!Hpm?=H!WAnORmpijCJ0b|Xxcq8 1: - local_inds, global_inds = batch_out - pred = val_output[local_inds].max(1)[1].type_as(labels) - correct_num_val += pred.eq(labels[global_inds]).double().sum() - val_num += len(local_inds) - else: - pred = val_output.max(1)[1].type_as(labels) - correct_num_val += pred.eq(labels[batch_out]).double().sum() - val_num += len(batch_out) + val_output, out_y = model.mini_batch_prepare_forward(batch, device) + pred = val_output.max(1)[1].type_as(out_y) + correct_num_val += pred.eq(out_y).double().sum() + val_num += len(out_y) + acc_val = correct_num_val / val_num test_num = 0 for batch in test_loader: - batch_in, batch_out, block = batch - test_output = model.model_forward(batch_in, block, device) - if batch_out.dim() > 1: - local_inds, global_inds = batch_out - pred = test_output[local_inds].max(1)[1].type_as(labels) - correct_num_test += pred.eq(labels[global_inds]).double().sum() - test_num += len(local_inds) - else: - pred = test_output.max(1)[1].type_as(labels) - correct_num_test += pred.eq(labels[batch_out]).double().sum() - test_num += len(batch_out) + test_output, out_y = model.mini_batch_prepare_forward(batch, device) + pred = test_output.max(1)[1].type_as(out_y) + correct_num_test += pred.eq(out_y).double().sum() + test_num += len(out_y) acc_test = correct_num_test / test_num return acc_val.item(), acc_test.item() -def train(model, all_idx, train_idx, labels, device, optimizer, loss_fn): +def train(model, train_idx, optimizer, loss_fn): model.train() optimizer.zero_grad() - train_output = model.model_forward(all_idx, model.processed_block, device) - loss_train = loss_fn(train_output[train_idx], labels[train_idx]) - acc_train = accuracy(train_output[train_idx], labels[train_idx]) + train_output, out_y = model.full_batch_prepare_forward(train_idx) + loss_train = loss_fn(train_output, out_y) + acc_train = accuracy(train_output, out_y) loss_train.backward() optimizer.step() return loss_train.item(), acc_train -def mini_batch_train(model, train_loader, labels, device, optimizer, loss_fn): +def mini_batch_train(model, train_loader, inductive, device, optimizer, loss_fn): model.train() correct_num = 0 loss_train_sum = 0. 
     train_num = 0
     for batch in train_loader:
-        batch_in, batch_out, block = batch
         optimizer.zero_grad()
-        train_output = model.model_forward(batch_in, block, device)
-        if batch_out.dim() > 1:
-            local_inds, global_inds = batch_out
-            loss_train = loss_fn(train_output[local_inds], labels[global_inds])
-            pred = train_output[local_inds].max(1)[1].type_as(labels)
-            correct_num += pred.eq(labels[global_inds]).double().sum()
-            loss_train_sum += loss_train.item()
-            train_num += len(local_inds)
-        else:
-            loss_train = loss_fn(train_output, labels[batch_out])
-            pred = train_output.max(1)[1].type_as(labels)
-            correct_num += pred.eq(labels[batch_out]).double().sum()
-            loss_train_sum += loss_train.item()
-            train_num += len(batch_out)
+
+        train_output, out_y = model.mini_batch_prepare_forward(batch, device, inductive=inductive)
+        loss_train = loss_fn(train_output, out_y)
+        pred = train_output.max(1)[1].type_as(out_y)
+        correct_num += pred.eq(out_y).double().sum()
+        loss_train_sum += loss_train.item()
+        train_num += len(out_y)
+
         loss_train.backward()
         optimizer.step()

From b02651a01f617a851586960a897b3c967e2801de Mon Sep 17 00:00:00 2001
From: infinity
Date: Sun, 3 Dec 2023 05:02:34 +0000
Subject: [PATCH 09/28] delete useless files.

---
 sgl/__pycache__/__init__.cpython-37.pyc | Bin 132 -> 0 bytes sgl/__pycache__/__init__.cpython-39.pyc | Bin 136 -> 0 bytes sgl/data/__pycache__/__init__.cpython-37.pyc | Bin 724 -> 0 bytes sgl/data/__pycache__/__init__.cpython-39.pyc | Bin 672 -> 0 bytes sgl/data/__pycache__/base_data.cpython-37.pyc | Bin 12287 -> 0 bytes sgl/data/__pycache__/base_data.cpython-39.pyc | Bin 12433 -> 0 bytes .../__pycache__/base_dataset.cpython-37.pyc | Bin 12807 -> 0 bytes .../__pycache__/base_dataset.cpython-39.pyc | Bin 12673 -> 0 bytes .../__pycache__/base_sampler.cpython-37.pyc | Bin 1565 -> 0 bytes .../__pycache__/transforms.cpython-37.pyc | Bin 7082 -> 0 bytes .../__pycache__/transforms.cpython-39.pyc | Bin 7060 -> 0 bytes sgl/data/__pycache__/utils.cpython-37.pyc | Bin 2786 -> 0 bytes sgl/data/__pycache__/utils.cpython-39.pyc | Bin 2816 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 1276 -> 0 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 1192 -> 0 bytes sgl/dataset/__pycache__/acm.cpython-37.pyc | Bin 4328 -> 0 bytes sgl/dataset/__pycache__/acm.cpython-39.pyc | Bin 4365 -> 0 bytes sgl/dataset/__pycache__/actor.cpython-37.pyc | Bin 4571 -> 0 bytes sgl/dataset/__pycache__/actor.cpython-39.pyc | Bin 4654 -> 0 bytes .../__pycache__/airports.cpython-37.pyc | Bin 3862 -> 0 bytes .../__pycache__/airports.cpython-39.pyc | Bin 3968 -> 0 bytes sgl/dataset/__pycache__/amazon.cpython-37.pyc | Bin 2570 -> 0 bytes sgl/dataset/__pycache__/amazon.cpython-39.pyc | Bin 2626 -> 0 bytes .../__pycache__/amazon_product.cpython-37.pyc | Bin 3674 -> 0 bytes .../__pycache__/amazon_product.cpython-39.pyc | Bin 3704 -> 0 bytes sgl/dataset/__pycache__/aminer.cpython-37.pyc | Bin 3443 -> 0 bytes sgl/dataset/__pycache__/aminer.cpython-39.pyc | Bin 3497 -> 0 bytes .../choose_edge_type.cpython-37.pyc | Bin 2972 -> 0 bytes .../choose_edge_type.cpython-39.pyc | Bin 2962 -> 0 bytes .../__pycache__/coauthor.cpython-37.pyc | Bin 2568 -> 0 bytes .../__pycache__/coauthor.cpython-39.pyc | Bin 2624 -> 0 bytes .../__pycache__/custom_dataset.cpython-37.pyc | Bin 7129 -> 0 bytes .../__pycache__/custom_dataset.cpython-39.pyc | Bin 7049 -> 0 bytes sgl/dataset/__pycache__/dblp.cpython-37.pyc | Bin 4367 -> 0 bytes sgl/dataset/__pycache__/dblp.cpython-39.pyc | Bin 4414 -> 0 bytes
.../__pycache__/dblp_original.cpython-37.pyc | Bin 4532 -> 0 bytes .../__pycache__/dblp_original.cpython-39.pyc | Bin 4507 -> 0 bytes .../__pycache__/facebook.cpython-37.pyc | Bin 3227 -> 0 bytes .../__pycache__/facebook.cpython-39.pyc | Bin 3270 -> 0 bytes sgl/dataset/__pycache__/flickr.cpython-37.pyc | Bin 3590 -> 0 bytes sgl/dataset/__pycache__/flickr.cpython-39.pyc | Bin 3620 -> 0 bytes sgl/dataset/__pycache__/github.cpython-37.pyc | Bin 3212 -> 0 bytes sgl/dataset/__pycache__/github.cpython-39.pyc | Bin 3257 -> 0 bytes sgl/dataset/__pycache__/imdb.cpython-37.pyc | Bin 4255 -> 0 bytes sgl/dataset/__pycache__/imdb.cpython-39.pyc | Bin 4310 -> 0 bytes .../__pycache__/karateclub.cpython-37.pyc | Bin 3112 -> 0 bytes .../__pycache__/karateclub.cpython-39.pyc | Bin 3174 -> 0 bytes .../__pycache__/linkx_dataset.cpython-37.pyc | Bin 4654 -> 0 bytes .../__pycache__/linkx_dataset.cpython-39.pyc | Bin 4672 -> 0 bytes sgl/dataset/__pycache__/nell.cpython-37.pyc | Bin 4164 -> 0 bytes sgl/dataset/__pycache__/nell.cpython-39.pyc | Bin 4177 -> 0 bytes sgl/dataset/__pycache__/ogbn.cpython-37.pyc | Bin 2856 -> 0 bytes sgl/dataset/__pycache__/ogbn.cpython-39.pyc | Bin 2904 -> 0 bytes .../__pycache__/ogbn_mag.cpython-37.pyc | Bin 4571 -> 0 bytes .../__pycache__/ogbn_mag.cpython-39.pyc | Bin 4631 -> 0 bytes .../__pycache__/planetoid.cpython-37.pyc | Bin 4274 -> 0 bytes .../__pycache__/planetoid.cpython-39.pyc | Bin 4338 -> 0 bytes .../planetoid_sampling.cpython-37.pyc | Bin 4481 -> 0 bytes .../planetoid_sampling.cpython-39.pyc | Bin 4447 -> 0 bytes sgl/dataset/__pycache__/reddit.cpython-37.pyc | Bin 3228 -> 0 bytes sgl/dataset/__pycache__/reddit.cpython-39.pyc | Bin 3253 -> 0 bytes sgl/dataset/__pycache__/twitch.cpython-37.pyc | Bin 3340 -> 0 bytes sgl/dataset/__pycache__/twitch.cpython-39.pyc | Bin 3386 -> 0 bytes sgl/dataset/__pycache__/utils.cpython-37.pyc | Bin 2834 -> 0 bytes sgl/dataset/__pycache__/utils.cpython-39.pyc | Bin 2933 -> 0 bytes sgl/dataset/__pycache__/webkb.cpython-37.pyc | Bin 4685 -> 0 bytes sgl/dataset/__pycache__/webkb.cpython-39.pyc | Bin 4720 -> 0 bytes sgl/dataset/__pycache__/wikics.cpython-37.pyc | Bin 3979 -> 0 bytes sgl/dataset/__pycache__/wikics.cpython-39.pyc | Bin 4071 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 127 -> 0 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 131 -> 0 bytes .../__pycache__/base_model.cpython-37.pyc | Bin 10649 -> 0 bytes .../__pycache__/base_model.cpython-39.pyc | Bin 9683 -> 0 bytes .../__pycache__/sample_models.cpython-37.pyc | Bin 4119 -> 0 bytes .../__pycache__/sample_models.cpython-39.pyc | Bin 5267 -> 0 bytes .../__pycache__/simple_models.cpython-37.pyc | Bin 12552 -> 0 bytes .../__pycache__/simple_models.cpython-39.pyc | Bin 11395 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 130 -> 0 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 134 -> 0 bytes .../__pycache__/base_op.cpython-37.pyc | Bin 3060 -> 0 bytes .../__pycache__/base_op.cpython-39.pyc | Bin 3114 -> 0 bytes .../__pycache__/utils.cpython-37.pyc | Bin 3796 -> 0 bytes .../__pycache__/utils.cpython-39.pyc | Bin 3703 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 299 -> 0 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 273 -> 0 bytes .../__pycache__/base_sampler.cpython-37.pyc | Bin 3740 -> 0 bytes .../__pycache__/base_sampler.cpython-39.pyc | Bin 3453 -> 0 bytes .../__pycache__/fastgcn.cpython-37.pyc | Bin 1778 -> 0 bytes .../__pycache__/sampler.cpython-37.pyc | Bin 9426 -> 0 bytes .../__pycache__/sampler.cpython-39.pyc | Bin 9387 -> 0 bytes 
.../sampler_fastgcn.cpython-37.pyc | Bin 1786 -> 0 bytes .../sampler_methods.cpython-37.pyc | Bin 9763 -> 0 bytes sgl/sampler/__pycache__/utils.cpython-37.pyc | Bin 1869 -> 0 bytes sgl/sampler/__pycache__/utils.cpython-39.pyc | Bin 1919 -> 0 bytes sgl/tasks/__pycache__/__init__.cpython-37.pyc | Bin 908 -> 0 bytes sgl/tasks/__pycache__/__init__.cpython-39.pyc | Bin 876 -> 0 bytes .../__pycache__/base_task.cpython-37.pyc | Bin 752 -> 0 bytes .../__pycache__/base_task.cpython-39.pyc | Bin 772 -> 0 bytes .../clustering_metrics.cpython-37.pyc | Bin 3268 -> 0 bytes .../clustering_metrics.cpython-39.pyc | Bin 3280 -> 0 bytes .../correct_and_smooth.cpython-37.pyc | Bin 4686 -> 0 bytes .../correct_and_smooth.cpython-39.pyc | Bin 4732 -> 0 bytes .../link_prediction.cpython-37.pyc | Bin 9960 -> 0 bytes .../link_prediction.cpython-39.pyc | Bin 9949 -> 0 bytes .../node_classification.cpython-37.pyc | Bin 6805 -> 0 bytes .../node_classification.cpython-39.pyc | Bin 6645 -> 0 bytes .../node_classification_dist.cpython-37.pyc | Bin 4768 -> 0 bytes .../node_classification_dist.cpython-39.pyc | Bin 4800 -> 0 bytes ...ode_classification_sampling.cpython-37.pyc | Bin 10727 -> 0 bytes ...ode_classification_sampling.cpython-39.pyc | Bin 10264 -> 0 bytes ...assification_with_label_use.cpython-37.pyc | Bin 5266 -> 0 bytes ...assification_with_label_use.cpython-39.pyc | Bin 5292 -> 0 bytes .../node_clustering.cpython-37.pyc | Bin 8114 -> 0 bytes .../node_clustering.cpython-39.pyc | Bin 8115 -> 0 bytes sgl/tasks/__pycache__/utils.cpython-37.pyc | Bin 10793 -> 0 bytes sgl/tasks/__pycache__/utils.cpython-39.pyc | Bin 10756 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 208 -> 0 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 214 -> 0 bytes .../correct_and_smooth.cpython-37.pyc | Bin 2344 -> 0 bytes .../correct_and_smooth.cpython-39.pyc | Bin 2365 -> 0 bytes sgl/tricks/__pycache__/utils.cpython-37.pyc | Bin 2237 -> 0 bytes sgl/tricks/__pycache__/utils.cpython-39.pyc | Bin 2240 -> 0 bytes sgl/utils/__pycache__/__init__.cpython-37.pyc | Bin 287 -> 0 bytes sgl/utils/__pycache__/__init__.cpython-39.pyc | Bin 293 -> 0 bytes .../auto_choose_gpu.cpython-37.pyc | Bin 1163 -> 0 bytes .../auto_choose_gpu.cpython-39.pyc | Bin 1185 -> 0 bytes .../basic_operations.cpython-37.pyc | Bin 583 -> 0 bytes .../basic_operations.cpython-39.pyc | Bin 593 -> 0 bytes sgl_dair.egg-info/PKG-INFO | 175 ------------------ sgl_dair.egg-info/SOURCES.txt | 127 ------------- sgl_dair.egg-info/dependency_links.txt | 1 - sgl_dair.egg-info/requires.txt | 10 - sgl_dair.egg-info/top_level.txt | 1 - 133 files changed, 314 deletions(-) delete mode 100644 sgl/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/data/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/data/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/data/__pycache__/base_data.cpython-37.pyc delete mode 100644 sgl/data/__pycache__/base_data.cpython-39.pyc delete mode 100644 sgl/data/__pycache__/base_dataset.cpython-37.pyc delete mode 100644 sgl/data/__pycache__/base_dataset.cpython-39.pyc delete mode 100644 sgl/data/__pycache__/base_sampler.cpython-37.pyc delete mode 100644 sgl/data/__pycache__/transforms.cpython-37.pyc delete mode 100644 sgl/data/__pycache__/transforms.cpython-39.pyc delete mode 100644 sgl/data/__pycache__/utils.cpython-37.pyc delete mode 100644 sgl/data/__pycache__/utils.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/__init__.cpython-37.pyc delete mode 100644 
sgl/dataset/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/acm.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/acm.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/actor.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/actor.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/airports.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/airports.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/amazon.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/amazon.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/amazon_product.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/amazon_product.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/aminer.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/aminer.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/choose_edge_type.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/choose_edge_type.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/coauthor.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/coauthor.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/custom_dataset.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/custom_dataset.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/dblp.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/dblp.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/dblp_original.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/dblp_original.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/facebook.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/facebook.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/flickr.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/flickr.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/github.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/github.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/imdb.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/imdb.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/karateclub.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/karateclub.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/linkx_dataset.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/linkx_dataset.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/nell.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/nell.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/ogbn.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/ogbn.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/ogbn_mag.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/ogbn_mag.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/planetoid.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/planetoid.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/planetoid_sampling.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/planetoid_sampling.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/reddit.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/reddit.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/twitch.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/twitch.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/utils.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/utils.cpython-39.pyc delete mode 100644 sgl/dataset/__pycache__/webkb.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/webkb.cpython-39.pyc delete mode 100644 
sgl/dataset/__pycache__/wikics.cpython-37.pyc delete mode 100644 sgl/dataset/__pycache__/wikics.cpython-39.pyc delete mode 100644 sgl/models/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/models/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/models/__pycache__/base_model.cpython-37.pyc delete mode 100644 sgl/models/__pycache__/base_model.cpython-39.pyc delete mode 100644 sgl/models/__pycache__/sample_models.cpython-37.pyc delete mode 100644 sgl/models/__pycache__/sample_models.cpython-39.pyc delete mode 100644 sgl/models/__pycache__/simple_models.cpython-37.pyc delete mode 100644 sgl/models/__pycache__/simple_models.cpython-39.pyc delete mode 100644 sgl/operators/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/operators/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/operators/__pycache__/base_op.cpython-37.pyc delete mode 100644 sgl/operators/__pycache__/base_op.cpython-39.pyc delete mode 100644 sgl/operators/__pycache__/utils.cpython-37.pyc delete mode 100644 sgl/operators/__pycache__/utils.cpython-39.pyc delete mode 100644 sgl/sampler/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/sampler/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/sampler/__pycache__/base_sampler.cpython-37.pyc delete mode 100644 sgl/sampler/__pycache__/base_sampler.cpython-39.pyc delete mode 100644 sgl/sampler/__pycache__/fastgcn.cpython-37.pyc delete mode 100644 sgl/sampler/__pycache__/sampler.cpython-37.pyc delete mode 100644 sgl/sampler/__pycache__/sampler.cpython-39.pyc delete mode 100644 sgl/sampler/__pycache__/sampler_fastgcn.cpython-37.pyc delete mode 100644 sgl/sampler/__pycache__/sampler_methods.cpython-37.pyc delete mode 100644 sgl/sampler/__pycache__/utils.cpython-37.pyc delete mode 100644 sgl/sampler/__pycache__/utils.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/base_task.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/base_task.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/clustering_metrics.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/clustering_metrics.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/correct_and_smooth.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/correct_and_smooth.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/link_prediction.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/link_prediction.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification_dist.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification_dist.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification_sampling.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification_sampling.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification_with_label_use.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/node_classification_with_label_use.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/node_clustering.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/node_clustering.cpython-39.pyc delete mode 100644 sgl/tasks/__pycache__/utils.cpython-37.pyc delete mode 100644 sgl/tasks/__pycache__/utils.cpython-39.pyc delete mode 100644 sgl/tricks/__pycache__/__init__.cpython-37.pyc delete mode 100644 
sgl/tricks/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/tricks/__pycache__/correct_and_smooth.cpython-37.pyc delete mode 100644 sgl/tricks/__pycache__/correct_and_smooth.cpython-39.pyc delete mode 100644 sgl/tricks/__pycache__/utils.cpython-37.pyc delete mode 100644 sgl/tricks/__pycache__/utils.cpython-39.pyc delete mode 100644 sgl/utils/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/utils/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/utils/__pycache__/auto_choose_gpu.cpython-37.pyc delete mode 100644 sgl/utils/__pycache__/auto_choose_gpu.cpython-39.pyc delete mode 100644 sgl/utils/__pycache__/basic_operations.cpython-37.pyc delete mode 100644 sgl/utils/__pycache__/basic_operations.cpython-39.pyc delete mode 100644 sgl_dair.egg-info/PKG-INFO delete mode 100644 sgl_dair.egg-info/SOURCES.txt delete mode 100644 sgl_dair.egg-info/dependency_links.txt delete mode 100644 sgl_dair.egg-info/requires.txt delete mode 100644 sgl_dair.egg-info/top_level.txt
[... per-file GIT binary patch sections for the deleted compiled caches omitted ...]
zQH+Spnq4h8JKuqXEHrb*-wJFDf4|3K!E=86Cb%O!w|;9~VRzUko!nApolVUA02B&g`9ir6%c*x?O0Y zGaRY?>FzWQ&79h z1&_Q-w%?t3T>}tlw_TFgL1FG)d%FMu116=39K}kOF$D@^C-ziEz4`@%c;b^*^^6S$ zpeu~;Og(_0mE$4R8e?GGv0B^*qIkS<)|;&@v1LVnrmag@SVb|~YCtN52Y$c>zXDy& ztv~z?Pu&1b25p_Pg)K^_oV;nFg*nfyA)c7Ua%k@OPG&*3w?4M>QtsyPYXni5i&d3V z14y}>+0ZK>YsleR;8CMs`A1k*J4@+=lNe)xtDp1$|=h|RX*^s z7ETMjs;FYFa(vC4eAS#>&V8KW!)aBDvu?4YzusqL0Y7CwV;?c>RKrfyta?x5o=#~`G(=RC7RIF z`}yWcQ(anYWgFQhIqv%`e%)c}g+u?3T>gu*yv5n9r7mYJ^R})C+&xj@O=9jht&i9* zSZsat#4VZkvc6+N3C-o1$&l%v`2+$WgRl*AUD)iDFFgW(%qsxJ7eTQ?Q-1uE2^$do z9w?R|I=i+s)Fs3`5AQYi9}x&X`Zsz;}tG>lCOdG zd0hPv<4dj{^H<=#Y?(aJ_8^HUFcj(pfl3857dRjkHymuEV>q}>zEW@)3}4_$70+OK zL-vKW|1bCQ4eWh(qn|#H_xC`ty3re7eWOo0(^qNp3tE#mG4vIu^qPGNLIVWAL14OI zkZxA<6Ky@%ZC}++G6`dYiioKYIvc|QX!%Pfie&?e{04=1Thn-lO3(nWoaFeVlYOL+qBsVJhifjVIGK!glQlEY{`s zeqtg{dBc#0$0OCgR3I`(=x2gTgnkcF?LgzGDPWXj6PgEwWUdLK5Q3T(kwOt$rP>2~ zCfw2$6LgWOXrF>B1hgVNmOr3vs%GFpI2!IFZZ#iqJWMrd{MZA{ozK{vZkw4Q!9%$y9A)4U!`hGi)QDL0}QT+^!Xm z#EKD=Qc@}3GH30{1Xg7X-%^Qoc9SHM1a!I_C-E$li7v~qcoLgCiF$85iNY}=noty$ zPANybUPQ43%DQ^aeyS@g{!@rWsSpW}$`6RpR^y4>MNbZ>5;Rsh&!Wn!B+6UmP*0m~ z^XV+?m)Ih`hKj)u$$WkOBL23~;txg)gN6~imS@ejmH_rtB8RA!;8#+nS?GkELc2^h}E4-70SGHeG5;DXV9)RcH(OYK=GT{{TA4Y~TO@ diff --git a/sgl/dataset/__pycache__/actor.cpython-39.pyc b/sgl/dataset/__pycache__/actor.cpython-39.pyc deleted file mode 100644 index 2b98eac5c24cdf9b68c55a5374295b9373a65f30..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4654 zcma)A&5zs073T~oilRPNyDM*OH%%3{H6k>YHffunso|zb8aF_&4eBCk)TD&0;cB&} zNbPVGdj+zG$VGAqnx2aQyGQ>6`8#@^Yk>ehwD zZ@$}6byDls%-o6UlSaQ`=A~$D((E@U>-}|R-L!-&%J(f%4jT)*zaiKytK&Vysn$-1 z>)Ng9K&IjE!Z=N&_J+wI3e#cOVY+lv4yJpURCf{)zCBQbG*rN-Pwzy53+A zisWt_C3scIkaj)&5SR3%JLp2gVV{fj&#hf1O2WNg>D!?r%EG&EEm*&Vc}4hU?ux3Y z-M7GZT{JLyVofwLRz9~JtM7w{4PE->P=V;ufQCoB+tXgMyE`6^2a#kra9EqDwvs5Xw_T2@`te~9PnZqs2 zJhVz?9r>Ak%svKIP(`bu)rAG}oSlcH0rRa}(`c;PW6_Qi)lO&AX(Cl9q=%24hAzoL zJPNh{!62H2zmk&p0Z=FnfcAo59FJ8HNEh(BHI>OQOw+Kq{9rmzd#Nr30@BfLkkV1Q zQjEb^=&it9Q4oy6IFthwE<2P8vqL+KC5~P&+b7 z6zHPu`<@v9^?>21U>FV3GzcD8kFNIil1bQ0(}Uivo9|tV4)68SQPd;apgBDwrS9}l zHo*zClsb#mn8&twjWw9V_~UuQP~2VdJtRS{9*)Tq8NQ4c0lAGJDdF6=AkET(y$wZr z&dQllF&FDzX4W8)^3MD_$xOWx#L(wpHyo&$41;KJJB+&Oo|?DJ#}R3L5QU1M~G+`HGKWH^Y@8LC#876 z^5BX@SroxHdZ7ILj1Kn38Q++-1$j4xhK4#rqi+ zNkvF3Z=t1I=!i9#EJw4ASokp~q3q1vu90uYkv{~SyoydYf)lF(r!XhtPZ|@;eDVAmJ@~K<@ulaCF^bLlWxuxJK`BXX1Q(D z$<9Fi*LLo}DLD%UN91I1ObfWAgAIeth4mk?o&SRUzHw2do%vgPN=?%@dOb*?I~uFK z+3i^x%3%_#FjiftVb5Hcm_*lRm(bL8?OekS=4s){8@E3XHu+ zn%|pyJ>w|SUZ*5~1PF8E+Pe*BV4O;6D#x*sWqbkz>5{f*GV0gQD8#8(!s;9S3jkM0 z#mJ{kxRqcPh)Dg68d%KX8_b7i@mTH2Y`(ViFe~^olDYuoDw>hhufdsuiSIMPEIe1gKs(j>8h%uvsv-~f4?q%Ll<=DzSRo?f3iQt7< zRaEh$me-fJ;4W`rAKLp2{&S9KO+0dP!_epQzL!yNC_8H4zSgjJLp8GoT&8!lF4!?! 
zjSJ^vo;ULH8TvL}vA$yI7iTDJ5M9^-)Zm`ia=N1p+``%dyX8&M$k)|owwAeBGpl9w zY(1-H6|wfEmv0z8TA~R_{yyKFH`Rrst!yLPq%i-X6~E?K>bbA@;T2}(EoRN1R~NIE zvAjzH(jl0SaC>>{Pp#N5#%*gGtEX<1DCX%Krztkpw!UMM63yfpg(Ks65>v=|8UNfc z7mRI>*)pQAN2~&$`#c~XEPXM%_mmMB7yMH|7#BRhvQ+e?dwN#9(9C~G?(^YeJm=`5 zSZiWk3ll?Ejp9aztm#z5KptR|CBH@oKjknfCj5-n$tWMOml5`CncUU(Fo}>_mV`P) zu2aFsbB9yVGKaS@F^6BIKr0yZiygrl3maNqk9`jR|07=h9J{ZA8~gNG^uG&;;zFQFmT5|4B#o=v8Q$m{k7(@-;| zee(O`(JYyz(x!c!bUa7;m&9Qz={_X|Ky9M}(2fAIF2DO5leEfh10LQRtImbOnKNXM zCND+ycsJD!B#xePPWcN0W+w&v+$`4M@qA_2a-nmCOV64ur{YS(PCSO7Z;hwTX@^% z9^NR1Y~k%9yMY<9R6u%s%!#^3tP(&Ozv>`Erp2@ z$`QCQZh{4+11ACfc3v{+0KyBx1nLshB(AV8!6O`%l)djBBWNHD)Kx;A%z*QW-9us+wZvA@mfp3NXE_EZjRVw4;hBgAU)Ud1uzDX038&trl;AOAlQB zOZlI%*i%(eq53NG;fia!Hj*)aM^c=XU3A(n$`RxfceF?4h?!MPIYLuue8?LF>Y`v0 z$z=ZG`%#EgqCJkMGt~wTDlC(5v{4ixD~Vt1)Cx1MmvAL6|cPG;*oFE+vMWO2rz;^0I`b+*;S5L@Kb!A0=3bv@8 zA_h|V0Rh@-GL?In$uV{Nv{pGO#mcKn%v%*=Pb+Wp=~C_w*do1x&VnK~p~yaa9={hb z;s-T`N<)oZ?(zB7(tSNu4I&vP2P~;vEo4Fwp;M*elX9*g(0&k15-~$=8|YdP9Lxq$ z@x>H#+JoZ5jjO{V(IY1zU!`tD9i>#dzN{37VHBa7J#m6mK}P6wIZ3-z|7wT&BD-{_9O3(vNJ z(zHwNsY?l|$y<$w(-+QvK=mpqMC`ALY8<7eEH^oi_@% diff --git a/sgl/dataset/__pycache__/airports.cpython-37.pyc b/sgl/dataset/__pycache__/airports.cpython-37.pyc deleted file mode 100644 index 4834271a965110eee6ffcf3f62e369a2248281a7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3862 zcmbVPUvuO}5ucH?T512X56NeOvJUdz@> zd!yN$y+}O7N9Bov_f)FvH~0p80zLqb^UBkGg}lJ8XaCq|@`Sdg)#@3Idb<19-TTvK zv&L|J{pmZ=%~i(!O%IF90rNISEI8{SD8qh`N3YV}*AcE8QpeI_cxea?g%)lcpIir^oxZuJGe#CE$}SMJMjd;q4t zn~CUNsKPu_U0XNDM@b-~Pz3vN66w0go~22KN0sSy8Kxo|1^GCMRUl^X3~0y2RhcvS zE||!irf-SPr)-}KS5%+-eLHf5C;aE^l=my(Yocy=S2RTPIl~!SqK&&JR>Ugq{-?}g z{Tg&%)80>FInJcYwLQthFRY!O_Oku`co2ul3$o#hb&!EOSAy>-pss-i)uW&&90`j0%u(Nt$ zpYsoCrh(DKXbA>Io!zc2J$z5QdorBHiFTuj%*N4lb8hGDG#o|SX{NUKvq>s$O#Q{l zdfm3JB-yh_>e_F^WD@;c%1r7?J{d<+R%uf&2;vm>4LAw@ z;~wqKe5KcebTSH*4C6GwM!_Hn^L#G-I80)3NxBs*lnH`iltwaC(R`cQe2Z8dLk+wb z5a$Of$`!F)S37x>>}!V(s2!PQ3c6Fc=(R7=pl^g%6hKV+!A!#-_&fWe+dIfcQ7_M* z_CC1(;C6Cy+{=eaZytWVa2^vo;}h9|L7Esj#R21<`E6JY?wX&&r>l#z?JTS~;ItMO z-O0c|!^n3)6gz|S4=oJ)GobTx>)cw%1W6r8?a%FVYiJp_Pdm4l+C*X)C2_8%)g;`D zlKl3xc`>PuRkxz;EFWvHdn7z)Yo6GaPyy~S>5=zY}H^3}TborDxjkP;3auCjq^Lb-Y`}Nlf9imHtHo&CgZ&*mRT=X zax!@1<7m)>C3~YVSCQ-)gY^DTUbu4U=ZF9IcJlzayek4kQarvOf>3eyGdT*CF}JRa zWt=KWUX}eYEqTRaQ#3aL?>>2EJN)$gUijw|dI&GI-eq@14TuoISn!m_0^Ge4qV>Jz=VP z<`ouV!B_rSt?&=)qO!&CTq|lbzN`zkYz$f1oNI=8quHY?nz5y3Zxs${bxEUE*)E!A zuFvE!Wga z;TEe!r>GZ=Vy(bAg!gB!TsIo83LkcSP+mJ(QJu53V!gOVb?Grn-*lK-KdSwe%fE3} zZg4hTSJ#RS6YU$qpn)tUF!#reC+v4DwVqtEok2itSAGkQ)Ld#_d+{G!CgS?lc?#>V z&zS(FKM3XzT}$(0%}=!TY_Gei9aOc{1R;_=ng;n*+R~4Z>rLIzHX>EK;TX$=uBc27 z4s`87mJT09X^z5XrDJVJClRW|K{&=OlrlWg{1LWJBNX}s`%?lj=|tP;PP8K?qp_~O z`_R-z`8s{liH>8{-JAizk5dsH2cvMTt&?1H)FY}En#bCf*)u35j^&&5nmfau_V%Mt zO(Z_(QJG575^0}kpiZDw!ypU0#;CN~&gDScB3HV$kVnVb9%PBG8^Nyh1cvANB~hItwHuWM#pYJcAdAl1<-Z~ zwwFrVJi!i3=Ear-b05Q?mfUPQAPezxZ;oH9wC5-RMulzg3zW|B>j1x&s2KcK41Oz1 z{JLeesLt{0l|JIxnqe0BwGDp#Iey=H2`PkM)lTZc{a{eJ0UWed1*kAOzuBENW$xJ9$@1S{&K@wtTzK(bch0cSwkCY#G!M&34Z zNX0_le)`YH47hKL3SqoxExGoBOK-u3)g!Xu?}1`=xO$3SaHeDGsg6L_`h-J|R?*(K zQBeI|N&f^4=;XJF`wo#eh)|Z-?yS3>-uQJAp&8$f)A2-Yo91ZSv{2h zrdyYUkeQFD!pR0mw@&Rgb!0)HYe6u|!~~63pc_H(bP^`B8Pj)a52%CH)XB^*)c(jG zk&lR62hq)WUpt7B1pWK6u|o%SgBHC;gnCd?ASVSptJQ2&8;;lVZkd_xs;Y!oXp%l2YhsZL24NOwi*Zy|`e*>jWFRd=- vAlaFyILXZ)j#=MfVuLijIK+&jBWUmL%uqk3z|LDB9iakSzMzHa{)7PIuH diff --git a/sgl/dataset/__pycache__/airports.cpython-39.pyc b/sgl/dataset/__pycache__/airports.cpython-39.pyc deleted file mode 100644 index 
9a9dd261186c27d432d1200c48926882ec0899d2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3968 zcmbVP&668P74M!IjYgllt6lGUQ%SK-j0;hCWd|w+#z251F2GrvvI`ScU}{!9tDW^| zM(%0ZyJE(N$W*vc_)eutd-VUpkyD?jf{U*lb0GX)OWLn+LaKiA=5-XO8y*ELl zQDyjDd-75E_6lSFq?@xp2RH9yWCBDm!DH6t=4uV>uES@tQr9)Ia<`1D6?=n9w_@h( zxH_nHYle5?`k>Km44U2MV5z&r*&QZI!hOMn8`dUvcUkZUtX+PMH?f^I*QGl$8193q z?W7|7AW%UTsyMH|CDCx7!8c3U{KcCY1`7n`?R|&gK-pVHyp`y7>+mQ_TEZ@LAaHqYHKeYCF07s zvN&C*y`)QV`Ye>X`kNpgg&#_pN?poE!%)gHZR+`cl)w+Zq{Hd8p-g*WmW4CV`NKf% zXS!tEq}`d%^qQZH2ELL(l=#@l@5Mou&81I+I1*>1oBl!>-|vS>C<7JFx2evzh{Pe( zz=J+<6<>v!BDU>nCkx{}?a%?WBhyqtcM2Q5`ad-28sP;6UUBR=)6n-{u`k=5{d5p^ zvg}Fc!JWI;90tXg`{a0Z13+vcg$OK6pNbN7|W2;D;VoSgFfy04v;P;1UgCR?2A>{)RpYiI0k_f6h#%Lm>U%zf)Ts`Y72|73^n;N z?sMyew-jb39O$)ka+@RUO&}ZhRt~HixpA6sW^#kr($4tfeWixktxl&Osr}LJb}t=t z;xLubZWzh5lPNjsz56umbzslVAjniGJH{lPKNP1fJ@NOw|9iW6pq$+mfg&j`UlT#d zxO=G_1j-m(mxeM*lq5gOZ_~VO;zfI;i$9XFu9z)}z2F$&Y?0>4tJyLLV=I+!!X5h<(hiuD{iOUmyo59a6cy9u)bN1xcBNR#p z|B206vV4o$*CYxJA9sipE@1%o|O+ z7c^xHO>>8|JfmT=Sjrny*XY=s;G~9IBCd~jL(6}h)3SAwz2w=ob9z>$^}KP+4;saC zzDzoT;(6H6Emoj=Q!VAFL(se>%EhW$&fR<^Z{@YTp0DOOnehJX6>H`UE208R|F*a^ zURJH?YQC0VqFVNdCGR;*t-a(&f8(rJ=WP6@x|FXQM63&gm`Cgr*t1xFzMfQP*Bk5x zo;+jYnfCuWtG&Tbbw_>&A)vX`y!!ep{7k%d%kxCoU7IsL+6mvE-?S~w4>dp1*0bIA zhIUXP6BEKH9B4%3H)zXBh`M2FkhT%$+6{(SCUi-qvbV3RchjW*FibKONh=v@J3I(GvbrPa+h_OGVAd`@^jgChzWZf3#0zNf*AouJ<>%6WN?ey^w?(ybZd zfqvVdzcNSv2j?P_(Eq2iacR(xECaN&gGx~Ytrm4if%#*8@oE#=8U@g-(90b(j6U7M z#kwXC<~|0Ng^OI}LE`3(+!HLXk^aXPvIohUxd%i7lN)R2Ix{a7^)*(!mIlDfS}knZ6>0;oTSM(l&LNQ7KBYnt%`g=gkq_^ zJnM1HtA0t+UlF-W_tg<|7ym?w9FSN|_jT3x2dNmLo%40w_n(Y{cs65tTI~Uau$nrW`DK;%xJkq(av4N7<~?&S zjAMKoyh9U)Tp^BnM)UbVuQ!F>vziTyS6vL-YkF7BYc lF(x{w3qNxVa~9M3dbek$`Z0xe)&!wUhGrj)ex0w`{{>lR{ZRk_ diff --git a/sgl/dataset/__pycache__/amazon.cpython-37.pyc b/sgl/dataset/__pycache__/amazon.cpython-37.pyc deleted file mode 100644 index 4e3c8e8866e94e090ed71a6a74ab12e81bb20ecc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2570 zcmZWrOK%)S5bmD$?8Dw94h~5W4oILCWEXPaKoAPgCBaVD zh68IUr-&bbkn9`$1Aa?i5r_PRoT%#AwZY7)YihcCdg}4jSNn0d+ahqi@%VG`w@=7l z_^`YjQ0~Aee*nP=rwK`EWYI>qBgf3SkqdV#@lrqX&6=GAX&8k@cam1xj@oG_>ZIMM zo32DFl-woU<=zqDo(Si5w94tXr0+k6uaM0?)$Uyx7h_P`n>iPs#wsp_>N~ny>?KT! zn6vFkB6OSQ`&p91qsnz)=Cfj^*JPaWJY{8(OcdkI+ao-Cd7)o;m75?0i6|$L#d}Z5 zHsua?kHW|n4)?f!MCLSdK@WInbdR@q`-nt7@9-|%1HQsn;T}FE4vAXe#Rctsl*Y3> z({3@&RsO;n47JyoG+KTb?>KBU;y#Fqa=Pcso`OKm>CaSAWmRNu9ov<~?R8i`fv_v5 zBD?OKR)l$#4brd3aZuSO^btOVFj_F$oPZ(DX5Z4`C{K&260(#bd}elS5zLJ&PQ^x+ ztBrD66uDG_-`$79nK-@9TiENd{4663y3Zh0I5yEJc$M zR>W#tYS#o^drb!PI%@*L*pA4AjFng{w3w3dBx4hP00;jVCq@gI5~V^FGVYWj+13tP zs~wr=3ar7j0yaX%MoC7v4P>pTH=ugHYrcQm|`rl=iW6LVks`{brw7 z5O#Z0Uur(bL?T8?=GkOavY9ggZ1!Dk=VhUt-F%XP&obU;d?Mu)I90nzzAt2-%5`{> z_@PF;f{QD-6;a>B_!7*O76@X^E-wuK?~YURaTgT!nfaX7^wb4J7}<3YUrY1d9jv~Y z`=lN`sr#;c6<(GbAoL15%Qu*TWjncW^MRg~`! 
zhdUEBo^B1GE{5fJG@i=gPL|!=64_{+#&T~6Q5>dm2_<9DGlZkItd&Yg${iT~S2i%cOQH>ycHvwWsV|N4XV{1#p`6jMhH4>wr6TsJzOpyvnbf zDyYJW9{VSxvbp;#s9P$4b3^XUtzW*V+lM{XI&N33s*U9HfMmBEqB?sm`8y?bhmzTv z>Q)^S@fGgl{Q-6cSo^8-kbF-v>!F$XhTJC)VIDFOY)ULAC|{?X%!KwK^L#o|5i$uX z47i*j%uusIC32=jOdtiIpgu;FsgL5ZYC_9F4V;XQI?LAWtIVwVO@T;t-|_ zm_g)e8@h*fI81B*-dBd(|3Y)61(M2<_1pz~Cal!J*gb z3bkOShrl>xro}rWQ~2--;?9O)aPc-Mrj1zg>c8*-=%}46P*&4eDhnEivu&HIkO->T zjjs~{sNR@l&?z>6h9`W(kn9G~In+UR1F($@BJZF;FhV&Pl0w6Ec*f%XlEi~0ho3Ki z-uW(o@m|ycIBx~kY<*b==lTU$fjvy({}p3-zm5F}tC2C?Vk`yF06{X{VeIiVPMQ@% z`8t5uK^Iep20I22a)1ZjLxFUvyNiZ05=jCL?v#Lmk+0z<YN`LGZLd9!cN1nqaA=5I>%?GiOo++Pk`N{#1{(`QS=&`Jty%qT zRkdfPE#)G;;46~{2!WYP=PB|8d4ReNx$rCORn94O@9a+0)X`B@sdRq6^PSEnlOciY z?I*vL{~ZwWAAET6V({=w82Sqkf(TlWlBOQ*bw6cpE=U8odqr3d(t%s^i>Qp#*y*eo zmZNl3j?-~DNhjqrol^3Uh(Lr-i3nx9_0t1EKPL0RckmT*Jf}8zsIq1S%ILTj^4FQk zv@~;OC(Yx6E13y?k{8mBM15WrH9VTyMsB`pHujFnDp8kQH$`r^=-*!A*)Lw`8(#Vt zgd`~yB=yAXYjQ#b6T#Cs^(7Ob7(69gng*apB6fNxhGO)TqysS)6SzlWDh}Wte@z%k zhv3O8HvFi}Hg#o#W>uT|nYT!6*qgQ7e;7ZGIu0L({xgV{3i>5gGXnwL(!Wwejn|T` zcjdR9@DE}A8p3Z_OU{BVZ3z!sALO7VS5fO<)6elKhB1UO5(JE4$8*od%erjVMk=jh z_{`?+E~rOUR?4HQHb;8hG_^8P+}{i`XM8dr+uJ>S zb|6T4#}>_{IsmgrF!UG%p%MJvpko@+%|UPFV(;OS;*Q;gdI;lJF!VhThFn1!&OCwA zzhc+q?~uBG_}3nU-aodl^`GRjkV~WLDqm{8VGF?NJg|PP8_Uk>yaK;fcFskv)SGas z4T}0)syS7M@T9)~T>JnRr??d>ppWxwFsD5b#M`{SGydfrFS&p_q1bO+5VWJu-9V&~ zGln?ZyAYgWHEx2hy2V%BJWxmQvbqn#PWesV!E_s*)T+!NHyjf+_nh@MLQ`?#mLF;c zXP|%Dm4wIW=I+k?n`3X>e8~wc7<|j+;Q=V<1Qd^V_&g)m^pJF(fx_8RS29pYXWrJa zm&?uO=PP3x{XvqP=4Q2iw165(^lG_UtK_t*-gzXe<*Llo;{>9al$nO|aY#xad(UXH z3x$my0X;eYeguX<$*a6FDgb@%_Yv^ds_4VOLAw@I03bq30u6Mu^E$t!h|KSijODrx z`0}5WbYn_3_spanyO>YK0PmlY&!A&;9>A=C6f&~)DXC`f68*0Sgu35n zARTo@#|jG$8UfS!M%n@3Ladi2MZQ6W1M_nPAwB@j(F;p2t#|(DJNMp!!dg{*TJ$Pw zqzz21mMhEZN^0u^E)~{;4XSn7T*7n%Gl)y;L$9$+z_cCw?vIWE)lYDN$&1{~Cw*o9 z6laD~oJ!zBAurU=a3X@!cm=J?`l>#+{<1FY$Za~8`RU5o7#F#@1cPEX?^JqfJ#}K~ z?Lq-^*cHSN;dK2L2tsFc2mpB%fWqiKI;9@WOc65A`D^#i(oqAx0~ewaYsVo5-NAV0$EwJ9eG6(mXff=>cQvCBZkzjmDkNb3D6gQW906hH;>gkQ5fDFANp&BP zj~t^uLV+-aVsTW4#@qOY?+1IjkM?YTyF6yM`^CF(#I8yp0AA#64);}ZtFwR=*zFYl z%W;;^N7%WrRynsr&PxCe&?vWK&Y!HaqF-^%!A1~2=y&GQ;m3hOC3w*9Q6TBs$*%P* zWl=yYyrJp4xCOb|bxZubMo>j`8hLc+!(dSu9k|u`%noY3z?NwlN=L9ZZ1UyfLaGn( zl=nf{SD}+Fu#+v2a9nTg!Tb9EFTnpLg?YBG?AKssu{L?3UB`Cx)K`bg6Ozsg_B&5p ZuZceFWAqCwDAz2<0J6v%L$iF-{~t8Yo=gA$ diff --git a/sgl/dataset/__pycache__/amazon_product.cpython-37.pyc b/sgl/dataset/__pycache__/amazon_product.cpython-37.pyc deleted file mode 100644 index 8580934d3bfdecde08774b4272a95193b0dd09bf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3674 zcmbVPOK%&=5$^684u?-Ytd&jj(4p%ws-wn$FM^`VJUXyP?o+YCTW-9lyjS{8szGux9WyjsF-|^j`GdT&Cm<&AC9mph+G$Z@$fMtx+vm~R3j1p;N zq-X9!k5kA%cC3sX6f@(HzQ&pjsRXIa37mwrgas=#DyonU5>M!wvK_~deCap>_frdS zu8xzt*6&wyqRhE+`fiKm0=fss9XUmH`;Y#nQzwAii`jO-FJy`#Z8)&KI zNH3KrHcVxvUeHx$G87fK#SElWLzgbHDz&IdM~f$p+??FHLt!j_4ANBN=b^$W zXWZcCmYGwH1Vo^n(R+XaL|$x%oaX6}v&mX!;vzNxvs3hXR z)8dKk@HNgrNtX33?KL}Y{z^-#HY~RFq9Fei86b!AAwSZyX$aY z0LWQ$RoRGd)MbMTTv1OEOCY^3-w6CvcH%I(h1;ZfPm7FJPpSi#!!^e5-MkUR9XCjC zfiRgoJ_XautI}k@bliNjD{iRjC)YZQo1$U~yu|3-Pi;F@4!%hN(p^aDB4nr(KstkD z{1sm2Z|sn1M31cImbQIU;O)fEAX9b3!x4x>vxS(=V4{O6I{N0eKm>)TQ$>GhQr9(U5_Rvd2ub^$*t zgU(mpUJ?s=bK*xY`uyha^NXE0>iQzQDM!nfukCENZm<33^221a-+zAT<&`_((`Vt@ zv)1|-Jig?*Fl6<``dV=J&F$W${^pmRe!u(j$=xr8tHIjF2A99x>%ox64}+_HKYZNn z%HGvW*IMB#HyErB?>@V~cH6yw?WtP693_u_wHJh&QLoj1wYIwb=*qJ%*LGjtfA!>> zUL>zR`hTpmVI8ALASjg8jYa5!b4r9CNr9kj*kc<>>55;7-??qIDjY zL+_~3`W$Qw@g^DI5n>1jcLTVf8SND~TOwl&pBQ245vH^t!T_?D2;`Q|i%XLFJI`(;u$nrmTUrBc3aJ 
z7qZGh{gB)v-!8mnc{Q(PwX7^>vg$s)L3aMWPS$~6JR{G^29dJ|b6FW?)#ZFr)~68UdIx5hZ{DQ6pJ8GAkjGmXYors z5Q8vM?xey2U<-dw6@U%Ju`)!wtBg(@h*cbA#)%g}xPkbk3cmEhR2g6bI9hbJRrzro zZN2oOG!_CwiuR+*0Kx?~8-$4}K74Xt2UOevPv$9^dkuUC-azma3#HdL$dBsx&n;agj9m0Z7l*Jqn!4cXd#aS4hM0JLkt7F&*n^;2L&AyjAJxs1~^T7=asSnaj6zN3*XiyR`!cVIJFA+x#$L`J(MaQhFKhhl&(_Qr+MQVh9719<_6n$It28g>CZmGnfk4BqZGEOZ)3cd+f(kf}4RSUwDszze-ddR)J` zDx<&lwA?6;83~4Qf&3cj9+2_UF-0s+7@|EPh^Jp`_B6%ZfQ8ao$Oyc~0$pI|p|qII zM(0nt!BgC{VF-L1fuUQg2@_j!S{q7Wc7#9!PT& zJx+5I-dnSIZMN!@k2`=@3j-(3P5PjvIgoyk>i-Bj3hI*{UvDgF{%7A9tMC@uLW)7h PhPSDX6L7#Poi+XqHzL&h diff --git a/sgl/dataset/__pycache__/amazon_product.cpython-39.pyc b/sgl/dataset/__pycache__/amazon_product.cpython-39.pyc deleted file mode 100644 index b2d968db74e6f20bb23924efe686bbcb57553855..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3704 zcmai1OK%&=5$^684u?-svQ~~Zjw3AsV;Mx0yzxUZY$txNVTZP3$MO=p7>w9MX{dRr zyJ?BcOb(WT%_Z2)c@aR$&G{Ai5jpiWC!cfE#i{NQ_1a`JL_epey1M$Ss*i51Rz~n# z*u3w}*AV)b8XSL282knb7f?||aez8`i#tS*jnM2Aa6V^sih9ltOP%t;T%}WiITBRE zTBoMxjG!J)bS8A)3>x8NXEL1XOkwl@(E_#JB5HXPJ4WXW#m`Z*_#Spf?IxCm2h2^@ zp;Kwc)Vu2nmwTc~WG&eU9Ok*y>G^>tD>UAUf*6J(?y7H(Kk7W{;cciq)WHIuafi|SgERfW!$O0cE9@8UfJB}au!f}|| zPfmkwU5|4wx1E!?Vx7x^L&3Pxa(EK>LY{X{9LsUmyvSp&@DARW9l=~ba(ucC@?X1w z?v)(jaiO|RQ<~fhdeV&C(32*MV*yGj+3&C{IZiilId`0QXn%2WJr2D^&NmmIKX|+l z47V5gTCk|pfxBMJ?dcpCw3Q551x_&mg;$|Mc!pH5g-uL8jAl+K-5Q^}i`Bw|hd`@r zuI>IYbk+djf5REpcp%hxYp3`wAOY-YwwvaD)unx>Hwc1Ol#KREJpzwUd6n!<7VS#q>pHQEui~bAQG@m0j3u~H&G-94evmg5r z7?-(Q4)qy;p0t*vt?;gF`VKQCQGddeSXsUv_*``3FuAG5|AvA_OD9dDgeeL6t+%cR zan}v_P3Vk2&2=DJxGar#61CjY=^^=m$!V9k8g@}zu=YD9@6UI?DQ-E>@3ZH>W z83hnepcvo5&HR<@5smDT)mqcBz+^v#b!-kQSy#9?05jMKP^?VDw3^D6N7i8qEa#wb z1pq`wpPbb3NY0UYPH1YwSL%RT(|*8stQpv6J=*WPv#GjU@aSugnX#vif^-8ky`EW9B`^B1pfuD5S3e}3^{vbwSH zyM-5@-435V3zwg@S1!|d!F6HD(({$&;Lht?{e_LyE8UHa-is%9E)SQ2&3m7Pww|4ap@8JG291Jx66K_s`F4ats|VUd6abks<@fEqXS>(NYj|p&RBBBhMR+Z_|Kt2*d!(d zD)^B)E@K;)|3^DIdlH<7vQ8j!3Ph?{s1Z2=Pg0l?1TY5~gF;Y&yT>W+7buxUIZjWk zF!c&kSddQuy|V~vOV`C&)H7gY&murlAua6Mdq}nY5|lEOih&5qtgM(puFOibyoOUU zjk2<+9%@k8BRlw;<;)Z{@M(qFN?O^i@1eWstLZOER?TW@EiH?Qw7P?@qs_ms0HZMB zw_{)wH7+9l&n1NQ?+d6!V|OwwgOs|MN-gjsJ0*}2ODRt4wEAZ|n-Env`59WX{p z8l#4o-u)<@NE<4TuAu0OiNwsec=!UNY!aixLgYo4>^`D^P!mh?ccBttGsht4OB4R0t-rK_|8chw8H#A}kUp2Xg1@3Y5w;W})< zTTmf94K8PrY5WPc$qe+$dZvMY0d5G+_`w2aB)9=T$Y7_L#6f!K>g1us!Bd2yvV!)% zH_%Bj#T$gJDkcXk4l)>p>;*pXaYlr>ThO_{IDsdH@@vP@lp53kAB2peYw$|#X=YZ) ztY`*yuwaKoSZ{D-iz0=r@&Eu>b@fIMlpSqU`J5l|9<8F}qO>q(2b5skyq{L{b# zUNz@O$w15jar5b%re$*w&xFUKd5DRhD-nlTR!2|vP|e+jN}g%Q;$fHsUg$-_qZ$XY ztfIE3`^{p`OjJ5&zgJ%X{R5I-I--olF=MpHgz@wZHcnH{S0IQtph9pz)A&d;Rjs*YtY$JYkKT}N7>H4o*VI|@(->O QlMT01%?t4JDsCA60f`~KrT_o{ diff --git a/sgl/dataset/__pycache__/aminer.cpython-37.pyc b/sgl/dataset/__pycache__/aminer.cpython-37.pyc deleted file mode 100644 index a800d8072575c5b67b49cd943189b47ff225be9c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3443 zcmZuz-ESMm5#PN#9*>`*DA}@*52)!^*dkO=q-{{BMw>`+4LGuEsj)-AxYFDeb=L8Y z+B;elbL>MUGTgFMZ5^4uIcaBq1ij1P|GOcPyU6cE>S$g-!vh6}p3Br)Y3?=nYDpl7XGD zJg9Uk1}=ovL9J6Wup8C~OPwXoZZT05-Z>MVtex4NhTwawReFva*=~z#_qU%0kyL;R z`=enf0aSh`Q>o%xsyjTy=W_nJ8w=C#5kF=s+GR7oX zhYQxR1Q(XD`*z1Rs|DR0=v5FdN!NEeu2}n$9dJ<=mGe@kC_Pyc)qhx`ChF%-rz|UC zNi@#c8Shj9FN>yuYhp#Lp0iF}Toc!^E`7-y)@eY)8_;Y(n(4x@JH+C4N9kdV&!;jP zN$vW68vB0BRvyIEcYv?9-xM+LwtTn$uQ?qsL*897^E_ks|j3LyIoW zG5LWPR(rE&0YN_-gG@tf@&P8nr&h``80N413D1~srFQCM?993-WIVGj5eTW9u_whd z!g_*1EqfP=7`@;GN5yg4Dk-``I|9zpl|=P?g5V}yNJh|H6$$P65H?7CU%B+TI#h8_ 
zCJA&P2l~TqdYI^fIoh(dlgRKuH4WdWJD{j2`F=0#CW-I=o4xq)=3zXLn@MuKxp(W% zd*SJLGwFw$vx_%#3ERU{bsa3;!br%ndOE~Jaj0D- z(~*i6_9G?9b{}ITKfy${JL7^M*=jX~70>veIjm`~GuX^pBg;DLY>li|*nLh%C+|h0 z!DFd5;)4wt%7Kj1WFx?T^t2lWV&foIfQO@W!-$|b&OE1kVG%o&kuHDM4M%cEDa6sr z&i1VxfB*A)JO1|0o!g(@-ru?3veY`r)!Q_^xtPcR`ub}MQurmQ`VnvmxrFg5_xR+k znQ8t%b&o`%A4x^SSfdUAIwPXi*c3K^#@6nnxw-JMaZGzi54Yf(pveW*9yypS&@X z{UzI%@xRc3tW3jbKx&(}$9E0;7SEmhntw~ah)<$0?uwt&h2$K@CyjZe&ykBeR)Mcz zP+%teAP?ZX@gKi}AGg_~&!$M3fRP8**^#S0oWp{zvkO@95`NCCEfzHw7>azem)gR~ ztgo3cd1Re^&Hm|}an6pb59k)8;HfiyXLi$#sg*gIO}wz0`+mu$&hb-2TL8U_Q5;V| z??AHQ)ccAdLwT2cS~7Qg>rA*9kQ978TItl0>NLGb5wec3wK)Rnn8>7RtE-Ij9sh8F8S6@S3 zr88O8Xo?+ZS$oEhw~d}l;~O&$by%KdOGE|iUktYKDt}Q0rMctCea6*)=3k-;?HY70 zLaxpsaQNl{OG~hY3w?N2&MJ^#DQldSrpsv=XDgz1X8rBXw0YV{s~5FwIcrkpf54&- z9G2FP%Fisd&DnH?v&m{|XDj5%bV;Bp5EWVj;QqAokbTY~>!I2CJ-g2y<}q8})qHFc znJyvu`4OskGAX~zgOkz*?NJ)kWs`J*NXYSX3V-!IOpAb}h%R8PU?h~+`z&LR>QjX2 z1p@Dq|J^!v_VT=kH}XSJAg^2P_f0yy@@gm{z9?^@h^D=GFueJQSQ3AiPrg5YF|!xi z-U{PhH%x8ojh(e zbb0sgcAnz*6g^jz#b6@QP^chEbz@Ng_Wb=n+`F@5RMv$wR=q>*#F0cnRo$pB6{QgD z3}s|8lD5%3pcTp}D%KC?B{UVYVf8R4u9Kr?J_Sfs61}*Ca>F$1s-N@@Kv-S4enWW z;Bb=a?xH1`g&ky~cxo+E=R~E&_jTF#2eBBTUGjC+_m4;2FvqB0fmE-VX5j88ejYY= z{d@O!wr_vDuYLfm`ZZ1TY!n5owo&`@fbno6+O+gky-QdsS&C|wdPsD&S=-ffzjhg5H4Zi&t-KXg!^T5CMT?D#Q sMY%*T&vDpkkJ2DaOn+*me5q^><>9fFyWl1zDD&hg6g&h(6%MBPe=^fi82|tP diff --git a/sgl/dataset/__pycache__/aminer.cpython-39.pyc b/sgl/dataset/__pycache__/aminer.cpython-39.pyc deleted file mode 100644 index 051a7a09d6aa2f2a2a93612deabb25bbe5fa45c2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3497 zcmZuz&2!tv72gFw5ClI&eb};^q}8<354ID^$+S)nt|ysDaXsV6o=8ne+QTf&qzNkKTI=-xmny5c# zQ{JfoZioc~*F{q-K4+bVSQ5*4FMP!u)@ee&E6{6T^s;qn*d0ox-R>ycPn0e_mGMZb z3ItQ+9_Sm zZ7Xw1vEVzaz0tF<&HS+ekeTS6{t}bm6Dwmm7uKKoW1h3X%IwU^+0;5N3C9`kb9mtnlSAw7yb<;EmUb63gy0M=O<= zmQskV<*m)zTfy#w`&+^0*4;aw-r3#ytYxV+u&W=?M5-6l_W|jvuPI4^m}crHprzy% z#%tW;<9BA}`TyKK5^4UBWHgLL8UUa-B5RdRUE(Iod3SqTJ zK|nF^y2K6DJ-iShzrYEpx$KtD$5&=@Ugmj2K8imh-`fBX^*m$8i0H?bu#RoKot&Q_ z+KcEe^(hUFOJ7dfUrt&jZ71naJCBnv){WpykwApF;qcX`CYiA1zDVi(Sp=Bcad+KBVO{t;jj6+5{Tq5j*_nU6&*@WVtm}3 zNB*36>6{htR52(}(_JhH;m5Dx(M|U7^NA%0MrK%JPno(ghXr3_$FSuIJf2%OS-f;c zp?hF(nJt{$`Uc6DzSr0{?C;K$b9Uh41SIH*Gdtzl#LAuArhV`>ydSWMb1(#73B$!G zgV%vX!-@AbL$2~p_@rXa^wyYgF>nI*@aRpqFE8blo9q%tdB%YihQlR}%sr;Pc|O|* zT9r>Y^ePke;9Jnx#5aYbBK+q#o#dGENflDnphcZHAlduS!pJ^ph$=W)=H+$#8mq{z zv?8w=4S|7n^%wl$26kAOt!}{5EMFiFP%==O=lhBpV6*pAP}M)@U!n$$nzSZHrY<0? 
z{!3PY_AWHwxt~{Y?n2(2Rwj#?kCjzXpIU$Y{bXs}%xcH=d@*05tp5dzZ#XP#ys(bi zoK2QF8^4{|`7(JdTM(!R#Pa}AI9dMVa_krH>xk|zfOGJh+8lB49(!S>pIk@i((eI= ziQC4u=Fd!G(-kDcAV#H6$NqU<99Px}M`=;}CJBYHkk4Mx6;MCKbSY>lvP&4t7%AoZ zKFis|1_kah!taFt%{p*)iZs~bCfOn1TkY#bihK_%E~H5RCdzBpO9sPR4{1ZH8!op_ z9gOjh=C@~dW7{{Qq}PqoTL9*(-jIzh(++KgmaLG?qvUu&Dsw#Y#|L20JbXQI{2FM9 z@gEG(kF<68sMXZ|_Pxy_`R^#YxG0+8Sfc$!g-H8nXQjH}rjnfo zd&wvkx@M@MD0W-v_IjhiDC%bDG7JUs=tUB~H&O+LR%`cJvA{hl-n_3ZwWs+kg_{U5 z_0yYI0CI`xHJRmbpEvNcPJMfcf6sO}VLaXdrGR~_$^BE$E@p=>o%)Vz``kHoojJG1 zE$;reX4h;7m93bYe8s9^t#j(w*q=A?Lv=oDUS>gvn5M{TEmDs~g(nEK9|VI$jL=vG zx)uZnqi$4C)NjbytEOeR_eoHM(7oXPXIq4Go2vSm-HGYk9)3XIhI#$RGKA=`wk7y zHeArAoTP1P3AAGx11kO69frLpkyKl>+ed`CW&%3sOZ0=83VZEY511BRsS5{(alg=H z2|c2rw;!QDHN9pL{8!#2pnp{*I$&{51ExL7!YDNztda79l1AMkNNyIUzeVZFTzYD% MJhWCd3j_220Ir@_j{pDw diff --git a/sgl/dataset/__pycache__/choose_edge_type.cpython-37.pyc b/sgl/dataset/__pycache__/choose_edge_type.cpython-37.pyc deleted file mode 100644 index 73fbc656308dcca39e6efaa7219ac177080e49ba..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2972 zcmaJ@L66(U6`mo76s484+O^`{b<;vl(ljbsZ|WWz6kDUHcM}9~Hb%Vdwgggw)=aeO zQlvC9veyO$6qV5?$6k7=3#^YVddi`P{uEw&@?XfQ-y^katx=T342O?1!#D4J-#7D% zR;$5qy*vF-eXzmUf9U1ns-f|76!R}sl1ZMhuzK>4m+wyK;OQjpScGEig&t?p4SiKp zo~rFIDV{shlm0Ui21?*2TMf2A)?urw2HBQ<8_$JYlTFyyV8<<6&t2J;>##IS3-uLM zZ_Lwq?#T_giLWhrO?IELu&vhR8>f!`TD2Dpp6~~(fBhwVXNP^K5cfvL7T(cplBmLu zBAZ1K%ylJ)DzZ-|FluFIIvsEtg3Cd1QSPFcO;kBMXO_RPPd+|6_^j|O{jl$74->$}i@!w)-1$4x6#j31w0n|`)vhtq z-3K4t`(g6rY}X8vT^ZZhD7!m2$ugs&%dxj6Pl}BL%ya~|A6R|6Tolb@1S-ZGTsQH= z@*0N7LELl1$+! z3zXQXj90j()gz^il994yjH+sl=Gyv9jk8B;AAubWVyiA90sJ!4M#Z>52m;(2X zz9?Fkp@`Dixb&lFO|+7ufsF=9Y?cwCl+#s5p^(wL`3I`pIpJSXoLm%v2mu0f@9dA3 z=WM|dx2JRePwWdWxpbZdxwjANyoPAienL_6Frv2LAF#Ox_=p9UF2I^t_gqMkv%fOw z{mr%B3x~7&%=+gT38)3RdmiKgML8Ge^{VwHv&=T~U;z(4WnZyr@0i&&+ss|EAP@fG zWughiw2W)CjQ3Ctf#c76z%bHU;dsHn1kvFTI~5wq{j)O_gzaS-3pY_|;R1;lsB&(- zuREBk&x`BoY?5S%Vw7gGB2q*>35}Z4^-W`spgzdPlRJI98*;{H4sY_OZ{01CeUZRh zdr3Bk6LSZRiy@^CPp{pL)9=~d7o$}6RGQ6(C%yNcw)#CvX^+TjZuVRH4a{74I!=d5 zcc~pbj&(Xpho%tJD%O5!ls>L(pnIIkY+U%D*Jz;hI*ot^5Vk7h5&{a}WZEj(uWOQk z5;mG*8D7UYqtaKjG$p6QodM<@fxD~lu0C}$d!qE}2!yKOQMh@DO+&&&E)=a)J$}`- z3@B2vMNX77XkQK9oA_YfLd9Gk_!n;AwB1cN;7zBA_YQZQ23XPHT}Qu-UeGO`7W}l5 z1wsj3Bng+Y@F8Au=;$dkogA<}WqR8}?-%?J0Poy6dk4@5-~jz?NQIL_B(_tcIGL#ftuw86$RzYTcDUAO{V@Yv`A_c)x3hFuj32M9ZdS1`PsjvSs4e{4{DIr4t6T;BQ z>m(8P*)%+6(<4Nk80;UuEO|&l{y#)&_InrM0gu*7!7K?>Mz0X#*q#(3$%ePf+~^xL z-M6SB(3NN^+0$?6?~~zbc#G|e$zP#)NtMp)+d#NIj?YHpIEl21O_s(F6BQY&CMFE% z5e-tSZ{i=}7O~@hO#C`Rx)n=40e(ggmKZPuZ1V$D%n^VJ_~CZ=4KAEESQ3=rA-*`} zE<#F0ywSVpKcGQ$5kv@VB~0$4L8>oZDd@l>Yjnhs4LZ;an{y^xb1vI+2gRkGHFrT3 zz7TV9*3AXz)5M-o0tPok2WT$+r4%}F=T)41es)lW@ICZL?JkGQXH9rWtj9RwSE7U+^!(&7J7uPO4RGQ22e+-!dG)072|$+H%JX}S9#Pw#+M?9BI=71 zU`G=U^tY)ZTo*pJ6Kolp6j8fWl?NEtSY7Q3bet$$xK{z61g3QQa^^H1t0-D=?s6|I mg1vDjXNkH)2P8xPBeVs);|S5Mb=S6UZq^0A@tvR@@c#ld*wHBf diff --git a/sgl/dataset/__pycache__/choose_edge_type.cpython-39.pyc b/sgl/dataset/__pycache__/choose_edge_type.cpython-39.pyc deleted file mode 100644 index 846154101741dd027e4deb6b5993a89506e76e38..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2962 zcmZuz&yU;26`mOmDM~A?wQFU6q-iP#sT;M3*EI^%2iFMd-2?%gZLxM-lt4hx;zTPf zMM}fb`Uj{>WuVEiy><(%duo85a_Aq!Yft_ca&f;ml-61!Bsd&C&KtgY@B6-C)M_;t zj_Z?8)$I+&{znh1M-303;L87tCT1~DSWw-0!0odL1nwenMlz5iH*h(NoxoEy<*M2) zi{+V!-Pn5}17As;WUImE$9351szJ8J*v6TR*WxDZYp~;tTW3zZ9$$f_X)QEXR--Xb z`^=3u;!V75#aH9b3l^-aEAcxgLVvH;=M0hX2duaC3bC`po+#!0VQxzI;dGp+(hEbA 
zg(1v!6%SNso{eGD%1m|I=j5_D1nv&5d${sVGzB|nhQH)9ZbZT7^bLO4b4oWKCqq+S zJ=p*3Abj}eKOThp2lqe!;`4_GN2P1%qdlQrOn@+#Uknq3^c~Ywh97>tcbtvXUY<|( z9(;EH659FW5f&>efscR){&DIfWDRT`YT>p(l3PHk-*p7W|ZJ zzu@Ph5HC4r6ZDY|=kOI}P*c-$Y!KQ*8`E?+nXW7;FNHEdiV1DL zf;Fv9mCjWhDr0k0HET52_K_N8r)nRG9rh!mmYD#KS-y^jamjt09b8SW$=;J?>mn6l zIvv?K%GOw`c-S{#KZ){1hN$GrqmB-xjcob9(G=pCe@l6Ca7p9{FqpXpS(|g@t(bX# zWnc4{$Kr)wxcjiqYRFdYx0E#(<7;z%o54dAa?WE1kWGwpE@N4+f3Vp7r(@igf&;>n zGvl4Z7jW|n=iD!R%5x#l>s9X?W|(Od{u~kffqly+@0&H#EF5}9CA`B+r6w%XLax$6 zK0-49n!oA-$tZ1w<~jcclt)nPL~0cHZ(3%&-EO8g(eqZXO9k!{o%3tGmVO7HDP0|<1Eo9E_n$;M z9j1f4l+-KNzRZ>WQ)PnyL}{FjN)Kck_LaUu9#92BT7_f6Lh0q1HY)DbH3@vX!bQZ-!Qh-KG$1 zcGoYb?mfJ)Rc4L{E=b20P0>J>JqbpLCTz{uQb^e7+!XplctFW`TJn=hCo^}A$}`scXm>u9Jj`7}YYhHs0*+P!S_ zc$h|Jn5DnN(42&j@iwFYi{@Kg{}RtVw`}Iq#{(UWkF6S9u4a`1TJiEAAQOMARDO@n zg8n3tZo#o;ch13^_X=<(2BRtqG9fyQ2)=5033=E?PG%nRRWt5FJOU8C77Iv4t*8|q zW-CYnR;0lvX_y?DdV%kO>dvfQ)C+ma^xr3s@Wr~76si9MnP?PsE6hieN0TFp3X<^3 z@{tm3bFX*Nd*1AIm#F}6)~sAw!q`leauk{4QYP8pmaT*S3Aq!(^iQcFtaN2y!@}cR_zEQD52*EvAeXCLp?r}{{eR5nL z1PR(jEi7)MGeGhG#b=b>~ z`+pKm&2grK_KL}NFIi4iI#t0&HMa7r25qjI7E@gsR(ZDVZPjsO#j*@x#ZUIfY3NQ`7U_2v4P+WT0EU+PH(mG9heD$;9kmGYf42zv z2VYi~3(7}uS3iIdM9`FEG-7n2JCSQ`UgW`(rGD0mT6WDzgDi|ftGj7C>qMQb8+Ef@ z)XVx&pOV`|c*1{3gfGLn6Kx23mke4@;8)1bfEw?%ipw!5ot;9+PhuTcQV(3yD-Tny zWGwjJB$cKkilaO&;8hnUu=c#1nJX&JMUnBUOedO)#n%a*y}B?iys8}#l0;OHh>6Wd zWRD70cn`zKk*@GX>mixb$OAnPq1AoS7M+JAYKgAs!7~tju>sHU5phY>hA1wX;8qb& z^|(-mCDl_l92tM%QL_47eB$k(4|hP+RM5kg+SHK8IsJuds##6u?9{25aJFFm46?4> znjCm@S`+Tq4#-wbPJ`Mxqu=9O2zMLqjv(NNyE9-agtH^39G}dtFSEFv$C=#D3%y-U z%d${fiW{>qS$aGgbWQu)IGxH*m4W~~*d&z-;20BdKFJ}WT%qgcN~wxOR+U^7gO{-$ zSH`pXHvS^{<|<$1g7bZuOBHLmTxxTz;z`aY;uwzpF;1-(axE*3DwNx;WV&Zu^w_wn zC^YzkH3eXVoF{2qRh<7xo(@LiB9o)4dN8_s`^%f@$?>S#Pe;qjk5<7A%M;ZH*E?`m zT@Zu@@b?1kQlHK)tvnsB;!G&+feW8L+_&JaUIL-XDHP#=36##Mdq#eP!u{@?F-W|# zW4c+z<0KX`nkzAp*CO@W36ltFqh6=UORt&@OhZ) z5U@w^8=L&MK*128U)a{V3})JFGJAilby>YP8tqT?c)B}G zifmMklkrrI_VfJauFR8h7OTS%(U~ z&1peX z)Q<2T2TfZCaBe94Is5gCrgO5X+ozqnU3ZXu?veZ*m+01FZery-&U)IlFHy-;g`xKG>la!H&Xif(~}8$xK2DV;~r_X&)SToCC_iGGK-t z8+u~c(e4xb#fHKbRExw{W7uRoU8rPi+#;8iaR7G;@ov0)nw2MDF2RDhGY-rR;|egF z*5_YYYEy6Ef-8@KPI`;JdmAlnDfT50osh@s9W(@R8qZ-28AlaI#z~6QbnK=hIoTg; z6XGJ*C*V+M?N+7xhN(S6FE#?S{u1sONp%ebp_{Y~5V#CbaOq35PZ?M?5fsmm((TO`* ziP*LryA6B}U69`ZWFv#9k5C|MLNi#Jx(d1p&q=(s;_zTa;TPe*d2t9}#+O|H;j@6v zuAX-BosR`7tnqF{|oS+kHR{eORF`o45xaMR(23u l`dRdbtrL>Ux%Ugr&XM4QMTS1acCrC<0b&8`!u)&D`4?AfnbiOQ diff --git a/sgl/dataset/__pycache__/coauthor.cpython-39.pyc b/sgl/dataset/__pycache__/coauthor.cpython-39.pyc deleted file mode 100644 index dfaa038867bdd45b6c571723d555c2119ce36f9d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2624 zcmZWrTaV<#6|VBF?Y>R#X2YQ;i4-ZID2U4{5gHIHfv}JYud8O-v%YOr zH8a!LJkTrgg!lnbq?t$f75oAoc&yhv@fY?f=ajo=W(T+G__(TEK9}!&r<2jBPvCm( z@dx799wGn3hpU$d5C4Rrz68Mur#UHT?9fhkW6#d{u@84A4~kygvukc17Ev5o-OKyM zARZLMcvy_$Q8A9kl-wuW=fM-g0}-v=c!$#mWZL@;zC!k=)cE&hQZGOm>{nd8m*_+Z zJ@w3}ev~sQ63z~?T$ll`j?27)M_rlF&X@Jd+>%MjtAeRI&otxR+bN!X^+Mn9Qu`nT zi76+s!zbU71Ij(_KZ#;jcs$^}CuB`yAM}t%Ru6ce51x>?$A^3b_mGeI4&0;fh)3c+ zcyhyp?^elDFDhxAR6TWOa}#t{rQ7e~w^7I8!cd=pXeg&&Q902N&o%v)YO0-vterEr zak#q+>*o-C?KR}cU(<%Lpm9O=8gdpk?m7JepCTB27z0kg7H@y*$Oz7k-Fk7ldT|rY zy)r4pURmiqwXEw(YQgWWqI6^Rd^$A!Pm_Ep-j@=b@L`jXGK6DH$XHfFV3|bI&8=Ei zsZdIENnv%O7s~iH?k4DxU~aQ*JQzC^rI3jho25R}GRaDo@e?@ut0cEt#I#TvRY<$1 zM1EjAwAgsEsx;Vxx&4uvh_N(Jlw#~h@^m_1RE3x;^?3f^{)cbmrzdlDn9nx|m~Wk% z)u(a{rVn7KAqYZ4_&rC5G@z>+TT5qKKU0c(HW%svjCWzEJ0LVUgG3xT9Ho2aos<7S 
z^8V|dI}m(#-wX@Ik~HC>$WpfQW`NMCZ`?}N#yhIA5}cOFG2@w(yKtQG^Xgd0DV5K` zlf*A{;w4-h<5t9e7v?{}oN_=AXSI7__p^JRu^#tAp7q#FqdCNLUL;BjX`wWS3_hjvO)yOO5Xi}T z_$nA;{90zEmOkiHw~K+hlzA5j2HLei0D`wHhPhT)VuX7179etU^CBb{o1Tfo2=*e1 zZ5)09MHBFnu7L+y+B&VogSPjN6mqmRF(>2T>Dvb+W#&{g2P48bHSpt-=e#qxm79UxbS$#Ng*Fi`GJOgLMRwk*^w2o+z|iYu0_3nMhF9Qp z^#TY&C$taXxCtQf=p8zy4$O=ZE?1~)^Hyrf0pEfPG3mmvX!#Q;*liv;@`Wp?0nD^s z8EU8NKaB(Z#XE3q#YlwO>h8z60L1TQWxdpUz}FeyvuwNv%nvnE-UZwv$H;e3AiqMX zSTef}x`{67zPDxja7*^ri(_)VUAzfLY-$8T;DpZV_GOJ+Yb{_0_C1OJajfmr0X8qJ zRmM!8v7+Kjpi*Xrj6Gf^dADMFg9#ymfGmA#F=T-uXL!)RQQQP!Mw{M~iadvIcuCcm zCW(A)8zz2UO+bWn96Ge`!tlZ%+_9_EiRo2phF#NmC@saBpw7}qxsY$;DSrcDZbB=Y zVJn*<;n?=t0r-!iLo--K^>G5QA< RmF<;709xn_p;tcd{vVcCodo~@ diff --git a/sgl/dataset/__pycache__/custom_dataset.cpython-37.pyc b/sgl/dataset/__pycache__/custom_dataset.cpython-37.pyc deleted file mode 100644 index 902fb0af8463d50c2b0e0e78f2ff3375f98fd2c7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7129 zcmd5>TW=i6b?)lh^jtY4hc_*)OxBn6IJ=~=abzb^HtQATUBxoBmS}5tqD{Ls)g*_U z=^j?sNaSp{A4~-{0t8WB36j+3c13m3F0HZC7>O>ec!S?S+25UDtKHx7c55FZGw(%ewCL8vT{_ zN`JMz+Fxt0G2>l>yS(ts;DvDM+-k3*?(rh(#c&z*4bC1Itf%=>1VT9&E~j(ffc1Mcs4F&)Yxa-=FYQX+Y0S`A?K{^psW{xvdc@n{LE+< zc!gK-_IQmi;9dNwVH@od7PqL%w}y#~`~G`zKNbueJ~g+t_c{gQh#$R401Hp@7a*yf z8orU5piG*nbBugjbJu|Rn+lub!{ z(ppeOA7A(pl&?zq<;$}{sH|QXsnQRF-Y~o^L@ZR@5BWi;JAM7Fa+2X76rv2a%Jcnh z)Rn$3JgPPZBJPAq66Sm54+43ZD97j6U$yDasLuBf!YCAh4Eq-bAMbM4>E*68I#5Ztn0|+<^}BP$R@jcVQQOF&HNm6loclfWgo@eNY(tyO@_eAg>Cf%iCw0RX_;Dn<&m(|Ft-x10`6bHljvh=*=*FD zUBKnsm#dvF_I?1Y6D? zIT_2$%*yP{k%nX!IMGKq!!K}(#hcQ;+LEypZs)it4l&9RZK5eVa=Eo(9Gl6$EO47n z|H+d{D^9c%)o5JuM0(uOz5%rie&MK?6=rz4n&)q3c zd6=~{p0z}B@bVhjV2M{Yj5Q-G&Tyy_hv!VZKEp$|WlriGm3igZMDMGza#2kkUVC9e zC$)=()Oo>Bs>9n^@Oy4rNQ)`Mh=srQvXWc`wpIO3UQ!9vaa% zY(uUbmoV?onUPhPF>1)wv~tei3)gs!-VZ6Fq3_3)4~!3u$oxRJeBZcle2|;M;%>{? zrQgUo-HLEv$zn#Yg-=3p)+CS7Ob+8=k2ep4Ct)*a?uUUKim>@4?8sP*ZooC-_>SPqGTy#&HaHW^dmL?t50)FJ@z2G`hmWZD!s}uK>P0nMqjcJ+~6Li9; zbJiTs(~X*)IDR8PKIE*UkbGw#&Xnu>GWPvX=-iHqz2I@!qvd=;%R$U!r;42^^XcQ( zx+tPae3eLx$XAFEgb+0l`0B#|UcL5h_+%lwaipEPSfl|9M3!ij9r12onbAPGfjH;~ zr^L%Bq*uE-Gow{NN$`t#G&ASpf&9S$jMJL|^ zF<6Z);g4MMHM0ilnQ+o3wOuzobKPv9RKZ`#T*p|G)%EB(*FB!LqdhEUC2#HE(e9v! z0^EU>AFGs!$1?{fcL$GWj*V`~9B!q~u`AxeVZhPLOm6G?ebjB#EsleQ{>P|0sN4Aw z+PC12-Caa^m~^>|ZGwep6DwXu%fcjBded0xywogOZXREL7po8iki~|xLJ`qmCAO;d=t1-P&m3)D@vw8Bwk&-zMFk)4waNq|C|3ABV1VC` z`49-tR~6qs83w()N5@{}VQT%15wp+gztpC9&ibH{YiXymC6)P zHFzQB2@#Lg$`z;LdY1PU-P1wi^hX;{zeIve;UsHYyb;cXGezp^yQbh!Jh z+Af4Hpk?7#riPbx+Y5_1yadoJp{yY$%H=W0tnfNt1lX+dCBBSzO@qz_7{H3EPL!#k zvv?DdXxIsKyhgCikN3wIF2GGRr@fpN|Ron;!>^Mc{3LTw!{XJYap#cet-x?2~wE& zb0Uw3G>Op30v6G>rj1V<95Ma~PeN|lfZHML2**UYarMgXPLU&^)OCzit_TsNqNz6` zuVmr5g(vwMh%{1cl};-fl8w^}Mt%m{Vj{>R@1YP6Q*w3%VVnP9x{x2%-^YhIxCDK5 z08z_RMf*Fq@80>|oxR)l$M{c6etGuvp;N0n(&Jz5)cuv#F4roH$aM6LZtpEzqLZ(aH;Z3J8YM?hM<1{PBoJrk#C-^A|`c^zb; zwKu3i8;vH~a_`z2t0^{Rr#(dtUh_4(?d5dPmoX6;ZGyt0S@+f)P0_cb+c+RAQHuszXC_e&F)@ibauSlMM z!cK2c?~d`&>zONya6xdra6N#VHw{@rTy=kGAe!R6Vf@ngJHTvaobZQ`C@XL`^G;tK zf4>fB?E#|VJG}MyZyK3*^7ndd5#vjE$}@5owA}wOBe!nJOf4N8Af!9O>}5#zqpb2X z94N! 
zfEMUr5%WLST3E&`4a~MOqXocyogDnZyAY9KX$Ci0bmYvZ|jLV(`QR!`-4o+Iq>;>A)fd)nAHq0O}`GyZxyD6L(rr`OU~W)h_hBubWejznowCrYJsQzuH-IIfJQiPGkeH(#14jVhB{ z62gadwJdR$Il*M!oN858L6HRlO9fd*Xa2nqU zk0hhgW$&ms4fc7Wmcx`8m^*6CO;5o97g><13qc?#@zZw|;u~;2f@0Lv_mTS|^hV)7 zp-_K-fcQ}@V>p-yiZ9sTnkV*yDLU+61d=Or>+3ptU5Ojyi9apfBrMX2`-8Xsfu^MS zT{imKoO>MmC*PTZn4p9pD1{0RthzbH%urCG`m zL39v`Ki1_DZgUgm$~;XZg^NEWq6634RQfiNcZlqOw7!rBsX8$} z+=q2!dr5}|Ntp5`AEWh`;8 zNxAlFA_F1)L7ZGdcbqM?Haq9|-a&s*9R!?%Pp)NEBo3!OUDmp1dUo+=B=~oU+$Hiw zkckPC_38}iV-%9V0x?*_rlfciNpS^ebQ2eGwuWDIlS(%F8m32yGX9oKJnXvVTBuD& zdPtv_SWVY$N~|%H$JW_8W~iVC3AVY37K|A8u4B%cSw~VnlX}mi-~a!~_gvoG`d?++ z;`_8Iv`abBdvqVwf5r3iwEFs-vzbqKq zr+tPy=yr~Kp?Di#szRJ>(XF_` zaA&_o_iS5_ae1fzuR&i{335eh<+HuG+ZU862)Ytii<1Fp-WtkoFUb=`eX&-eVFYdS j^l_?#%RN)x&HTdREy_al#^rY)bjwtSvDabr_1*sh=~~#z diff --git a/sgl/dataset/__pycache__/custom_dataset.cpython-39.pyc b/sgl/dataset/__pycache__/custom_dataset.cpython-39.pyc deleted file mode 100644 index ae0b1cbd149fb370b2194c7d2872ba14ed47c863..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7049 zcmd5>&2J>fb?@r0>G|aFn_SWIL{2Qzk;t_XtPN+Q7}l;<>ue}RqAhE%a=SUzB!@fG zJ*=+blA~@TFgLId3B0)(wi9pdE!&QQBR z4nBCA^}4>g-h1`&-tQHIdc9)c_od^v!hih(!}xb9O#f^YzKN9l41^n;^^A7@W^GfK zt+uV-PTy^pSl;IK%k7HJo4snk)~@yI?Yb^oy+(hoJ=brxo4RcG=KBlnh5llDQJ0|l)@Ts}Az1t}f$KpqM62d}C zz7LYxso@)`3Cg6ITIYUPfTV=FEx(J=M1ft z)VQdo?h|&O`qq%@NDXd4C~M~vV)g0k(XhW4idMYe3VUHcjAYX4;x9T0dR^Yyj|K9B zp=?Q7o6R{@_VI-uVG&eCzkGQ%2$j_fBUSls&>M!gg@}b}`XN6Eb)&DpRZcP-ghEuo zR(ZbPjk?nJg-6AufrvX{l7#u1`GY_nCd%&=%Q`6`B2VoS7K!$!Y=yfH8kRs?t zewUwOR7kCJWnYGgr0ix%*-6;jS9TQiLscfpbCKt+3g}HVvqc$F? zQlgKnWwTLhT8YcCFEyS~Wo9aA3_1?tRb+BbU!LUjH0l{F`#gRak?r^ban_T^M$6ND zCp~7E(VEkK*Yox(u;%Q(ld;Uqtjx|FX-IZ~6Mlp<{RF34d{x@BH5p4`iiS)aqL(Au zL{oO;a%gJ;P_-*ixcid4eFPi$P#z7tw1S*ojCHc(iBfu^ZfM`PhFFF zsmDv}2B@cLc^ziSG;*mU52IF!QOhI;uddR5Ec4pBv1(-HDGqhw@R*6$rg-SO%t@W2 z3a>pd(fW$4Uer>DH=dZ7llny?b)GQf=3xCS*g-cfrR9{N$K2m|Sw%L%ZJsyJ&7XcR ztB#s-;bJkZq*bzh9~jXM+mK6-+1bySk=2+n+K|g>?VQ0fuJ8uEKQKPTqGz=q)}l)O zzHVGcEr$WNjdP0`(ZW}Z$1M4$>o|1ue2Ap2=KRj4vqQg;bGjAbVAF0Iy&9f`;;coM zrIj4U!ya!P1}9-FXzhoA9Ez}Y5_V)PMmJzWVIf-Bq$gc;4!iOYws+hMr$BU1TQIqo zMwtd3H3EL*=dIv6QkIAxDytLslubrvw2fg}=o56psB_jT#_2|_P8`3UA1^ZMQAjp6 z5NFEueHr`yC$y7CSz2#2PVh73;`-)vPEij~o^`Jwa(7*<3umwD1s;`+1P|t+%HmUAq(=*r1CFE*&D&`vc znyiV(yK2&lVLRGz(6IQC6?c%dY3LyX(_rOCEM;PE>Ja6oVQ=aP>6*;pR_dG>;#)Wv zIEIgImu6=)&8OXLG0-vd!~?L*=0HJdBO93bI}JHLl-k|hv>RoIH2c@OicA9J)Z zV{xy1x@qatC~Ek4ce6?mgZzi@10wkSm=A#heO2@Q<6+RtTeM$R9+n4tSf1sZn;eX| zL1dQ*xl2_qyniR`^?+F}9W40+LEA;c8e*+oyvG_edcgXMVK&^h+ej z1az{t#g{@m1jt-BKwVf^`KSlhQDKw5HLB@Jl-ZZFCcNDF}L zSCMMkIGa188~xw`Y%rNzc^Ck+!Y+~$1aDy*5l7@!S|i-NrANU=T~&78xpmv${qWt} z@2gdM`R}~t=XLM+@4k2Y*4;OERrB_(J4Gjd_dD<2yDey|iajC(@ids#)=aFTs9v5R zS3SE|H@jG;Z>|w}31qXBA0t9Hf;NlbL_#9JM?{}5z2o!-)Z32OeuR`XKnxfk!j>>l zgezBO)^>s{H%;+*^i{41;hdrwt3(#cLV6P^At)q`6zilzk0r^9(1Ve^!J3$G{D_ap z?PD`JJA&|zAW8D0_{YfMSk6-gxd^p5nLz*6?R$6s{O<1U_aKf&tt8Xb$s+C4%4Bop zdyBkQBE=!vGz>O0j2`{|tfJ5VXa<;(qobttBvTA+COGg*D2wOdKy4rTXdJ`mi9OQG zyY<@w(Bd@`V2j8f5xF!sddancA}J;a_?21x3wJXUOubn9{~rd{7V(t0(iSn;Xmyt7 z^YJfjWx~I-rIMe6c8h1aUfCXU;&Y5&j=bkel-}^8i zIu6!&6w&dI7`5FpK6)*45d%9Z>?TYG(Bn-*mf;%RUl?$Rc&{73F#ZmxnHk4>T_Y=T zH}i^b*MO2epiO*(x9W~DyYV@Wf1dLk;-@FZ&Tenq;whFb$+C}m0QG=t$6P8w~{`kL1{tR{< z*A-P>dqnzx>p9EjAaN5K&KDYf6&h;2vISm8PXj1>9?!xmOfG$qHnsMsf#&m%+3{au zghh<7g!l54?m+7r+yBRu_PQ<`uf*5EYbjl(@vwemM%)KB@cwUr=-CMnHgZeA_0T`p z2ONRikADMhD}}sASk30ed@ggkv~ObmR;NZY#p7Z$%1w=xt>vY36*pYQUzzC&p0!ov z_D#$(^5C|TuBMA5-(&OmI`mtg(r-@dXNrHG(offARw3IrAe*CaW=Jnn{0}67bE6zq 
zY5}neG{YLL6Nkv-MxFz&f02t=ZNx0ltC7xO&KJ}5^Ky1gE@ABDBDQ%6vCYcGYI-ew zX)3l^du*La#5Nl`wkf9@I<{HoxbvCBHXA?O(6J5P&y8)YFwiZEvRuJ*rm6Cf+T0D+kl@xe%4oBNpq8u9LGN1{A&f&x-6OKqmmCM#q zdE%h+*eZt?Qz&(`d1ZLA_PC>hR9y%HL5ZKfeh@cdb_6-08S6&|jGz~VDhV0W16co$ z8nkUM;Lo10zcr6-+zF7`-DILiFH@OY*R|W6atzu*3o6KHxhq}>dk8}MD@Ab-bbh0cUy?W#J8zP?pIHXxP=^IxF}4-H!16gAUX)e zw{(7lTiZmrG7tA?#>HDiwEwz8xwnbjCGuw=o1e>pz@H$mYTAhx>n`4)s`_{b;;ySB z+Y2)643!4re1jnr1M4((R7~6cU&Gz=A&w4@-B%< zfNTcw?#IK3tD45cU7ioH6m&Yn{;-ETO58Qmbt6Pj^CW#|D5maA-P8PovWz9}Eh%PR zNn{|TKZuh{+>HaKR;R}-?hEt>^+CWn_~f&!%EaN+r^{IPOph-9f&_n$2<@hPlW1_G z!34!PA0s6{2Qk=^O`+`uLfZRWWjwY7(fGyVSneQ_knoT%l`w&MO>xM&k}U9INF9ZHr^jZYJsw295y-1$GYRlQgM diff --git a/sgl/dataset/__pycache__/dblp.cpython-37.pyc b/sgl/dataset/__pycache__/dblp.cpython-37.pyc deleted file mode 100644 index 9fdc5b99ebc6de95650663405e9cd2f44edf780f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4367 zcmZu!TXP$?6$Y@&7_IMKYhY$pZpIxdFpqt6m>~O0eb)(0O#Ob@ZsHBt<3QE z+W!0gt80w?hZ?7UHi&mnk`R?-l80=-+k)q`)wWGvptUV8%uR#0< z;ZQrBQTi}ex-jeveWh(^J<#QD9QAzVM_pe#o|nd+*Rqrg9kfVeRRV!FeN>Gq8oAIx zR=&_t$vdb&pE6y8`p-JF8~1uaH|T^dOWP`rQ(YN$(uZE$^TxqY7m{HZq+Rk9`lBxu z3njUUDrFhZaK3*Lk9fvBky@#pu_JL@$Z+PXOj;-Kr_{;Vqv8>9xEVVxW!4G*y@gg8 zrGiqG3@2powuG9)s;Yr%T#z`fO>NVsRZ|7@p#f*ml|*$tqTm(~afEYJMH+TJFNlKF z^A!1kUK*;n>nDkyJBmk|CAwgy>Sd42;dy;O@>M7GJ)^2Fd#UOKkr&7VEcdt*8mi=_ zev%T|vbCM~Ved>(HIKdEl_k&XhMgqwynnH8Ufp~c5B$v}+26c-{pPjs@L)6PhnrKO zH%X_B;h|c=)bFAsRaA@%UKcB(%56Tbk#rlU`m}_e^JU-n(!(JtiW%)FKOL#4D`+bI z=u3uq4<-2_Dl*OymweAsODW88#6ROOi?z;R4j2ctz0OwY=nK2gY3le|G#WhcRWt52 z{m>uyQJORZ{6{#5KsI}^0zDk1O(TNhnDd**=DzqkBuh+ii8>!I8qF?p zT+oPo{G3M6gB_DzLW*OA(<7VonQEj9eM+U{!l#psPbaN{B9Bl%LZ$0x8j;Z2fmsKg zRJAa8K?BnBk_KuOOmCu0#fGP@8?k4HbWXjQ@MIP$MtIWK#5>8ABhJ}=^*-%lNinexUYYK+HW8Vf zS+o|uhW|}AvG*Sv-U9d?l;Xh{{5B*TPTVgTvZZ^%CndADyUwJ8f}OCAOJ54otdNzq z*aeQ_lmjyiheI5xb6h;n1?`F_9QqY$bpN0CX$>|hn>!A--&5>Qk{t0#1(H>vP3>Td zcp%;D(8lO6smls@S?Xpr{OYU7@iZr^8copyEo;yC{UTd;+ zzAsw@nqhwu*#dq`tDt&hE6~_;AzM5ypRjC>iZQqRQxz56V!z#QJC*`<1nWh#@tW*p#LzRXbh!Gv)fgd zXL}_&=OzS_c)fTONnJIgfy`H1>U6uK!6*cnk~4YqJPF?&sl10~Yv&-JP)g$z080h~ zy^^FIm3qTCA?P{VRIi?2v*&k!cziRVHta}=eNsP405(dr!+}Q_#5v4Y=Y1_yPxJf= zJN3RM9lm>d1O5W7Btyk48}OozUsWvdm(ZfN%Nv3yqRXPe7fiAF70a>8ygcJ_f!r2& z-LeIK`8KoDZ0^T+o z^tRb}0Nb1QRi5RRzde9An<5C&>4ZAEiNTi(CS+Xfm%ZaSRj%~fm0P+Aw{GZQ$}{hCXj{HFy5Y($DW;8 z_4L}(W|pENQ^|$GTuG>sxaFVV59tFJPCiq`fdl6EdUj>WMyu|5{q_32ejokou-Wt( z{=RweUi8`qWB;Va`9BAZcTn;Km0*G=Y{+{SFKN5yn7&G{g0Gdh!)mW;#_Yr!)_OHV zJBdH6_v(hOB#mLS*EF=7w1x}41uN-)7y~Q|!q0x?ET8j7LcX zqW;rJMKZf4!_hr77t7{eCZe5Cg?XgV>FoUE#w~oCqX$VKqfi8cIEnhi>-=A3*2sSk zB4Rx*SkKyret+XVTUgM`?p4rs(00XIRHdJAzpy`L18$^k_PnTuzSRu*w zWfYDgsU7G$)P6rp2a$}@ek47}tKA?_SrByHu3dImDw8#gpiLjyq>83q=pa*H=%|++ z)L$-{@u2?a9oowVgSa1uN!Qkn%rd3xqfp%ovOzG3N4k=al34Z0Rpx(0Q?XI<8>kd3 zcwq_tN9#T>SYRnzIR!hmPAUZ#);bgRDV$2V1-oB8rWszrPHKgH%70^{<)hS58iL`t zoV~6k7r<4vQB5iWXSb~#`gEI;Op69QLf3QI4~T+qK+G|YR90!&3xYU}RS-y*nkyrj z^`ktGN^c2ByIfbyQoS0GZGvDwN+TJnC@|`3KTtA^(;yayVD~Ug3{?wMlq(`(Oec?$ z!E-_7BKG={>sk=>lQ7SN;7j(O*S7Cv!)QCt54LY#yYYH*bhw@GC)+cjw@Ih1(UDxk z(r=*Tq%q?bcb@t7n$_SAx1LR!B;nRvr>>S%7pN7}ic#VLFSs{#?rnTpz@L`9zWqMj|9NQ0$cz5&T{6KqzCPnL~#7iYYr z75V%*t)L4#A>AxtoglpKJEYT0D_z;A{5q+8Jl*Bl_7vom26G0%Rg8Mgh{>Xk@TW_K?Xd zb6W6Cc7j7ZMVJ)URhBNDjZn~-K4lB1us&x3nU2dbc8m|BHX z*u)24Gxtq4bq+?DS3z-6s+j9QqS4g*j3IY=r+iv7dwQEpxG2~GJUseRs1}u?c9mV2 zQJu}e3Nyo{8Ol8&-n^LcfmYR14!x>0dhi9t*wi4ztsToZ!?3rvhBa0wl+1YC)%3DNE5lwWiA#I9amuGN0+n`MzQqpoaayDwgnD zSqC*B+knP?OU3euf69slk`a6V9n!DP_J;JLG3WGS}snU7k*L4IV}{c$F=F2B3@0=JhuMytLgfrty(7w#agjWH|8%{dc|R?{lq%@ 
z2WQg_&L*!ayVxKPQ;Pz3Db4v1dQ3Ndw~_kg_cFVTUV&`quw$FUZ@&*I^M7APpwaI` z6cc5wJ=P^NU+Ic8%J) z0z}upr=2W~fJ9`N?qkIS=oc*dN?_)@?2WJb3Z)_eUdDXF$G{G%w|zgI3a z=F+vZJ6HQ>d*!;EP1qz42H7|jx?x6RQF2=e`~C57oB(QxGkNqP3Ev;fvWHe{_pn@0 zQ<(z%67f*4)K*+>BD!KE0 zjri4kcA*&`GVIV-~#xgFeDFp7bTad0Pgq$OTJH#2L(49-61E|DLb)G8S?S-e%x(m{RAK8 z^8pUe`kizQ94p}XTjhMhQ%_ZdYXGH%`^g#DfKb&v*}~4+FOl>x%pYivkhGy{21y&z zh5w6eIJtByiGby2kg|lB290-soA6XhFLm8A#MTXry!l>G#@@}~*896V*WbOPt7(=} z80qQ&9tN}P<-JT@9|9y3dX0n$9QiHC1?*$6C&6}Idye+2B%*=+bFBYDBGBy@UkiT$ ziy3tH>)mU-%h*9DRvoeYu(UMhs9f&dSWV3>(9U|^sdL2xh*lX679 zgQ@a9Ro|wHOsZ`lVe@K`lyu9S;=vK{1X4$GoB9Zm$sbZh=~EH{(#_f1q#q>-UPm1o zFxA(HGFgtmWCJCC1(nCWDmkfF-|*WNi~9~0&-Q+89$MW+UCr_>Lbcj4*p|?wc1Lml zK@!QIp&@rn4Z62``{6dI4?Sg|*Eqtju<+FAv~lCkYhxHs@pL*%AcIGa2v4 uX_%BR!o~CBV-+X4dHWeL&*|AAk;?eKS|0b?99_nE^qlZ`1FtUgv;POQPf7Ry diff --git a/sgl/dataset/__pycache__/dblp_original.cpython-37.pyc b/sgl/dataset/__pycache__/dblp_original.cpython-37.pyc deleted file mode 100644 index 73e22997df25d6ab853919834061acb257a4dd4f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4532 zcmd5<-ESMm5x>1Veu|VRik7W>UX7wK0hyGQG+~O>S(#UY{uMZWmN0dMyuV`QN3Gd)aaO^P>9P>8^41nbj=3 z!md`Q$#h{X6F%vuO-Gl1B~l@i2SFO>m0!n?#VC=-)KenWOX4V%K|l4U;TW{VPl6~4 zdXdoXqfn)~_?vMWCUFq;iShX_KX0f%qalofU7NG6!|kri>mrALUKBR0SYzIBn z-rm~UZf)K9;NI4UPke72jD>6&c1D>{d4acQPV}6yQ3P_wX%NQXjQuClOOMAw5e_0w z!!$(uwiis&y+qFRy!HO()=%$xal(aaZqD{m-mN?DZ@#yEZ!rwmMi{-jd3XE0#b6L~ z+HN2u!8vEQy)~bo3Hv$FpKs>{Qg}Rk7IFh>n#=^0LWw~3_h>r69vmjwgHdjJ)bd6~ zmrbL2Ak%OV_CsJC#;J&+5W4q48OFeR6p)R2Qlv5z&w{8qli<^UcpHs9pawPjG6eWu z?-&sFg5m9Ca2w~v6_$=k8UScG&5S3MGf<2K2jbX}w#ufvD~D>ww8&&+lBC`(h3(^E zbI1D~b$KXZMYw6+;T~`@mPC4yajLy!5DWxXf(?#EhA4q*)=Zq0+H%bDoR~EXMMM}G5hdQsK z5Gpov+EpSNNQ%>PiK+@!Roaf}lJEDUKq=qTL|zyc07tPiEe3r5aVs8Ar^^Eo;2H|$Z@t#P^ICr!{o`pFL0d%1w3o@H`^Np7 z{yI)GjD0vpKNTwVLw-a&G^HgbFR?q*g7G=y^oO=>T%78~<{yBGvQb&~bi-`@eN+nph?1<*2tc~Q z{;U(0GFYXxZ$GywodVpf({$uA<|Wy&E|5O(*-hI3qFu7Z*I_C0aZcRSK6O&} zGzXdUJfAs#MYi zZFVs{R$topt#8@W^)IZ@I;otmI;o-s?)&wVWwf9>S}SNl6SQi)o>d2SR^bf%s`xKo zw@e@0X1#g6G9Pyq;~2S^%WCs^_r7)I4A+d#t4F8lI(#MU5Ssr5XI(#8%T}{WR!2UJ zGH*3Dtgj$Vwze3hmKhyDi_`08)=494?BC$^1NU27)(mEi{f$egU7L@+acOKF=)IX; z%hrcBuk&l)kgu+1YsAxk>s$M2?F+guZ*%LQ1g?M2tkbtM1`A-|^$vRPW^9N)Xgrek z?EE&ZcXWQ1YU!NGShY);qb^^hiLdBp>e9loOUeR5?;@4$@fdsL4o|u*^3}xMs&XH^oa0(X5vwku-oI?DdmGJ8_K7Wpbz;JWzB_r9uBd zmjxdR1V}E9OmIWkLqe$w^oD`#at%8+ug_e3!$^7&8#D{C`FW_2_h`Dh=o4`ALmJDe zaTEediUg3GA8D6B)J$oo;|ElYse(QozpvRY^pbs`3H)=rKHg)JJfM6)6~&xOafV*0 zl#Jci7IS>VDy)J3P4Fn& z)5_&&v2e?D+Jvle%M<`$%+@}G`(ZC_la@^FTTR3P^52wykmA1oUm_*u>G9zE38-gb z>O3Hv&mht1G*{&xh+WzDM|hVbu=={<`%fo9G-Iy|zR#0BhMkFw0AVU+5i2C687Aq_ zMSz8rC_R>xO2{{A(|W4d-ZXn*BH7RHOc#qm;aX}sK9X4XB67kTNa#Y>{`q|jU8>=Ub1 zet}i3Zj0&iL?(P1q^*Lk{ZgbtCLj5!uNQt5KM>dJ2?X_&-*6p-MHztXIu{)C1wM)qvZdIfIx9KI}y_?JX^&B|-;eC>Q; z4VY27(W{86sEOj5#miqaUg6cJ&J)&~1GQ)h(MKV=0){n=1!Fd=D-*oTpr0GO$bvnM>4%P`SRlV|MjaurUOB$EhF2kc{=G(oGL0LJlzZdWJL_+6-Z>)!wlS!;*KcpWeMaJ^FNMp)$00YMrb(`#6iWCq*r8z`YjC(_AC9^1Qp+71 zTh@)`zD&bG7=*w$j8hRsA#4x88OFeR?316mQlv5zkNv2XYw)p8vW>+qP=lF)3<18| zKLCV%e|R$)+{Auyg{c$L1^^mKx$}fF2gPWxD-H~6yJouEa;UaUi(Ez~N$TEESUw)M zw%p%SmxmHwgwy8k?Eoj^Nn{r}r`k!T5%dtQ@*fMHt4o4#9AiZmxi^B@VO)2U*)fFV zZFew@1Jbc`&YJG{shj${*kO3AAD^j361S56XzL;L2zI=oNXCRw&?ZIN=b_=D3odAj z$(&%E${1dDNzU1F!0)k;3M#Z++GDdXInfn6vh)c2r1YipSxQ=RL-^IY6 zM(Nf}*KQJvhpn>-*ck6j-v12CnK;74kN+=W=?gl$#v@#u5gw}D|BBAx_G$ds+uLh% znN0f0ejCBAqdEsW?+oQc?5aDrr$f2fq0r$WEfC~zI0XFndGxbe-Ik*r@?Jehfz3+> zlnZqA3~=4Jgnv-HkOrttrGTO~QDt^!c~;7vSjTq8vO?yZ6quD3j-AXJ6**h8PS|I} 
zQ$i_o3$feoN`{%OieC1Nj^0H*ifW4Brz@U^Sde<2E*eMHPL%juFPzz9CMoTZFY1zt zLYUbqYDbA^ASq|dIjZVZHBdcgy6SmBZGVOT^>7EJy zroW2a3}X+0(MyF&y^!yd47rB&pEJ3Nk7?)CE9mNZ>F;@iFcMzukA>=hrbutgK3p+& zULOAs8_7;%BVCEi?ZoaLmQB=|r1?5;VC37&HqUqTeDQa|c)IwSS^T@WE~!AOl2;lP z$UgX=b;MEzue3(?6PvD80G)MP1v!s#Nq(%$WEJ9e%Qm2BhkWtdI#zhH7nA>OVTvOA zELJ+*{?oBb+FS)o6~$=tcj%T{PoPb{^+z_evgs1|xxHpRVINt4{NOYDsBq+@_HiL~ zj*HN^#0y#BFBW&66;A9Ec2rEu$Ca#jTs^^wI{n7QbaXf!Evv_gF zMQ&aq?n}I7!&@@aOEcaDtXE1Gvl6fD+Q&;L&@p4Yin+^|=GM&IvYA^xp9|hO@UC3q ztsCA7Z@}6=996SQR)ux*&w#)YtbuK{a|uKlHFG(mWwfvGg)i-+Ilh?9WwpB&YO{us zvGmNoYkkEY)kEv3emduh&WQ~EH)%hHxHSnLmYMK7B z>7$p^vV|GFbVdi|d@Ng>jaTnlmv*md>|IJ%5h3Ayu=>x~Z}(2TFFS+mxa zHS2Tel7TX#Qp->Wz~}htiFMS>nxkucW!L%2mbSsGIa<4b*p(UWl?$|0AosoOO13(* z`3hhCk|OkSwn8$!SHH3!Auf+zq;tM*Dx_TDU4jmJJUISH97?rQ1dO-(P{+O!kRQ-r57ZqM6 zMH9hv0U?Z}<89(l6hn|La*A`G>Y_LH)vmljlPLJeS5dv77MKb=9Jk((w1BRPcsdpm zcXw^8iLTH|_7DwqAwky}C1EUoKmzcxlLK8&5msY`d1<1H158Xc+t+qH zk;}x3q!bTz0ryi~P}8xlVwz`e*P0#Z;vt@4NKySHiKM{;xhzN$U5H~WE|Wc7z{^7y zRq6-3x+eHgAnS3lZ!#8g8E%TYOpg=jE?2N(>uMgmn?}>~ps!ho#ZPm9e4mCJXMF-r zeoVBYnnWR>q!16i`M!1tM9q{gbUmM{301JC>jj!^!!8*BP2gYL_VCb>zlI$78>%Ro z{NuFHvPo22(c^yE*8RL7%va?n;2 zU&8lG`B&biadSkfcst7!UafMsZhqHW!<;=y$uiHhC7fDZc5?)b$p#S8{en@(JmXH4 YyE(M)(cUWB*9ttM4v&Ff6q|*A1BBSnO8@`> diff --git a/sgl/dataset/__pycache__/facebook.cpython-37.pyc b/sgl/dataset/__pycache__/facebook.cpython-37.pyc deleted file mode 100644 index de2d3827c0708f01d5887187af9ecb576cb9444b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3227 zcmZ`*OLH7a5w5C!_dF!aT5Q1rvVn-5Aa=&=iHC9U@Y*b~A=vc>HWO@gw5M95k$SpY z)h#T|_HaNVmJ31f52Pc1h6@*>4twFGD@R}8%j%htC7>rdJ16UvmHFl4C!J2n@Vs;Q zB>8)bv47H^Ftid|`Z73Ur+nyu4`s{nu58QB8N(gBat&W!uFG5a2A?yV4MXUDTl-Jq zQSvg+5469#;9iUVK>PXb?qoEH)61{n*h{kJ$X#ecU$V1>*JxF-ir4H^oQqmWeuvfk zLU2}DHQTpOF%mg7_{gnoP_K4C{hFNzwRORNLn9&H7T&gG&`Rvgf~8u6UPn9SQIRO+ z;W+I_(IkVVB1IbL%|hj)q%4!hwowsRd!=@aA+_6BRc}Vw(KM=5Jjo(#6phljEEm$> z#_2@9A>E3WCr8mZ$r2S;$zq#uvBg9lLk%p9hzp`BDJx=ouC~i0-PLv$PZMpcJg=ZT zIac5PUm6(U4F!OnAZj#>qQA1&y}@2SO$KFoIC%E-!*|n@<3Tx22MeGKmOkqjC#nmB zw8XB{k$DjV_QKsrlzukZC1okxgc!sW9&|L4SOCLy@~l$N~>B4JZ92O>g9 zac&}m%yJ=MS!+i-V@e$}=d*eLvw6=^_pnCY2hrW=3gWPmc6W0%jRCC47^%66`WDS6 z<6fCdt>G`ZW1+XrXJ1|EeYNM6y;j;27U-@<`**-x*ZxAx1!9|@a>Te@^IDuD+D+tI zJF^FSRaKPV9}LWmrU=$-+>a+q#Fm4i8ehQ{IJ%CiH}9^#fwQ$!s7Y2SN`pPCah!Fe zQf*VL-?%ekXPI2d7^M%wxXruV=d+E~?N^glSEkrODQT|!1(=ny_!hf>^J*@|7Mml% z?OPJExh1VRE-!8HPGz4vwS##MB+eGwcjs>H(mYUi3sE-r=Dzgi0rvGOzxGZ!W`|W! 
zd$nJONRM_M)S`Bz|0jRms*n)(TQWEmfByNreX?G)&pUOiZc|!)&a(GxR&@_T^*7Gu z9nNN()mq&#?pT*DeVG> zbsNXql=swkNbXZc%}{4`Ks2gOaIBlgFO#&?A*qi-A(LZmsr;3;MtQ0OBQ+CShJBSx z#(S0KV=dIK=E+So`ODa&@G2=B8D9sKLT(FC5%LXgTU~J%023IYu{V?LBFaW|S$YUz z1W6NE=tWQ{B|i}AA1ky0W55NKQHwAJ6cEZ#Mu9PQ?nrS#Y2@11B6PWguj9SS#av+P zG5d6Au;v3V|f6t)>kiTmo{dqfz`Z)tg^&Ys( zf8)_^(;3a~{W49^NNrED;;7m-<#O8;)omnzDpA=zpyRu=!S{#|3e^S?sz(6ww5U$B zgUtZ%7J~ zyZmhcFyw7Oa6|ZFwsn=)ZnP{&rUW+?5HU6;1S;wedM#?DsDwn34x?zA%Oe!dNVlWt z@F-535z|v>A2^0ps*}dJ6tC(5k&lRMg6PhoXBj1Fir#EhiBQC;e>C;mG;pMVX;6IL z4cboI@m>F}89^>_U`8}D08S&NcuRZ|&#CWR;#MnUc-}pD zn*6Q7*gxp!`eWng1AK)9kxcTG^|@)q*y`K7mN|Xb$h^LXR;2zo=m)0HO2cuZ-!QzL zHpi`gYuxU)$1D96&Ym&pNcWUUH)+hR{;K59S;xD;PHelwwew8HlU*>4?OY~L;xaCh zvSaDiWIv5m63b|3lqR|%^P?=y@muD4UB#Kq$5Am!M`a}IIRjep`r|Bqifu57(Wo!v z+UIPCOILcQLElPj>C51h&3WGeAIgT|UD=eaQwBZTas{m~SLGeF!RO3o{SfEBtNo|( zAbFYR``X`GaIZzTr~Q0qXEYea>BU#E?FCtM;4V%=4e?}e3!gbz$x2?axi}M*ko+F2 z__^S$v?{h|&(RY(75K=lY*4RqK>doH1(kKqe?vVXz6QRgWH_1Fo(9XK^*Sx>6o->U zDG%FeKZ-^fEEXv`fnJ}ee2^4HQrkG1#N}?G9b-`K*4EYQQFb_vN)?Z?2n$7nG%kvT z^tW+3lGmgg(Q@Y~8YWqy;xbt*6E2n*$zvP?6C>h+s7#8I*p931B1w0&oyFrs+bYjX zoSiOM-~C@FFv430c*U_|eZnYu#a?%MyZJcj6~#gC`Lhqdo1Pr^iecJY0H(LRx9;Rb zwPBDZT?rA!za*O6HShG!Qm5|HdIL^l7LPzncWLhPJ{ZY|e8vsRgtSimia!Cqa5MJy z>*nQ_*<||0G6apg_=>kdhOAoqjS9B&bMLhtg!$;|;CTxtb`u48d_n@_)TdG@O{uPk;$1HuI9Y&0O#+|=M)%mn~LLe^-T!0Gf|_gR1^(6 zR()&Mp-Qz)?tV*MxX&`qehuRbAB1t6x4F-!YnO^&4!WT&xe&5Tt?j?TT;3JmVCT5I zic7J`BjK0X;(odDo6T%@@I__adEF9gSq(gi&=BJS~kyGRikQBbp4oR@7t_w zpZt}xS&OsjJLO8%GWW15U1~pOpWr5Dt>3k>ux_{57NimvYO}e;S+@EfE56ub+`ONF znwtx^6@>@5U=o*yDk;d&!$iAftcD37m&jpKT+o<)n|dN-vgAps2o2a-YI&hGeV*qEUH*ZQV3}nWTjd>G;SZ zGC9_k%8#@)$Wt8{sTtTX>`^ir?v|PlwNN{nC%19tFCCH~S`csA^U%B9>t_;ux;h0;C;0xR#5^Os!qkbDC={j)^I-)qiM0o13<{IW&J{XNGS z;14?%VxqULsPHp@R6hWB@gF?et;?tBgI}cys;$jYHaRReO)lLu8Fv$*p-fcv0ONL%Q-4YSqTbqHxw zP)c9`8vuu7wxjKfRS=)I1HnU^KHv54n_ff{2ZltQLk7@ZUUW_L{~oPDX;QnB(O^GK z)KijtK$10-N?nQy-IoA~DGKY@^46OjDA^SyJAG%_2TAv^9HoURHBI)ZFP}2{H7z#j dmolyPAJvw6Orf!$Wa9HCZplX}*5vEfe*v%@MUMag diff --git a/sgl/dataset/__pycache__/flickr.cpython-37.pyc b/sgl/dataset/__pycache__/flickr.cpython-37.pyc deleted file mode 100644 index b1308335de7c0cc4f8b314b86d9b68602402b81d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3590 zcmbVOTW=h<6(%{ev$L}otz_Ab5+k+y;C738727GCxV95Jj$OsJ8egR>91O-w?rNmD zndDm9?Mxpc1L;eV_BkjZJr?MX=H>`hzfHs!r4GBub$w)4rZN{q$Y*CsrY2sD~+nJG8$= z159mNd0FdO%%%>lzC^pYR{`0jH6uIJqxF{vZr7kqC|%m3b5Pd4MKQDF9rB0SU#|_9Nk~l}ZWRK*8;*%)H`vfCt zgX}vw+NeN+5j55YeO}V}wCZY6wuJ==m8BhC62v%yz z&jPt2RK=&DM&(Bfhnr*84F?;$6sw zC`Vr}6j-7?{0Fp2^i%$sJhvYr{tk3^EM=u4Q}$*WCZI(gZ2L6i0ASMnK)D*bT6rDJ z^@?T+Ujpg1cO?pi?5A;dP1l9Kr-LS)13l<<^yKnMyH}#LA4K9B2-8{R_Qb5KYSWEG z*?h2_URkx?Uh9P3O=}c^?kC58bibK8^ea>#eFjPBnOY5iVgkwf9^T?FWS^9Rb%N;Deb;H8kBWmJL45Y$>Fak|IH(DN1prgy6E*j^v7Ss!&`C0 zK2~SWZ^-q}|*afl#}+5CjIgBx@d6;dgdlfJUR^>13u^8X0(e#O%AUKAbh%lBEXI46s!@ML3lg8Dw$#ot{hY}o|+)a34KD>qxfu=5r5bwd6wXfHAo_fe2U2ao(bhKf1-7Y`{&A(dI7s z?rzZ@x8>sA@q8|C>$tdvk`HYpPi%RA!u$qC#XLrnMR_uxH^%u6ZRt|)MxpQb^Xup< zl#ul|_BVzzxB7;!=h}0rN*D)2_BIZcZK7i}OlSzPQ#l}H5`2@`2r|}dn<+lf3f5$rtj}_T|_DXNnSyCi{FrX>Q5d`y#`X~f~gx43`2Q2e&M^%8wq^Yua zx~;5!8u1VGC_Bwq0>K2L6C7T$SSSlnuSfI#hVt&F$?y?NM9R4i7{+HR3#by@Y7}Rx zy7Jr4Ifve}W;M10p?aY-7GoMC4mWm_QS2wcwp2~$z8q&v;eoQilK3^f>?eA5ozWrl zw^=yckg8@XW+WL4RhwCtvbNY*+4`)?iGpV=5(*C$;RA))QS)-A@!;36f%qk42rq(7 z*<=y_6uaaENScdAJ2X!EL+%c6Ue;4njg->p*{1NHs{tyb=&{Oz{@stD1-b z34>roAy5It+Xa#KUd8x->kz=XMJ+HVh!`dZ12aZr7nQ`+A42yV?ybq;ATrxR1b$Lg~Wyj+BQMiuJMmIoY 
diff --git a/sgl/dataset/__pycache__/flickr.cpython-39.pyc b/sgl/dataset/__pycache__/flickr.cpython-39.pyc
deleted file mode 100644
index 9a06a0c053f1cb7d54f5db74f38838370ca5bd4d..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/flickr.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/imdb.cpython-37.pyc b/sgl/dataset/__pycache__/imdb.cpython-37.pyc
deleted file mode 100644
index efdde3371008875747c8cd60b495769ab0e914b0..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/imdb.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/imdb.cpython-39.pyc b/sgl/dataset/__pycache__/imdb.cpython-39.pyc
deleted file mode 100644
index bc99db55349188ce83705f19c08a3427c9e8ec5b..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/imdb.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/linkx_dataset.cpython-37.pyc b/sgl/dataset/__pycache__/linkx_dataset.cpython-37.pyc
deleted file mode 100644
index 953f46c0e3e45647fd65a198354ef6e73dbd9188..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/linkx_dataset.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/linkx_dataset.cpython-39.pyc b/sgl/dataset/__pycache__/linkx_dataset.cpython-39.pyc
deleted file mode 100644
index 773335b6b4055cb4ad97a0ed2e2ca96d1064cd88..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/linkx_dataset.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/nell.cpython-37.pyc b/sgl/dataset/__pycache__/nell.cpython-37.pyc
deleted file mode 100644
index 0f38e5a36b2e449c1126745bf930033c60ed26d6..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/nell.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/nell.cpython-39.pyc b/sgl/dataset/__pycache__/nell.cpython-39.pyc
deleted file mode 100644
index 308608767b8f039d41547e8d27ede63b185aadba..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/nell.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/ogbn.cpython-37.pyc b/sgl/dataset/__pycache__/ogbn.cpython-37.pyc
deleted file mode 100644
index 2cce8e43301b1ad5296c6dd0830488f3006734f9..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/ogbn.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/ogbn.cpython-39.pyc b/sgl/dataset/__pycache__/ogbn.cpython-39.pyc
deleted file mode 100644
index 80e9cf867290554eeb7690257757d72193b37c21..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/ogbn.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/ogbn_mag.cpython-37.pyc b/sgl/dataset/__pycache__/ogbn_mag.cpython-37.pyc
deleted file mode 100644
index 2d73dfc34df927fffb5744fd5d94de3889029401..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/ogbn_mag.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/ogbn_mag.cpython-39.pyc b/sgl/dataset/__pycache__/ogbn_mag.cpython-39.pyc
deleted file mode 100644
index 5555e1144b2520a1500a4212e9e8a76233f6b927..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/ogbn_mag.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/planetoid.cpython-37.pyc b/sgl/dataset/__pycache__/planetoid.cpython-37.pyc
deleted file mode 100644
index 905c5c6ee00215d7f66c49bdd2ac5f43ef17eef9..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/planetoid.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/planetoid.cpython-39.pyc b/sgl/dataset/__pycache__/planetoid.cpython-39.pyc
deleted file mode 100644
index d5599acc1baa31f8887530bd0a88086ed17ffd1a..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/planetoid.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/planetoid_sampling.cpython-37.pyc b/sgl/dataset/__pycache__/planetoid_sampling.cpython-37.pyc
deleted file mode 100644
index 1591038ece75ed06e173df65459c758847d57ebf..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/planetoid_sampling.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/planetoid_sampling.cpython-39.pyc b/sgl/dataset/__pycache__/planetoid_sampling.cpython-39.pyc
deleted file mode 100644
index 686f68b0df91a437730405bd5f7e86da13c36236..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/planetoid_sampling.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/reddit.cpython-37.pyc b/sgl/dataset/__pycache__/reddit.cpython-37.pyc
deleted file mode 100644
index 0294b2fbd6cee77ad5baf0be4473c67d16381303..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/reddit.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/reddit.cpython-39.pyc b/sgl/dataset/__pycache__/reddit.cpython-39.pyc
deleted file mode 100644
index 6579a83de2175701f6f73891299eb95373cfe502..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/reddit.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/utils.cpython-39.pyc b/sgl/dataset/__pycache__/utils.cpython-39.pyc
deleted file mode 100644
index 3e08430be5346f7084e9f0f62dc1da07abc199d9..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/utils.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/webkb.cpython-37.pyc b/sgl/dataset/__pycache__/webkb.cpython-37.pyc
deleted file mode 100644
index 7aa220dfea44481c64bbb4ab41f657730b431b3f..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/webkb.cpython-37.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/webkb.cpython-39.pyc b/sgl/dataset/__pycache__/webkb.cpython-39.pyc
deleted file mode 100644
index d4382fd4fdf081823a2e687784e45bedb3c9bc42..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/webkb.cpython-39.pyc and /dev/null differ
diff --git a/sgl/dataset/__pycache__/wikics.cpython-39.pyc b/sgl/dataset/__pycache__/wikics.cpython-39.pyc
deleted file mode 100644
index 63f108958f102640cc3338d5897f7532aac5f701..0000000000000000000000000000000000000000
Binary files a/sgl/dataset/__pycache__/wikics.cpython-39.pyc and /dev/null differ
diff --git a/sgl/models/__pycache__/__init__.cpython-37.pyc b/sgl/models/__pycache__/__init__.cpython-37.pyc
deleted file mode 100644
index e80b5a4b0f665408ef212af6d8c270a6243ab39c..0000000000000000000000000000000000000000
Binary files a/sgl/models/__pycache__/__init__.cpython-37.pyc and /dev/null differ
diff --git a/sgl/models/__pycache__/__init__.cpython-39.pyc b/sgl/models/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index a57627d14c1000d33570937964a2ca3cc658e9e0..0000000000000000000000000000000000000000
Binary files a/sgl/models/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/sgl/models/__pycache__/base_model.cpython-39.pyc b/sgl/models/__pycache__/base_model.cpython-39.pyc
deleted file mode 100644
index a879883b1aca069674a7fd0d8f8ee75ec69c1327..0000000000000000000000000000000000000000
Binary files a/sgl/models/__pycache__/base_model.cpython-39.pyc and /dev/null differ
diff --git a/sgl/models/__pycache__/sample_models.cpython-37.pyc b/sgl/models/__pycache__/sample_models.cpython-37.pyc
deleted file mode 100644
index 498780d04abf56d3202011acc119e2d88d795a48..0000000000000000000000000000000000000000
Binary files a/sgl/models/__pycache__/sample_models.cpython-37.pyc and /dev/null differ
diff --git a/sgl/models/__pycache__/sample_models.cpython-39.pyc b/sgl/models/__pycache__/sample_models.cpython-39.pyc
deleted file mode 100644
index d488bdc98f3cd06856fa47e2ff30a4f198b950e4..0000000000000000000000000000000000000000
Binary files a/sgl/models/__pycache__/sample_models.cpython-39.pyc and /dev/null differ
diff --git a/sgl/models/__pycache__/simple_models.cpython-37.pyc b/sgl/models/__pycache__/simple_models.cpython-37.pyc
deleted file mode 100644
index cede134e2a81b99d1806c490f515439c4cead2d5..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 12552
z`>5*IN&B5w{npgzQ|MLut$+5qffmy=rqItZ~!K} zVavZ8F8Ij;bS0jWH(Lu$(k?!COKbBJn(?0pg|s%WyyiDM;mblGEO5>nM{TXxbvN33 zuDqh0E)||QGoj5!Jyyn?_&E+m2S%QrP<1_$lX_W?Vp)C3PwHE%aL}I&Wa%o@1^8W_ zozPj3zAxN(HN5_d@}UXU+9GUrv;*O4y{CnGs}VpAs3=E;_Ds6$FLC~*tZ!>QI>O7? zM&0;45(YM0{21JrGiS|H@HL8%!wT~AR*e$*@Hm1GaDh0_i(Fm6t#IK5WD*!4dloD< zbJy%HQR^f2@yM>lazqQP?LuG04@_ONh?0`x*10quVy{YvB!G@{HvL8f6-@{f^hM+} zlwiHnr0dRi;wseJd>{=$(&KbvUXZ7GdITU`wq%u z`m6?6gOM+FTTm1ajU`P6)0TJvZB{UwaTXPct#tfRI-0WIz@6Gr)AMdou&j>{ z29QDsx~RbDKsilP1wkRkMJoeTR-4*xS_5&{eWP}cqIG+-sDJV!fHKe!{+BusSP^_*i>Zc+Lcuv zM)j^$7I58Ery&a{)VA5@7jQn(X!>FyX)SCeNjrXSdD$1;#@(fs*j)D4_N`9WjX_Fsq-0?S%jG!pEXI!u+3-t7Du|mUC1%Z_5iTahO2JWlEX*TvP zliWQx*FsR}7Dic2gTy}a+Kj9=A{FLm=vvlbpyf{^K^`5sJZG?Q|~$e5TFm`NH$ z*@NLgGyFkH5zK;ZoCo&JDjhLJ4@jzUyDY9!W{2M?thZM;(H*XTdtTezC7i4H$reu7 z?ubO2P)^+7aZBJQzI3zWC%wY-49ba<;O?m9K>hx=xKm?l9;OtG$Wz%OLHnUoz1&AG zpiK^0IwpvCky_loh48L{^D?WC$dNkabky>e^vTmn9B1~OoB1hPK_4JlJ3HA*+Q=%V zC@4c(78()!Vrl!qHlkpicDsdDPr`z~<7(32985zP$5!uHww_GJ&6PtDBw6yjpjF3N z4%c;m(OGiqUvfN}qOR>&(^Z&0Ed2=!ynFM-YcDH_|I(1e^GubmGP&=>;Q~+G7W{Z2 zJ6A(!W8qEzE!FPUkeqX&O3vHJ49j_fL&LD(&q zq=O-qw1lRyF2R|$2=u`^+7Sh&4!W8Y9mUZ@;V|-+3JTJirT@|A*(JRX+RzA&s`#TN z&9*)PKyF2rBe-*Z6G8w7#O#jnxpXfz{hjr|zx*%g^FMGKz0YPe%|&_XQnLj|CBBS| zuDv*|Ev7qYkb1X`xDny3GpdKu#+jrTz)8YNL;&TX%0d$e*3^GOI-K7G9Ad-q5ubm= zo0tayv^J>3h($$f7Gmmn$#}30*K)+r^ zu_>JAGZ=pbyr~4HszvjgDAR;hqDOX@_GT6Esce~%iUCXEux#MxE~0x6=HMD1Ps&Uv-L_{K`N z;y+tOksb}PR7?DQ_a(Hj1AKu-;Rk6h0ssgB7M+ay80GsTqol1OH*tdkqqLY?nGx8a z_&&dEDW$r^lx`vrm%Q&9lv`JoMd4g+UPceYHgt$ zA!%(M3Why`E~dHTXOZCtFViqp=fIJtz>Bps5`?_vTD_X$g%qT6Lx7-}^S}@B(EZ+l zqxx^Y0Sj)ofFt~3cfC(nat+bt(i`@!yIbC^^ecM`z9Fb)-#tyMk33Mtu2uUL!Z~)W zs_nkOI%V9p#dZQQ?^ji=BEB`M?z#JhiX@&zgu1+4>GNA88(%7!cWT&Ao<%|Qu|{nw zWMr<&atP241%ir7RAf16c|mx$0nKdT6Sa5PVnFr^9;IDxceJ@_*ZpL%mZCt{^P9~B z2b1`)EeM4y>fUtR?b~c`bz|%cqsP)!fA|JZp@jHNj;55#aBNPN8^9P5!d^JXeWE5d z9J&(>+tOR*4+vt-TD@ttLwN(7y%F)muOm6EQIwhb9lj?Mh(BU-gkwsa{0O(!DCTs< zb=(tzrFGo-rH}(p@g*dgMhPi0)`opRuN-nuQ@w(K9zL6Nv1wIl!|YZoNZl8x6{-_n zRBPlqLSjrfKU%GTg9d7aDxn0m;`Z^GNUm0RmO{=#u;Q>B6JqCuFvb3mXe!PpOFA$^Ol?lE~639LOHj@*|Dd5Y{n6Ia`InF>KHy9T=bpJ)(4 zCLU2Bf-Lkr=c2ZrO5+5_|9_qFF*-wI_-mvnsAqg?bs@gXF5Y4C@#;eS2I?JD6B>Aa zi>=Qw`5u#D2vYcPg*$v7N}|fW|EbXt;(e0iZ6@Dm^4m-lsfQJV_#M7b{=*NDOBI8l zIuyUlveLnsHlRso~g2^Neppf}BlK)8TpCaK0&|e;nKKdMkBB})T{xF|_ z_C=1Yya;nnV?Ik4Pl%9xamg@#_(Atl6e;g@?L_vtw9$#`{My}b%E}fLc#2SMOAKI&Et~(_VzNStc_$IcH(%?o_%=s V?Ckt(X|^)^*zDuTUzmOBe*koDhF|~y diff --git a/sgl/models/__pycache__/simple_models.cpython-39.pyc b/sgl/models/__pycache__/simple_models.cpython-39.pyc deleted file mode 100644 index fd41b8640f753bda743976f98114353fb8d7367c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11395 zcmeHN+mjsES?}9ic4qdX(duSP%2=^t54LyZ*s%!+#kTAfG15j@f`o<8>YeW0oz+aw z=yZ=R!z@sdlW-H4K&p~d9(H-5_8;JZ;)Qn}D2k_3R6!M0q{IV7!3)KU`2Ei5p8E*d z6<#QMtIwYK`t-Sc=lg!&<+L|HUodbLfB9>{nO6C`v*&MNrEAQBaO1 zlqFD>{bQgUODIP`IqshT-l?4qtZiBLeG;z#vC9tZy4zrf>Sz z$N3H0w|(bhV*{)9v-qF+w&55XS@f8bN4^pU@7J4`TjAYSqZ8L#VU5gAuM8@4xH&}h z5sKJ|&7Sel+P4nC;-1wrTswC5GY1B0**Mp;9-8kPUq8F&^o-r?ejd-};=+EhXZOql z>%crPP26} z_RQGawf4C^+?p)6L~y&Sj^NIV917`hL*(Lp?`}O1ZYQj7wM5fBv;-$5vzoEu1lz0P zBx{Q(j^d7#L}A9=bCa_j=WdNNO|L8}T)&z%#c$i#vtrKORAP2%&yKBK%)Muy0xI~{ zBMX8@(Q`=I%|P<4N9Hm_Q9U*HGCgZQ*UPB2F7zzufV2LXLt>X~?3W?5*%L4 zqkwqEUhTGGh@9($(Vb2ZbOY%$11}Usv^dOaRxGHEZ$KDb7IN9!-WI{O7YDxEeWyy8#bi9UsqGU~Nn|KKW$oWKUYeA#oxdSzl=UsalXW5*ML|87mRqG$~Y*MTtA(t$V_1@F_r zmpuw!zBv=Vto6rAX66M)6R)wL%*roJg3;`0_hC)xo1ntWOP{QRah%gKOyLwI78GQ$ zz=Dbi^``EMBP=LSVwr{7*?|gkA?I|f+)2@aHi1EkHy0j}t2&PtHc 
z?#e@RBdZ*ioN`!>t$9&AG4&CxCjxL9{2!48V@?Yk7}r*=ixX$|;k4}MVL_~ASP;4> zcsi%7;0j)?7PMNwfx1@f^X}9c09{ACfpfE(4ZH*`!+qV9PAd!|`SPq6$Mw5vk8AK0 z<-JTEa0qw~R@TgBU^Jd|&rV&!Kt+dbFze)`>WM%BUcq-*f(_OZ41eMCAXhE%HvY+6 zVjS^X0Su*LnrvPNJyHXi-80kYlSti&BSWku(5mTB^*Nb+P6@u~u6qXfQ(JECoXrJc{)mMKWB;U2ljA6^m1 zOLdG~`Jt)80)_)6gatYraFljsly)tB*$=|FUh^8);2qb)?e(fDvm$_L4JWpNQo&mx z%&3kHyhvN?&n2|ei~ecWrby~9v8O~HPK{4vwrElwlq_)3J%mZhc~u-tBoMBX_6>&E zs(cqBzy>nuDvuq1eCiaR9s}Aow(gmktkGf7R}yJ$Jo!$2X)CpYZ|3h4+^YPd%( zZUnE0<|N%6bAEm#J}=1SYn?`1zv?{*#PvYbf_5xg;g3{@#y@^(t;Q`% zF9HL9lU32*qZr-xp6#2r0fsSRqKD=;5Hn%>l?7D6dk*jkO9+cNr-g$tQe5C$Ck+*C zo!lF41J=`8B4%74deiDTn%c{v_adGYeY<0bzo4a65N!@eTu3oP+J5gNO@$Lo{FRs8);#YjHRcVlK(F-?~ z8K@rB@_hxI(Wboci(aD>d>{nE{1@D5oYq!dccUZW!Zj5-sdU40Gv3*%McOZ;z$*Vu zo|@5gGnBJ>*$g9De9wz(JL?FxU-xD1eOL;Jv^+bbGe3Erx&Bdb^%vyn8O_=vXmnH% zaG2MIPwVbZ-3R)stVMy&OtS0`un3t;`?fa2nUTUST1MwkFinxiA5#^}=8}2bT!zU+ z)PVZ3#pFfz%rtfn)Ss}PUX*9rOL8-!k&NMc>|5~G%sm=@8+nWunHShK5%*iWnV!l7 zKzi9O3rbGOX_-(5r$9Pn01TV6?bSorW-5(}CFBAuU#-(%5YA`fDxatT<1}hA+YJzF zy3*OKdr=BC8+>AI)rLSqnHNS`zlmEJY(u4N%4YZc)O9~K)*yoLR_nioiq3%a_3{#6 zknWFl3ZVEf;fZfS>wXe+aAZK=2X8djz!z1!v%q(KWJ}b$cF$Hh5v4IVA3wpBh9Dne zWX+IktgW-S!Q$spq~YW{Bz%s=8jDenN;y2=0zC;Q%Mg=-x=R-9!Re{2NpKj({Acmz z+UyX7dR=Gd^!@YZ!ts1okXFMza`h3J^XwyXZW z!$AMnz*s3$3c}~foM*-1vc}8r1hRW3=b?DB7tUtqBiCxohMc8tzVXu_X{|7Opr1g#Tb0|QOtxlIqX~Z&wLW!a*#?X$lR67>pb}B zUMW;?-B2eF6%e)G^lAlX5%q>ADsii_6UXi7r&d=zald|dr6smkz0GL#^|vm3apl72 zzj)zo)x+Az*%7Btv9obUzlvfE0qNXLNyDTmn4T@x;QHzGjB*Uoz7Uom9!12m#(`y0 zi2LVARcl>*9ivo}L@b>V?dBEhp-QPUJuFLbp!&bYGB7$jN=vVLl4WG>5oh7VqksA{ zk~_W4mW{OeGj$#U$&eW#oJ zJQW6}dl#o~J+sJoNO1&8X>x=-_GCUi?9t3kPJEmw_V12ccRDQFGF4{fYwGhI&`RV5r5@FM;;AToa*$n5<6TtnOmEs#LAbH8|6wNgkPVUB6Ys z$q83CI9HjSGoZ}=CGI50<2VaNI5DLDFQ|Y1#=BQOQ2PHfBl^$NOg>6cpc4fZo^w^a zsFI#$A>&ZF=RHvG-JH;RZlu!sHY%f9ALFntb{<-yg3UFzEXXi%0LOADvnnLb#d(IH z%Jk?QMIoWkk<}EOCktO6YvVm)!u)VxNL5nOJ(08GPS9ypO3C=3$8$r*3-o)FgkjZ^ zW?Os)%~)HLb0GpSuK}e%&|uyllM$J}*zlU0zIW+g@Wy}OHoBipdD~a-m5YrQf{W-9 zD!O%bLHj|!h8L1}v>w$%oHT~DK-xGF6a#?ybvDYeASB341DMv(|AcfnzZuA+UYz1+ z`~$j)j-#MZEBgg>%35V8tz#Cf?vaTVOz@H*r2vZI4&Vj=r9KJ(DCycDwPcHr(MGGl zH(66zz_I9n`0=l}2YH!O{Yg9(JcyKi@L*h`u+kzTDz}8dVWNhkVoDVt1oornd=v77 z17SKY4wDUU8*IIs5g&pASyb;(EMQqEQSda>2}}odwH_P`?|8`Fsud|le~Zl$U*k8} z>>djO%SMhNb-WGGx$lZG5mJHCzoS)RD*!U;ke?}g?8Q=_4Slq^DIT!r0S5Ae}A)A!6L!7JI-MKhKgHiPDN7ynr1?bZ)C}Zd79MM|U6x^O_QES|9 zqpmSoX_N{l?CmsZqxhE#lk7)M zswx-9w5h}55du_hzNnj6hqP3xag=|Q`Ylt5Y$TVmd}L>MzA0borw@GWJw z?H-+Z8gl%9NV87gMd?}W>?-s}i8LhPc+b95K^-42>38^_?}WR0BqTE}@i(dkKJKC} zgE1GN&(vs(VZ$?w@LMAzB=3azIPYheBE_ucm;m*&kC10R%<1za7=e5&kRHPMhzdaz zI&@cV?gW*{!{EpSkJIKA&< z6SLyLFh;TC*raH7f@{o(tCM9K6rS*>hs!@ieGHGZCKvfYVT+Ot1w%>7_V@Vc5sTkP z0Z)Lb?q|TEbJKAMnaL73_Tzvd@dupMAF>#jCzTPUjJ!I6HAgw-|3l3L)QCSO`@YNK zds7%mMZWNiid-a%Q#~FJQB-X4%nD37eSyPLd-CQo6__-ftc3WdqprOAhH@ngmAkre zH2t}<)d_3-z|?EV!k{jQbuz0@Rn(_r`V*=**x@A>@3FYT;uect7PQd{JJk3?a;79s p8M*vL{gK-xOjv*N$P)>moLxRsK3hIs&Xo)0m&>o9e!hI}zX9FPlym?9 diff --git a/sgl/operators/__pycache__/__init__.cpython-37.pyc b/sgl/operators/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 8130b0d16569d5074132a9cb9277be239f7fe14f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 130 zcmZ?b<>g`k0)s-gR1p0bM8E(ekl_Ht#VkM~g&~+hlhJP_LlH&@WEU(a$ePElMoOFDllLkI&4@EQycTE2zB1VUwGmQks)$2QuU{5HkP( DgR&fO diff --git a/sgl/operators/__pycache__/__init__.cpython-39.pyc b/sgl/operators/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 20ac75bb2cee4b8d4a3be7ba31a48253bcb8a61d..0000000000000000000000000000000000000000 GIT binary 
patch literal 0 HcmV?d00001 literal 134 zcmYe~<>g`k0)s-gR1p0bL?8o3AjbiSi&=m~3PUi1CZpdyD*V!hTg_`n7| zZi2_6P-EG1%s&$)BUQ#yi716pM0u?9NoHP6CTSe~&1ga&T0IULcTv>qsDwQpj>c2x zDI2<}n|gS|)il7{PaAL~G>!W*E)O4;2`#3F-c&w8QCq0ye9j_71vTP@MSK8vH=dPk`d1SWkXK zW6sVHzjKNp*Up@4_sCl~>NV}Be8GR<#e1+iIyeGa)CaIS)_QBsj`&fSy4!5-;O%X* z?_BwnZUSRiHFpm9seAH;ZeiZWg5P7`?45eLeYSZH^nK^ZpLB=5Fn80&Hd{ESez8UU zvn!S5oS*u0_iStKpL5QR55Hcbbnf0`3yaUe?ES}wqKo6li71k37g&^|lWsmyy88qJ zU6qW=X&0#MR+wGR*qOaU3-?8=CsMq)teE8Gv|psLl<{=u8Kq59HT(452>DT?ana33 z3fm=z-Bf6iAZ6YCvFz$Kf8rvwmi3KxjQ{1Rn4CQO6Hh|lrH7J7S$Tw_ZlF3qUcTv^ za?Q>-W#?z-*n3nV6(IHj$mGy$e3p%4U0+Y9|E}?9w@YyOqqUcflupKZd55l@g-XC? zZ>*p~mmxlgwdn6<<0Q`19W+)edROHYERqdWrZrL{052{QVH{PO7Tx(O@TU2AS_%tY z;}oT7EF)-qbr_ezw7!b7$#UG6ad99dVbO$T7HdQ=H-VbsuI7@UYTR_JO%Ru*DAHbF zuY&B*Kw}xtVaMKK>P6(N;nA>Zd-z6qDFmijpYb|AqFzPCc-v|77N>H9hfct^(7xtu z@eZF|eIYi3UO@N0h;tD|rWHl`IGtqFZb#AaB+jZ4Nt94U>@8{$w_K>``bx?)H53IP zMC^ED(bNDBlFlWM4UHFv(C<-=3+|mswnbAGYuYIIA_i|BNYcH zyU3Q~P|HK{SLznvT1_BL@?9*m=L96Js&=jR{CTQZyhU#+pRCq^!YDmQ{#Gm#t&Dz!X zXrDD`-lJt0d?yAt3)t2GWHR8p|53z8yeHkYVNM3$|P zM|TrAd7oxl6cVQUD2kvH@{Zpoh9qn6s`0HpOJZ=jgUYn6&i5;ghltq!Q@Ks!#Pue$ zV%aYW$WmcLVy@%Qu3!Jz^(T6iDJwgY@@=`;a)AjRSdy=9*-JRYgnN|?+#zVHYpB5I PkP#|vXUpxp(h2_p(^kyo diff --git a/sgl/operators/__pycache__/base_op.cpython-39.pyc b/sgl/operators/__pycache__/base_op.cpython-39.pyc deleted file mode 100644 index d397ce3153de8359bf27ddb5fc4be9e58607c5f5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3114 zcmbVO&2JmW6`z^?;F6*k*+$eF2~sL(63}U=qD|YNRSg4iY7|K50y1(SkS^A{BZ~5VA4s?Q+COKmJ^9>QFYWKml9VW=mkuy*XWo3gH}C7W$@X@Oq5b*M zbNTm(v47I!X7e$5g0BAplCg~EY{=bL42Ao8g+B~YPFmgPRhedPPbYm}X z6MC}tSO_@w-(-TC*y~#*LnjT*r3O4 z_(GO?ta`%!DAg*~RidB3aZT&dOPQmoG%#UpAdgV1!Te#xikM>W-h1}QrZ2$1tm!IU*^L>4s?~f~~5;InM z|7C)QkE?z)wNV_8%8`lVHfccaV?5-uou#t<)$$#!Id(d&5eCe#@0_+O#N;)%4+dlN z$D+z*Axk5(FO(wdt0q>h2ko7nH6gjS6Q|=+8#PG{J&p{aotgM-cdg0Bmej+ZRU@eZ zy3^+$7|huvEO|w?;hXm5_7$S>TcQ4HI&^;SWsQeyAufV) zhvt`W)toDS5zM{Ioq2G@IXgS5V|MO6WeXRrgV}FhoX9SopQN%(r(Hy*8l88Gi8kGr zSm^3>R86~x(QX~V%N={OkLlo%Ow2^dyT{5&QBC`0mME1>_ui1>6p`5%Pe<^N9*xUx zG13s1o^&&5WD1vckH)HNHuOo#%yBja?b+bVQ8_t(^B=U7?p-x+2++EY|~-l*dvv=*pJlew>Ypb?L$Qw&wJ>kI+~3=Ixsz|DW=wIwW~n?D+XDc`&U(sAvk0_?+yy z68vye@Qq!7F!%OepgzU{^)n)$5OJRQ1&w~lDSa$_i0;aI#9N*s&K;MyqEFj?3$k`e z>w9(tK!aPFLr?m(V-_56=8zZ2?J#%P_@4P}7{Ogzwz>AzV52^h(^p-SMur*G=7qQv zSB$V8b=rgPgrJd+9uW4`J`9m=6F`^xB@Vh5BUF>)V-*`@rwxIMi863BrIlA>Oi=SY zuz*I*I8O6K>o~5>Wn&s&)y_@zve&(I$AMQt!<^xPxh6udgufhfePyCEkI8)_0Gmnxe?I>M_VN7}W2H z@na&_NpkHJNht948TUu4 z2m{pHv>rJy?UyBht8^x@5AYAzU;hL4CuWpu2TY25?W(n_4>tVD!9bmF*Cul#XSe8} Y28zcNnwqLInXK)4TkLq9cRJC30emOS;{X5v diff --git a/sgl/operators/__pycache__/utils.cpython-37.pyc b/sgl/operators/__pycache__/utils.cpython-37.pyc deleted file mode 100644 index 3ce4b41a875ded214c501d2d7c8408e912ddc351..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3796 zcmb_fTXP&o6`r2^&R!&2E7^u%R$PJ?h!i^%P(fuv?8ICWJ0K?yTToNu?U6K_*_m~B zuZ>i*Rg}f6DR|{+^_c&{51{&$C;x>!Bz&i5wXzjMQH5t}PoF-0neKDG{`%~@?RI40 z>3;D^8eFife^KT1v7x+!pZpU9w>ZnKF&o&d)}4WCY~I)(1Pfa@ZVaM@-W)X1Zs)CW zd(bvECvT2BgAR0;dw;eDE8OP+%2gin2FeRO;!Ttnd5gDEuJI0ELD}VO*Xpl6!z!%1 zeWtx(oD~X+e?Ki`DTcIAdOG+u@VkYd{0D@x=1kcqY-U4qDrlTtv1X2P=N`8!=fs}5 z${qPC7=_%SWj%Ff-onyYTKtM$53jEx73~L=J8H%)Zke*Jf_bN6qm>gjU&Txflow8{ zSx{M{i>H=Zh}zn$0ok2JRo$aWqkeDCTC;ZL&YM+Gg;k@9s%F)y+9z;o70lP`nc)Ec zu3*^H(c0GbP>SJpo;?^RYCO%iWciF*bad<9-TS|~^YOipcl(a^xH_Dq+AVlC*4|#8 z?92X!c1k(XLGnOO5_O=%INr)|=CwZo{GgP^5QS%r6pgs2c^JlyG<|v|UWJ 
z2lcHRI@m)eD1w|nSbLvpw|0N`;WWREprrRc{78q1*rz2(9g0-ens=0nQsQd|3D)K( z#!QnXN~MK%d7=^>WLWbM<9sL+C3JxCG@#$nPQpj%pC)Q5Qr$4CueZW85gRSQ9%%JC zN`#27b3%)5Ep)SNe7KO0Qz5Qk@#1SFzD^Ew;7}bv#G-s$E5kDX%o5ko`uVl(gL0g1 zOZjkn_s$1zzr^?oDm0D3_&=e%I*hC61(b(1%3&q6 z|BonFYgKs6mN>UY>ocI-oQKu|?d@4-f%XRQzCe4`ISCEkH@<@ROJ{f=PI43JmP16iz5TbIN-i478$ zNL(gC0MbsL7H23Duff(|5#ONpt0cZjqDNwr#J5O%o5XiW&=iA^fFBpd(R1^Bms(#Z z@jViR!CLg8>DR~^@qYl?570wyL7W5HDd6l509XSpTte-#EeM-SsF4fFy#{K3hf)D- zBL)hy|6*OTW|ng1Ze@W=*$G~dtTXp3y8_8xwodGM01|YFW?>IOwrq#M+8_qul8R;)aLIbV3^Qf{wvQZVF9#$d9(YGlOCpZ z+$iGO0}~{N+L2Sx;-n-RpiqL?WLE5fHZw4#c6h1!4HHDo#Gu8~eUYXz&eJ_LSP|fq zN3nb;RBc!jbi&UW+oxkhwtt`otIrLpE!}gYj8%8QaB)$Jagt|8DL)_gb!t{+EDy)y zREcb8`hI{NNW!MYR`GVhAd$VwE>elQ_z7BubV<-dL0ty+eN>1;5J-CwNM>$J#%%=A zMG$R{yAw8#Zcqf#p9iGPoQeXAP|V!QJqyCKFYzoR0VkS+wvF$IH%VM4VNTBt(%vBP z7KtB0^u5~W!NPfiM(s8b@}Pa@FI2^kY0SA(v9jO^5gScXS>WH_M&mGgYfi~EW8Fyj- z2zKaC;92T9@^CZ(aSTbt}QObelIC_54R~&qic^-!d`!~P91jcAEdoSH#4U0 z3S-@`wjCso(jL$D_EM1+st0o0R|j&_9I&I;X^Q82%Y$;7^WKBBm-Mcmes^<;HDXmF z=Ux0HvEe?v`-hDgJUBOrBrrxuV}G+B5{8_b)_;ng^%=Ic`zT9s8_S2&G(AfDwzjw4 z)Glq~T=YMWcE}OKay+>?rtj4sKe|>HxQDZGT(3Ljal-l5PF@a^T;7CoI!MrGkytwD z4DyJG5u{3rA<`)7BpkHs#(g9;?+H-~K|E`yMMs|``Lwob581%*K>Lp{cKE6I8FeD$ z_dO&mST>RqsdTemOU1@O~7YB>H0lvfRYO097KT<$-$XSf?i4WolRa2O)o7D&QJ&4sOod+-Y2PJ24~ zzu^9i!1LeXeq#yuJ`?Ytd%q=qfkM1Xg0LdUuV*L`x2gIb1TtX4hHgxY1+KO8U_2H( zXafb_p?<$4@oR`dX91N3Cd6;3M)#LIjAQ$NP5+rnxW_FqQEn55QIw+B}(!T+h zq{*DLS4_u~L&MyI6{BwrO$JYSbz`ZTluEbcFq<51)d{cm>$2u`vsb_FZ!)afEtJxl wWW!uchK~&{&61wmnc1UBcfy@<$)|aGlP)nyw+Aw1Y-1sva1(DQjKa>p0anz+3jhEB diff --git a/sgl/operators/__pycache__/utils.cpython-39.pyc b/sgl/operators/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 249d654efef0b04289e66c8ec155736ce566d8e8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3703 zcmb_eTaO$^74GW$^jvoB9j|Q&X3Qmcfb3!i5JZa;$w^EgUYpnn(g@UQPxbEhq^Ema zRb#WO9tnAQHG)^3W*_}u_yMS1c|d;V%@Ty~RL_pRYlC=5xBAqnQ?bg{Me|xhfn?%e_BXgB9-c0OcwVc?0Dc9`Pp1v%JOIDA#z0ub}L5cGv2!KEo=k z?LO1qFwP1E#lM#pvJ^vF=<(Bm(!h5EpZqI?vgS_p`|@>X5PZmSX%swUJtLYA{Ffhl{;$2EpC~zt%7-{VxyHKHeba|4U}h& ztXWW5qqE1BS%})&tO41bMOEFSNuz#a&swu~<<6T`P=!^aimGPSs@g|zY!%GU)ic8h z{d_V&GxZhm_E)4NX{?Qyj~Nwr(>Y^=SV zJlU1~b?uaLqJ!kVoFr;bhjF}<T>Kugz`RlJB|NIN&52rj? zqF=b!z3@o*5dHIl{34npS|r*eh#dquO{|gt>sjI~i8T@i#;d1cEV|Toj>I~N^CT{i zxJZKFqn$i0USLRk4YvM@xJ2z=C($EunZ!3pe3QhtNPL?FO)&@w_Hj`hyfn|7)cPF~ zS4a>7Yw=xZ`VDeG{9j;3P?0}`cm-zO3z&77y@Xk~1lk2#yoGTmfJT%`&NV>$3zP~_ z8!=Fs{YUG(HM5j6cPk5&%8u}YY@NAZ*%iq4f^}ri1CXIZbPIb3(q&r&)dp#x-{Y~U zq3)TnUd5WnBYlkY%}9Sa5*p~$okx`gvW=<$^{@&-lJ2$v!#}V}^BBtPH?D=WBYiF$aeSCVD+UzwWWJ%l(Fg#m@Y0#F;4RAAmyhKaE_W)8O#0g zI8`DWn!X=n2a=#_u~ocXa7bjYva?j8E`ExZAzc>74+VA^+z(M97C|uWMKGDUIT^PR zNEd;$Iqp%|Ji1L0NPixXHghTpF2XT$EB8epo_&EwY+znE(Hyy}&;@azc$)f0Q*A*oN)aoNP>A-Cqb3VzD^jH^k5a;4|FSs~&$sUFrM*Qr zGp6neW8JT|?IjP>9?y1mQjr#_2ddjud-Ae5WCw536fgIdd*w9ez58h|>0Lek@8uJC zBc3IazKc&1N$$eKzg?ffi*u7)0%wFa_GkM6;mE1U{T=kI&$6xEhgpi-Sw5Ji=|S4J zwY~8J?b23GqafNMXAH~n9*y|jrW%NY?q}$6yziz+X*KLjCUy|-I3v>{H3XpoR z5b*D^OH6zM`zuF)?o-3|kD+|oM}U%T90Qm8PwW|PiBZGopb#912=@lk@=#Gm*8o%=0ulcxAN2@0I}h{Q5v;ulo?B?OXF3Y2b4iv`5A zbALP*@1YH>bc_0Jlehyh=q%#42$1*{wVWcOu6k6yf;fSWuL(7YbD zNr}2k(xf%oE2iVgz9G}WiqSWQCb_1Rx3N@BzN1@mm`(OK>ReR&by-ub*{fgoH|bC9 z7D^sW4q)yILy-m~W=SvY%g`kg7z<-se6F*V-N=hSbz)%ATHJb5-AKRj5!Rsj8TlaOi@g^%u&p_EKw|s z3@J=0%qc7>tm!OKtSM~444UjO8G*_*S#I&UB^H;sJNpGE<`(3n7Tpqb&M7S}Ni9MW z;`d9DkTY{O15d;BFZ%LB? 
diff --git a/sgl/sampler/__pycache__/__init__.cpython-39.pyc b/sgl/sampler/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 5a55c50262ffa93da4ce33068f6d2e5e9f3dfbbb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 273 zcmYe~<>g`kg7z<-se6F*V-N=!FabFZKwPW>BvKes7;_kM8KW3;nWC6-nWLCXf6V|qlg(quz(0wAaRR5K0YxgCq8~9LlFl^5Jdda z)X&JzP1P?hF4PZp_tDL%tk5q`&(VjPqaPoinU`4-AFo$Xd5gm)H$SB`C)Ez*;9`&s G5=;OPL`Nh5 diff --git a/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc b/sgl/sampler/__pycache__/base_sampler.cpython-37.pyc deleted file mode 100644 index 3e89faa3c0813e31a25a8c5b323cbd94d17793ad..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3740 zcma)9-Hszi74GV8+wHbJbczUMY z-CI?&yS6+x>y_XlL2wBu(hPUu3Bnuj40RPLSNjHNulP>2?HP|2M0>jWRMn|d=Tv>? zXYRIIzJ;gt?}K=^ZCPJZXYn{7evDUgG~D7Wwfg0k^=N9S^5XS9iR^Q!99Z$srx-MAac zct08yX)Ffzd@mPY+U{p~<-gEO+0+VghG>~-ae|;7ZliU%Gqukv7uHEt6rifyxnO6` z6e^wAqNi%g2!_gQ!TOK@w&6MN}O9 zycoi8JYT$xt74*G|K;QPM893+Qe9o#X=pdZ9qXo%7>h8AM)9H-rtzaV)f?gUWx9D) z3GJT5W2v1YFLZ4ft4JxKz4@L}uZ4p=lS-TpRJm11iwlttVktYWc4eF%YKOF_o;;4k zP-<^TmcY2rEO8U2yx844%17~@ln?jz??2d0$4~a;Fx{Igy>|dClxD@ETa2|AhDnyF zFl=D4e1lGdRTn?I$$aLqO}tIEHt%htw`Nbaub`^C5HesS01v^8Xag_8nzk%g!x>)$f7=nlIS7&28Yz zE}dvNb=gM3un>5x>~qH0av zS~&usW`2LAPTkuGKv631fLN@a-1{JnanS*w^C;nQVD=1*4MCCx>L?Bl5x;``Fo+ZM zL=XkEI>u7XE!jF`Am1zH0QF>@kP zmD?rmYi7bpuG>JclT9=haMosP%(Fc<`Rf0)dwCHA+$vrJwH`o6w@*Ry!p02~?Qjb; zMJwy1F1TXH_Bnb}`)?L^KP5ibjJZ?yf?>XLUL^`Bptx31^{G2!;H`cNY?M)E3Opgc zeXM*Sra__j6Z@<(tzd0)TDf4aQtYf{AA_$BzJFl6MV!ZL+`m;uz*?4JwsB(ri3!i> ziSP|ws)j}g2VFCCoBNhkRwJm9!?VKWPs-gud%ot>Vp14jp;8W*vyQ{IbGc!Tz2^e7H_A|@LP zL=vCuX2p2&y$5;3gF}&z0K#||2pycJk&1b+2*?D^gXkausKzCuW?ex+0}kUxPZFtX zr&&TceNLl&&j20V+8-gyW^pn+lF!LWbW0|qP#uX_9_1+?{`D{P{PS%DOcfV;Bh4R& z4^N@B8i#U_i+Hm3Zk~Nh1@g2ga-o9BM=RPav`zA?8@z*%3?C9Jb|9e`FS=?&l<@!s zKqe2toTxzMw1mPp&P3af4<@hgUz6(gBGSg0N34c(a67IaCXPAB8Y(~)+O4kaA;C@Nw>K})=o9{MrDY=-0v%`Fs=IFP{Htn|UgKP52N>C_OBZ(glCtV>GS_@@~4^+5}P|-A~ncVuXN3tw>Nl|*= zAweFRGFRgY8ESe#h*{Jo_ds7WSl6kBH}(FnNdQT**yP5_z6*|(Bd73npO%z+K1E$W z-&1e(a`hgS=A#%%mz$J2*8#_b3h%lSK75b@X?U3ZjzrG2*tx33B@4{@u;WveN7ZT= zYCjA|IX_K_ZiXf*mNVijkY>s+@lB%WjKoi=d5@Y0)LhnTm&l~lOaG!tsuVot-Six< z?tR^Fd5)mD&W89tKDvSmp%{xF5WOV_Nipu0m)}Ko9;HJ;hKL{0^d<_{E{{~yJ)mlp zD04Uc9hyM65)uBBlHEn=tLw`#@d`=VK-1qK=6TuMJyl66P3Y5g8ZY~VaHHhQEYsfY U@}%z2^)1X}O_af#6qW4%08Q$_&;S4c diff --git a/sgl/sampler/__pycache__/base_sampler.cpython-39.pyc b/sgl/sampler/__pycache__/base_sampler.cpython-39.pyc deleted file mode 100644 index c6e8bc195ac5af8c4040b02f9cf26771ca56a148..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3453 zcma)9&2JmW72la%lFJoEQ?jGjNz-mqBrVaHYJ;{w1H%sTVIysz233#)kzljptf-Zj zyX%>uZ4%TaGSH%E3iQ|(1*D>f`jT^c>A8pg1%n<6^yG7IIraBuNr_ZGkfrwRo3A(T z&HU!~M)CT3;Nbb~C;jAaZO8d96_y_tg;(*a_aTHMSmyK@zC7b2x9>5kbFW^i`8%y@Yr zM~|7Qy_OY&yQtWV#bKmnoaRxS$Js=s3eBM^W2us8^f1x|{^Z~wnpf#0S16Og<>xGz z+o(-!Abl>F;E$ZXD_r5>>xr7E<69FA;p1Bu8zK;Ek66DUT4EhFzG#a}_%=lcZf}~k z*J73I#-lPzAZBdlL|8*eW+fL6bXRalPq;IFQag3-JFUb8_!v)aO^w(m>0eZNWMoJJkE6Idcx3(`qLF%;KlkqQ7MEFRK>G zk9=!x^V4&iR1b^6zLPpL{#W)pMmveXn>)taDe~lY|57QFNYMhM`7kO<(|J2Cvv`ok z`5hUT2frwX2p-RWZecT-G|gW=m^bu0Wuf%);O?67BJ8(mS&NB`@_3XiN>P^FPcpL+ zT^wdw%T5^YZZc8EE#XXKnCMt*Y5e(|%Cw_Fkt;3716@rO(cn@RgG8yWXFQc;`^F_N z>UST+a;S_yq(~6lV@G}ip?r3I`=A&l+e+Qr-o5ko^=$HRTMe`AxzpQwz(f^RB75b; z_)(PRsg5E#Zt6u^4OUpmD^o2k$;xzei0HM`{R|3_%uXDO+n1pM z*{AG;3l6OD>QK%28_qFGGs3x>`x=Mrn0>@_4ajnj39EIBsrrst%i;+(U8U1Rea2o| 
znVS5h0C(}KD-c5*pAT9`Oyg)Bv)}W3-fk5jaOS4D?sCI>KQO_$%?oQ_f*VnDvF`0iQaL!SQz{p#qKq28yShfUGA zR{in6u6ag%IR$mh>=ryaaiw5+%)!Vkdp#!ot^JcYKebzXx9r-ig&J13h z&)m^#;?j)&$q}7D1DPKo*MaLccKXfvdDv41YOxpTNtu`?rK&gsSS1}nz|M;wrph$N zc}n~Jgldg%0h8*92}a0ad6Etf)DuFpSy$;O(g!k82Sp}^fBXYIf4>C^(n)DHvf@E> zZ;T1_BvOMyCe!vCMgCnKs&QEsQiszI337Y!UP^sJ;Bqj|Vx1OwFMJ&&2P{hTNUoo? z$a95Y3qy&Jejq5vZW=E>l6&nOXI5X$au&L z3iv&U%S6`DBRP-KxE|1x#@5s4p@_zAVP7t2fc@^pRT-kGcE~>DNeS)rwo=1Rk(c9qkQk5H@jNQoK)*TB;f#qx zWjeU^t=#u2^RJRYiq))Fw}g2_w^3!2hJ$1q01O`M@Bpyz0G8>sFI|b1D;;eMy|0ro zAEL@U$U5IbK@waR8JD_T2-q~~rfP4syJ%hH-KSROUiA3fY7>>*p<&g$6*bv;K|7as z3jG#cK}QKviLi-c?g^F)H<*jA1UPyiiX__4-&^u!YuT6cF1VJz+YRUzL$|RgGC>rL z3Ng+|Z$*})suuZ8ShH6P`F+xEl6aE@ZGg1cJBPtJFexZXDhmIFI6m{Q`mW#fuLkSB zE2*v9F*T&Taw10w4C&+{qN!88D~RDJ0pKe&6~K3^+c@W-X2}VNRxzCvqcRB zSSiLo463)d)tTm5oqU$8JO$C;pw9CfU2m+@Oj$BBO{%ZT1WW7R()ykSW&9h}BHg6D Pv-uV`Jh0Oy@e%(&mI7ju diff --git a/sgl/sampler/__pycache__/fastgcn.cpython-37.pyc b/sgl/sampler/__pycache__/fastgcn.cpython-37.pyc deleted file mode 100644 index d0e5023ccc925a2abd089648bd745dde90840b8d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1778 zcmZuxOK%)S5boDJc6T-e+XN@MSYSn339ZQ`h{ys&5Ec^h#YjXWP-{HhyPjoU>7F&V z)qAp0PVgJn7bN}!H~vIlIr+w!3sleSI#HA!Rdr22s;cX&sxODbK7#S&uTSKkApPxr z)`NlY1(1FYOdx@C6n8U@ar;kV(#|wy?d%nP9AI~bjm#K@r6fMh8eR;awr& z$>@^CLxGRb$iD&Y=x~HBD=SsN92_M@olBLvjrE`_L4O2-M9zVjyhIZ$upn1qOa&F} z3V|b@@L~2|B8Flg3um&~Qpi43%h8Q{_hW&;y zYN(;k%&%nO8LweW;Me#eGKl*EfeCu3Fx+-;NX{}1COti`rBY704LQ%stl``abe$+I zxqGN3Wo@aR7uKs)b*x9eWm@KwxK}I5+a()LWYWx))Rrb={>KifMx_;ycE^F7_ga&H1ppQh!wK34%a)xOw0AK`m6!WY|b zglN3>IVJ88j1b6Ocm||%U{f@OSk2HDS`dR5^lab7ZU%X>4fqRkNxworK3Xt?mJmzW zp@!W@*W?F?;=(sRM40U(<1Ofm-R~imOJ>M523=_88G0Wr0uvY_=!8Oq*~O>*oz?iJ ztmln>di(E|PoM4_Crx^?Q&vLiojlXc?f#9!=H@>b9ZJ6Km-B+>$+=Y8dMYWW(r)sq zl-(L{W0sXuJ7~{von1&BZ}arT9px(!n~FEG)Rj_0&_4=Q1e3+5@gT`_4*1uVYT{u! zSG<}xfEPFxuQX<*kZ1A0?R5Z&)z9iyf3?b=f2`gGwH@1V2Ew&>?oV$4LtaD}-o^}X z;1GwTkNa-{bbT#dS5gmdI%)y^9RvrSafOY3Ark@s&;~C_g8{fz`d?$=fB{Hn%z@dK z176au92Op&^_D)|3V=eOo?+M6eE=kbQw)G-#&5x;oh1`=wv5-VhlVcu&0rY`&xD3- zp@+!bUr7#St_lUD3zRBKk2f7IjAiaRbz>+ttYZ%sz<#R+PHTLbu8ePwgW&ElCaGT8?3Qr zrOdP?_2W2LquQI_0(^^U{rC+`)%)PuE|9(l40#&>W*^_d4{)D^_%1-YjW0et>MD8N z#Kvz)SZF-ZY|5mN&}{mg7nPXjPL4Q#F;DVtr)>c00a$7c+u^E(q!RL$zw&K_)m?YU zC0g6bTo|?M0U?fpaHCxt-LbS5$CedIb7H-wQYlPOU(je;5F(vs_4&9f8tZqn`T(@p z&DXcac~&O*RJnemT(gPagZ?!ZNt29^6KI-UNR@K=+nRpm>FedO-#z*>SM=KD0+dAi JMI^!z`5!0J$3Fl7 diff --git a/sgl/sampler/__pycache__/sampler.cpython-37.pyc b/sgl/sampler/__pycache__/sampler.cpython-37.pyc deleted file mode 100644 index e45dd18369efd176176b05eb55c1d7ad89e82c71..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9426 zcmeHN%WoV!vHx1$YBq;Y)*?@^B>eHNKU#1$SHeCeqVLZ@S%qd z?Cvc?bahpAz3cnF`hCCRyUnI=;P<2d*-E0X8ODE7V)k_MZgSZ4vesPo1v-G zR%q+D6IS%w4XgU?4QruqQdw(IA2z~Nm7yV zwxya~o(-ckkqO#UPdG_Xwsqe zZKG>4lUWZvOx|YBLnCy)Z8%0)!5}KCzOgqL%o?-G2B;wyPaTi=T?EpQW^No>eJeM+ zW^R4Q#EKv%vYCa@VK%}_ZcSb0$ZBr%y?%{V){H~*h`#%|F|AYX5!Goh7hPA!8-7s! 
zK1hRZ+>x&Za;GN(QA{lmJxuW+;ECMhslc}+=wh(JHjhU;>+{N-Mye`0cSCt;#*(PV z1UHer*>0+e*c&ChDCI?1#h#2J)hK(3(s-E6Qqdr}pA1xE#FI#W@1@(y-cF>d@HpL0 zloQ3Qj}~|0QKEbrO|%uu&W@_l2E9ATI=i!Minnd$h-A=JHdS+XAH;lHUV!Hb8k~v;E zv5V`oId*6T=tpnb()aL)MFiN2*<*HXY=*hnw)ir>@D&7Ylb@&f+B~>~nDS;zvWk?@ zvT~;9S%J2zu+p@$SZP0sh8f!%P`nXE_x9pJ@kP&vU#6y;<$`n)tY~!HF-Ah2aq3_1 zAkAEmqI+heu5&s_%zrdX`iWmd@GK1|6c}c}XCA9PG(%VGL=|gMR|}g-Z+mAe0e zqY-dg3{C>kwu=qBKp7V)qg*$=gL?5ucCO$N9)g+dj91Uh^l{mtsndG@HXcD!CfVs4 zxk)AfadRFwrxw{6t?)3jn6ka8syrDD;!bj>?Wm;)npEf&)T=j6gsMlKc+iPPJlhhy zh1U6t2q24E{%9u0vY>?k5=6?03Y|oVUqAst>Tg&L{Ee?XHk}iNJF{J^qC8d7>hL{8 zB&Ks{@)km4V#3my((2oX7Jo4}r_RwVuhMr9E%BDDOslzhWKPU&V*-0PvG?DS9+skZ zXx=nFd3EBz+W9#wV!xg{YsQgv*P2u&F05k%Hq$+_uyB=$eec7MX#tzL#VlrjZlNtZ zhkcy3a_gv2ODFy=bayTjAEcu_d1Ia=~sUO6)OY~_HJJKqdkf4#OdO*{%!=@gKR4v%zH`pspll_>;xj);~j{0(SKpLdtAsM z-rnZPb}V~Y8VurtgwOjro=S!pKBH|33!2DNZNrd$Y13qkZ| z=rcHw`Vbbu1Dr9Vgi$GN)dOv>=i}9f%LXj@BFuhKH7csM7nBE?9>iM%oDeMuL1mwqS1q-7Fy)7R0bAX{V9JseOC6Sms)Eyr}sMe_ost(G;uG?Tp2$gUqZ z@aTW>>&$F9SvgW!t*#fK4^SE?ENqSWEvsv-855gSXxHG^By3P__iZSrll~a$=}5Eh z^r5)Ot>W<*$}_tQg`=E)ZN*^D=k}zs|7l)<@{XaZ?u1I(xtmwNMXh_e*Z0xB#q)ci zR%&kzaA?|?#gEL(#-v87%qqJc{~^A$s7`KeLy=P}w|1LNd&F@MTgGrAV6ZFT$) zbNSSk;*9FvCo|E5**sw>POB3WqzQzxAQL=EBak$(s)y6+wmvE_Ey;5JHG@+j+`=y43EuKBV zKx2-xP-Y{bNs42DNQCv206_uuIZ3>#lNr#!DuJ>V5mxkiMhE#DDF0h{1g)V_UACZT zwqsqhT4u}iP0PGwUWEd-tVOGVaQxC)6?}q~0g`wNjkP^a8qJ@h;3@?KC^#L!{Oc4@ z-~4$B=BRN7LNxve@QI+F#vd>UHP7>&(}1Rf-1%>&^x{{L^_2je_PkEOx0-5g;6%9f z-saB=y%CggHu*;wk*er@-k1S0%KIq**$~wWKxP(Tethxt*nie|Obh)++u<+Kz{ySg zN~lZkv!KDA09%*O4)sg1)UxvQ; zD-bAH<6k5>zD5C^7M$=0U#Ea@6u(Y^HcxL-?3XC`WeQGUBTnXq|0;q4^8kgGCYBt@ zut6B9mzFT7N-0?NV@p@~r(_5J5RV{(H^V-@=>YjG0{bk_kk8r~#W@B(bIe2D-W>D% z3lb3eGnl793&1I$V90$qG8vo=az6UN2Z&Yj3d-1dMdK4vq>9ETp>#pfIESEd4&ihF znRwG0N&=&Rq|w+T2Wits{UZl%i<7%(4T@JHj%tayDmvfG?P)_cr!D3J-B>Fy@*2I^ zIcgNRB!2HPW&6266KonB7LA9Gu?eA>lQ1NZS+Iqy0p}E04F#dPQ2d}kFco5#x?@{ z{_EG<3n%bWHOCz?ECEo$dNHV~ zT7h(GbL^<@kI`br?^3y^TI?M+&7NST1QCvWjEBe@fGDIOO(}12`XvR@!E$qI7Q&Pn!v4U0@uU$8Xg}{ zojHF?q;7-3azxFNvi;^k+K%8#`zj0QkuWgo0c5sV`+xGDp$M0f5UF%zIUNMT=QS2sWIj2W+Niem4~*A zSdBU;qON`k)tY-FEk0dFP7e)XFkm+F%eK>^5pgo@Vg;P1WiYVj2xKd!#MLaVDITt%N@cFm&RF?UXx7u1=d-5;_kFPpDe4FB)ZzgOV2uLnK#$qn$(RCf(4OlM2c5i3FK6(mUWkpJxrkt)4My z?BC4M`zLQtn!1(d(^|GhuB`9>T3Zz8%uO{Q(O62lZ8q9G3GFr11SE+_ukhpkf zw5-Yo{iE!0Z^S>Q&;Aw>$uM#8w)P%)LVeMJtK9n#9Nhj@^g@EM8sXeXaDeHxBKJZ$ zxCTo&oyy84dTDl5xgRbRWV-e*Xgc9CNkb;E9VNLu+*cLdIag_61w-W9EjTTg@RdLr zK{5}2izwsFQ&-LLm1oK6TqemCn|m}l%qfTfm{AxCC=OjK_iYQ;DHB_2l)_-fkXIqT zR=UL9QVRSnJRNu^;C;c{I>dxdiFx%I&bhmUS8&Z7ftih(B=b>g|6k>TTttb+Ax_m* zm^&Dqe}ZpR8oc7oEHaYlAKb5WalHfX7pyB?>eQ!8L@vOaaaWBa8$54qx(v+Z;XB38 z9T|7+i9e}N8o57RIWqD(g=hK@#oFFiCMU~_-Ub^ zwY-J1y;`PP(@S|v^TK0IW?ZJ@1}1kOqq{<*(&Gk`+$B2XIl5AkSEtV--jV^r7np}5 zg>qkHv_6-C1^ge_be!mOy<Im1`#Ris zXi60?a=n%HU~7Eczl{Stc#1XUiP#%R5$u8P;2O@38v!6f(!q7e!MfjmiIZW~zKg06 z)mdCa=xD;W6T!)-D4*gbEjl-<{JZq2HU}FTQSd*e_Wp!|KczsM1Cl;|2Z3@^T(Mw| z4Ps&TBiznoC^X}Ye~k(idX6MeA-QA)0bPSCiD=NoIDGSIe|ePO*m)oF%% z9ezmP$(r#A0{CWR1(ZWjQkk2|d|yLVW%jf=;D1gfzd#UrGgHU^lJfqF0hB_(OdDZ#-fh0V#9~bnhbm8c<*t3VaDMcvQ$=h7t#O7vCz@Uqudb) ze*CkGUSMzbk1_lY=vzbk0qe!WD7SbQyV%e-Zlwj)VicY)vbGca#|2#*tnbO*KolZK zW;|TQzh0~t?kZp30`4y<0M6+^)&Fq71B!h_f%Z$txSrV#<-J|}%fmb5qzn2N1~Q$+ SUjw@U0^V9&Sp1#EoBs#8JyekZ diff --git a/sgl/sampler/__pycache__/sampler.cpython-39.pyc b/sgl/sampler/__pycache__/sampler.cpython-39.pyc deleted file mode 100644 index 2d54369bcb67b1d5d0f4f0510c82de8730a27214..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9387 
zcmeHN%a0^SUe3tO%F4=m^kaI)-818H`hl6kc8|?tcNr|R?09?d;_Y1+vw~twD!U@P zyK1U3Ya*g%dWx(SxV^h(*;#=AiN_jMTW}Z&A;b+1-1!&cfRH%N4a5xzneP`_RbAcl zz)1K;wK6gzGBV!3FTURw9oOrDh2JaRznNTo#j^gDB8xv4kuT$t69nAitY@`~J8Ll$ zw_A?6yRDMBd#$p$`~6BQV3gPHRr|G8%{)83dcVTD?C6vF!43StGtGI&BV`gcEf7cAEK4kc9ZGajX^AA@?Mh4 zOf+p>-xXOuN)vTA6Sod@y&ub4k;2ury=@wltXxz+W%6<*q~4%rdC4{~c#rPL$(rp7%q zp*-3O!s+jwG~A8b>eW!~cVsAwZicdh?%oSUqJ|=sc$b8`Xmq$I;=%sr{9{3_)g_a< zrTsMHNz~y|FeGm`>$)U|gG3a${0Jl3QBkC8(^{f5?kBTY)JyIpJzX1!Br@+i>7I7> z5~WKbPWKY+MltWA#Qk`X=zv-i?Zm3RuPZcJU+HD-TeD<(e&A?VCcRzl&}ZJQyRq1l z4;XSs?OrTp6g{y1^YWGbte;$w^7fS**WZ1uckkee-0NMLwr_Ud9NZIR3UyqvhQM;y znP>jav(ehZY;Ml_*rpMf{pgvd_8Km^i~!RyyV$OS31Qn!Tdd-ZSVQoDiF5RPVIEw> zllEuBvW}QE>*iGLvkXm7W3*{>G1@^C^)o)~(Q_?|ZV%&L@y7H?T%w}u(?Mw`7}aQ@ z;tgs7)o$jH`Yy@reO9ZQlr|zKe$-F;$uA&ynmrUI4PXek&nq9YmS-%ajM1p-rR}7% zx4)B#8&k78Z6(wQVN}ow6?dl zo1Um3YSN1FqgDtS7(vTJoTb2wnwdQ_bhLKpLHY=0YiG26s*jIfJJfON>@VSxKZ$_W z&aMS0UagSzcrLK2q%P*pI8~-6}c)lXuNe4r9bsi-kHAqA|hP)caLd5sNY_~`ZY2k$QX8ys~6Sbf5 zlLgF1tXO`)zIpC%3>B)AC$rD$yBhBGvYog$uO&I4nv=M_AIfYf+E~~{{RM`1gj8X? zw+vz};Q*NeCmt$!j#1a;y4v-@>>QpwyJ3&c&sN?PYB!Gy zc=#!P?HQs|L48QrqQWd$6{ovs7pbko0udAf>q z-7X9lsbyR~cQ!2U?%Ld?lx}4WH1EDMF6AW{ZwgcO#+1*=y}bM#Dw|9@$o&nFB137+ z^AWpXjVpKqYBRLNCwN!S{Vt4m59XWNxjhb0^HO&?_j8b;ZY{6y@_ly*P%C0$ObIyb3@-OX&?#oua>#vSIDoKv;kDSMxP#(BYK1-XUn{ z04f*I;iL(1NvofD7S@4CQuP?iJYHw8A>g*L7qda~`SvBlxS(qb)_;sV+u}{kp0S#T z1jAq*E4b;hRHB1rw}k`(bIk_LH9=HGyiQSXn09Ut6J4HmLp#0Bj;p59?$yPYwu?mCGC;Xr(9RtS=z>8 zPI2jqOcd|z?!LjBHDgGv0K3Yasl5cpgVS1>4&l^LiT?#@axgWmRzqZW(a>066I9s% zl#-s&6hvCplz2eF{lykFyh`?Om=(;jUOe!k)>FnOej53|hD#EuvdXKrZ#Nv*zHHZE zap&13vO4UKh}okTPutyNpry&od_C~Rb0|-Q6g*D>F$!^+f}f*+>J=|ga7-DGnH%vJ zN;mWa7*O$j|5?Le+DJX|o7uXg5=r&K*cdCA{OFmJt^Pd8m|6Z?;$>>|k5h0wX@u5hUT;$a7Z}xr)6M%# zgzC3Z%999{d`a08LbyfpR02MF3*-s-ixmG9d8(WsPhUWe;1TjvT_8_?|9O+AVyT6@ zodZyX88VuEl!#YCSOP?aKO3up8kyyD?8g2-#Y)7}DVAtrDFJt1!A-nM0qqOo4GPFB zF;q$X6g?R*`Wbp67f^hWf(7~`zC;N>i=d!2AevJ!y-H~{qCuT>%5=&U1Ad=25kkFt zxK7ZUfVm(y4bYq5Qy(B}=*`9{!#QR*Cx{I>b#r3#4~SqzI zcqeY|p){D@lxCEt#HwV{{oI+mslmW~pUiR*Z0 zXoCTH@r#IpD2a<+K82}hPz1&Z4(OB9urg{9d#<7Gj-}pN6jV2dt>qfg+uu;~|HP~#=`vOT)=m^oSGr!%jDzNF7-UMNle86fo|-~?ybCVzk@ zoStCZcS-V_xQN5D7Gj=>mAQQgpIu=SEpps+KEg2vt|~6t8Y`-SHy(#O`IVafSu)`?8d)N_-62)_;pPCiL((}2ju76p!GnuzXv;c ze+}vndZergEES%F{Rj4KNcU!|fkZ*dmbi(a>1ccNb=te9-7*MZD;R0m0uv^j|3H0}40=Lj=v|MM}?i zD7Z_(0Rrur*^y#IDZfQ2P+U5xGJ~bv9%fNjj@wWsAb~=o>n4qyC7j$HXa`uL-3%_1 zIHZyYPxJ~wR7c|A%tKhJU?5t&XM6+Er@CD8+9OaeFz(IK$V53N|W1XWc#1XWM^Q=VtGCevZ z6V1*Y4_Zrw8g9G?>P~Bw=0B4d{;4|LKhPypIZDCS)v8XvAZpnP%={Y}e>N{*UoD?R zUsI>qq0BUy86r32pK!^y5Lkf2m!Yw)vvYQUJCs)2gBub!4YmsPCooith?bAPW#n;j zFJc$j=;G5>=808Esx)ox*Ko-LKqy>%KqtTf#zH5>UB||02+l$bIpKgVq?I@zMR_E> zQjFx+`w&NS$hO>t910(1LNaQ=v9aAGih@()sD`Uft2}C`CAEzAwL@&8>i{?a&0oXk znL=D~=QhD2YKL^IJegsIwoO))mo(DM14)W?QTA$bRI{M(@0v0Adx|c9GhI3q9=CcS+`Wmd=CJrO9)6Zm1C9^W4W4La8rs8k-B) z_=Arf+CI!CUM)NXgKBUSJf7TveA#rh1Bnr9KI`ccs1bzI+I0Du7VChmf-tSWIeF_W|(7F@9fk1mH&PUM08mTJ>5sul> z6zXp#u281J#t{VtBsZ%Nkt;AG8TC51l0p=Cm$DLyL3%~+3*`De6X!UT5kH{!gk<8o z2;grK0%(_bqh{Nhebw+%%{m4O#Gg>kKSj{;XD}x|rnJAHfRIysf&eXZYQT@cCRo}@ zC0Catu;Mn(+==oU>93Gb|*%1Hcxnfr^Gi;`Vx}&`T zM0)%Oiy2^l_TMgIpWf9B>EA2{MtjA*n8lhoE}Nofv5d|ZNqY(Yb%Lb-&7tb_WU&MZ z!dvV3ABxSw1r?h+Ao)|bUI^Me&HoJ)2k76TDD`DwCJe-y4cvdT_`ij>$nlosCJ=BI Se>IFcgmim(Y55z=Z~r$HLQz2g diff --git a/sgl/sampler/__pycache__/sampler_fastgcn.cpython-37.pyc b/sgl/sampler/__pycache__/sampler_fastgcn.cpython-37.pyc deleted file mode 100644 index 66487d2573561c1c6dc97bdc455da7f6a9c83206..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 
1786 zcmZuxOK%)S5boDJc6ZhU+XN@MSP(^8DO!^Yg2)0y5Ec^h#YjXWP-{HhyPjoU>7F&V z)qAp0PVg64Uy%3{{E)tK@{KbWsGiw%q9{G;s+xXORo7QlUk-LJC0-A{*#!roW`u>Ug5_9c4tT)7QMLN$}}GoQ5<1(fCLlX6%w9| zE@?az_y~>s8_^m!k{Od$F2hR>u_cAwpm`)zR)E@(BP=gJ+Br{@&p%Y@*Zy2M7 z8tTmaN*12+8ny&_jUOO`xGxZxpoa>>ZTE)cEYo1p)AL#?<&@iy^Q_Do&h0?giPDn0 zhgwqBmg;$7y;@brdgNQCWj={}wUWGDvf)G~&0I-sX(DFU|LNyMO?4cmRo?0VgFVjE zJkgr-KheLB_fD!p?&aradq)S~?&)d1r@L6LIiElXrfE5@&s6}|=@Wsv> zp&GA!Pl-DOBZTQLJcFroAX7Ajn9a~OS`dR5^lab7a0dCZ19>3GCH)Hh_-MflT0%@= zhZ=SdU6UUmnhW3f5NWoLjJKdK9)1roT{1(iG3Y`w&(M2l5tzUbK_?U<%`QIe@2BD2m3E6) zrR>&t8@sHW+Ch7U>+C}6c!#Ga?kHb@*;KrdrLL43g8oscB48Gu#)Bl!IpjfIsU{w# zbH%H91Gs@>@k(P>3V9X}++GKiSpBSS^;fI>`A6zqP}`ABXCPdA=l=9I5adOK;T_EI zCJu2(`ndlVK-bsObw%~yrkgFGe}Le?Gp@kVFJwXh0NUUsX)pk{DuZh*954Xsj5#pd zvcOB)mBqq?v)^=aJ!6^p7Gvl}5($10zI$OqT*F!^>{bsO?gl9rS zw$TIR?k^}L0~Drf9P6S}9yqPGfP6Ov0^^_0M?T}V%U}66!s?DY zI)i83qqvRtUe!iMPvPrs}Den z-Fkg%oM&Z{PnGL8$~Bw#J?LX&ku=HpIDzKbg;Xh*zpd$4p1xil``x2Ib49ORES^(K)nMM9UcK!odbe5?sv~!%Sv`M~mZg(CySYZ^eyav(x0f9k+%HodwQ2@#64I=M3ldc&W35Ico9Q;c{nL z*KTQ{8`d9bVLe*dH#_H0ZiG#gn_ON&xfL#;ybwKu_Vb~BTWc?V4l1>cww^n8qC}^0s&Bp%$mn)39K{iJ^CUO>!yP|M@h`eteo>W`Q!Qm8=qG*_$qa3A zKkL$R<;N^C@-i}!*3m<)V|=1*>Y*MQk6av)8Cs9DjuqOWgKzs2&C)s!$gbth+i&0c zenCvvC?=u_>1pDTFCxjbOrL7|#=w~BoBGuF8#;f~_-1GzwL%k3PE_+d_FWjpOv8^lph@v_X+Z95{%a6_TKre)^7A+}V$N9`8otsNRG z8t%)!KUFQ1Pu$E2U7Cw~nPL5q?zF}wt0=i#va|X@BWxV%+@lfFTurQ~4O-a(*875T zdlphx{B~*$77xxGsy_9DrKxpr_E0<2_iVH*PwlW7ww@U1jsE_*Q_U`(8M)%%=oXY$quaRfp2Mk0OC z4tmS~Jbev6Y z9Geyl_>@WK)do%7mBJqendt94-@cezpj+mPZ%Y{ol9lfZk&3+0m;FR$L4s|YAS5?? zQ6|n|KY8QFK|GGQn|KBl;yl&btZ>8y%GZ1!N8Ihy_~Z4E0zHNFhX_ zuLLo-`CGK9!yi&|wn~<%zw1MoaWCC9l;k=RO<%;{s?kDv(P$a2L1hRaY8?Dg{i5#b zOZqa}jtONkS*8A3n2#vGjz_|(K>bbiOy4);N@g6GQyomK_q07?d`rB59P0M0oqs^t zhNylD7S{&Opa#~aoEz#-%#V#d2h8paAVP=QUTx|eG!J3UNZfyJu4z1`6&l-h@d&*Z zFs?PNt!bnRo*24@{)>m&$L5~9SKn)7#^B7fK6QIW=!7+hQFCfQ#F~dt%ou5j^{I(+ zjmr1ef1*ukopL2(3bkF8Oz!_bNY8>^YI8rcZqqDW4(p#mxrYYCtP$+PLd4)lWis)sxZvaRfdjG2wbINg+nmSdDp~4Z_hm?gZP^i^WVy)h6bOH`kQ_ z21)4kvq-=|Jc#05PpM-$*X3vZyCk1p6m+-BH6Nt$0~irZTZ|>+;=Ss;+koCC-O%Qq*M8Xek0hz~K54ZYQ0#&=-!-zVo;p<+tL7Md4#Dj&;`rUCH zh|5ZS^aa#~IxX5(u^=&ZB%y=26QqmIE*1S=KM7)Q2)p*`>?Gxk^dF706IYQy*1wk({EZ-nii@gyQ||jmzHqcX1wc3h#U0^`M6D z;fzQ_hP`NZ;dCCPhOt08BDEF7mrEk4C$R^Jq* zsCgD&f;2ajtr_-*ZA%b-$!$6bS)zf^r|so76VW!gi})4|{UaK>ES{r00U_}MlH5@{ zWM#>El(kacP*+WpDP>A1iFFkgQh1KpReXacX;g~RXKG=svY@*SZ8{o7NytRBTdWZC z#!QBCi)a0B?&l3UL6W{`=eWjn6unn56{AFLp`kYn z%W!ndu%I1RjfTFgH}n0&)bZ*@fZ446K-aFouIAkxMFvevY|{6{1TL~fZ+0gx*I zkX8jCS706V1F{D;!#db9l+4iXDVS$*8CHf{?Rl_Sfk-ELE~~*}xa2x1P)AU4Y8~p& z5&(3vCRv$%av53-*%}+2Iw5jc@q03__sN#{klMV-=*?&8`7)3uR|jEVekW(ae06*+KXX?PfTi3Ly zy=UHk5uO7(hEwAPdurF7r@a&sRD5$t|lTdzdV+wAYay0RgJ z#UNuP*$|L&gf)bf4>D!WnK_OdK6w;CbjQL6(#EA=o~y9p1z|EB3&4@m7Q((OJG;ro zm}jmwOLhvN)Q4j^0s=VI97^wH2G4~9U)aL-QCh1RnP0^lxWdZ$BP&C4dmE1=7VAM) z?zZ-H!khd0-x~LU!&pl~Q4kljjoes&O~Lb{-1bG|8X4?vIvlk&@6xy9LPu^v9SQX5%%cHMI&H1gdud*Xvp1pb5|3ox)3h z7M?QP40DPRn*_dV!8^Z08Xu79hif~}^BKHaz{nZTL%s;Y;$u$2lDPw@d&IpF^=9p+AV zE5$*Jm#F0}YB9q!6R%Lt4*LOONp%*2PQAL6yg7HW@T(PfcdR7p%iI{f-l>(`UoTHA z>=ZCF^hd8BlLPTnEcPor%DgNtLjYi6j1_=4a&ui{RiC_YO!`le+qNriVHWWNO5UY} zAc(?j9m>%<;)j$}8mELUj93VC&;d${#sRL`?vi_+(N7l@)ra)aTX-ZTC14A6#3mtJ z9GirnOR-7vX9=jOAZ(|mg0CG{1zkJ!a1p|ACO>;U9!t1^GyFFB4)bAAPvJ&lbjAdi zz0D{<2nZYtVIU7u$%Qf`UBXZFp>Lv-l>!0!Dd83P&Pa7Krap`(QsfRS4n9%n*O>ge z5`e4V$Y4mJByjuNML-GMeQywTGw&gSMdS(sHG}dw|E(Ctv|jZ#q6`74?Hv^#of*5L zSFsf=aI@G{!S*jFA=|q4i6LsYSG_w*8$19p@Q+XFPr+1+sLyRCCU?u#;0W}f8h#O|iqQDR_)xY48l+kJwlA*uJ#S_fPQG1T%k3!3N`@#ofw+KUyAW;9WcqAEo&4P`(bjsGNqpdIkZNB13 
za^J{M(dM@eaXtfT-`905SwhaF95~Y?&MY|;-!^k9kVM;_MAebphzKjBh!!4^6L4qN zC%?R`FYkX&YA$h7ovHNm~e61HH2olA97-WGzzpZ=~FH;h^ z83~_b6Fvw~Sn(-WM{4lbcqI97n$^-9CSV5sPUGGRY0$;dr{rGZU`eh#IcJX$QhbbS z9N_v0dq0L!;TJZ9$6d5obZte86WX{uXEI~a;>>2-WmU&`ch;$5Az~j|onht@G7K}_ zJ;eLxASO#}+Yo<-4tcGZ`nVYXm>Thv@bWv?k89!e6IXE>)I35Psoirl=|v=cut!KRR7f+5H`if=q2tV$=!lVS(1%6Zeeuf7Rf{b8v_<8Z2a#i|38VZ<{;hh z5fQ;$xmXb4tB{wtPc5Vj5qKg8BwsE&3?qt($*ulqp2RZG#+3ZUf5TH~37GJ2e1#J} za+);^XI}dVoY&VL8@@_)&KsPXoLZdPpXxi89~%SbhW4p`*ZT!V3~C4NC)y{#!OW{} zR_8i|trfRY_#2)had92NfvO(8z}43{WR12ak3-EG#;K9N&_BWZLzsBHmc+=G;&HC4 z9j&`Sbbte{Y1u-y_*kbmNstp`urxRux=-NGB5zPTzD@2&R3b;K9jQaFqOR<}jCQ6g ztcP`l_fcaNvzRlK;J zQJ>oEP;tIAIqi4bSLa6)*UvuV%d5UEAH+V|6 z|K3TCdaADcowED&^a_)Cnojwz8GY5M(A0D)thUZ0!>VBsfeZ>&z_prF3rcyJN_px1 zQ`eE0;+jx2+U3J_MP zbQl>WdS5uC;IL4qRbAwNqi3;!-+O%upeZ8e))xpHX6lnBw^`T{P-OW`~A72{9WGNX-AT~*VOx>H%SFb(nAOw08k?? Rj>tG30PL2r^wH7}{|_PSraAxs diff --git a/sgl/sampler/__pycache__/utils.cpython-37.pyc b/sgl/sampler/__pycache__/utils.cpython-37.pyc deleted file mode 100644 index 7ee82cff61418bc1d37947df469e47cf9f61f6b2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1869 zcma(R%Wfn!&~|svbk8ef741S;q?Ujc10>-PaYC>Hq`j;nu~`Y!67u$rok=JCnApxH z6M0U_og;$N%nd$;U+|R!3Lk*PUZAQwuSEgjR+Zi5a=oiQ@Ao?d*3$P+<)1A=e#M{5 z<-y@0fZhZoh@d%1XuY$92`0QVl6aycTB3bM6aNtrp@`0ih)!7&2)a#%-E(+Jwnhv$ zSUvz~AzuLK4*?l6rvS_nOBZZGXY7LFSl_6QeaFg)^YmfiJW&}J zOsXTPhJMnoRK`o0jd!0{O1r_hs>F?ds*|8dkGcMts0akTe6X=w6>>xCgN^M+kJs|) z@rEAf8#*l}xl|j6Cd>8uWa^ed3!kV;9F7cEsmUsIM58B!-hoeb;dD*`{r2MzxPJL; z6yPnaC4e?S?eksWv1Cpb;N2T!7L3WPb^M2+HaMkUl5c;Xwc(c8_JYp61(}65-1Fyw zz{o_VWusGS={`|UY-`@O{(@3+uzIj+LL1Ga1v#a&PQx|p+OF{Kf|h-zelneT*LIMy zO02g+UI%=S+@#zi!p9q+dL~@>vtGIRzg~H-zA`-<+1@Fe^=;p>eP6vqOT~qyH?97f zfc(B~foB6dfb0fa3Illl7c7N&Ry{z&fs4k{aJ`#OqzhD9j-_h>MO8;GFx9B4oHwd+ z*H*IO@p4(ZRyvu;QaCRwjq?kcmI|vG=(U78U#^|%RH<}2TykD2_LAiyE3@LTxGtGw zMQ25prz$g3Zgx>cl?&% zI9eNB7pRxM7$>h?{d2H<3|g>ST<2PIGPog&my^198-&g4?x1L+(yZiZndVcSX)JI3 z01%-)>azhIz}}-BIQM7>r;pht>)?qEScj>%fpdiHmkSf{5J0~NNY;H1Jx`(Eo4)sj zF9PWQpzeN>v_uFwjNA%FDHS^1k$fym3Efhu6a6&ZmpV?%_=*NUA}(rjX5!}68a|7k zPvcGuOG}kOa8I&w9FNjGk4*)9GJaOdWAiL7q}i=RTn+!-3EiJ$dW4V60#AjvY!P_1Q16$-sgqmhd9HMrF19{CS_HRmh( zus~{k_u*j49{H|Eaj%22O}-*e$Tu+ao4|MqtDyGWfvXsQaP?Nh6i)Sz04|~s{<~F; zn5|*EmYEikbJyX#fH9lnxyS47Dw~@&=b{>MuHHveK0t)3Zo~y$QrHmc9YB~S9RdnH zpN4+u)zl%Yr;6&s#jxc9=+w!y9t{e+rJf6QGt}I-33uTqjfJ0L${GzcWMR;`|2G@I B)QbQB diff --git a/sgl/sampler/__pycache__/utils.cpython-39.pyc b/sgl/sampler/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index d116c303b5952cadf1d7f86af84bbb1c4aa6e384..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1919 zcma)6UvC^W5VyU%ce|InlvWW91*DaL!U0W+RPltM5=eciBBWMAtb||Qq6KNs6+ zdzJ4gdFK(q)7=YFJ_=vKuRNge34j-v@g?oGlnSqA>>bx7zsiaoft3e`<=P7LU0(Avyy^e_h?)ZELk{%)$1Z zTeq)O)5qIpTx}asPb#Ii4{ceQt;y711Qb5etvnoAu7xdIbi|-1gkFMAufgh+0{iW! 
zA94Kh#VEnuvvC3DE>Q6M5}b2nP8Q(d8)TM@$*g0kOOEJQ-vi zG-@l~)oOVrs#N1UVlq*U^ikPZAJi-xWq^EvCZ@J`!@TxLB7>Ss&KV zQl{82WnET6mv+kSE<&`G^s%vO;xAV1BYtowD%>=qR;wawTB!_g4(noVnMJ%7df(F$ zz~mlBYqMtw8i|MF;ss8%#@Ew4YaMTe_ES<9j6eN!%G?^2+1N|ie6MeQc2%vp@yB+QOxB@<8@cgQJpfS2HzC_A7 zq>$x&2%83*1fd=JyTmJy$Hr2lX diff --git a/sgl/tasks/__pycache__/__init__.cpython-37.pyc b/sgl/tasks/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index f4c1274b624759dd7debbfa927facdacba230ff0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 908 zcmZuv%Z}496iw1c+9qlGfMFyMi`g(;4EX_sC<6?OszO8v$r~%grnPD(B~CTdU*Qk< zA6D72;ulzPuiH{ajjbFVU!UB2eeBmk(6eA%B{Bcmu&iIWI1CTSEBN$x0A?{;TFNH2 zjobn55XbO_Y^Ww_supRPbyK#LOI*WSz&oU4cw2T=kMs<8fqTR=yd!-TkihV+?5mK3 zhWF$^4apF=$9xvBJ`34^4cUl|*#q`);gJzLi>=9{-!?Qjv8(go z)rW#Jkzz|X)3+$SAKXhXYYKlB<&s|KMZwc@nz2~rdAWoPPfxWdoh4tm{G>T3&Q29? zgf0uXZmVz*+6XQ}2ce6A&!WH=Q1}QkdX8`=~;=fE8W;bL-d`8KpZDXp%9ijSZ$*a#2_I7HyMg2Hj|G4n)rDgbXdgqgF1Z ziIkN7u!;em8sX2gXqhV>Y5gsVZ)R6=dl%`2jLJl>bVRAhL`muM^|tD~P&wO3{u1A$ Sz69`X-yXpC-)uPc>HG`)k^Rg7 diff --git a/sgl/tasks/__pycache__/__init__.cpython-39.pyc b/sgl/tasks/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index a5faead87d113183a97cf2700d33b50f6bc6361d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 876 zcmZuv&5qMB5Vq5wq)oEjb`e?$#9=Sc!){&xp{*8J4poJS5RxxeNK9+hPD-4L?HlkM z+;|^OxpLwaI5Fe4bXScdfBHQBGBf_{Bn9DW-DiEvUEn_RE%#KQLl#=zQzv@B z2A22LsUET+C3BMbKPVezK}^Qs9}4Y^X?^}qRH7_qMJ8@kVvPJE)1;C`4tt}9s8c^V zj`BFsD<$)JebKP_L8O~hiM<9ApKMiY1Et}wL#cgaHhCLEbU2my@_i{XnPN+~lQ$@R z7~Dx;G!*_Ms|CL)%2K4&B+p`96x9MUJUiB+been?>Z1{$czmpQEloA{OSocE;+>XV z1Q)?W=po=!D1C$gp^eZ1n2pcsvm89;=>wCunJ_1N+OW;d&euaGHrVO*Rvr4N2XF;T zzWVRip|N_DZf)R6o^fN1u+6fIrn2AAIE8lbwGs9}?7T^sfhqQ=waa;;6z4xlImA;V z{CyrR3N0dIzDDuw^ipl^A~RP}m6)Z8IG4GsIKNtLYVV~kvb7Sg@Ku^e009l?Da?DH F^AG#K`bq!* diff --git a/sgl/tasks/__pycache__/base_task.cpython-37.pyc b/sgl/tasks/__pycache__/base_task.cpython-37.pyc deleted file mode 100644 index 981375b8c9d9c9f4eb85620046281912342e8e7b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 752 zcmbtS%}&BV5Z-M|DJU`GOF*LZ1&kpWB_^g9JoU1v%&Nrx!0y(7S9uvi^Qv}IE4lJ~QgSBL#4K_Gv9jJkewgz?Z(7MpTKFz`#L|WcO`gzQ;n1n3pC8~B1 zfMF&N-Vq8%%lJVN;R?Px{!^AnU+bAa8s478i@C2KW8Yvy?cZZ-L2hrlC_Eu1>BI;@ z#hQTe*W-%4S4H=aVahRZh?5{QH_0@XiA;@zE2T2k<%*WJ8BNjX$SUV@oEurg(+<|( z=pR^>iyii`G=zMK;yg0a+oEZ2UakuNlgAZCMU(Vki{xlSvWf-WI;}d55-CK{6e7tW zkI6q2VwOj78Bw$kRVSd4RfAytSxb)=sH(p~Q<}#;o1WEefXkrhuCl^AFL^I0;|rDj E0G(Ei=>Px# diff --git a/sgl/tasks/__pycache__/base_task.cpython-39.pyc b/sgl/tasks/__pycache__/base_task.cpython-39.pyc deleted file mode 100644 index 08bcf452f23acf9279a6f290b7074f66fd968176..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 772 zcmbtS%}&BV5Z-M|DF`v*OF*LZ1&kpYB_^g9JoU0!nN^AXf!)@CS9uv;saG%Fy?Jt` zSiyuC6Q`NKoz8sU{>XN_$xynpOZA5Sm6C0-F&Utm4w@kY27HP7VucK>5$oAwT>l{! zW&>1nghsKD0}CzSU@ci_gAERP2WsG=uR$F=^e!}TRA*4juvy@+(&+fLydon;{}EN>7ww2m?V=- z2x4r=F#dc{G3Zy-->GCsMN=7(b@eI??akw!O<%bh^myrawH4mU8uyg? 
H-%;5o`mK)C diff --git a/sgl/tasks/__pycache__/clustering_metrics.cpython-37.pyc b/sgl/tasks/__pycache__/clustering_metrics.cpython-37.pyc deleted file mode 100644 index f9e0e38aca650430b50f91eb8a0ecda9f5384d1a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3268 zcmbVOOK%*<5$>M%zW9=?NQe*#@x^Pzq{%o54~$@2vK$05FEk)wAY(9`?p+Nzk7duS zNiJq{p#aGT9r6#B$K3LJ`kG7l7ktWBGs`6<(M^EL) zJ;we;i_6c%TB|5iHgLlFeGeKIb)* zC+)Q3cRBl(2}iiEnQ+exzbE(}>w6dAz;^pwyWfqoLnRBErm|FVL{sbgz1?qo^Q+-7 zBE-vYA>cM3lKEUP2*!o+n)#+Mg@w@)ws0`o5=@2nnw@dq6)n*Q%o80j@9MQE85gBg zaW)L*Oe3OAAB_guW3*xgUBxRFu*yIWemKu}`-X0pYAk~!+?R=NkCYV6yl-l&kjX$f zIO+FjBaX^u?;lLH0R4G$<5ix@jiNZ(*!%X`XUXZwMlnn_%CI;rHhwAD`sh@9K@ewg z83bLD%0kDu!zaJK0s9g&;)I3o4w{Qb7c2e=Xvs?c);KWE_!*z_(yVAS=1{AKnt+;s zTDZ1#XsT_pT4^6Rv>q6xD~xx>)IJ%M9-V%|e(+x!Q|Bl4h7EXa2(R=S9#3F^se7^w zOi%>o?||w3#NY5Y#=wBFtk+)U)>tiYZJqJTuAFb;hRc%AU2) z*{pNUrgqhybt`zI}l4IRHd9sjcn`D~yDzoYkMS*sY5ogHg65m=zoU6Gn zRlTZPErIL5!Lc`6BKfQCxiMWXdk4$SN_e96&YZ3Qvhp5e6_C~UAZyo7wd+&2uAOQi zx38UQARkrB)k?KmtyQ6pvO_~z75}?&WbHN1S9>6w)L8#q8}X(#$=XniNiqcABRVRA`eAu17Bh)QH@QSVIoTYi8ic8}d4Y@aJ~+`gd@ zshC=-=MZ*Fd9d{5UWzRWHwH58sUj_rrjXO z$Xt#~Kro+2N+g#ym#)^5Dcli?Fi{F?$yrOrRBG#`!@b-m&5qome$7scu z=t?$2%0TR32PUEi$J)BK27G3pvzPAFoLW;G5z3u<(-xu>uZ)4&Bp9c$?)@xRX_&+l z80bV30Q^tUal5qFbRjpo}I06=q_-{U^xMZ#QN(vozKouxc++pJPjr>#8m4 zeox&O)ct|FKT=n}uEi7659E(wG7d|))8i(2ewPcGJXLx6jFRA=ai$=1FusESjxpJ~ zQrAL*|ARAKT7gaJpfF*MHH5JZ+%bm z7up=eLv2QRf62T0mkG6+O%1m4P$r!B_+8V{#p>}e3?oeKQ z0$SS42ZM=6+}20g@ZzMVj|y)TPgb9z5Gr0tu@DAkv5!>@1ChdXZ-6?f1hHa(x1-Hb zQfdb_8!M^)LI_rP5*NN(AVve+cbf-dl}94=+pl6FWERMDUws2)-ztkt`p%1hM)fs7 zAKJ7_V$PpL@mCY~=b!PUHOH!g_S&dPnanMT$Bvj?5ty{Qm0=srmBsx?@8!Q|wRy#v_rTw60 zhLJ=pPcmTV;6wibRrHd-G>2Y#(z&;u`ra-@iK&VJ*~QMxdo%Oi+xH!Jy8(gc!{g`T zr(Ht+g}}|n0^n2l)JqsR;WQ&*Gt-a(%rYytLz^Pb%ADK{-Fj_jUhaoJC4$Ub+_|7N zmS#cT4m*^*Al&8NYr?&A7Irz^Cw>1CIFP+QHQpDK;z$aG%Uo!gCb+c!yubJR(E4mN zO7P Q?(zzPylPCz)y*=rJ7+~PLOHg~uSvm<~p_g|B98hX6N1Hk#b4FYt`W|B>m z7BVeH(SmV;w67jqGx(mur+P3{v?39y7>vMkpL|RA`pg7cPDGT&2O=}USPI^(`(nL3V-x7&-_|NXnt1=fmN*(X)zxe#)?Cf+$jj|mbt0T4ZztRoHXU31B zv`BRnbxN{;n(dWOFlviYUj{Jc*wNJ**i9Kw2yb6k+1&plFi(2$Qv@GwK;t4HyoaU8D<`w z18Q*YBRJ=OL*LLhY{N73i*gB_`Q@Qs@_) zW=<8%I~91_tgW57t1aLH5(C^>%JK+hnRP3xY9lQU-#vmnK{gLL*0RB^n(tcGtvb~j z@cXCk&eu@ts&m0+>$-cm-ay>vt#{U}2bkUpW&772ZEsQ8iy30GW0NXV2-$(6d42$77NL7ncYsB0T zr9oVY(m*J4t*DiQs5!Gdh|t1)y4dXENljc#7QZ2I7u-c^wpT%*9_S3C?o0`2+~KJ_UKn#`eSKo%i~ z7UU5Muzdpt{M@-9FTI&Hvu6%uDsSe`T9B=1#fBDys|!T>Ze!#7MJe+*OQ%p{@`;|r z8PtMd*%-3+r^c1;#QYE|8;@jM@Wu9TkWkPV+bZ(Z_;9NB75Q`6lAqx43mkrl!>@2a z&Bo}Xl8l38BQ$)WgxyYW&GXES4;cwxN!k8tXN3O@f7-n4K7^Bt~=J3fnOU1x^meS*hg*fHR&-VhvH(_YM8hSi^K= ztfb7K*n?e6`??uW5CNz1BfR5`O9<)jD$z!w&tPAnk<`7z(bn9Ca_OS8VM!ekZvUNfr?M_U4Y2ES7w=Z?p@-cP zlqv=gDmZ9bxw-^R@9*na>IW!}hNkJkMB#{y+F&ra#4H6luTw;rn94#FT7i-*Xnc$_ zl0p>5KFB7J3=d>^QiSe!BFC99?wQE4@+7p!X>k;~CHU+}Tz&}Dj8zVYQy=*a#>MFJ zwB`Uk-Z-6ZJcl+&y%KyW4)AhLWD2H1RhjM(x+D$6)DWV^SmR6^7pygrLjDmC*zswq zz?_hy0p0hS(6IF}&%@wV%DE^aksrtxa5%KJDn#gh8R0C^Xy3JF2gO_jgF*&%%SdRz z0{DEoUFVkR`~SzjS2y?iE$s1%I2VxXgDA>NKFM(1iK1f&`DRCA`zJA`!mAw8xi7osMA(z1`F|06GDx~34-}XQDJw^kk-s!KIAdfZvO_m-Q2WrCjOk$-D zy}aqv7X?WjWg?bEotwhm0vhCTkq*m@H&(DFN|=pndeFimW#cxp{43C{+1@6p@gFyT T{{&09!r~9&(hk0$)sFRFM;txk diff --git a/sgl/tasks/__pycache__/correct_and_smooth.cpython-37.pyc b/sgl/tasks/__pycache__/correct_and_smooth.cpython-37.pyc deleted file mode 100644 index 8909a19c3b07fd722af67aadf0f8a5287f62a4bb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4686 
zcmb7I&5zs073c6<5~bB@KfL~m<-|>+z}bxx#|7N9sXyZMV7NtXAT7~?pf!|NT2hjS zRI;m5r!|mV4EWSb5kPzFzciRI^>?Ve; zF)gY1RYfu5TH^Vh!mYTTH2g-g=C379znQH2>q*ORB^&-mvgvQCHFms}Z2Q|x`%2>u zcVB4Sjl5&SKgZc)ty6go*|bL;Ce1rMOwd>FhgtY=6!J)5%iIHVV?CB#AQ7 z?*_w>pw~C=jf9B0*_|POoQy`oIO~=S9W0BAdIs@o#C%Z6Ex>AC2dw2bpqDoQ>scdrayPH!)x4H_$27B+ zw*Z^@I^a4dj;(x+YX=qaA5Oel*@k-KCa4D2x%raby;V}T^CoBu>ba7-^9J>Nwvk;> zJT49{9kVy)F6Y|(9muW)SGWdsT8}!q)MFvNZ=?RR{Vd>7H#~aqq+`fxVJK;T8c8SJ zpY(chBr9Bu#sQR{b=uNOC*w$nIus|}Ab@dYK_EW~-u8fkfWwMmrIH=V8!K#!!>3W4 z%8ysrlE7$Up_8c9aVSFAb0kuEa|IToan?_u4Y87}(RkF|Pgl}PJWA7`H(cp8Jo={J zjpW`6-$SlJc%+ffR@gNjr5U_Ozk47*U13|Q+(aN&(s;Mgr}ZN(8ZP&UptbD9=;3pM_RRp0u<$gDdpf(cW)+!>nsjBPJDeo&#XRzF zufQ&h1;?djpCkk7Pr_cF4-ECE}EoyyEJ~guui{4 zVlNIg4v%mH{rX}Y#^e3aU#BIdv6L&ka;O%_N;t_zX*Yy3vWm`G+!u(0VAm`T?u_?Z zQy+aw@GFaV3g2BACZbv;2)Yzkg5Y1;PdE1VM@h7orqA~tfA#RA_~>vieHQP*Umc`- zi~Zq{m%{FjkHi+fxq+761klWeZnG*iJa+M5)F2GKi(Y4&XW9jQdgEAH&Tlkl8c$ewVC~j9p2A8?5Q8K``Rn^x0^GkT<3z?;MRdF{+F#0Z!~Mr1gxP| z(GZ^ptz-MI`m8dm7I797Yp|){7NA;rY+*%h<_-UpZG(2o=iFii`hu5J^4eLhBVOmx ztO4<(m|gIsmEZ|Y*z=3I!<}>53F9vC3i=9qXn|MJ*U;C{d+5FS#p3I7cL(c_ZD_5Q zn+K-&KD#u(%#oeoP4ArOGs2Ad70OqjPkr7#(PoY8>fqhHf$X#KQlHs*ZD!`)%zFO) z*UlG6}E%rkGb#?A8Z!Kjrlx$=TF5f@~MQ6O3 zTW9!Ku-dIh)9pWsXe>tE2#MbCS(}fBQM*5EPmSH1J;~BCo71!h)B2YbV{fu>(|>)X;kp*ToZJ+#T5W)(r$5{(CT=pEP)C~lLk_E$quD87X4u+^=u^G zp^ZBPC><7;2)svtQebh7z(oR=2~-HYPv8Rrl+ue|5THs*{E|QiK-xW21Yw2}V?Qel*imzQyxmz@G${>TaadF z2Jr0;QV-IbbV(<4aI&H>(vDwK=utq-rKMtzWSOjLKkL6tB%p7pc zIVg~vo{medqxcXcu4DpduMF9wSS=`kC_7fu;$0eECC~D; zjvfM4*@tv-l~9vqN8&e-O&P6{SJunirW+v%3{F1@S;1{uNx6Wk;IaNV{70J-&h01?RUOsIB(5>~<)6zY5P51Qw1Am)eXaE2J diff --git a/sgl/tasks/__pycache__/correct_and_smooth.cpython-39.pyc b/sgl/tasks/__pycache__/correct_and_smooth.cpython-39.pyc deleted file mode 100644 index 7af1cd483970601f08fdb57ac39d33f27dcbe42f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4732 zcmb7I&5zs073c6<)W>SIA6~~_a^j{@;OxeU;{ryS#IfU`fV(Mb14W4z1g)XGl1WJ( zQi)fkPHP}113C3l1kfJyFU_fFbMY;Qo?5i-?+q#KIwwoSym|9x=JD~(oA;rxUiUQo z-hFa6+S=B%f70Oe&&1#}wDhk4T;nX({Gw+*Q)4|be1j29kIlsLErlDgojAUua2C6X zp=(S_D!!*EW?W5beof(4Tu&N)BU$s;lBVBG*8TOQ<+qXze&q9{SK4nEgmN5y*ptR-W!EH5*WFkhiUX6Odod~>4x3z zM184Q@q*27Bm5H!Fl+C0+><16Sta}g?6gT=hj3;3h$&FKUOX60@2SGLp(!(T) zGSTk_!;zrZH*b%Gh`QOWA-|uDM%h7^1U&t7%R#yiB+^KgzRnrfBV%qJYuw-_9*bK& z_gk&UxXqns72k?%?()hr?TGmf_jvW0=DWPc>*y=rYNqCUP)9?)vp?d|?Kn)+ey`sR zv;Jrp{Gp#61V5MS!J|TspFVx@F0lP_`1en{8>N8M5$>QFXz5P?nU=9!JJRR+F-|oH z9vHbcFa{QuO_=xxCtj^=L%nemRDd`|D)DyiFf6SM{OOiA5&g?cvI z$lg#q&JE5Vu~+6UaAT!Tuj{f;j6SjgHpQUB3F7VxMW9^QS>F{D=*OxmAD z(n${{y2s$qwbU6}H9U z!zfPW`zvfoU^ubRNmOqfiV$`kiBw)+fyHQ?^%H1AtR!nR9(51Wm9!F%(lqD|S9%SP zzV3G;xwpdikc$vC(#S_E>>7{K3|^$)eJnp*VOy%)MIctvc(?F2*tLp;Cmv1E#g-N2 z^MXEqDu1y`C93~Un&l(2ts?x{caXPZ6v~tRYagoB?n7{!$*-c>B>lVI7#4> zdF0<*fn6BK0r~g^OE-v~M%@Xbz$z!Hz68OhvmwniihI&o++wS^WNT2mXp-jb()jJd zI{g-jy*Si3Jj4z3>x*$1j}JnBotBiwQm&|#L$yFw!bvttyCIyBRdmkczCauVy=HN6 zr@Ysi`sfpaUs=3U`0m0me_*l_1YL?PLGXk2?`wMpqa@l((R3R}XskrA;HkK+{rX$MFU0IdO{B%pJwFK6}4rQ~MWn9IPsY`fH&Tunl8%|hwV9pOu)co8?r7h!KVP3Ye3s9{*vaq5$s||mjZG(2g=geXS`hu5J z^4eLhBVK3GtO4<(7+&zCmEZ|2*z`tMD3HtJ2^G4Zd(PDsD5PTb|KrDTSZjO@VW z8_1{VbT@PB6dwyVyR|>v{+)=%V$_Y0`VAko`Dhrm`@{Cs*uCD9EG_dpO}jg-e?bxV zX8SKU-~W)$UeFzwwv>xdj)5kyXEt4jjThzSx_PUIoT7dE?fX~TOGeYG!s%MvY)`kpDsl*wxwh{jV@eMqetCNC!6*ykHt}vp zaX3yRzI?~j`y89wN~~%9YWpFL3Vxj{f^N3B1Rzb?EzT0^jhD(2sC_hPAfcD+skFwT zKg^__jl@OTxI=&vWO1IrTLdT}7FP(IBXEI0g}~bc-XTCqzW4mmY~eqq&O}pkSMEG)8b7UUMA26 zkWLwz3)3bumUa>1g++^&M%-}=^Rf#I@vF<2&Ii7?NI7tAOtMK?>-C4ooJwo>BeC)c 
[GIT binary patch hunks omitted: deletion of compiled __pycache__ bytecode under sgl/tasks/, sgl/tricks/, and sgl/utils/ (node_classification_dist, node_classification_sampling, node_classification_with_label_use, node_clustering, utils, __init__, correct_and_smooth, and auto_choose_gpu; cpython-37 and cpython-39 variants)]
[remaining GIT binary patch hunks omitted: sgl/utils/__pycache__/auto_choose_gpu and basic_operations .pyc deletions; the index line and the opening metadata lines of the PKG-INFO deletion below are garbled in the original]
diff --git a/sgl_dair.egg-info/PKG-INFO b/sgl_dair.egg-info/PKG-INFO
deleted file mode 100644
-Requires-Python: >=3.6
-Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: torch>=1.8
-Requires-Dist: networkx
-Requires-Dist: tqdm
-Requires-Dist: numpy>=1.21
-Requires-Dist: scipy
-Requires-Dist: gensim
-Requires-Dist: scikit_learn
-Requires-Dist: ogb
-Requires-Dist: openbox
-Requires-Dist: munkres
-
-## SGL: Scalable Graph Learning
-
-**SGL** is a Graph Neural Network (GNN) toolkit targeting scalable graph learning, which supports deep graph learning on
-extremely large datasets. SGL allows users to easily implement scalable graph neural networks and evaluate their
-performance on various downstream tasks like node classification, node clustering, and link prediction. Further, SGL
-supports auto neural architecture search functionality based
-on OpenBox. SGL is designed and
-developed by the graph learning team from
-the DAIR Lab at Peking University.
-
-## Why SGL?
-The key difference between SGL and existing GNN toolkits, such as PyTorch Geometric (PyG) and Deep Graph Library (DGL), is that SGL enjoys the characteristics of the following three perspectives.
-
-+ **High scalability**: Following the scalable design paradigm **SGAP**
-  in PaSca, SGL can scale to graph data with
-  billions of nodes and edges.
-+ **Auto neural architecture search**: SGL can automatically choose decent and scalable graph neural architectures according to specific tasks and
-  pre-defined multiple objectives (e.g., inference time, memory cost, and predictive performance).
-+ **Ease of use**: SGL has user-friendly interfaces for implementing existing scalable GNNs and executing various downstream tasks.
-
-## Installation
-
-Some datasets in SGL are constructed based
-on PyG. Please follow the
-link below to install PyG first before installing
-SGL: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html.
-
-### Install from pip
-
-To install SGL from PyPI:
-
-```bash
-pip install sgl-dair
-```
-
-## Quick Start
-
-A quick start example is given by:
-
-```python
-from sgl.dataset import Planetoid
-from sgl.models.homo import SGC
-from sgl.tasks import NodeClassification
-
-dataset = Planetoid("pubmed", "./", "official")
-model = SGC(prop_steps=3, feat_dim=dataset.num_features, output_dim=dataset.num_classes)
-
-device = "cuda:0"
-test_acc = NodeClassification(dataset, model, lr=0.1, weight_decay=5e-5, epochs=200, device=device).test_acc
-```
-
-An example of the auto neural architecture search functionality is as follows:
-
-```python
-import torch
-from openbox.optimizer.generic_smbo import SMBO
-
-from sgl.dataset.planetoid import Planetoid
-from sgl.search.search_config import ConfigManager
-
-dataset = Planetoid("cora", "./", "official")
-device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
-
-## Define Initial Arch and Configuration
-initial_arch = [2, 0, 1, 2, 3, 0, 0]
-configer = ConfigManager(initial_arch)
-configer._setParameters(dataset, device, 128, 200, 1e-2, 5e-4)
-
-## Define Search Parameters
-dim = 7
-bo = SMBO(configer._configFunction,
-          configer._configSpace(),
-          num_objs=2,
-          num_constraints=0,
-          max_runs=3500,
-          surrogate_type='prf',
-          acq_type='ehvi',
-          acq_optimizer_type='local_random',
-          initial_runs=2 * (dim + 1),
-          init_strategy='sobol',
-          ref_point=[-1, 0.00001],
-          time_limit_per_trial=5000,
-          task_id='quick_start',
-          random_state=1)
-
-## Search
-history = bo.run()
-print(history)
-```
-
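To make the quick-start snippet above more concrete, the following is a minimal, self-contained sketch of the precompute-then-train pattern that scalable models such as SGC follow: the graph propagation controlled by `prop_steps` is applied to the node features once, up front, with sparse matrix products, and only a small classifier is trained afterwards. This is illustrative NumPy/SciPy/PyTorch code on a made-up toy graph, not SGL's internal implementation, and every name in it is hypothetical.

```python
# Illustrative sketch of SGC-style "precompute, then train" (not SGL's internal code).
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F

# Toy graph: 6 nodes on a cycle (made-up data for illustration only).
edges = np.array([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]])
num_nodes, feat_dim, num_classes, prop_steps = 6, 4, 2, 3

adj = sp.coo_matrix(
    (np.ones(len(edges)), (edges[:, 0], edges[:, 1])), shape=(num_nodes, num_nodes)
)
adj = adj + adj.T + sp.eye(num_nodes)          # symmetrize and add self-loops
deg_inv_sqrt = sp.diags(np.asarray(adj.sum(1)).flatten() ** -0.5)
adj_norm = deg_inv_sqrt @ adj @ deg_inv_sqrt   # D^-1/2 (A + I) D^-1/2

x = np.random.rand(num_nodes, feat_dim).astype(np.float32)
labels = torch.tensor([0, 0, 0, 1, 1, 1])

# Pre-compute step: propagate features prop_steps times with sparse matmuls.
for _ in range(prop_steps):
    x = adj_norm @ x

# Training step: a plain linear classifier on the smoothed features.
features = torch.from_numpy(np.asarray(x, dtype=np.float32))
clf = torch.nn.Linear(feat_dim, num_classes)
optimizer = torch.optim.Adam(clf.parameters(), lr=0.1, weight_decay=5e-5)
for _ in range(100):
    optimizer.zero_grad()
    loss = F.cross_entropy(clf(features), labels)
    loss.backward()
    optimizer.step()
print("training loss:", loss.item())
```

Because the expensive propagation runs once, outside the training loop, the same precomputed features can be reused across epochs and across search trials, which is what makes this style of model cheap to train at scale.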
-## Related Publications
-
-**PaSca: a Graph Neural Architecture Search System under the Scalable Paradigm**[[PDF](https://dl.acm.org/doi/pdf/10.1145/3485447.3511986)]<br/>
-Wentao Zhang, Yu Shen, Zheyu Lin, Yang Li, Xiaosen Li, Wen Ouyang, Yangyu Tao, Zhi Yang, and Bin Cui.<br/>
-The World Wide Web Conference.<br/>
-***WWW 2022, CCF-A, 🏆 Best Student Paper Award (among 1822 submissions)<br/>
*** - - -**Node Dependent Local Smoothing for Scalable Graph Learning** [[PDF](https://arxiv.org/pdf/2110.14377)]
-Wentao Zhang, Mingyu Yang, Zeang Sheng, Yang Li, Wen Ouyang, Yangyu Tao, Zhi Yang, Bin Cui.
-Thirty-fifth Conference on Neural Information Processing Systems.
-***NeurIPS 2021, CCF-A, Spotlight Presentation, Acceptance Rate: < 3%***. - -**NAFS: A Simple yet Tough-to-beat Baseline for Graph Representation Learning.** [[PDF](https://arxiv.org/abs/2206.08583)]
-Wentao Zhang, Zeang Sheng, Mingyu Yang, Yang Li, Yu Shen, Zhi Yang, Bin Cui.
-The 39th International Conference on Machine Learning.
-***ICML 2022, CCF-A***. - -**Deep and Flexible Graph Neural Architecture Search.** [[PDF](https://arxiv.org/abs/2206.08582)]
-Wentao Zhang, Zheyu Lin, Yu Shen, Yang Li, Zhi Yang, Bin Cui.
-The 39th International Conference on Machine Learning.
-***ICML 2022, CCF-A***. - -**Model Degradation Hinders Deep Graph Neural Networks.** [[PDF](https://arxiv.org/abs/2206.04361)]
-Wentao Zhang, Zeang Sheng, Yuezihan Jiang, Yikuan Xia, Jun Gao, Zhi Yang, Bin Cui.
-SIGKDD Conference on Knowledge Discovery and Data Mining.
-***KDD 2022, CCF-A***. - -**Graph Attention Multi-Layer Perceptron** [[PDF](https://arxiv.org/pdf/2108.10097)]
-Wentao Zhang, Ziqi Yin, Zeang Sheng, Wen Ouyang, Xiaosen Li, Yangyu Tao, Zhi Yang, Bin Cui.
-ACM SIGKDD Conference on Knowledge Discovery and Data Mining.
-***KDD 2022, CCF-A, Rank \#1 in [Open Graph Benchmark](https://ogb.stanford.edu/docs/leader_nodeprop/\#ogbn-mag)*** - -**[OpenBox](https://github.com/PKU-DAIR/open-box): A Generalized Black-box Optimization Service** [[PDF](https://arxiv.org/abs/2106.00421)]
-Yang Li, Yu Shen, Wentao Zhang, Yuanwei Chen, ..., Wentao Wu, Zhi Yang, Ce Zhang, Bin Cui.
-ACM SIGKDD Conference on Knowledge Discovery and Data Mining.
-***KDD 2021, CCF-A, top prize in [open-source innovation competition @ 2021 CCF ChinaSoft](https://mp.weixin.qq.com/s/8JX5ymkUt5MvDcHLOjB3Xw)*** - - - -## Citing SGL - -Please cite our [paper](https://dl.acm.org/doi/pdf/10.1145/3485447.3511986) if you find *SGL* useful in your work: -``` -@inproceedings{zhang2022pasca, - title={PaSca: A Graph Neural Architecture Search System under the Scalable Paradigm}, - author={Zhang, Wentao and Shen, Yu and Lin, Zheyu and Li, Yang and Li, Xiaosen and Ouyang, Wen and Tao, Yangyu and Yang, Zhi and Cui, Bin}, - booktitle={Proceedings of the ACM Web Conference 2022}, - pages={1817--1828}, - year={2022} -} -``` - -## Contact - -If you have any technical questions, please submit new issues. - -If you have any other questions, please contact: Wentao Zhang[wentao.zhang@pku.edu.cn] and Zeang Sheng[shengzeang18@pku.edu.cn]. - -## License - -The entire codebase is under [MIT license](LICENSE). diff --git a/sgl_dair.egg-info/SOURCES.txt b/sgl_dair.egg-info/SOURCES.txt deleted file mode 100644 index a717e6d..0000000 --- a/sgl_dair.egg-info/SOURCES.txt +++ /dev/null @@ -1,127 +0,0 @@ -LICENSE -MANIFEST.in -README.md -pyproject.toml -requirements.txt -setup.py -sgl/__init__.py -sgl/data/__init__.py -sgl/data/base_data.py -sgl/data/base_dataset.py -sgl/data/transforms.py -sgl/data/utils.py -sgl/dataset/__init__.py -sgl/dataset/acm.py -sgl/dataset/actor.py -sgl/dataset/airports.py -sgl/dataset/amazon.py -sgl/dataset/amazon_product.py -sgl/dataset/aminer.py -sgl/dataset/choose_edge_type.py -sgl/dataset/coauthor.py -sgl/dataset/custom_dataset.py -sgl/dataset/dblp.py -sgl/dataset/dblp_original.py -sgl/dataset/facebook.py -sgl/dataset/flickr.py -sgl/dataset/github.py -sgl/dataset/imdb.py -sgl/dataset/karateclub.py -sgl/dataset/linkx_dataset.py -sgl/dataset/nell.py -sgl/dataset/ogbn.py -sgl/dataset/ogbn_mag.py -sgl/dataset/planetoid.py -sgl/dataset/reddit.py -sgl/dataset/twitch.py -sgl/dataset/utils.py -sgl/dataset/webkb.py -sgl/dataset/wikics.py -sgl/etc/__init__.py -sgl/etc/auto_select_edge_type_for_nars.py -sgl/etc/hetero_search.py -sgl/etc/hetero_test.py -sgl/etc/stability_of_subgraph_weight.py -sgl/models/__init__.py -sgl/models/backup.py -sgl/models/base_model.py -sgl/models/base_model_dist.py -sgl/models/simple_models.py -sgl/models/hetero/__init__.py -sgl/models/hetero/fast_nars_sgc.py -sgl/models/hetero/nars_sign.py -sgl/models/homo/__init__.py -sgl/models/homo/clustergcn.py -sgl/models/homo/fastgcn.py -sgl/models/homo/gamlp.py -sgl/models/homo/gamlp_dist.py -sgl/models/homo/gamlp_recursive.py -sgl/models/homo/gbp.py -sgl/models/homo/graphsage.py -sgl/models/homo/lazygnn.py -sgl/models/homo/nafs.py -sgl/models/homo/pasca_v1.py -sgl/models/homo/pasca_v2.py -sgl/models/homo/pasca_v3.py -sgl/models/homo/sgc.py -sgl/models/homo/sgc_dist.py -sgl/models/homo/sign.py -sgl/models/homo/ssgc.py -sgl/models/homo/vanillagnn.py -sgl/operators/__init__.py -sgl/operators/base_op.py -sgl/operators/utils.py -sgl/operators/csrc/libcudamatmul.so -sgl/operators/csrc/libmatmul.so -sgl/operators/graph_op/__init__.py -sgl/operators/graph_op/laplacian_graph_op.py -sgl/operators/graph_op/ppr_graph_op.py -sgl/operators/graph_op/rw_graph_op.py -sgl/operators/message_op/__init__.py -sgl/operators/message_op/concat_message_op.py -sgl/operators/message_op/iterate_learnable_weighted_message_op.py -sgl/operators/message_op/last_message_op.py -sgl/operators/message_op/learnable_weighted_messahe_op.py -sgl/operators/message_op/max_message_op.py -sgl/operators/message_op/mean_message_op.py 
-sgl/operators/message_op/min_message_op.py -sgl/operators/message_op/over_smooth_distance_op.py -sgl/operators/message_op/pre_normalize_message_op.py -sgl/operators/message_op/projected_concat_message_op.py -sgl/operators/message_op/simple_weighted_message_op.py -sgl/operators/message_op/sum_message_op.py -sgl/sampler/__init__.py -sgl/sampler/base_sampler.py -sgl/sampler/sampler.py -sgl/sampler/utils.py -sgl/search/__init__.py -sgl/search/auto_search.py -sgl/search/auto_search_dist.py -sgl/search/base_search.py -sgl/search/search_config.py -sgl/search/search_config_dist.py -sgl/search/search_models.py -sgl/search/search_models_dist.py -sgl/search/utils.py -sgl/tasks/__init__.py -sgl/tasks/base_task.py -sgl/tasks/clustering_metrics.py -sgl/tasks/correct_and_smooth.py -sgl/tasks/link_prediction.py -sgl/tasks/node_classification.py -sgl/tasks/node_classification_dist.py -sgl/tasks/node_classification_sampling.py -sgl/tasks/node_classification_with_label_use.py -sgl/tasks/node_clustering.py -sgl/tasks/utils.py -sgl/tricks/__init__.py -sgl/tricks/correct_and_smooth.py -sgl/tricks/utils.py -sgl/utils/__init__.py -sgl/utils/auto_choose_gpu.py -sgl/utils/basic_operations.py -sgl_dair.egg-info/PKG-INFO -sgl_dair.egg-info/SOURCES.txt -sgl_dair.egg-info/dependency_links.txt -sgl_dair.egg-info/requires.txt -sgl_dair.egg-info/top_level.txt \ No newline at end of file diff --git a/sgl_dair.egg-info/dependency_links.txt b/sgl_dair.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/sgl_dair.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/sgl_dair.egg-info/requires.txt b/sgl_dair.egg-info/requires.txt deleted file mode 100644 index 5c6362f..0000000 --- a/sgl_dair.egg-info/requires.txt +++ /dev/null @@ -1,10 +0,0 @@ -torch>=1.8 -networkx -tqdm -numpy>=1.21 -scipy -gensim -scikit_learn -ogb -openbox -munkres diff --git a/sgl_dair.egg-info/top_level.txt b/sgl_dair.egg-info/top_level.txt deleted file mode 100644 index d1cebd6..0000000 --- a/sgl_dair.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -sgl From efe0d92df87298f841a924b5580163980fbccb48 Mon Sep 17 00:00:00 2001 From: infinity Date: Sun, 3 Dec 2023 05:16:46 +0000 Subject: [PATCH 10/28] add .gitignore --- .gitignore | 3 +++ .../homo/__pycache__/__init__.cpython-37.pyc | Bin 763 -> 0 bytes .../homo/__pycache__/__init__.cpython-39.pyc | Bin 723 -> 0 bytes .../homo/__pycache__/clustergcn.cpython-37.pyc | Bin 1369 -> 0 bytes .../homo/__pycache__/clustergcn.cpython-39.pyc | Bin 1379 -> 0 bytes .../homo/__pycache__/fastgcn.cpython-37.pyc | Bin 980 -> 0 bytes .../homo/__pycache__/fastgcn.cpython-39.pyc | Bin 990 -> 0 bytes sgl/models/homo/__pycache__/gamlp.cpython-37.pyc | Bin 922 -> 0 bytes sgl/models/homo/__pycache__/gamlp.cpython-39.pyc | Bin 932 -> 0 bytes .../__pycache__/gamlp_recursive.cpython-37.pyc | Bin 962 -> 0 bytes .../__pycache__/gamlp_recursive.cpython-39.pyc | Bin 972 -> 0 bytes sgl/models/homo/__pycache__/gbp.cpython-37.pyc | Bin 956 -> 0 bytes sgl/models/homo/__pycache__/gbp.cpython-39.pyc | Bin 962 -> 0 bytes .../homo/__pycache__/graphsage.cpython-37.pyc | Bin 1080 -> 0 bytes .../homo/__pycache__/graphsage.cpython-39.pyc | Bin 1094 -> 0 bytes .../homo/__pycache__/lazygcn.cpython-37.pyc | Bin 4364 -> 0 bytes .../homo/__pycache__/lazygnn.cpython-37.pyc | Bin 4447 -> 0 bytes .../homo/__pycache__/lazygnn.cpython-39.pyc | Bin 4510 -> 0 bytes sgl/models/homo/__pycache__/nafs.cpython-37.pyc | Bin 857 -> 0 bytes sgl/models/homo/__pycache__/nafs.cpython-39.pyc | Bin 867 -> 0 bytes 
sgl/models/homo/__pycache__/sgc.cpython-37.pyc | Bin 845 -> 0 bytes sgl/models/homo/__pycache__/sgc.cpython-39.pyc | Bin 855 -> 0 bytes .../homo/__pycache__/sgc_dist.cpython-37.pyc | Bin 858 -> 0 bytes .../homo/__pycache__/sgc_dist.cpython-39.pyc | Bin 868 -> 0 bytes sgl/models/homo/__pycache__/sign.cpython-37.pyc | Bin 941 -> 0 bytes sgl/models/homo/__pycache__/sign.cpython-39.pyc | Bin 951 -> 0 bytes sgl/models/homo/__pycache__/ssgc.cpython-37.pyc | Bin 882 -> 0 bytes sgl/models/homo/__pycache__/ssgc.cpython-39.pyc | Bin 892 -> 0 bytes .../homo/__pycache__/vanillagcn.cpython-37.pyc | Bin 1422 -> 0 bytes .../homo/__pycache__/vanillagnn.cpython-37.pyc | Bin 1266 -> 0 bytes .../homo/__pycache__/vanillagnn.cpython-39.pyc | Bin 1276 -> 0 bytes .../graph_op/__pycache__/__init__.cpython-37.pyc | Bin 330 -> 0 bytes .../graph_op/__pycache__/__init__.cpython-39.pyc | Bin 330 -> 0 bytes .../laplacian_graph_op.cpython-37.pyc | Bin 1090 -> 0 bytes .../laplacian_graph_op.cpython-39.pyc | Bin 1104 -> 0 bytes .../__pycache__/ppr_graph_op.cpython-37.pyc | Bin 1164 -> 0 bytes .../__pycache__/ppr_graph_op.cpython-39.pyc | Bin 1178 -> 0 bytes .../__pycache__/rw_graph_op.cpython-37.pyc | Bin 1000 -> 0 bytes .../__pycache__/rw_graph_op.cpython-39.pyc | Bin 1014 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 1044 -> 0 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 1008 -> 0 bytes .../__pycache__/concat_message_op.cpython-37.pyc | Bin 782 -> 0 bytes .../__pycache__/concat_message_op.cpython-39.pyc | Bin 796 -> 0 bytes ..._learnable_weighted_message_op.cpython-37.pyc | Bin 1799 -> 0 bytes ..._learnable_weighted_message_op.cpython-39.pyc | Bin 1817 -> 0 bytes .../__pycache__/last_message_op.cpython-37.pyc | Bin 706 -> 0 bytes .../__pycache__/last_message_op.cpython-39.pyc | Bin 720 -> 0 bytes .../learnable_weighted_messahe_op.cpython-37.pyc | Bin 2908 -> 0 bytes .../learnable_weighted_messahe_op.cpython-39.pyc | Bin 2956 -> 0 bytes .../__pycache__/max_message_op.cpython-37.pyc | Bin 799 -> 0 bytes .../__pycache__/max_message_op.cpython-39.pyc | Bin 813 -> 0 bytes .../__pycache__/mean_message_op.cpython-37.pyc | Bin 757 -> 0 bytes .../__pycache__/mean_message_op.cpython-39.pyc | Bin 771 -> 0 bytes .../__pycache__/min_message_op.cpython-37.pyc | Bin 799 -> 0 bytes .../__pycache__/min_message_op.cpython-39.pyc | Bin 813 -> 0 bytes .../over_smooth_distance_op.cpython-37.pyc | Bin 1360 -> 0 bytes .../over_smooth_distance_op.cpython-39.pyc | Bin 1362 -> 0 bytes .../pre_normalize_message_op.cpython-37.pyc | Bin 845 -> 0 bytes .../pre_normalize_message_op.cpython-39.pyc | Bin 859 -> 0 bytes .../projected_concat_message_op.cpython-37.pyc | Bin 1341 -> 0 bytes .../projected_concat_message_op.cpython-39.pyc | Bin 1351 -> 0 bytes .../simple_weighted_message_op.cpython-37.pyc | Bin 1932 -> 0 bytes .../simple_weighted_message_op.cpython-39.pyc | Bin 1954 -> 0 bytes .../__pycache__/sum_message_op.cpython-37.pyc | Bin 740 -> 0 bytes .../__pycache__/sum_message_op.cpython-39.pyc | Bin 754 -> 0 bytes 65 files changed, 3 insertions(+) create mode 100644 .gitignore delete mode 100644 sgl/models/homo/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/clustergcn.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/clustergcn.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/fastgcn.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/fastgcn.cpython-39.pyc delete mode 100644 
sgl/models/homo/__pycache__/gamlp.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/gamlp.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/gamlp_recursive.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/gamlp_recursive.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/gbp.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/gbp.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/graphsage.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/graphsage.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/lazygcn.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/lazygnn.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/lazygnn.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/nafs.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/nafs.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/sgc.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/sgc.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/sgc_dist.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/sgc_dist.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/sign.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/sign.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/ssgc.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/ssgc.cpython-39.pyc delete mode 100644 sgl/models/homo/__pycache__/vanillagcn.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/vanillagnn.cpython-37.pyc delete mode 100644 sgl/models/homo/__pycache__/vanillagnn.cpython-39.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/laplacian_graph_op.cpython-37.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/laplacian_graph_op.cpython-39.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/ppr_graph_op.cpython-37.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/ppr_graph_op.cpython-39.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/rw_graph_op.cpython-37.pyc delete mode 100644 sgl/operators/graph_op/__pycache__/rw_graph_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/__init__.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/__init__.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/concat_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/concat_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/iterate_learnable_weighted_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/iterate_learnable_weighted_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/last_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/last_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/learnable_weighted_messahe_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/learnable_weighted_messahe_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/max_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/max_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/mean_message_op.cpython-37.pyc delete mode 100644 
sgl/operators/message_op/__pycache__/mean_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/min_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/min_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/over_smooth_distance_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/over_smooth_distance_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/pre_normalize_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/pre_normalize_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/projected_concat_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/projected_concat_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/simple_weighted_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/simple_weighted_message_op.cpython-39.pyc delete mode 100644 sgl/operators/message_op/__pycache__/sum_message_op.cpython-37.pyc delete mode 100644 sgl/operators/message_op/__pycache__/sum_message_op.cpython-39.pyc diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a937f30 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ + +sgl_dair.egg-info \ No newline at end of file diff --git a/sgl/models/homo/__pycache__/__init__.cpython-37.pyc b/sgl/models/homo/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index e12704ed54bde257bc41ed1109a54d5bc4cc3f20..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 763 zcmZ{h&u-H|5XSA;j{hW1nxtu(mKVsOwZs{rrh!^WAwpOJX)cz@x~`>JCwP;J@J74@ z58^8)-T-l8XF_kVD}Vc&nf2~B+S@P;3_X7JKIS(q!}vkL(+D-(=udsuA`D@oAx$t% z;uf}~1s1b~wsgQ@wy`Z;aG4$Sqz^uG8w1&a4zr708A8bHVNdp<&+Oxg9Ke7%z@Z$$ zh`EEOatveUE}qE=OqfGFms6N(?uouQ5d$$4BXKIm;!I4$xtNNJB7_TZnHaNYKTX}l z%xs)^{&x9JLyzHozTVfWe9U#mif^f4CGkQpcJd}(@%lGj&F>N&yE=STsvY57rgrgy zb_EOEt6g3nBR{UQ?L#t;UsLHrR+WfZyrPZnGW&YaSUcKJ-6nBKY!dwfMs1OJBt8kf zPu(HulF*~nJuR!*B{)SUam&9!pE<3c8sc-nDn4!bY_L?Z=1Ut}l|{u@gRLk_2EeYe zd&T#Hr<1ABJ+cFE?{%BSx?(F$us(K539zCs$#d`qHdWYvfa38 znjxm?H=`cvnC16X^ste6q}1mqiI-P+_!6lCBe@YdDynUwG%c%gm!{X-L*u>Bl|AM! S=~2~lEnV|J(`uRjo__#)wW(kL diff --git a/sgl/models/homo/__pycache__/__init__.cpython-39.pyc b/sgl/models/homo/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 591a3bbbca63e80307fe655a88d6bad474263456..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 723 zcmZ{h&u-H|5XSA;j{hWf8cHcGPmn|L3sk6SpcYby5SBoi%W8?;t)*Efc$2E(aX7&% z<;saSK%Cf_fW(Eh{OxaM*1O+$Wtt|A9Y3=V^;OStzEkk;usP?-zUHfy%Bfs*va~7FWUxD77JP!&gJJFVcG@OQP(F4NqiFeyPZc8k;Ei)AKidtNRpCdRu+?I5Gbi} z&7Z-DS#*C};(fqVAJ@D$ct$Px(7kUO)$-Bc8_JRa@LPFn_*{r=QdG7}b^zh6ZB#8= zwgSoW&?_a#7Q~9$HB!~=5R&^t$Et1FEre*Brig6sLm?!h5Z|0`Y-5&R7x~>v*SRsD v^0Js;;O-$e3Uj@xYBW?^y)$81l#n0+xq#G%%v?Dj0$HKuuvw6n(hAZOD6K1 z5-F$s54#5rT=_XZMdE<+7dTPXoyjf-+^Xtwl^<30eSS2ZW(3CXfBqtqfRMj%vKufb z--4L0KnWtKA{AXwingb`%3t_{?W>>)7a=7_LGLp|~+Ft`*-%4TjK7+w@(Y zxzO~j)H;OSE@7;!%Z@RPJ#tglYROjELe{kIHtRHpc)en~Jkb3i!%S|v^w1WsYup|? 

Date: Sun, 3 Dec 2023 14:55:30 +0800
Subject: [PATCH 11/28] Fix the bug of ClusterGCN

make examples clustergcn_nodeclass.py runnable.
---
 sgl/models/homo/clustergcn.py | 6 +++++-
 sgl/sampler/sampler.py        | 5 ++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py
index e89a8ef..7c9d27b 100644
--- a/sgl/models/homo/clustergcn.py
+++ b/sgl/models/homo/clustergcn.py
@@ -10,7 +10,7 @@ def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, dr
         self._eval_sampling_op = eval_sampler
         self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device)
 
-    def mini_batch_prepare_forward(self, batch, device):
+    def mini_batch_prepare_forward(self, batch, device, inductive = False):
         batch_in, batch_out, block = batch
         local_inds, global_inds = batch_out
         in_x = self._processed_feature[batch_in].to(device)
@@ -18,3 +18,7 @@ def mini_batch_prepare_forward(self, batch, device):
         block.to_device(device)
         y_pred = self._base_model(in_x, block)[local_inds]
         return y_pred, y_truth
+
+    @property
+    def collate_fn(self):
+        return self._training_sampling_op.collate_fn
diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py
index c4affc5..56ab143 100644
--- a/sgl/sampler/sampler.py
+++ b/sgl/sampler/sampler.py
@@ -263,4 +263,7 @@ def _metis_clustering(self):
         if self._save_dir is not None:
             torch.save((self.perm_adjs, self.partptr, self.perm_node_idx), self._save_path_pt)
             pkl.dump(self.splitted_perm_adjs, open(self._save_path_pkl, "wb"))
-            print(f"\nSave Metis graph clustering results under the {self._save_dir} directory.\n")
\ No newline at end of file
+            print(f"\nSave Metis graph clustering results under the {self._save_dir} directory.\n")
+
+class GraphSaintSampler(BaseSampler):
+    pass
\ No newline at end of file

From b253bb62fd587f7b651baffb9a5ff477c3ada2f2 Mon Sep 17 00:00:00 2001
From: TheRoadQaQ
Date: Mon, 4 Dec 2023 13:19:35 +0800
Subject: [PATCH 12/28] GraphSAINT only node sampler

---
 .gitignore                       |  5 ++
 examples/configs/graphsaint.yml  | 21 ++++++++
 examples/graphsaint_nodeclass.py | 55 +++++++++++++++++++
 sgl/models/homo/__init__.py      |  4 +-
 sgl/models/homo/graphsaint.py    | 17 ++++++
 sgl/sampler/__init__.py          |  5 +-
 sgl/sampler/sampler.py           | 93 +++++++++++++++++++++++++++++++-
 7 files changed, 195 insertions(+), 5 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 examples/configs/graphsaint.yml
 create mode 100644 examples/graphsaint_nodeclass.py
 create mode 100644 sgl/models/homo/graphsaint.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ef07548
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+
+*.pyc
+*.pyc
+*.pyc
+*.pyc
diff --git a/examples/configs/graphsaint.yml b/examples/configs/graphsaint.yml
new file mode 100644
index
0000000..d016b2b --- /dev/null +++ b/examples/configs/graphsaint.yml @@ -0,0 +1,21 @@ +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" +sampler: + train: + pre_sampling_graphs: 10 + samplertype: "Node" + nodebudget: 2048 + pre_sampling_op: "RwGraphOp" +model: + hidden_dim: 128 + dropout: 0.5 + num_layers: 2 +task: + train_batch_size: 2048 + epochs: 100 + lr: 0.01 + weight_decay: 0.00005 + seed: 42 + diff --git a/examples/graphsaint_nodeclass.py b/examples/graphsaint_nodeclass.py new file mode 100644 index 0000000..66295f3 --- /dev/null +++ b/examples/graphsaint_nodeclass.py @@ -0,0 +1,55 @@ +import yaml +import argparse +from torch.nn.functional import nll_loss +import sgl.dataset as Dataset +from sgl.models.homo import GraphSAINT +import sgl.sampler as Sampler +from sgl.sampler import GraphSAINTSampler +from sgl.tasks import NodeClassification_Sampling + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GraphSaint-Models.") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/graphsaint.yml", help="save path of the configuration file" + ) + args = parser.parse_args() + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + dataset_kwargs = config["dataset"] + sampler_kwargs = config["sampler"] + model_kwargs = config["model"] + task_kwargs = config["task"] + + classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, classname)(**dataset_kwargs) + train_sampler_kwargs = sampler_kwargs["train"] + train_sampler_kwargs.update({"save_dir": dataset.processed_dir}) + + train_sampler = GraphSAINTSampler(dataset.adj, **train_sampler_kwargs) + if "eval" in sampler_kwargs: + eval_sampler_kwargs = sampler_kwargs["eval"] + eval_sampler_name = eval_sampler_kwargs["name"] + if eval_sampler_name == "ClusterGCNSampler": + eval_sampler_kwargs.update({"save_dir": dataset.processed_dir}) + eval_cluster_number = eval_sampler_kwargs["cluster_number"] + task_kwargs.update({"eval_graph_number": eval_cluster_number}) + eval_sampler = GraphSAINTSampler(dataset, **eval_sampler_kwargs) + else: + eval_sampler = getattr(Sampler, eval_sampler_name)(dataset.adj, **eval_sampler_kwargs) + else: + eval_sampler = None + + model_kwargs.update({"device": device}) + model = GraphSAINT(dataset, train_sampler, eval_sampler, **model_kwargs) + task_kwargs.update({"device": device}) + + def myloss(pred,labels): + loss = nll_loss(pred, labels, reduction="none") + loss = (loss/model.cur_loss_norm).sum() + return loss + + task_kwargs.update({"loss_fn":myloss}) + test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc diff --git a/sgl/models/homo/__init__.py b/sgl/models/homo/__init__.py index d6e6b03..06c65a4 100644 --- a/sgl/models/homo/__init__.py +++ b/sgl/models/homo/__init__.py @@ -11,6 +11,7 @@ from .graphsage import GraphSAGE from .vanillagnn import VanillaGNN from .lazygnn import LazyGNN +from .graphsaint import GraphSAINT __all__ = [ "SGC", @@ -25,5 +26,6 @@ "ClusterGCN", "GraphSAGE", "VanillaGNN", - "LazyGNN" + "LazyGNN", + "GraphSAINT" ] diff --git a/sgl/models/homo/graphsaint.py b/sgl/models/homo/graphsaint.py new file mode 100644 index 0000000..76c3eea --- /dev/null +++ b/sgl/models/homo/graphsaint.py @@ -0,0 +1,17 @@ +from sgl.models.simple_models import GCN +from sgl.models.base_model import BaseSAMPLEModel +from sgl.operators.graph_op import 
RwGraphOp + +class GraphSAINT(BaseSAMPLEModel): + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): + super(GraphSAINT, self).__init__() + self._pre_graph_op = RwGraphOp() + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler + self._base_model = GCN( + nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + ).to(device) + + @property + def cur_loss_norm(self): + return self._training_sampling_op.loss_norm[self._training_sampling_op.index] \ No newline at end of file diff --git a/sgl/sampler/__init__.py b/sgl/sampler/__init__.py index d1d9382..aa78af1 100644 --- a/sgl/sampler/__init__.py +++ b/sgl/sampler/__init__.py @@ -1,8 +1,9 @@ -from .sampler import FastGCNSampler, ClusterGCNSampler, NeighborSampler, FullSampler +from .sampler import FastGCNSampler, ClusterGCNSampler, NeighborSampler, FullSampler , GraphSAINTSampler __all__ = [ "FastGCNSampler", "ClusterGCNSampler", "NeighborSampler", - "FullSampler" + "FullSampler", + "GraphSAINTSampler" ] diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index 56ab143..a8ae036 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -4,6 +4,7 @@ import pickle as pkl import networkx as nx import scipy.sparse as sp + from torch_sparse import SparseTensor from torch_geometric.utils import from_networkx, mask_to_index @@ -265,5 +266,93 @@ def _metis_clustering(self): pkl.dump(self.splitted_perm_adjs, open(self._save_path_pkl, "wb")) print(f"\nSave Metis graph clustering results under the {self._save_dir} directory.\n") -class GraphSaintSampler(BaseSampler): - pass \ No newline at end of file +class GraphSAINTSampler(BaseSampler): + ''' + sample the wholo graph, feature and label as GraphSAINT method + ''' + def __init__(self, adj, **kwargs): + """ + Inputs: + adj: adj of dgl Graph:sp.matrix + kwargs: some params + """ + self.replace = True + self.node_budget = kwargs['nodebudget'] + + super(GraphSAINTSampler, self).__init__(adj, **kwargs) + + self.sampler_name = "GraphSaintSampler" + self.sample_level = "graph" + self.pre_sampling = False + + def _pre_process(self, **kwargs): + if kwargs['samplertype'] == "Node": + self._calc_probs(**kwargs) + self.sample = self.node_sample + else: + raise NotImplementedError + + self._calc_norm(**kwargs) + + def node_sample(self): + """ + Inputs: + batch_ids: is not used in this method + + method: sample fixed size of nodes as a subgraph + + Outputs: + batch_in: global node index + batch_out: global node index + block: sampled adjs in the form of sparse tensors wrapped in Block class + """ + + p = self.probs + sampled = np.random.choice(np.arange(np.size(p)), self.node_budget, self.replace, p) + sampled = np.unique(sampled) + + adj = self._adj[sampled, :].tocsc() + adj = adj[:, sampled].tocsr() + return sampled, adj + + def _calc_norm(self, **kwargs): + """ + methods: calculate the norm to estimate embedding and loss + """ + times = kwargs['pre_sampling_graphs'] + + node_value = np.zeros(np.size(self.probs)) + edge_value = sp.lil_matrix((np.size(self.probs),np.size(self.probs))) + + for _ in range(times): + sampled, adj = self.sample() + adj = adj.tocoo() + for row, col in zip(adj.row, adj.col): + edge_value[sampled[row],sampled[col]] += 1 + node_value[sampled] += 1 + + edge_value = edge_value.tocsr().dot(sp.diags(1.0 / np.maximum(node_value, 1))) + + self.aggr_norm = edge_value + self.loss_norm = torch.FloatTensor(np.maximum(node_value, 1)) + 
return + + def collate_fn(self, batch_ids): + """ + Inputs: + batch_ids: is not used in this method + + method: sample fixed size of nodes as a subgraph + + Outputs: batch_in: global node index + batch_out: global node index + block: sampled adjs in the form of sparse tensors wrapped in Block class + """ + sampled, adj = self.sample() + sampled_aggr_norm = self.aggr_norm[sampled, :].tocsc() + sampled_aggr_norm = sampled_aggr_norm[:, sampled] + adj = adj.multiply(sampled_aggr_norm.transpose()) + + self.index = sampled + + return sampled,sampled,self._to_Block(adj) \ No newline at end of file From ef019b9484cb193fac347b75e6593aeae2a9f2fe Mon Sep 17 00:00:00 2001 From: infinity Date: Mon, 4 Dec 2023 06:23:42 +0000 Subject: [PATCH 13/28] implement C++ version of node-wise sampling for one layer. --- .gitignore | 6 +- examples/configs/graphsage.yml | 2 +- sgl/sampler/sampler.py | 40 +--- sgl/sampler/sampling_ops.cpp | 227 ++++++++++++++++++++++ sgl/sampler/sampling_ops.hpp | 19 ++ sgl/sampler/setup.py | 53 +++++ sgl/tasks/node_classification_sampling.py | 2 +- 7 files changed, 311 insertions(+), 38 deletions(-) create mode 100644 sgl/sampler/sampling_ops.cpp create mode 100644 sgl/sampler/sampling_ops.hpp create mode 100644 sgl/sampler/setup.py diff --git a/.gitignore b/.gitignore index a937f30..5e704a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ __pycache__/ -sgl_dair.egg-info \ No newline at end of file + +*.egg-info/ + +build/ +dist/ \ No newline at end of file diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml index 8bfbfb3..d12c9a1 100644 --- a/examples/configs/graphsage.yml +++ b/examples/configs/graphsage.yml @@ -25,7 +25,7 @@ model: task: name: "NodeClassification_Sampling" train_batch_size: 1024 - train_num_workers: 5 + train_num_workers: 0 # eval_batch_size: 1024 # eval_num_workers: 5 # eval_together: True diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index c4affc5..f130ca8 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -8,6 +8,7 @@ from torch_geometric.utils import from_networkx, mask_to_index from sgl.sampler.base_sampler import BaseSampler +from sampling_ops import NodeWiseOneLayer class FullSampler(BaseSampler): def __init__(self, adj, **kwargs): @@ -61,9 +62,12 @@ def collate_fn(self, batch_inds): batch_inds = np.asarray(batch_inds) all_adjs = [] + indptr, indices, values = self._adj.indptr, self._adj.indices, self._adj.data + cur_tgt_nodes = batch_inds for layer_index in range(self.num_layers): - cur_src_nodes, adj_sampled = self._one_layer_sampling(cur_tgt_nodes, self.layer_sizes[layer_index]) + cur_src_nodes, (s_indptr, s_indices, s_data) = NodeWiseOneLayer(cur_tgt_nodes, indptr, indices, values, self.layer_sizes[layer_index], self.probs, True, self.replace) + adj_sampled = sp.csr_matrix((s_data, s_indices, s_indptr), shape=(len(cur_tgt_nodes), len(cur_src_nodes))) all_adjs.insert(0, adj_sampled) cur_tgt_nodes = cur_src_nodes @@ -71,40 +75,6 @@ def collate_fn(self, batch_inds): return cur_tgt_nodes, batch_inds, self._to_Block(all_adjs) - def _one_layer_sampling(self, prev_nodes, layer_size=-1): - """ - Inputs: - v_indices: array of target node inds of the current layer - layer_size: size of sampled neighbors as the source nodes - """ - - current_layer_adj = self._adj[prev_nodes, :] - - if layer_size < 0: - # in case layer_size < 0, we simply keep all the neighbors - next_nodes = np.unique(current_layer_adj.indices) - - else: - next_nodes = [] - - row_start_stop = 
np.lib.stride_tricks.as_strided(current_layer_adj.indptr, shape=(current_layer_adj.shape[0], 2), strides=2*current_layer_adj.indptr.strides) - - for start, stop in row_start_stop: - neigh_index = current_layer_adj.indices[start:stop] - if neigh_index.size == 0: - continue - probs = self.probs[neigh_index] / np.sum(self.probs[neigh_index]) - num_samples = np.min([neigh_index.size, layer_size]) if self.replace is False else layer_size - sampled_nodes = np.random.choice(neigh_index, num_samples, replace=self.replace, p=probs) - next_nodes.append(sampled_nodes) - - next_nodes = np.unique(np.concatenate(next_nodes)) - - next_nodes = np.setdiff1d(next_nodes, prev_nodes) - next_nodes = np.concatenate((prev_nodes, next_nodes)) - - return next_nodes, current_layer_adj[:, next_nodes] - class FastGCNSampler(BaseSampler): def __init__(self, adj, **kwargs): super(FastGCNSampler, self).__init__(adj, **kwargs) diff --git a/sgl/sampler/sampling_ops.cpp b/sgl/sampler/sampling_ops.cpp new file mode 100644 index 0000000..f50995d --- /dev/null +++ b/sgl/sampler/sampling_ops.cpp @@ -0,0 +1,227 @@ +#include + +#include "sampling_ops.hpp" + +std::mt19937 gen; + +// BatchSamples NodeWiseMultiLayers(PyArrInt batch_inds, PyArrInt indptr, PyArrInt indices, PyArrFloat values, PyArrInt layer_sizes, PyArrFloat probability, bool biased, bool replace) { +// py::buffer_info buf_batch_inds = batch_inds.request(); + +// } + +SingleSample NodeWiseOneLayer(PyArrInt prev_nodes, PyArrInt indptr, PyArrInt indices, PyArrFloat values, int32_t layer_size, PyArrFloat probability, bool biased, bool replace) { + py::buffer_info buf_prev_nodes = prev_nodes.request(); + py::buffer_info buf_indptr = indptr.request(); + py::buffer_info buf_indices = indices.request(); + py::buffer_info buf_values = values.request(); + py::buffer_info buf_probability = probability.request(); + + int32_t* ptr_prev_nodes = static_cast (buf_prev_nodes.ptr); + int32_t* ptr_indptr = static_cast (buf_indptr.ptr); + int32_t* ptr_indices = static_cast (buf_indices.ptr); + float* ptr_values = static_cast (buf_values.ptr); + float* ptr_probability = static_cast (buf_probability.ptr); + + std::vector>> cols; // col, v + std::vector n_ids; + std::unordered_map n_id_map; + + auto out_indptr = PyArrInt(prev_nodes.size() + 1); + py::buffer_info buf_out_indptr = out_indptr.request(); + int32_t* ptr_out_indptr = static_cast (buf_out_indptr.ptr); + ptr_out_indptr[0] = 0; + + int32_t n, c, e, start_, end_, num_neighbors; + float v; + + for (int32_t i = 0; i < prev_nodes.size(); i++) { + n = ptr_prev_nodes[i]; + cols.push_back(std::vector>()); + n_id_map[n] = i; + n_ids.push_back(n); + } + + if (layer_size < 0) { + // No sampling + for (int32_t i = 0; i < prev_nodes.size(); i++) { + n = ptr_prev_nodes[i]; + start_ = ptr_indptr[n], end_ = ptr_indptr[n + 1]; + num_neighbors = end_ - start_; + + for (int32_t j = 0; j < num_neighbors; j++) { + e = start_ + j; + c = ptr_indices[e]; + v = ptr_values[e]; + + if (n_id_map.count(c) == 0) { + n_id_map[c] = n_ids.size(); + n_ids.push_back(c); + } + cols[i].push_back(std::make_tuple(n_id_map[c], v)); + } + ptr_out_indptr[i + 1] = ptr_out_indptr[i] + cols[i].size(); + } + } + else if (replace) { + // Sample with replacement + if (biased) { + for (int32_t i = 0; i < prev_nodes.size(); i++) { + n = ptr_prev_nodes[i]; + start_ = ptr_indptr[n], end_ = ptr_indptr[n + 1]; + num_neighbors = end_ - start_; + + if (num_neighbors > 0) { + std::vector temp_probability(ptr_probability + start_, ptr_probability + end_); + for (int32_t j = 0; j < 
layer_size; j++) { + std::discrete_distribution<> d(temp_probability.begin(), temp_probability.end()); + e = start_ + d(gen); + c = ptr_indices[e]; + v = ptr_values[e]; + + if (n_id_map.count(c) == 0) { + n_id_map[c] = n_ids.size(); + n_ids.push_back(c); + } + cols[i].push_back(std::make_tuple(n_id_map[c], v)); + } + } + ptr_out_indptr[i + 1] = ptr_out_indptr[i] + cols[i].size(); + } + } + else { + for (int32_t i = 0; i < prev_nodes.size(); i++) { + n = ptr_prev_nodes[i]; + start_ = ptr_indptr[n], end_ = ptr_indptr[n + 1]; + num_neighbors = end_ - start_; + + if (num_neighbors > 0) { + for (int32_t j = 0; j < layer_size; j++) { + e = start_ + rand() % num_neighbors; + c = ptr_indices[e]; + v = ptr_values[e]; + + if (n_id_map.count(c) == 0) { + n_id_map[c] = n_ids.size(); + n_ids.push_back(c); + } + cols[i].push_back(std::make_tuple(n_id_map[c], v)); + } + } + ptr_out_indptr[i + 1] = ptr_out_indptr[i] + cols[i].size(); + } + } + } + else { + // Sample without replacement + if (biased) { + for (int32_t i = 0; i < prev_nodes.size(); i++) { + n = ptr_prev_nodes[i]; + start_ = ptr_indptr[n], end_ = ptr_indptr[n + 1]; + num_neighbors = end_ - start_; + + if (num_neighbors <= layer_size) { + for(int32_t j = 0; j < num_neighbors; j++) { + e = start_ + j; + c = ptr_indices[e]; + v = ptr_values[e]; + + if (n_id_map.count(c) == 0) { + n_id_map[c] = n_ids.size(); + n_ids.push_back(c); + } + cols[i].push_back(std::make_tuple(n_id_map[c], v)); + } + } + else { + std::vector temp_probability(ptr_probability + start_, ptr_probability + end_); + std::discrete_distribution<> d(temp_probability.begin(), temp_probability.end()); + std::uniform_real_distribution dist(0.0, 1.0); + std::vector vals; + std::generate_n(std::back_inserter(vals), num_neighbors, [&dist]() { return dist(gen); }); + std::transform(vals.begin(), vals.end(), temp_probability.begin(), vals.begin(), [&](auto r, auto prob) { return std::pow(r, 1. 
/ prob); }); + std::vector> valIndices; + int32_t index = 0; + std::transform(vals.begin(), vals.end(), std::back_inserter(valIndices), [&index](auto v) { return std::pair(v, index++); }); + std::sort(valIndices.begin(), valIndices.end(), [](auto x, auto y) { return x.first > y.first; }); + std::vector candidates; + std::transform(valIndices.begin(), valIndices.end(), std::back_inserter(candidates), [](auto v) { return v.second; }); + for(int32_t j = 0; j < layer_size; j++) { + e = start_ + candidates[j]; + c = ptr_indices[e]; + v = ptr_values[e]; + + if (n_id_map.count(c) == 0) { + n_id_map[c] = n_ids.size(); + n_ids.push_back(c); + } + cols[i].push_back(std::make_tuple(n_id_map[c], v)); + } + } + ptr_out_indptr[i + 1] = ptr_out_indptr[i] + cols[i].size(); + } + } + else { + // via Robert Floyd algorithm + for (int32_t i = 0; i < prev_nodes.size(); i++) { + n = ptr_prev_nodes[i]; + start_ = ptr_indptr[n], end_ = ptr_indptr[n + 1]; + num_neighbors = end_ - start_; + + std::unordered_set perm; + if (num_neighbors <= layer_size) { + for (int32_t j = 0; j < num_neighbors; j++) perm.insert(j); + } else { + for (int32_t j = num_neighbors - layer_size; j < num_neighbors; j++) { + if (!perm.insert(rand() % j).second) perm.insert(j); + } + } + + for(const int32_t &p: perm) { + e = start_ + p; + c = ptr_indices[e]; + v = ptr_values[e]; + + if (n_id_map.count(c) == 0) { + n_id_map[c] = n_ids.size(); + n_ids.push_back(c); + } + cols[i].push_back(std::make_tuple(n_id_map[c], v)); + } + ptr_out_indptr[i + 1] = ptr_out_indptr[i] + cols[i].size(); + } + } + } + + int32_t E = ptr_out_indptr[prev_nodes.size()]; + auto out_indices = PyArrInt(E); + py::buffer_info buf_out_indices = out_indices.request(); + int32_t* ptr_out_indices = static_cast (buf_out_indices.ptr); + auto out_values = PyArrFloat(E); + py::buffer_info buf_out_values = out_values.request(); + float* ptr_out_values = static_cast (buf_out_values.ptr); + + int32_t i = 0; + for (std::vector> &col_vec : cols) { + std::sort(col_vec.begin(), col_vec.end(), + [](const std::tuple &a, + const std::tuple &b) -> bool { + return std::get<0>(a) < std::get<0>(b); + }); + for (const std::tuple &value : col_vec) { + ptr_out_indices[i] = std::get<0>(value); + ptr_out_values[i] = std::get<1>(value); + i += 1; + } + } + + PyArrInt out_n_ids(n_ids.size()); + py::buffer_info buf_out_n_ids = out_n_ids.request(); + int32_t *ptr_out_n_ids = static_cast(buf_out_n_ids.ptr); + std::copy(n_ids.begin(), n_ids.end(), ptr_out_n_ids); + Adj out_adj = std::make_tuple(out_indptr, out_indices, out_values); + return std::make_pair(out_n_ids, out_adj); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("NodeWiseOneLayer", &NodeWiseOneLayer); +} \ No newline at end of file diff --git a/sgl/sampler/sampling_ops.hpp b/sgl/sampler/sampling_ops.hpp new file mode 100644 index 0000000..5c87621 --- /dev/null +++ b/sgl/sampler/sampling_ops.hpp @@ -0,0 +1,19 @@ +#include +#include +#include +#include + +namespace py = pybind11; +typedef py::array_t PyArrInt; +typedef py::array_t PyArrFloat; + +using Adj = std::tuple; +using Adjs = std::vector; +using SingleSample = std::tuple; +using BatchSamples = std::tuple; + +SingleSample NodeWiseOneLayer(PyArrInt prev_nodes, PyArrInt indptr, PyArrInt indices, + PyArrFloat values, int32_t layer_size, PyArrFloat probability, bool biased, bool replace); +// BatchSamples NodeWiseMultiLayers(PyArrInt batch_inds, PyArrInt indptr, PyArrInt indices, +// PyArrFloat values, PyArrInt layer_sizes, PyArrFloat probability, +// bool biased, bool replace); \ 
No newline at end of file diff --git a/sgl/sampler/setup.py b/sgl/sampler/setup.py new file mode 100644 index 0000000..91c77da --- /dev/null +++ b/sgl/sampler/setup.py @@ -0,0 +1,53 @@ +import os +import sys +from pathlib import Path +from setuptools import setup + +from torch.__config__ import parallel_info +from torch.utils.cpp_extension import BuildExtension, CppExtension + + +def flags_to_list(flagstring): + return list(filter(bool, flagstring.split(' '))) + + +WITH_SYMBOLS = True if os.getenv('WITH_SYMBOLS', '0') == '1' else False +CXX_FLAGS = flags_to_list(os.getenv('CXX_FLAGS', '')) +ROOT_PATH = Path(__file__).resolve().parent + + +def get_extensions(): + define_macros = [] + libraries = [] + extra_compile_args = { + 'cxx': ['-O3', '-march=native', '-std=c++17', '-g'] + CXX_FLAGS} + extra_link_args = [] if WITH_SYMBOLS else ['-s'] + + info = parallel_info() + if 'backend: OpenMP' in info and 'OpenMP not found' not in info: + extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP'] + if sys.platform == 'win32': + extra_compile_args['cxx'] += ['/openmp'] + else: + extra_compile_args['cxx'] += ['-fopenmp'] + else: + print('Compiling without OpenMP...') + + return [ + CppExtension( + 'sampling_ops', + ['sampling_ops.cpp'], + define_macros=define_macros, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + libraries=libraries, + ), + ] + + +setup( + name='sampling_ops', + ext_modules=get_extensions(), + cmdclass={ + 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False) + }) diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index ae5ddff..b5d4079 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -57,7 +57,7 @@ def _execute(self): self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x, y=self.__dataset.y, device=self.__device, **kwargs) pre_time_ed = time.time() print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") - + if self.__mini_batch_train: if self.__train_determined_sample: self.__train_loader = DataLoader( From 6519321c5f19e2373b06b6136ba88fa804458e47 Mon Sep 17 00:00:00 2001 From: infinity Date: Mon, 4 Dec 2023 10:53:45 +0000 Subject: [PATCH 14/28] fix tiny bugs for clustergcn, add nodewisesampler and layerwisesampler. 
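
For reference, a minimal NumPy/SciPy sketch of the layer-wise importance
sampling that this commit factors into LayerWiseSampler.one_layer_sampling.
The function below is illustrative only and is not taken from the patch;
`probs` is assumed to be the pre-computed per-node sampling distribution.

    import numpy as np
    import scipy.sparse as sp

    def layer_wise_sample(adj, target_nodes, layer_size, probs, replace=False):
        # Restrict the adjacency matrix to the rows of the current target nodes.
        sub_adj = adj[target_nodes, :]
        # Candidate source nodes are all columns touched by any target node.
        candidates = np.nonzero(np.sum(sub_adj, axis=0))[1]
        p = probs[candidates] / probs[candidates].sum()
        if not replace:
            layer_size = min(len(candidates), layer_size)
        picked = np.random.choice(len(candidates), layer_size, replace=replace, p=p)
        sources = candidates[picked]
        # Rescale each sampled column by 1 / (q_u * layer_size) so the sparse
        # aggregation over the sample stays an unbiased estimate of the full
        # aggregation (FastGCN-style importance sampling).
        sub_adj = sub_adj[:, sources].dot(sp.diags(1.0 / (p[picked] * layer_size)))
        return sources, sub_adj

Sampling a fixed budget per layer, rather than per target node, keeps the
number of source nodes bounded by layer_size regardless of the fan-out.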
--- examples/configs/fastgcn.yml | 2 +- sgl/models/homo/clustergcn.py | 8 +- sgl/sampler/base_sampler.py | 42 +++++++++++ sgl/sampler/sampler.py | 44 ++--------- sgl/sampler/sampling_ops.cpp | 92 +++++++++++++++++++++-- sgl/sampler/sampling_ops.hpp | 6 +- sgl/tasks/node_classification_sampling.py | 7 -- 7 files changed, 145 insertions(+), 56 deletions(-) diff --git a/examples/configs/fastgcn.yml b/examples/configs/fastgcn.yml index 9feaa22..a5d4f09 100644 --- a/examples/configs/fastgcn.yml +++ b/examples/configs/fastgcn.yml @@ -10,7 +10,7 @@ sampler: pre_sampling_op: "LaplacianGraphOp" layer_sizes: "2048,2048" prob_type: "normalize" - replace: True + replace: False model: name: "FastGCN" hidden_dim: 128 diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index e89a8ef..7d77222 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -10,7 +10,7 @@ def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, dr self._eval_sampling_op = eval_sampler self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) - def mini_batch_prepare_forward(self, batch, device): + def mini_batch_prepare_forward(self, batch, device, **kwargs): batch_in, batch_out, block = batch local_inds, global_inds = batch_out in_x = self._processed_feature[batch_in].to(device) @@ -18,3 +18,9 @@ def mini_batch_prepare_forward(self, batch, device): block.to_device(device) y_pred = self._base_model(in_x, block)[local_inds] return y_pred, y_truth + + def collate_fn(self, batch_inds, mode): + if self.training: + return self._training_sampling_op.collate_fn(batch_inds, mode) + else: + return self._eval_sampling_op.collate_fn(batch_inds, mode) diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py index 05ea9c7..2dece79 100644 --- a/sgl/sampler/base_sampler.py +++ b/sgl/sampler/base_sampler.py @@ -1,5 +1,6 @@ import os import numpy as np +import scipy.sparse as sp from scipy.sparse.linalg import norm as sparse_norm from sgl.data.base_data import Block @@ -7,6 +8,8 @@ from sgl.sampler.utils import adj_train_analysis from sgl.utils import sparse_mx_to_torch_sparse_tensor +from sampling_ops import NodeWiseOneLayer + class BaseSampler: def __init__(self, adj, **kwargs): self._adj = adj @@ -93,3 +96,42 @@ def _to_Block(self, adjs): def collate_fn(self, *args): raise NotImplementedError + +class NodeWiseSampler(BaseSampler): + def __init__(self, adj, **kwargs): + super(NodeWiseSampler, self).__init__(adj, **kwargs) + self.__indptr = self._adj.indptr + self.__indices = self._adj.indices + self.__values = self._adj.data + + def one_layer_sampling(self, target_nodes, layer_size, biased): + source_nodes, (s_indptr, s_indices, s_data) = NodeWiseOneLayer(target_nodes, self.__indptr, self.__indices, self.__values, layer_size, self.probs, biased, self.replace) + adj_sampled = sp.csr_matrix((s_data, s_indices, s_indptr), shape=(len(target_nodes), len(source_nodes))) + return source_nodes, adj_sampled + +class LayerWiseSampler(BaseSampler): + def __init__(self, adj, **kwargs): + super(LayerWiseSampler, self).__init__(adj, **kwargs) + + def one_layer_sampling(self, target_nodes, layer_size, probability): + subgraph_adj = self._adj[target_nodes, :] + neis = np.nonzero(np.sum(subgraph_adj, axis=0))[1] + p1 = probability[neis] + p1 = p1 / np.sum(p1) + + if self.replace is False: + layer_size = min(len(neis), layer_size) + + local_nids = np.random.choice(np.arange(np.size(neis)), + layer_size, self.replace, p1) + + 
source_nodes = neis[local_nids] + subgraph_adj = subgraph_adj[:, source_nodes] + sampled_p1 = p1[local_nids] + + subgraph_adj = subgraph_adj.dot(sp.diags(1.0 / (sampled_p1 * layer_size))) + return source_nodes, subgraph_adj + +class GraphWiseSampler(BaseSampler): + def __init__(self, adj, **kwargs): + super(GraphWiseSampler, self).__init__(adj, **kwargs) \ No newline at end of file diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index f130ca8..19fe91c 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -7,8 +7,7 @@ from torch_sparse import SparseTensor from torch_geometric.utils import from_networkx, mask_to_index -from sgl.sampler.base_sampler import BaseSampler -from sampling_ops import NodeWiseOneLayer +from sgl.sampler.base_sampler import BaseSampler, NodeWiseSampler, LayerWiseSampler class FullSampler(BaseSampler): def __init__(self, adj, **kwargs): @@ -25,10 +24,10 @@ def __init__(self, adj, **kwargs): def sampling(self): return self.full_batch, self.full_batch, self.full_block -class NeighborSampler(BaseSampler): +class NeighborSampler(NodeWiseSampler): def __init__(self, adj, **kwargs): """ - Node-wise neighbor sampler + Neighborhood sampler """ super(NeighborSampler, self).__init__(adj, **kwargs) self.sampler_name = "NeighborSampler" @@ -48,7 +47,7 @@ def collate_fn(self, batch_inds): Input: batch_inds: array of batch node inds Method: - Neighbor sampling + Neighborhood sampling Outputs: batch_in: global node index of each source node in the first aggregation layer batch_out: global node index of each target node in the last aggregation layer @@ -62,12 +61,10 @@ def collate_fn(self, batch_inds): batch_inds = np.asarray(batch_inds) all_adjs = [] - indptr, indices, values = self._adj.indptr, self._adj.indices, self._adj.data cur_tgt_nodes = batch_inds for layer_index in range(self.num_layers): - cur_src_nodes, (s_indptr, s_indices, s_data) = NodeWiseOneLayer(cur_tgt_nodes, indptr, indices, values, self.layer_sizes[layer_index], self.probs, True, self.replace) - adj_sampled = sp.csr_matrix((s_data, s_indices, s_indptr), shape=(len(cur_tgt_nodes), len(cur_src_nodes))) + cur_src_nodes, adj_sampled = self.one_layer_sampling(cur_tgt_nodes, self.layer_sizes[layer_index], True) all_adjs.insert(0, adj_sampled) cur_tgt_nodes = cur_src_nodes @@ -75,7 +72,7 @@ def collate_fn(self, batch_inds): return cur_tgt_nodes, batch_inds, self._to_Block(all_adjs) -class FastGCNSampler(BaseSampler): +class FastGCNSampler(LayerWiseSampler): def __init__(self, adj, **kwargs): super(FastGCNSampler, self).__init__(adj, **kwargs) self.sampler_name = "FastGCNSampler" @@ -109,39 +106,14 @@ def collate_fn(self, batch_inds): cur_out_nodes = batch_inds for layer_index in range(self.num_layers): - cur_in_nodes, cur_adj = self._one_layer_sampling( - cur_out_nodes, self.layer_sizes[layer_index]) + cur_in_nodes, cur_adj = self.one_layer_sampling( + cur_out_nodes, self.layer_sizes[layer_index], self.probs) all_adjs.insert(0, cur_adj) cur_out_nodes = cur_in_nodes all_adjs = self._post_process(all_adjs, to_sparse_tensor=False) return cur_out_nodes, batch_inds, self._to_Block(all_adjs) - - def _one_layer_sampling(self, v_indices, output_size): - """ - Inputs: - v_indices: array of target node inds of the current layer - output_size: size of the source nodes to be sampled - Outputs: - u_samples: array of source node inds of the current layer - support: normalized sparse adjacency matrix of the current layer - """ - support = self._adj[v_indices, :] - neis = np.nonzero(np.sum(support, axis=0))[1] 
- p1 = self.probs[neis] - p1 = p1 / np.sum(p1) - if self.replace is False: - output_size = min(len(neis), output_size) - sampled = np.random.choice(np.arange(np.size(neis)), - output_size, self.replace, p1) - - u_sampled = neis[sampled] - support = support[:, u_sampled] - sampled_p1 = p1[sampled] - - support = support.dot(sp.diags(1.0 / (sampled_p1 * output_size))) - return u_sampled, support class ClusterGCNSampler(BaseSampler): """ diff --git a/sgl/sampler/sampling_ops.cpp b/sgl/sampler/sampling_ops.cpp index f50995d..ca1ab54 100644 --- a/sgl/sampler/sampling_ops.cpp +++ b/sgl/sampler/sampling_ops.cpp @@ -4,11 +4,6 @@ std::mt19937 gen; -// BatchSamples NodeWiseMultiLayers(PyArrInt batch_inds, PyArrInt indptr, PyArrInt indices, PyArrFloat values, PyArrInt layer_sizes, PyArrFloat probability, bool biased, bool replace) { -// py::buffer_info buf_batch_inds = batch_inds.request(); - -// } - SingleSample NodeWiseOneLayer(PyArrInt prev_nodes, PyArrInt indptr, PyArrInt indices, PyArrFloat values, int32_t layer_size, PyArrFloat probability, bool biased, bool replace) { py::buffer_info buf_prev_nodes = prev_nodes.request(); py::buffer_info buf_indptr = indptr.request(); @@ -143,7 +138,7 @@ SingleSample NodeWiseOneLayer(PyArrInt prev_nodes, PyArrInt indptr, PyArrInt ind int32_t index = 0; std::transform(vals.begin(), vals.end(), std::back_inserter(valIndices), [&index](auto v) { return std::pair(v, index++); }); std::sort(valIndices.begin(), valIndices.end(), [](auto x, auto y) { return x.first > y.first; }); - std::vector candidates; + std::vector candidates; std::transform(valIndices.begin(), valIndices.end(), std::back_inserter(candidates), [](auto v) { return v.second; }); for(int32_t j = 0; j < layer_size; j++) { e = start_ + candidates[j]; @@ -222,6 +217,91 @@ SingleSample NodeWiseOneLayer(PyArrInt prev_nodes, PyArrInt indptr, PyArrInt ind return std::make_pair(out_n_ids, out_adj); } +PyArrInt LayerWiseOneLayer(PyArrInt indices, int32_t layer_size, PyArrFloat probability, bool biased, bool replace) { + py::buffer_info buf_indices = indices.request(); + py::buffer_info buf_probability = probability.request(); + + int32_t* ptr_indices = static_cast (buf_indices.ptr); + float* ptr_probability = static_cast (buf_probability.ptr); + + std::vector neighbors(ptr_indices, ptr_indices + indices.size()); + std::sort(neighbors.begin(), neighbors.end()); + neighbors.erase(std::unique(neighbors.begin(), neighbors.end()), neighbors.end()); + std::vector n_ids; + int32_t e, c, num_neighbors = neighbors.size(); + + if (layer_size < 0) { + // No sampling + n_ids.insert(n_ids.end(), neighbors.begin(), neighbors.end()); + } else if (replace) { + // Sample with replacement + n_ids.resize(layer_size); + if (biased) { + std::vector selectedProbability(num_neighbors); + std::transform(neighbors.begin(), neighbors.end(), selectedProbability.begin(), + [&ptr_probability](int index) { return ptr_probability[index]; }); + + #pragma omp parallel for schedule(static) + for (int32_t j = 0; j < layer_size; j++) { + std::discrete_distribution<> d(selectedProbability.begin(), selectedProbability.end()); + e = d(gen); + c = neighbors[e]; + n_ids[j] = c; + } + } else { + #pragma omp parallel for schedule(static) + for (int32_t j = 0; j < layer_size; j++) { + e = rand() % num_neighbors; + c = neighbors[e]; + n_ids[j] = c; + } + } + } else { + // Sample without replacement + if (num_neighbors <= layer_size) { + n_ids.insert(n_ids.end(), neighbors.begin(), neighbors.end()); + } else if (biased) { + std::vector 
selectedProbability(num_neighbors); + std::transform(neighbors.begin(), neighbors.end(), selectedProbability.begin(), + [&ptr_probability](int index) { return ptr_probability[index]; }); + std::discrete_distribution<> d(selectedProbability.begin(), selectedProbability.end()); + std::uniform_real_distribution dist(0.0, 1.0); + std::vector vals; + std::generate_n(std::back_inserter(vals), num_neighbors, [&dist]() { return dist(gen); }); + std::transform(vals.begin(), vals.end(), selectedProbability.begin(), vals.begin(), [&](auto r, auto prob) { return std::pow(r, 1. / prob); }); + std::vector> valIndices; + int32_t index = 0; + std::transform(vals.begin(), vals.end(), std::back_inserter(valIndices), [&index](auto v) { return std::pair(v, index++); }); + std::sort(valIndices.begin(), valIndices.end(), [](auto x, auto y) { return x.first > y.first; }); + std::vector candidates; + std::transform(valIndices.begin(), valIndices.end(), std::back_inserter(candidates), [](auto v) { return v.second; }); + + n_ids.resize(layer_size); + #pragma omp parallel for schedule(static) + for (int32_t j = 0; j < layer_size; j++) { + c = candidates[j]; + n_ids[j] = c; + } + } else { + std::unordered_set perm; + for (int32_t j = num_neighbors - layer_size; j < num_neighbors; j++) { + if (!perm.insert(rand() % j).second) perm.insert(j); + } + for (const int32_t &p: perm) { + c = neighbors[p]; + n_ids.push_back(c); + } + } + } + + PyArrInt out_n_ids(n_ids.size()); + py::buffer_info buf_out_n_ids = out_n_ids.request(); + int32_t *ptr_out_n_ids = static_cast(buf_out_n_ids.ptr); + std::copy(n_ids.begin(), n_ids.end(), ptr_out_n_ids); + return out_n_ids; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("NodeWiseOneLayer", &NodeWiseOneLayer); + m.def("LayerWiseOneLayer", &LayerWiseOneLayer); } \ No newline at end of file diff --git a/sgl/sampler/sampling_ops.hpp b/sgl/sampler/sampling_ops.hpp index 5c87621..11a49d6 100644 --- a/sgl/sampler/sampling_ops.hpp +++ b/sgl/sampler/sampling_ops.hpp @@ -8,12 +8,8 @@ typedef py::array_t PyArrInt; typedef py::array_t PyArrFloat; using Adj = std::tuple; -using Adjs = std::vector; using SingleSample = std::tuple; -using BatchSamples = std::tuple; SingleSample NodeWiseOneLayer(PyArrInt prev_nodes, PyArrInt indptr, PyArrInt indices, PyArrFloat values, int32_t layer_size, PyArrFloat probability, bool biased, bool replace); -// BatchSamples NodeWiseMultiLayers(PyArrInt batch_inds, PyArrInt indptr, PyArrInt indices, -// PyArrFloat values, PyArrInt layer_sizes, PyArrFloat probability, -// bool biased, bool replace); \ No newline at end of file +PyArrInt LayerWiseOneLayer(PyArrInt indices, int32_t layer_size, PyArrFloat probability, bool biased, bool replace); \ No newline at end of file diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index b5d4079..ea76e5c 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -143,13 +143,6 @@ def _postprocess(self): else: outputs = self.__model.inference(self.__all_eval_loader, self.__device) labels = self.__dataset.y - # outputs, labels = [], [] - # for batch in self.__all_eval_loader: - # output, label = self.__model.mini_batch_prepare_forward(batch, self.__device, transfer_y_to_device=False) - # outputs.append(output.cpu()) - # labels.append(label) - # outputs = torch.vstack(outputs) - # labels = torch.cat(labels) # TODO: self.__model.postprocess now directly returns the raw outputs final_output = self.__model.postprocess(self.__dataset.adj, 
outputs) From c467498e10212615e98d74b881f84e413a2da219 Mon Sep 17 00:00:00 2001 From: infinity Date: Mon, 4 Dec 2023 12:09:29 +0000 Subject: [PATCH 15/28] add nodewise/layerwise/graphwise-sampler class --- sgl/models/homo/clustergcn.py | 6 +++ sgl/sampler/__init__.py | 8 +++- sgl/sampler/base_sampler.py | 42 ++++++++++++++++-- sgl/sampler/sampler.py | 52 +++++++---------------- sgl/tasks/node_classification_sampling.py | 4 +- 5 files changed, 69 insertions(+), 43 deletions(-) diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index 7d77222..16cd01d 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -10,6 +10,12 @@ def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, dr self._eval_sampling_op = eval_sampler self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) + def pre_sample(self, mode="train"): + if mode == "train": + self._training_sampling_op.multiple_graphs_sampling() + else: + self._eval_sampling_op.multiple_graphs_sampling() + def mini_batch_prepare_forward(self, batch, device, **kwargs): batch_in, batch_out, block = batch local_inds, global_inds = batch_out diff --git a/sgl/sampler/__init__.py b/sgl/sampler/__init__.py index d1d9382..eac7dc6 100644 --- a/sgl/sampler/__init__.py +++ b/sgl/sampler/__init__.py @@ -1,8 +1,12 @@ -from .sampler import FastGCNSampler, ClusterGCNSampler, NeighborSampler, FullSampler +from .sampler import FastGCNSampler, ClusterGCNSampler, NeighborSampler +from .base_sampler import FullSampler, NodeWiseSampler, LayerWiseSampler, GraphWiseSampler __all__ = [ "FastGCNSampler", "ClusterGCNSampler", "NeighborSampler", - "FullSampler" + "FullSampler", + "NodeWiseSampler", + "LayerWiseSampler", + "GraphWiseSampler" ] diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py index 2dece79..1cb767e 100644 --- a/sgl/sampler/base_sampler.py +++ b/sgl/sampler/base_sampler.py @@ -1,5 +1,7 @@ import os +import torch import numpy as np +import pickle as pkl import scipy.sparse as sp from scipy.sparse.linalg import norm as sparse_norm @@ -96,7 +98,22 @@ def _to_Block(self, adjs): def collate_fn(self, *args): raise NotImplementedError - + +class FullSampler(BaseSampler): + def __init__(self, adj, **kwargs): + """ + In fact, this sampler simply returns the full graph. 
+ """ + super(FullSampler, self).__init__(adj, **kwargs) + self.sampler_name = "FullSampler" + self.sample_level = "graph" + self.pre_sampling = False + self.full_batch = kwargs.get("node_ids", range(self._adj.shape[0])) + self.full_block = self._to_Block(self._adj) + + def sampling(self): + return self.full_batch, self.full_batch, self.full_block + class NodeWiseSampler(BaseSampler): def __init__(self, adj, **kwargs): super(NodeWiseSampler, self).__init__(adj, **kwargs) @@ -106,8 +123,8 @@ def __init__(self, adj, **kwargs): def one_layer_sampling(self, target_nodes, layer_size, biased): source_nodes, (s_indptr, s_indices, s_data) = NodeWiseOneLayer(target_nodes, self.__indptr, self.__indices, self.__values, layer_size, self.probs, biased, self.replace) - adj_sampled = sp.csr_matrix((s_data, s_indices, s_indptr), shape=(len(target_nodes), len(source_nodes))) - return source_nodes, adj_sampled + subgraph_adj = sp.csr_matrix((s_data, s_indices, s_indptr), shape=(len(target_nodes), len(source_nodes))) + return source_nodes, subgraph_adj class LayerWiseSampler(BaseSampler): def __init__(self, adj, **kwargs): @@ -134,4 +151,21 @@ def one_layer_sampling(self, target_nodes, layer_size, probability): class GraphWiseSampler(BaseSampler): def __init__(self, adj, **kwargs): - super(GraphWiseSampler, self).__init__(adj, **kwargs) \ No newline at end of file + super(GraphWiseSampler, self).__init__(adj, **kwargs) + + @property + def sample_graph_ops(self): + # Each subclass must implement its own sample operations + raise NotImplementedError + + def multiple_graphs_sampling(self): + if self.pre_sampling is False or self.sampling_done is False: + if self._save_dir is not None and os.path.exists(self._save_path_pt) and os.path.exists(self._save_path_pkl): + print("\nLoad from existing subgraphs.\n") + (self.perm_adjs, self.partptr, self.perm_node_idx) = torch.load(self._save_path_pt) + self.splitted_perm_adjs = pkl.load(open(self._save_path_pkl, "rb")) + else: + self.sample_graph_ops() + self.sampling_done = True + else: + print("\nSubgraphs already existed.\n") \ No newline at end of file diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index 19fe91c..f7cd0b1 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -7,22 +7,8 @@ from torch_sparse import SparseTensor from torch_geometric.utils import from_networkx, mask_to_index -from sgl.sampler.base_sampler import BaseSampler, NodeWiseSampler, LayerWiseSampler +from sgl.sampler.base_sampler import NodeWiseSampler, LayerWiseSampler, GraphWiseSampler -class FullSampler(BaseSampler): - def __init__(self, adj, **kwargs): - """ - In fact, this sampler simply returns the full graph. - """ - super(FullSampler, self).__init__(adj, **kwargs) - self.sampler_name = "FullSampler" - self.sample_level = "graph" - self.pre_sampling = False - self.full_batch = kwargs.get("node_ids", range(self._adj.shape[0])) - self.full_block = self._to_Block(self._adj) - - def sampling(self): - return self.full_batch, self.full_batch, self.full_block class NeighborSampler(NodeWiseSampler): def __init__(self, adj, **kwargs): @@ -115,7 +101,7 @@ def collate_fn(self, batch_inds): return cur_out_nodes, batch_inds, self._to_Block(all_adjs) -class ClusterGCNSampler(BaseSampler): +class ClusterGCNSampler(GraphWiseSampler): """ Clustering the graph, feature set and target. 
""" @@ -127,36 +113,30 @@ def __init__(self, dataset, **kwargs): super(ClusterGCNSampler, self).__init__(nx.from_scipy_sparse_matrix(dataset.adj), **kwargs) self.sampler_name = "ClusterGCNSampler" self.sample_level = "graph" - self.pre_sampling = True + self.pre_sampling = True # conduct sampling only once before training + self.sampling_done = False self._masks = {"train": dataset.train_mask, "val": dataset.val_mask, "test": dataset.test_mask} - self._sampling_done = False - def _pre_process(self, **kwargs): + @property + def sample_graph_ops(self): + if self.cluster_method == "metis": + return self._metis_clustering + else: + raise NotImplementedError - self.cluster_method = kwargs.get("cluster_method", "random") + def _pre_process(self, **kwargs): + + self.cluster_method = kwargs.get("cluster_method", "metis") self.cluster_number = kwargs.get("cluster_number", 32) + self._save_dir = kwargs.get("save_dir", None) if self._save_dir is not None: - self._save_path_pt = os.path.join(self._save_dir, f"partition_{self.cluster_method}_{self.cluster_number}.pt") - self._save_path_pkl = os.path.join(self._save_dir, f"partition_{self.cluster_method}_{self.cluster_number}.pkl") + self._save_path_pt = os.path.join(self._save_dir, f"cluster_partition_{self.cluster_method}_{self.cluster_number}.pt") + self._save_path_pkl = os.path.join(self._save_dir, f"cluster_partition_{self.cluster_method}_{self.cluster_number}.pkl") else: self._save_path_pt = self._save_path_pkl = None def collate_fn(self, batch_inds, mode): - if self._sampling_done is False: - if self._save_dir is not None and os.path.exists(self._save_path_pt) and os.path.exists(self._save_path_pkl): - print("\nLoad from existing clusters.\n") - (self.perm_adjs, self.partptr, self.perm_node_idx) = torch.load(self._save_path_pt) - self.splitted_perm_adjs = pkl.load(open(self._save_path_pkl, "rb")) - else: - if self.cluster_method == "metis": - print("\nMetis graph clustering started.\n") - self._metis_clustering() - else: - raise NotImplementedError - - self._sampling_done = True - if not isinstance(batch_inds, torch.Tensor): batch_inds = torch.tensor(batch_inds) diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index ea76e5c..b132421 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -60,6 +60,7 @@ def _execute(self): if self.__mini_batch_train: if self.__train_determined_sample: + self.__model.pre_sample("train") self.__train_loader = DataLoader( range(self.__train_graph_number), batch_size=self.__train_batch_size, num_workers=self.__train_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "train"), shuffle=True, drop_last=False) else: @@ -72,6 +73,7 @@ def _execute(self): if self.__mini_batch_eval: if self.__eval_determined_sample: + self.__model.pre_sample("eval") self.__val_loader = DataLoader( range(self.__eval_graph_number), batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "val"), shuffle=False, drop_last=False) self.__test_loader = DataLoader( @@ -86,7 +88,7 @@ def _execute(self): self.__dataset.test_idx, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False) self.__all_eval_loader = DataLoader( self.__dataset.node_ids, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False) - + 
self.__model = self.__model.to(self.__device) t_total = time.time() From ad033537b81263e08de0c9bbca0681a309af28cf Mon Sep 17 00:00:00 2001 From: infinity Date: Mon, 4 Dec 2023 12:17:59 +0000 Subject: [PATCH 16/28] add nodewise/layerwise/graphwise-sampler class --- sgl/sampler/base_sampler.py | 13 ++++++++++++- sgl/sampler/sampler.py | 16 ---------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py index 1cb767e..b0249a7 100644 --- a/sgl/sampler/base_sampler.py +++ b/sgl/sampler/base_sampler.py @@ -121,15 +121,26 @@ def __init__(self, adj, **kwargs): self.__indices = self._adj.indices self.__values = self._adj.data - def one_layer_sampling(self, target_nodes, layer_size, biased): + def _pre_process(self, **kwargs): + self._get_sample_sizes(**kwargs) + self._calc_probs(**kwargs) + self.replace = kwargs.get("replace", True) + + def one_layer_sampling(self, target_nodes, layer_size, biased): source_nodes, (s_indptr, s_indices, s_data) = NodeWiseOneLayer(target_nodes, self.__indptr, self.__indices, self.__values, layer_size, self.probs, biased, self.replace) subgraph_adj = sp.csr_matrix((s_data, s_indices, s_indptr), shape=(len(target_nodes), len(source_nodes))) + return source_nodes, subgraph_adj class LayerWiseSampler(BaseSampler): def __init__(self, adj, **kwargs): super(LayerWiseSampler, self).__init__(adj, **kwargs) + def _pre_process(self, **kwargs): + self._get_sample_sizes(**kwargs) + self._calc_probs(**kwargs) + self.replace = kwargs.get("replace", False) + def one_layer_sampling(self, target_nodes, layer_size, probability): subgraph_adj = self._adj[target_nodes, :] neis = np.nonzero(np.sum(subgraph_adj, axis=0))[1] diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index f7cd0b1..c22b3a9 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -19,14 +19,6 @@ def __init__(self, adj, **kwargs): self.sampler_name = "NeighborSampler" self.sample_level = "node" self.pre_sampling = False - - def _pre_process(self, **kwargs): - - self._get_sample_sizes(**kwargs) - - self._calc_probs(**kwargs) - - self.replace = kwargs.get("replace", True) def collate_fn(self, batch_inds): """ @@ -65,14 +57,6 @@ def __init__(self, adj, **kwargs): self.sample_level = "layer" self.pre_sampling = False - def _pre_process(self, **kwargs): - - self._get_sample_sizes(**kwargs) - - self._calc_probs(**kwargs) - - self.replace = kwargs.get("replace", False) - def collate_fn(self, batch_inds): """ Input: From 3f421088b9dac2bd0ac8a2f6dc8224683a4a52ce Mon Sep 17 00:00:00 2001 From: TheRoadQaQ Date: Thu, 7 Dec 2023 21:27:37 +0800 Subject: [PATCH 17/28] Add graphsaint. 
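
For reference, a minimal sketch of the GraphSAINT-style loss normalisation
that this commit wires into GraphSAINT.loss. The helpers below are
illustrative only and are not part of the patch; `sample_subgraph` stands in
for one of the node/edge/random-walk samplers added here and is assumed to
return the global ids of the sampled nodes together with the induced
sub-adjacency.

    import numpy as np
    import torch
    from torch.nn.functional import nll_loss

    def estimate_loss_norm(sample_subgraph, num_nodes, num_rounds=20):
        # Count how often each node appears across the pre-sampled subgraphs.
        counts = np.zeros(num_nodes)
        for _ in range(num_rounds):
            sampled_nodes, _ = sample_subgraph()
            counts[sampled_nodes] += 1
        # Clamp unseen nodes to 1 to avoid division by zero.
        return torch.FloatTensor(np.maximum(counts, 1))

    def normalized_loss(pred, labels, loss_norm, batch_nodes):
        # Down-weight frequently sampled nodes so that the mini-batch loss
        # approximates an unbiased estimate of the full-graph loss.
        per_node = nll_loss(pred, labels, reduction="none")
        return (per_node / loss_norm[batch_nodes]).sum()

The aggregator normalisation follows the same idea, dividing per-edge counts
by per-node counts estimated from the same pre-sampled subgraphs.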
--- examples/configs/graphsaint.yml | 12 +- examples/graphsaint_nodeclass.py | 10 +- sgl/models/homo/graphsaint.py | 31 ++++- sgl/sampler/__init__.py | 3 +- sgl/sampler/sampler.py | 192 ++++++++++++++++++++++++------- 5 files changed, 193 insertions(+), 55 deletions(-) diff --git a/examples/configs/graphsaint.yml b/examples/configs/graphsaint.yml index d016b2b..3b44515 100644 --- a/examples/configs/graphsaint.yml +++ b/examples/configs/graphsaint.yml @@ -4,16 +4,20 @@ dataset: root: "/home/ssq/test_data/" sampler: train: - pre_sampling_graphs: 10 - samplertype: "Node" - nodebudget: 2048 + pre_sampling_graphs: 20 + sampler_type: "Node" + nodebudget: 1000 + edgebudget: 3000 + r: 500 + h: 4 pre_sampling_op: "RwGraphOp" model: hidden_dim: 128 dropout: 0.5 num_layers: 2 task: - train_batch_size: 2048 + train_graph_number: 10 + train_batch_size: 5 epochs: 100 lr: 0.01 weight_decay: 0.00005 diff --git a/examples/graphsaint_nodeclass.py b/examples/graphsaint_nodeclass.py index 66295f3..b4a57a4 100644 --- a/examples/graphsaint_nodeclass.py +++ b/examples/graphsaint_nodeclass.py @@ -28,7 +28,7 @@ train_sampler_kwargs = sampler_kwargs["train"] train_sampler_kwargs.update({"save_dir": dataset.processed_dir}) - train_sampler = GraphSAINTSampler(dataset.adj, **train_sampler_kwargs) + train_sampler = GraphSAINTSampler(dataset, **train_sampler_kwargs) if "eval" in sampler_kwargs: eval_sampler_kwargs = sampler_kwargs["eval"] eval_sampler_name = eval_sampler_kwargs["name"] @@ -45,11 +45,5 @@ model_kwargs.update({"device": device}) model = GraphSAINT(dataset, train_sampler, eval_sampler, **model_kwargs) task_kwargs.update({"device": device}) - - def myloss(pred,labels): - loss = nll_loss(pred, labels, reduction="none") - loss = (loss/model.cur_loss_norm).sum() - return loss - - task_kwargs.update({"loss_fn":myloss}) + task_kwargs.update({"loss_fn":model.loss}) test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc diff --git a/sgl/models/homo/graphsaint.py b/sgl/models/homo/graphsaint.py index 76c3eea..6232fdc 100644 --- a/sgl/models/homo/graphsaint.py +++ b/sgl/models/homo/graphsaint.py @@ -2,16 +2,45 @@ from sgl.models.base_model import BaseSAMPLEModel from sgl.operators.graph_op import RwGraphOp +from torch.nn.functional import nll_loss + class GraphSAINT(BaseSAMPLEModel): def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): super(GraphSAINT, self).__init__() self._pre_graph_op = RwGraphOp() self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler + self.device = device self._base_model = GCN( nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout ).to(device) + def pre_sample(self, mode="train"): + self._training_sampling_op._calc_norm() + self._training_sampling_op.loss_norm.to(device=self.device) + return + + def mini_batch_prepare_forward(self, batch, device, **kwargs): + batch_in, batch_out, block = batch + local_inds, global_inds = batch_out + + in_x = self._processed_feature[batch_in].to(device) + y_truth = self._vanilla_y[global_inds].to(device) + block.to_device(device) + y_pred = self._base_model(in_x, block)[local_inds] + return y_pred, y_truth + + def collate_fn(self, batch_ids, mode): + if mode == "train": + return self._training_sampling_op.collate_fn(batch_ids, mode) + else: + return self._eval_sampling_op.collate_fn(batch_ids, mode) + + def loss(self, pred, labels): + loss = nll_loss(pred, labels, reduction="none") + loss = 
(loss / self.cur_loss_norm).sum() + return loss + @property def cur_loss_norm(self): - return self._training_sampling_op.loss_norm[self._training_sampling_op.index] \ No newline at end of file + return self._training_sampling_op.loss_norm[self._training_sampling_op.cur_index] \ No newline at end of file diff --git a/sgl/sampler/__init__.py b/sgl/sampler/__init__.py index eac7dc6..1df6716 100644 --- a/sgl/sampler/__init__.py +++ b/sgl/sampler/__init__.py @@ -1,9 +1,10 @@ -from .sampler import FastGCNSampler, ClusterGCNSampler, NeighborSampler +from .sampler import FastGCNSampler, ClusterGCNSampler, GraphSAINTSampler,NeighborSampler from .base_sampler import FullSampler, NodeWiseSampler, LayerWiseSampler, GraphWiseSampler __all__ = [ "FastGCNSampler", "ClusterGCNSampler", + "GraphSAINTSampler", "NeighborSampler", "FullSampler", "NodeWiseSampler", diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index f9873fb..9cee82b 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -172,66 +172,148 @@ def _metis_clustering(self): pkl.dump(self.splitted_perm_adjs, open(self._save_path_pkl, "wb")) print(f"\nSave Metis graph clustering results under the {self._save_dir} directory.\n") -class GraphSAINTSampler(BaseSampler): +class GraphSAINTSampler(GraphWiseSampler): ''' - sample the wholo graph, feature and label as GraphSAINT method + sample the wholo graph, feature set and label as GraphSAINT method ''' - def __init__(self, adj, **kwargs): + def __init__(self, dataset, **kwargs): """ Inputs: - adj: adj of dgl Graph:sp.matrix - kwargs: some params + adj: Adjacency matrix: scipy.sparse.csr_matrix """ - self.replace = True - self.node_budget = kwargs['nodebudget'] - - super(GraphSAINTSampler, self).__init__(adj, **kwargs) + super(GraphSAINTSampler, self).__init__(dataset.adj, **kwargs) + self.replace = True self.sampler_name = "GraphSaintSampler" self.sample_level = "graph" self.pre_sampling = False + self._masks = {"train": dataset.train_mask, "val": dataset.val_mask, "test": dataset.test_mask} - def _pre_process(self, **kwargs): - if kwargs['samplertype'] == "Node": + self.n = dataset.adj.shape[0] + self.e = dataset.adj.nnz + self.pre_sampling_times = kwargs.get("pre_sampling_graphs", 1) + self.used_sample_graphs = 0 + + if kwargs['sampler_type'] == "Node": + kwargs.update({"prob_type": "normalize"}) self._calc_probs(**kwargs) - self.sample = self.node_sample + self.node_probs = self.probs + self.node_budget = kwargs['nodebudget'] + self.sample_graph_type = "Node" + elif kwargs['sampler_type'] == 'Edge': + self._calc_edge_probs() + self.edge_budget = kwargs['edgebudget'] + self.sample_graph_type = "Edge" + elif kwargs['sampler_type'] == 'RandomWalk': + self.r = kwargs['r'] + self.h = kwargs['h'] + self.sample_graph_type = "RandomWalk" + else: + raise NotImplementedError + + @property + def sample_graph_ops(self): + if self.sample_graph_type == "Node": + return self.node_sampler() + elif self.sample_graph_type == "Edge": + return self.edge_sampler() + elif self.sample_graph_type == "RandomWalk": + return self.random_walk_sampler() else: raise NotImplementedError - self._calc_norm(**kwargs) + def node_sampler(self): + """ + method: sample fixed size of nodes as a subgraph with node_probs - def node_sample(self): + Outputs: + sampled_node: global node index + block: sampled adjs, csr sparse matrix """ - Inputs: - batch_ids: is not used in this method - method: sample fixed size of nodes as a subgraph + p = self.node_probs + + sampled_node = np.random.choice(a=self.n, 
size=self.node_budget, replace=self.replace, p=p) + sampled_node = np.unique(sampled_node) + + subadj = self._adj[sampled_node, :] + subadj = subadj[:, sampled_node] + return sampled_node, subadj + + def _calc_edge_probs(self): + """ + method: calculate edge probablity as 1/d(u)+1/d(v) + """ + degrees = self._adj.sum(axis=1).A1 + edges = self._adj.nonzero() + start_degrees = degrees[edges[0]] + end_degrees = degrees[edges[1]] + + self.edge_probs = 1 / start_degrees + 1 / end_degrees + self.edge_probs = self.edge_probs / np.sum(self.edge_probs) + return + + def edge_sampler(self): + """ + method: sample fixed size of edges as a subgraph with edge_probs Outputs: - batch_in: global node index - batch_out: global node index - block: sampled adjs in the form of sparse tensors wrapped in Block class + sampled_node: global node index + block: sampled adjs, csr sparse matrix """ - p = self.probs - sampled = np.random.choice(np.arange(np.size(p)), self.node_budget, self.replace, p) - sampled = np.unique(sampled) + p = self.edge_probs + sampled_edges = np.random.choice(a=self.e, size=self.edge_budget, replace=self.replace, p=p) + sampled_edges = np.unique(sampled_edges) + + edges = self._adj.nonzero() + sampled_start = edges[0][sampled_edges] + sampled_end = edges[1][sampled_edges] + + sampled_node = np.unique(np.concatenate([sampled_start,sampled_end])) + + subadj = self._adj[sampled_node, :] + subadj = subadj[:, sampled_node] + + return sampled_node, subadj + + def random_walk_sampler(self): + """ + method: sample like random walk + + Outputs: + sampled_node: global node index + block: sampled adjs, csr sparse matrix + """ + root_nodes = np.random.choice(a=self.n, size=self.r, replace = self.replace) + sampled_node = [] + for v in root_nodes: + sampled_node.append(v) + + neighbors = self._adj.indices[self._adj.indptr[v]:self._adj.indptr[v+1]] + sampled_nei = np.random.choice(a=neighbors, size=self.h, replace=self.replace) + + sampled_node.extend(sampled_nei.tolist()) + + sampled_node = np.unique(np.array(sampled_node)) + + subadj = self._adj[sampled_node, :] + subadj = subadj[:, sampled_node] - adj = self._adj[sampled, :].tocsc() - adj = adj[:, sampled].tocsr() - return sampled, adj + return sampled_node, subadj - def _calc_norm(self, **kwargs): + def _calc_norm(self): """ methods: calculate the norm to estimate embedding and loss """ - times = kwargs['pre_sampling_graphs'] + self.sampled_graphs = [] - node_value = np.zeros(np.size(self.probs)) - edge_value = sp.lil_matrix((np.size(self.probs),np.size(self.probs))) + node_value = np.zeros(self.n) + edge_value = sp.lil_matrix((self.n,self.n)) - for _ in range(times): - sampled, adj = self.sample() + for _ in range(self.pre_sampling_times): + sampled, adj = self.sample_graph_ops + self.sampled_graphs.append((sampled,adj)) adj = adj.tocoo() for row, col in zip(adj.row, adj.col): edge_value[sampled[row],sampled[col]] += 1 @@ -243,22 +325,50 @@ def _calc_norm(self, **kwargs): self.loss_norm = torch.FloatTensor(np.maximum(node_value, 1)) return - def collate_fn(self, batch_ids): + def collate_fn(self, batch_ids, mode): """ Inputs: - batch_ids: is not used in this method + batch_ids: only the len of it is used, means how many subgraphs are sampled to construct computation graph - method: sample fixed size of nodes as a subgraph + method: sample len(batch_ids) subgraphs as mini-batch Outputs: batch_in: global node index batch_out: global node index block: sampled adjs in the form of sparse tensors wrapped in Block class """ - sampled, adj = self.sample() 
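(For orientation while reading the hunk above: a minimal standalone sketch of the pre-sampling statistics that _calc_norm gathers. The helper name count_norm_statistics and the sample_fn callable are hypothetical stand-ins for the node/edge/random-walk samplers in this patch; the real code keeps the counts on the sampler object and converts loss_norm to a torch tensor.)

import numpy as np
import scipy.sparse as sp

def count_norm_statistics(adj, sample_fn, pre_sampling_times=20):
    # Count how often every node and edge appears across pre-sampled subgraphs;
    # these counts back the aggregation and loss normalization used at train time.
    n = adj.shape[0]
    node_value = np.zeros(n)
    edge_value = sp.lil_matrix((n, n))
    for _ in range(pre_sampling_times):
        sampled, sub_adj = sample_fn()          # (global node ids, csr subgraph)
        sub_adj = sub_adj.tocoo()
        for row, col in zip(sub_adj.row, sub_adj.col):
            edge_value[sampled[row], sampled[col]] += 1
        node_value[sampled] += 1
    loss_norm = np.maximum(node_value, 1)       # avoid dividing by zero for unseen nodes
    return node_value, edge_value, loss_norm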
- sampled_aggr_norm = self.aggr_norm[sampled, :].tocsc() - sampled_aggr_norm = sampled_aggr_norm[:, sampled] - adj = adj.multiply(sampled_aggr_norm.transpose()) - self.index = sampled + adjs = [] + batch_in = [] + for _ in range(len(batch_ids)): + if self.used_sample_graphs < self.pre_sampling_times: + sampled, adj = self.sampled_graphs[self.used_sample_graphs] + self.used_sample_graphs += 1 + else: + sampled, adj = self.sample_graph_ops + + sampled_aggr_norm = self.aggr_norm[sampled, :] + sampled_aggr_norm = sampled_aggr_norm[:, sampled] + adj = adj.multiply(sampled_aggr_norm.transpose()) + adjs.append(adj) + batch_in.extend(sampled) + + batched_adj = sp.block_diag(adjs, format='csr') + batch_in = torch.LongTensor(batch_in) + + if mode in ["train", "val", "test"]: + mask = self._masks[mode][batch_in] + global_inds = batch_in[mask] + local_inds = mask_to_index(mask) + batch_out = torch.vstack([local_inds, global_inds]) + else: + mode = mode.split("_") + batch_out = {} + for one_mode in mode: + mask = self._masks[one_mode][batch_in] + global_inds = batch_in[mask] + local_inds = mask_to_index(mask) + batch_out.update({one_mode: torch.vstack([local_inds, global_inds])}) + + self.cur_index = global_inds - return sampled,sampled,self._to_Block(adj) \ No newline at end of file + return batch_in,batch_out,self._to_Block(batched_adj) \ No newline at end of file From 0be039a40133d8e3147a2b1aaa77e5b9b6cb22b3 Mon Sep 17 00:00:00 2001 From: infinity Date: Fri, 8 Dec 2023 08:41:52 +0000 Subject: [PATCH 18/28] add GAugO --- sgl/models/simple_models.py | 128 +++++++++++++++++++++++++++++++----- sgl/tasks/__init__.py | 4 +- sgl/tasks/utils.py | 20 ++++++ 3 files changed, 135 insertions(+), 17 deletions(-) diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index 665e7b0..e5a79dc 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -190,7 +190,7 @@ class GCNConv(nn.Module): Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 """ - def __init__(self, in_features, out_features, bias=False): + def __init__(self, in_features, out_features, bias=True): super(GCNConv, self).__init__() self.in_features = in_features self.out_features = out_features @@ -202,10 +202,15 @@ def __init__(self, in_features, out_features, bias=False): self.reset_parameters() def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.weight.size(1)) - self.weight.data.uniform_(-stdv, stdv) - if self.bias is not None: - self.bias.data.uniform_(-stdv, stdv) + # stdv = 1.0 / math.sqrt(self.weight.size(1)) + # self.weight.data.uniform_(-stdv, stdv) + # if self.bias is not None: + # self.bias.data.uniform_(-stdv, stdv) + for param in self.parameters(): + if len(param.size()) == 2: + nn.init.xavier_uniform_(param) + else: + nn.init.constant_(param, 0.0) def forward(self, input, adj): support = torch.mm(input, self.weight) @@ -217,7 +222,7 @@ def forward(self, input, adj): class SAGEConv(nn.Module): """ - Simple GraphSAGE layer, use mean as aggregation way + Simple GraphSAGE layer """ def __init__(self, in_features, out_features, normalize=True): @@ -252,8 +257,60 @@ def forward(self, x, adj): return output +class GATConv(nn.Module): + """ + Simple GAT layer + """ + def __init__(self, in_features, out_features, n_heads, bias=True): + super(GATConv, self).__init__() + self.W = nn.Parameter(torch.FloatTensor(in_features, out_features)) + self.n_heads = n_heads + self.attn_l = nn.Linear(out_features, self.n_heads, bias=False) + self.attn_r = nn.Linear(out_features, self.n_heads, bias=False) + 
self.attn_drop = nn.Dropout(p=0.6) + if bias: + self.b = nn.Parameter(torch.FloatTensor(out_features)) + else: + self.b = None + self.reset_parameters() + + def reset_parameters(self): + """ Initialize weights with xavier uniform and biases with all zeros """ + for param in self.parameters(): + if len(param.size()) == 2: + nn.init.xavier_uniform_(param) + else: + nn.init.constant_(param, 0.0) + + def forward(self, x, adj): + repr = x @ self.W + el = self.attn_l(repr) + er = self.attn_r(repr) + if isinstance(adj, torch.sparse.FloatTensor): + nz_indices = adj._indices() + else: + nz_indices = adj.nonzero().T + attn = el[nz_indices[0]] + er[nz_indices[1]] + attn = F.leaky_relu(attn, negative_slope=0.2).squeeze() + attn = torch.exp(attn) + if self.n_heads == 1: + adj_attn = torch.zeros(size=(adj.size(0), adj.size(1)), device=adj.device) + adj_attn.index_put_((nz_indices[0], nz_indices[1]), attn) + else: + adj_attn = torch.zeros(size=(adj.size(0), adj.size(1), self.n_heads), device=adj.device) + adj_attn.index_put_((nz_indices[0], nz_indices[1]), attn) + adj_attn.transpose_(1, 2) + adj_attn = F.normalize(adj_attn, p=1, dim=-1) + adj_attn = self.attn_drop(adj_attn) + repr = adj_attn @ repr + if self.b is not None: + repr = repr + self.b + if self.n_heads > 1: + repr = repr.flatten(start_dim=1) + return repr + class SAGE(nn.Module): - def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5): + def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, activation=F.relu): super(SAGE, self).__init__() self.gcs = nn.ModuleList() self.gcs.append(SAGEConv(nfeat, nhid)) @@ -262,6 +319,7 @@ def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5): self.gcs.append(SAGEConv(nhid, nhid)) self.gcs.append(SAGEConv(nhid, nclass, normalize=False)) self.dropout = dropout + self.activation = activation def reset_parameter(self): for conv in self.gcs: @@ -269,16 +327,18 @@ def reset_parameter(self): def forward(self, x, block): repr = x + if isinstance(block, torch.Tensor): + block = [block] if len(block) == self.nlayers: for i in range(self.nlayers-1): repr = self.gcs[i](repr, block[i]) - repr = F.relu(repr) + repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, block[-1]) elif len(block) == 1: for gc in self.gcs[:-1]: repr = gc(repr, block[0]) - repr = F.relu(repr) + repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, block[0]) else: @@ -306,15 +366,16 @@ def inference(self, x_all, subgraph_loader, device): return x_all class GCN(nn.Module): - def __init__(self, nfeat, nhid, nclass, layer=GCNConv, nlayers=2, dropout=0.5): + def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, activation=F.relu): super(GCN, self).__init__() self.gcs = nn.ModuleList() - self.gcs.append(layer(nfeat, nhid)) + self.gcs.append(GCNConv(nfeat, nhid)) self.nlayers = nlayers for _ in range(nlayers-2): - self.gcs.append(layer(nhid, nhid)) - self.gcs.append(layer(nhid, nclass)) + self.gcs.append(GCNConv(nhid, nhid)) + self.gcs.append(GCNConv(nhid, nclass)) self.dropout = dropout + self.activation = activation def reset_parameter(self): for conv in self.gcs: @@ -322,16 +383,18 @@ def reset_parameter(self): def forward(self, x, block): repr = x + if isinstance(block, torch.Tensor): + block = [block] if len(block) == self.nlayers: for i in range(self.nlayers-1): repr = self.gcs[i](repr, block[i]) - repr = F.relu(repr) + repr = self.activation(repr) repr = F.dropout(repr, self.dropout, 
training=self.training) repr = self.gcs[-1](repr, block[-1]) elif len(block) == 1: for gc in self.gcs[:-1]: repr = gc(repr, block[0]) - repr = F.relu(repr) + repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, block[0]) else: @@ -356,4 +419,37 @@ def inference(self, x_all, subgraph_loader, device): x_all = torch.cat(xs, dim=0) - return x_all \ No newline at end of file + return x_all + +class GAT(nn.Module): + def __init__(self, nfeat, nhid, nclass, n_heads, nlayers=2, dropout=0.6, activation=F.elu): + super(GAT, self).__init__() + self.gcs = nn.ModuleList() + self.gcs.append(GATConv(nfeat, nhid // n_heads[0], n_heads[0])) + self.nlayers = nlayers + for i in range(nlayers-2): + self.gcs.append(GATConv(nhid, nhid // n_heads[i+1], n_heads[i+1])) + self.gcs.append(GATConv(nhid, nclass, n_heads[-1])) + self.dropout = dropout + self.activation = activation + + def forward(self, x, block): + repr = x + if isinstance(block, torch.Tensor): + block = [block] + if len(block) == self.nlayers: + for i in range(self.nlayers-1): + repr = self.gcs[i](repr, block[i]) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, block[-1]) + elif len(block) == 1: + for gc in self.gcs[:-1]: + repr = gc(repr, block[0]) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, block[0]) + else: + raise ValueError('The sampling layer must be equal to GNN layer.') + + return F.log_softmax(repr, dim=1) \ No newline at end of file diff --git a/sgl/tasks/__init__.py b/sgl/tasks/__init__.py index 7086af0..f68c17e 100644 --- a/sgl/tasks/__init__.py +++ b/sgl/tasks/__init__.py @@ -8,6 +8,7 @@ from .correct_and_smooth import NodeClassification_With_CorrectAndSmooth from .node_classification_with_label_use import NodeClassificationWithLabelUse from .node_classification_dist import NodeClassificationDist +from .node_classification_GAug import NodeClassification_GAug __all__ = [ "NodeClassification", @@ -20,5 +21,6 @@ "NodeClassificationWithLabelUse", "NodeClassificationDist", "NodeClassification_Sampling", - "NodeClassification_RecycleSampling" + "NodeClassification_RecycleSampling", + "NodeClassification_GAug" ] diff --git a/sgl/tasks/utils.py b/sgl/tasks/utils.py index 393de80..9f61a90 100644 --- a/sgl/tasks/utils.py +++ b/sgl/tasks/utils.py @@ -438,3 +438,23 @@ def sparse_mx_to_torch_sparse_tensor(sparse_mx): values = torch.from_numpy(sparse_mx.data) shape = torch.Size(sparse_mx.shape) return torch.sparse.FloatTensor(indices, values, shape) + +class MultipleOptimizer(): + """ a class that wraps multiple optimizers """ + def __init__(self, *op): + self.optimizers = op + + def zero_grad(self): + for op in self.optimizers: + op.zero_grad() + + def step(self): + for op in self.optimizers: + op.step() + + def update_lr(self, op_index, new_lr): + """ update the learning rate of one optimizer + Parameters: op_index: the index of the optimizer to update + new_lr: new learning rate for that optimizer """ + for param_group in self.optimizers[op_index].param_groups: + param_group['lr'] = new_lr \ No newline at end of file From 442d3e522a351adbf6c27b867a574a48873cd474 Mon Sep 17 00:00:00 2001 From: infinity Date: Fri, 8 Dec 2023 08:44:44 +0000 Subject: [PATCH 19/28] add GAugO --- examples/GDA/configs/GAugO.yml | 64 ++++++++ examples/GDA/test_GAug.py | 30 ++++ sgl/models/homo/gda/GAug.py | 221 ++++++++++++++++++++++++++ 
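(A short usage sketch for the MultipleOptimizer wrapper added above in sgl/tasks/utils.py; the two Linear layers are hypothetical stand-ins for model.ep_net and model.nc_net, which the GAug task in the following patch wires up the same way.)

import torch
from torch.optim import Adam
from sgl.tasks.utils import MultipleOptimizer

ep_net = torch.nn.Linear(16, 16)   # stand-in for the edge-prediction network
nc_net = torch.nn.Linear(16, 7)    # stand-in for the node-classification network

optims = MultipleOptimizer(
    Adam(ep_net.parameters(), lr=1e-2),
    Adam(nc_net.parameters(), lr=1e-2, weight_decay=5e-4),
)

x = torch.randn(4, 16)
loss = ep_net(x).pow(2).mean() + nc_net(x).pow(2).mean()
optims.zero_grad()
loss.backward()
optims.step()
optims.update_lr(0, 5e-3)   # warm up only the edge-prediction optimizer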
sgl/models/homo/gda/__init__.py | 5 + sgl/models/homo/gda/utils.py | 20 +++ sgl/tasks/node_classification_GAug.py | 172 ++++++++++++++++++++ 6 files changed, 512 insertions(+) create mode 100644 examples/GDA/configs/GAugO.yml create mode 100644 examples/GDA/test_GAug.py create mode 100644 sgl/models/homo/gda/GAug.py create mode 100644 sgl/models/homo/gda/__init__.py create mode 100644 sgl/models/homo/gda/utils.py create mode 100644 sgl/tasks/node_classification_GAug.py diff --git a/examples/GDA/configs/GAugO.yml b/examples/GDA/configs/GAugO.yml new file mode 100644 index 0000000..70b8c6b --- /dev/null +++ b/examples/GDA/configs/GAugO.yml @@ -0,0 +1,64 @@ +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" +model: + gnnlayer_type: 'gcn' + alpha: 1.0 + temperature: 1.2 + hidden_dim: 128 + emb_size: 32 + dropout: 0.5 + n_layers: 2 + gae: true + feat_norm: 'row' + sample_type: 'add_sample' +task: + lr: 0.01 + seed: 42 + warmup: 0 + beta: 0.8 + epochs: 200 + weight_decay: 0.0005 + pretrain_ep: 160 + pretrain_nc: 30 +# model: +# gnnlayer_type: 'gsage' +# alpha: 0.13 +# temperature: 1.0 +# hidden_dim: 128 +# emb_size: 32 +# dropout: 0.5 +# n_layers: 2 +# gae: true +# feat_norm: 'row' +# sample_type: 'add_sample' +# task: +# lr: 0.01 +# seed: 42 +# warmup: 2 +# beta: 3.2 +# epochs: 200 +# weight_decay: 0.0005 +# pretrain_ep: 195 +# pretrain_nc: 35 +# model: +# gnnlayer_type: 'gat' +# alpha: 0.02 +# temperature: 1.7 +# hidden_dim: 128 +# emb_size: 32 +# dropout: 0.6 +# n_layers: 2 +# gae: true +# feat_norm: 'row' +# sample_type: 'add_sample' +# task: +# lr: 0.01 +# seed: 42 +# warmup: 1 +# beta: 3.2 +# epochs: 200 +# weight_decay: 0.0005 +# pretrain_ep: 175 +# pretrain_nc: 45 diff --git a/examples/GDA/test_GAug.py b/examples/GDA/test_GAug.py new file mode 100644 index 0000000..e53e01e --- /dev/null +++ b/examples/GDA/test_GAug.py @@ -0,0 +1,30 @@ +import yaml +import argparse + +import sgl.dataset as Dataset +from sgl.models.homo.gda import GAug +from sgl.tasks import NodeClassification_GAug + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description = "GAug-Model.") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/GAugO.yml", help="save path of the configuration file" + ) + args = parser.parse_args() + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + dataset_kwargs = config["dataset"] + model_kwargs = config["model"] + task_kwargs = config["task"] + + dataset_classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, dataset_classname)(**dataset_kwargs) + for seed in range(10): + model = GAug(in_dim=dataset.num_features, n_classes=dataset.num_classes, **model_kwargs) + task_kwargs.update({"device": device}) + task_kwargs.update({"seed": seed}) + test_acc = NodeClassification_GAug(dataset, model, **task_kwargs).test_acc + print(f"test acc: {test_acc:.4f}") \ No newline at end of file diff --git a/sgl/models/homo/gda/GAug.py b/sgl/models/homo/gda/GAug.py new file mode 100644 index 0000000..08f05e8 --- /dev/null +++ b/sgl/models/homo/gda/GAug.py @@ -0,0 +1,221 @@ +import pyro +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import scipy.sparse as sp + +from sgl.models.simple_models import GCNConv, GCN, SAGE, GAT +from sgl.models.homo.gda.utils import RoundNoGradient, CeilNoGradient +from sgl.utils import 
sparse_mx_to_torch_sparse_tensor + +class GAug(nn.Module): + def __init__(self, + in_dim, + hidden_dim, + emb_size, + n_classes, + n_layers, + dropout, + gnnlayer_type, + activation=F.relu, + temperature=1, + gae=False, + alpha=1, + feat_norm="row", + sample_type="add_sample", + **kwargs): + super(GAug, self).__init__() + self.__temperature = temperature + self.__gnnlayer_type = gnnlayer_type + self.__alpha = alpha + self.__sample_type = sample_type + # edge prediction network + self.__gae = gae + self.__feat_norm = feat_norm + self.ep_net = VGAE(in_dim, hidden_dim, emb_size, activation, gae=gae) + # node classification network + select_model = {"gcn": GCN, "gsage": SAGE, "gat": GAT} + if gnnlayer_type == 'gat': + if kwargs.get("n_heads"): + n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) + else: + n_heads = [8] * (n_layers - 1) + [1] + kwargs.update({"n_heads": n_heads}) + activation = F.elu + + self.nc_net = select_model.get(gnnlayer_type)(in_dim, hidden_dim, n_classes, nlayers=n_layers, dropout=dropout, activation=activation, **kwargs) + + @property + def gae(self): + return self.__gae + + @staticmethod + def col_normalization(features): + """ column normalization for feature matrix """ + features = features.numpy() + m = features.mean(axis=0) + s = features.std(axis=0, ddof=0, keepdims=True) + 1e-12 + features -= m + features /= s + return torch.FloatTensor(features) + + def preprocess(self, features, adj_matrix, device): + if self.__feat_norm == 'row': + features = F.normalize(features, p=1, dim=1) + elif self.__feat_norm == 'col': + features = self.col_normalization(features) + features = features.to(device) + + assert sp.issparse(adj_matrix) + if not isinstance(adj_matrix, sp.coo_matrix): + adj_matrix = sp.coo_matrix(adj_matrix) + adj_matrix.setdiag(1) + adj_orig = sparse_mx_to_torch_sparse_tensor(adj_matrix).to_dense().to(device) + # normalized adj_matrix used as input for ep_net (torch.sparse.FloatTensor) + degrees = np.array(adj_matrix.sum(1)) + degree_mat_inv_sqrt = sp.diags(np.power(degrees, -0.5).flatten()) + adj_norm_matrix = degree_mat_inv_sqrt @ adj_matrix @ degree_mat_inv_sqrt + adj_norm = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix) + # adj_matrix used as input for nc_net (torch.sparse.FloatTensor) + if self.__gnnlayer_type == 'gcn': + adj = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix) + elif self.__gnnlayer_type == 'gsage': + adj_matrix_noselfloop = sp.coo_matrix(adj_matrix) + adj_matrix_noselfloop = sp.coo_matrix(adj_matrix_noselfloop / adj_matrix_noselfloop.sum(1)) + adj = sparse_mx_to_torch_sparse_tensor(adj_matrix_noselfloop) + elif self.__gnnlayer_type == 'gat': + adj = torch.FloatTensor(adj_matrix.todense()) + + adj_norm = adj_norm.to(device) + adj = adj.to(device) + + return features, adj_orig, adj_norm, adj + + def sample_adj(self, adj_logits): + """ sample an adj from the predicted edge probabilities of ep_net """ + edge_probs = adj_logits / torch.max(adj_logits) + # sampling + adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=self.__temperature, probs=edge_probs).rsample() + # making adj_sampled symmetric + adj_sampled = adj_sampled.triu(1) + adj_sampled = adj_sampled + adj_sampled.T + return adj_sampled + + def sample_adj_add_bernoulli(self, adj_logits, adj_orig, alpha): + edge_probs = adj_logits / torch.max(adj_logits) + edge_probs = alpha * edge_probs + (1-alpha) * adj_orig + # sampling + adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=self.__temperature, 
probs=edge_probs).rsample() + # making adj_sampled symmetric + adj_sampled = adj_sampled.triu(1) + adj_sampled = adj_sampled + adj_sampled.T + return adj_sampled + + def sample_adj_add_round(self, adj_logits, adj_orig, alpha): + edge_probs = adj_logits / torch.max(adj_logits) + edge_probs = alpha * edge_probs + (1-alpha) * adj_orig + # sampling + adj_sampled = RoundNoGradient.apply(edge_probs) + # making adj_sampled symmetric + adj_sampled = adj_sampled.triu(1) + adj_sampled = adj_sampled + adj_sampled.T + return adj_sampled + + def sample_adj_random(self, adj_logits): + adj_rand = torch.rand(adj_logits.size()) + adj_rand = adj_rand.triu(1) + adj_rand = torch.round(adj_rand) + adj_rand = adj_rand + adj_rand.T + return adj_rand + + def sample_adj_edge(self, adj_logits, adj_orig, change_frac): + adj = adj_orig.to_dense() if adj_orig.is_sparse else adj_orig + n_edges = adj.nonzero().size(0) + n_change = int(n_edges * change_frac / 2) + # take only the upper triangle + edge_probs = adj_logits.triu(1) + edge_probs = edge_probs - torch.min(edge_probs) + edge_probs = edge_probs / torch.max(edge_probs) + adj_inverse = 1 - adj + # get edges to be removed + mask_rm = edge_probs * adj + nz_mask_rm = mask_rm[mask_rm>0] + if len(nz_mask_rm) > 0: + n_rm = len(nz_mask_rm) if len(nz_mask_rm) < n_change else n_change + thresh_rm = torch.topk(mask_rm[mask_rm>0], n_rm, largest=False)[0][-1] + mask_rm[mask_rm > thresh_rm] = 0 + mask_rm = CeilNoGradient.apply(mask_rm) + mask_rm = mask_rm + mask_rm.T + # remove edges + adj_new = adj - mask_rm + # get edges to be added + mask_add = edge_probs * adj_inverse + nz_mask_add = mask_add[mask_add>0] + if len(nz_mask_add) > 0: + n_add = len(nz_mask_add) if len(nz_mask_add) < n_change else n_change + thresh_add = torch.topk(mask_add[mask_add>0], n_add, largest=True)[0][-1] + mask_add[mask_add < thresh_add] = 0 + mask_add = CeilNoGradient.apply(mask_add) + mask_add = mask_add + mask_add.T + # add edges + adj_new = adj_new + mask_add + return adj_new + + def normalize_adj(self, adj): + if self.__gnnlayer_type == 'gcn': + adj.fill_diagonal_(1) + # normalize adj with A = D^{-1/2} @ A @ D^{-1/2} + D_norm = torch.diag(torch.pow(adj.sum(1), -0.5)).to(adj.device) + adj = D_norm @ adj @ D_norm + elif self.__gnnlayer_type == 'gat': + adj.fill_diagonal_(1) + elif self.__gnnlayer_type == 'gsage': + adj.fill_diagonal_(1) + adj = F.normalize(adj, p=1, dim=1) + return adj + + def forward(self, adj, adj_orig, features): + adj_logits = self.ep_net(adj, features) + if self.__sample_type == 'edge': + adj_new = self.sample_adj_edge(adj_logits, adj_orig, self.__alpha) + elif self.__sample_type == 'add_round': + adj_new = self.sample_adj_add_round(adj_logits, adj_orig, self.__alpha) + elif self.__sample_type == 'rand': + adj_new = self.sample_adj_random(adj_logits) + elif self.__sample_type == 'add_sample': + if self.__alpha == 1: + adj_new = self.sample_adj(adj_logits) + else: + adj_new = self.sample_adj_add_bernoulli(adj_logits, adj_orig, self.__alpha) + adj_new_normed = self.normalize_adj(adj_new) + nc_logits = self.nc_net(features, adj_new_normed) + return nc_logits, adj_logits + + +class VGAE(nn.Module): + """ GAE/VGAE as edge prediction model """ + def __init__(self, in_dim, hidden_dim, emb_size, activation, gae=False): + super(VGAE, self).__init__() + self.gae = gae + self.activation = activation + self.gcn_base = GCNConv(in_dim, hidden_dim, bias=False) + self.gcn_mean = GCNConv(hidden_dim, emb_size, bias=False) + self.gcn_logstd = GCNConv(hidden_dim, emb_size, bias=False) + + def 
forward(self, adj, features): + # GCN encoder + hidden = self.gcn_base(features, adj, ) + self.mean = self.activation(self.gcn_mean(hidden, adj)) + if self.gae: + # GAE (no sampling at bottleneck) + Z = self.mean + else: + # VGAE + self.logstd = self.activation(self.gcn_logstd(hidden, adj)) + gaussian_noise = torch.randn_like(self.mean) + sampled_Z = gaussian_noise * torch.exp(self.logstd) + self.mean + Z = sampled_Z + # inner product decoder + adj_logits = Z @ Z.T + return adj_logits \ No newline at end of file diff --git a/sgl/models/homo/gda/__init__.py b/sgl/models/homo/gda/__init__.py new file mode 100644 index 0000000..bac6997 --- /dev/null +++ b/sgl/models/homo/gda/__init__.py @@ -0,0 +1,5 @@ +from .GAug import GAug + +__all__ = [ + "GAug" +] \ No newline at end of file diff --git a/sgl/models/homo/gda/utils.py b/sgl/models/homo/gda/utils.py new file mode 100644 index 0000000..e1d0b8f --- /dev/null +++ b/sgl/models/homo/gda/utils.py @@ -0,0 +1,20 @@ +import torch + +class RoundNoGradient(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.round() + + @staticmethod + def backward(ctx, g): + return g + + +class CeilNoGradient(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.ceil() + + @staticmethod + def backward(ctx, g): + return g \ No newline at end of file diff --git a/sgl/tasks/node_classification_GAug.py b/sgl/tasks/node_classification_GAug.py new file mode 100644 index 0000000..54fdf1d --- /dev/null +++ b/sgl/tasks/node_classification_GAug.py @@ -0,0 +1,172 @@ +import gc +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from sgl.tasks.base_task import BaseTask +from sgl.tasks.utils import set_seed, accuracy, MultipleOptimizer + +class NodeClassification_GAug(BaseTask): + def __init__(self, dataset, model, lr, weight_decay, epochs, device, seed, beta, warmup, pretrain_ep, pretrain_nc): + super(NodeClassification_GAug, self).__init__() + + self.__dataset = dataset + self.__model = model + + self.__lr = lr + self.__weight_decay = weight_decay + + self.__epochs = epochs + self.__device = device + self.__seed = seed + + self.__warmup = warmup + self.__beta = beta + + self.__pretrain_ep = pretrain_ep + self.__pretrain_nc = pretrain_nc + + self.__test_acc = self._execute() + + @property + def test_acc(self): + return self.__test_acc + + @staticmethod + def get_lr_schedule_by_sigmoid(n_epochs, lr, warmup): + """ schedule the learning rate with the sigmoid function. 
+ The learning rate will start with near zero and end with near lr """ + factors = torch.FloatTensor(np.arange(n_epochs)) + factors = ((factors / factors[-1]) * (warmup * 2)) - warmup + factors = torch.sigmoid(factors) + # range the factors to [0, 1] + factors = (factors - factors[0]) / (factors[-1] - factors[0]) + lr_schedule = factors * lr + return lr_schedule + + @staticmethod + def col_normalization(features): + """ column normalization for feature matrix """ + features = features.numpy() + m = features.mean(axis=0) + s = features.std(axis=0, ddof=0, keepdims=True) + 1e-12 + features -= m + features /= s + return torch.FloatTensor(features) + + def pretrain_ep_net(self, model, adj, features, adj_orig, norm_w, pos_weight): + """ pretrain the edge prediction network """ + optimizer = torch.optim.Adam(model.ep_net.parameters(), + lr=self.__lr) + model.train() + for _ in range(self.__pretrain_ep): + adj_logits = model.ep_net(adj, features) + loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) + if not model.gae: + mu = model.ep_net.mean + lgstd = model.ep_net.logstd + kl_divergence = 0.5 / adj_logits.size(0) * (1 + 2*lgstd - mu**2 - torch.exp(2*lgstd)).sum(1).mean() + loss -= kl_divergence + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def pretrain_nc_net(self, model, adj, features, labels): + """ pretrain the node classification network """ + optimizer = torch.optim.Adam(model.nc_net.parameters(), + lr=self.__lr, + weight_decay=self.__weight_decay) + # loss function for node classification + if labels.dim() == 2: + nc_criterion = nn.BCEWithLogitsLoss() + else: + nc_criterion = nn.CrossEntropyLoss() + + best_val_acc = 0. + for _ in range(self.__pretrain_nc): + model.train() + nc_logits = model.nc_net(features, adj) + # losses + loss = nc_criterion(nc_logits[self.__dataset.train_idx], labels[self.__dataset.train_idx]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + model.eval() + with torch.no_grad(): + nc_logits_eval = model.nc_net(features, adj) + val_acc = accuracy(nc_logits_eval[self.__dataset.val_idx], labels[self.__dataset.val_idx]) + if val_acc > best_val_acc: + best_val_acc = val_acc + + def _execute(self): + set_seed(self.__seed) + + features, adj_orig, adj_norm, adj = self.__model.preprocess(self.__dataset.x, self.__dataset.adj, self.__device) + + model = self.__model.to(self.__device) + labels = self.__dataset.y.to(self.__device) + + # weights for log_lik loss when training EP net + adj_t = adj_orig + norm_w = adj_t.shape[0]**2 / float((adj_t.shape[0]**2 - adj_t.sum()) * 2) + pos_weight = torch.FloatTensor([float(adj_t.shape[0]**2 - adj_t.sum()) / adj_t.sum()]).to(self.__device) + # pretrain VGAE if needed + if self.__pretrain_ep: + self.pretrain_ep_net(model, adj_norm, features, adj_orig, norm_w, pos_weight) + # pretrain GCN if needed + if self.__pretrain_nc: + self.pretrain_nc_net(model, adj, features, labels) + # optimizers + optims = MultipleOptimizer(torch.optim.Adam(model.ep_net.parameters(), + lr=self.__lr), + torch.optim.Adam(model.nc_net.parameters(), + lr=self.__lr, + weight_decay=self.__weight_decay)) + # get the learning rate schedule for the optimizer of ep_net if needed + if self.__warmup: + ep_lr_schedule = self.get_lr_schedule_by_sigmoid(self.__epochs, self.__lr, self.__warmup) + # loss function for node classification + if labels.dim() == 2: + nc_criterion = nn.BCEWithLogitsLoss() + else: + nc_criterion = nn.CrossEntropyLoss() + + # keep record of the best validation accuracy for early 
stopping + best_val_acc = 0. + patience_step = 0 + # train model + for epoch in range(self.__epochs): + # update the learning rate for ep_net if needed + if self.__warmup: + optims.update_lr(0, ep_lr_schedule[epoch]) + + model.train() + nc_logits, adj_logits = model(adj_norm, adj_orig, features) + + # losses + loss = nc_criterion(nc_logits[self.__dataset.train_idx], labels[self.__dataset.train_idx]) + ep_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) + loss += self.__beta * ep_loss + optims.zero_grad() + loss.backward() + optims.step() + # validate (without dropout) + model.eval() + with torch.no_grad(): + nc_logits_eval = model.nc_net(features, adj) + val_acc = accuracy(nc_logits_eval[self.__dataset.val_idx], labels[self.__dataset.val_idx]) + if val_acc > best_val_acc: + best_val_acc = val_acc + test_acc = accuracy(nc_logits_eval[self.__dataset.test_idx], labels[self.__dataset.test_idx]) + patience_step = 0 + else: + patience_step += 1 + if patience_step == 100: + break + # release RAM and GPU memory + del adj, features, labels, adj_orig + torch.cuda.empty_cache() + gc.collect() + + return test_acc \ No newline at end of file From 3999fbc6406710d453a3c70de4165e843ff9b0d7 Mon Sep 17 00:00:00 2001 From: infinity Date: Sat, 9 Dec 2023 10:14:19 +0000 Subject: [PATCH 20/28] add GAugM --- examples/GDA/configs/GAugM.yml | 54 ++++++ examples/GDA/configs/GAugO.yml | 63 ++++--- examples/GDA/test_GAug.py | 11 +- sgl/models/homo/clustergcn.py | 2 +- sgl/models/homo/fastgcn.py | 2 +- sgl/models/homo/gda/GAug.py | 167 +++++++++++++---- sgl/models/homo/gda/__init__.py | 5 +- sgl/models/homo/graphsage.py | 2 +- sgl/models/homo/graphsaint.py | 2 +- sgl/models/homo/lazygnn.py | 2 +- sgl/models/homo/vanillagnn.py | 2 +- sgl/models/simple_models.py | 56 +++--- sgl/tasks/__init__.py | 5 +- sgl/tasks/node_classification_GAug.py | 252 ++++++++++++++++++-------- 14 files changed, 442 insertions(+), 183 deletions(-) create mode 100644 examples/GDA/configs/GAugM.yml diff --git a/examples/GDA/configs/GAugM.yml b/examples/GDA/configs/GAugM.yml new file mode 100644 index 0000000..7559a07 --- /dev/null +++ b/examples/GDA/configs/GAugM.yml @@ -0,0 +1,54 @@ +#dataset: +# classname: "Planetoid" +# name: "cora" +# root: "/home/ssq/test_data/" +#model: +# model_name: 'GAugM' +# gnn_type: 'gcn' +# hidden_dim: 128 +# dropout: 0.5 +# n_layers: 2 +# choose_idx: 5 +# rm_pct: 2 +# add_pct: 57 +#task: +# lr: 0.01 +# seed: 42 +# epochs: 200 +# weight_decay: 0.0005 +#dataset: +# classname: "Planetoid" +# name: "cora" +# root: "/home/ssq/test_data/" +#model: +# model_name: 'GAugM' +# gnn_type: 'gsage' +# hidden_dim: 128 +# dropout: 0.5 +# n_layers: 2 +# choose_idx: 2 +# rm_pct: 1 +# add_pct: 80 +#task: +# lr: 0.01 +# seed: 42 +# epochs: 200 +# weight_decay: 0.0005 +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" +model: + model_name: 'GAugM' + gnn_type: 'gat' + hidden_dim: 128 + dropout: 0.5 + n_layers: 2 + choose_idx: 2 + rm_pct: 1 + add_pct: 68 +task: + lr: 0.01 + seed: 42 + epochs: 200 + weight_decay: 0.0005 \ No newline at end of file diff --git a/examples/GDA/configs/GAugO.yml b/examples/GDA/configs/GAugO.yml index 70b8c6b..971e62a 100644 --- a/examples/GDA/configs/GAugO.yml +++ b/examples/GDA/configs/GAugO.yml @@ -1,11 +1,37 @@ +#dataset: +# classname: "Planetoid" +# name: "cora" +# root: "/home/ssq/test_data/" +#model: +# model_name: 'GAugO' +# gnn_type: 'gcn' +# alpha: 1.0 +# temperature: 1.2 +# hidden_dim: 128 +# emb_size: 32 +# dropout: 0.5 +# 
n_layers: 2 +# gae: true +# feat_norm: 'row' +# sample_type: 'add_sample' +#task: +# lr: 0.01 +# seed: 42 +# warmup: 0 +# beta: 0.8 +# epochs: 200 +# weight_decay: 0.0005 +# pretrain_ep: 160 +# pretrain_nc: 30 dataset: classname: "Planetoid" name: "cora" root: "/home/ssq/test_data/" model: - gnnlayer_type: 'gcn' - alpha: 1.0 - temperature: 1.2 + model_name: 'GAugO' + gnn_type: 'gsage' + alpha: 0.13 + temperature: 1.0 hidden_dim: 128 emb_size: 32 dropout: 0.5 @@ -16,34 +42,15 @@ model: task: lr: 0.01 seed: 42 - warmup: 0 - beta: 0.8 + warmup: 2 + beta: 3.2 epochs: 200 weight_decay: 0.0005 - pretrain_ep: 160 - pretrain_nc: 30 + pretrain_ep: 195 + pretrain_nc: 35 # model: -# gnnlayer_type: 'gsage' -# alpha: 0.13 -# temperature: 1.0 -# hidden_dim: 128 -# emb_size: 32 -# dropout: 0.5 -# n_layers: 2 -# gae: true -# feat_norm: 'row' -# sample_type: 'add_sample' -# task: -# lr: 0.01 -# seed: 42 -# warmup: 2 -# beta: 3.2 -# epochs: 200 -# weight_decay: 0.0005 -# pretrain_ep: 195 -# pretrain_nc: 35 -# model: -# gnnlayer_type: 'gat' +# model_name: 'GAugO' +# gnn_type: 'gat' # alpha: 0.02 # temperature: 1.7 # hidden_dim: 128 diff --git a/examples/GDA/test_GAug.py b/examples/GDA/test_GAug.py index e53e01e..74ccd3b 100644 --- a/examples/GDA/test_GAug.py +++ b/examples/GDA/test_GAug.py @@ -2,8 +2,8 @@ import argparse import sgl.dataset as Dataset -from sgl.models.homo.gda import GAug -from sgl.tasks import NodeClassification_GAug +from sgl.models.homo.gda import GAugO, GAugM +from sgl.tasks import NodeClassificationGAugO, NodeClassificationGAugM if __name__ == "__main__": parser = argparse.ArgumentParser(description = "GAug-Model.") @@ -22,9 +22,12 @@ dataset_classname = dataset_kwargs.pop("classname") dataset = getattr(Dataset, dataset_classname)(**dataset_kwargs) + Model = {"GAugO": GAugO, "GAugM": GAugM} + Task = {"GAugO": NodeClassificationGAugO, "GAugM": NodeClassificationGAugM} + model_name = model_kwargs.pop("model_name") for seed in range(10): - model = GAug(in_dim=dataset.num_features, n_classes=dataset.num_classes, **model_kwargs) + model = Model.get(model_name)(in_dim=dataset.num_features, n_classes=dataset.num_classes, **model_kwargs) task_kwargs.update({"device": device}) task_kwargs.update({"seed": seed}) - test_acc = NodeClassification_GAug(dataset, model, **task_kwargs).test_acc + test_acc = Task.get(model_name)(dataset, model, **task_kwargs).test_acc print(f"test acc: {test_acc:.4f}") \ No newline at end of file diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index 16cd01d..3b5308c 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -8,7 +8,7 @@ def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, dr self._pre_graph_op = LaplacianGraphOp(r=0.5) self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler - self._base_model = GCN(nfeat=nfeat, nhid=hidden_dim, nclass=nclass, nlayers=num_layers, dropout=dropout).to(device) + self._base_model = GCN(n_feat=nfeat, n_hid=hidden_dim, n_class=nclass, n_layers=num_layers, dropout=dropout).to(device) def pre_sample(self, mode="train"): if mode == "train": diff --git a/sgl/models/homo/fastgcn.py b/sgl/models/homo/fastgcn.py index 7483316..a11f031 100644 --- a/sgl/models/homo/fastgcn.py +++ b/sgl/models/homo/fastgcn.py @@ -9,5 +9,5 @@ def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout= self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler self._base_model = GCN( - 
nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + n_feat=dataset.num_features, n_hid=hidden_dim, n_class=dataset.num_classes, n_layers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/homo/gda/GAug.py b/sgl/models/homo/gda/GAug.py index 08f05e8..acc5b5a 100644 --- a/sgl/models/homo/gda/GAug.py +++ b/sgl/models/homo/gda/GAug.py @@ -1,15 +1,20 @@ +import os import pyro +import copy import torch import torch.nn as nn import torch.nn.functional as F import numpy as np +import pickle as pkl import scipy.sparse as sp from sgl.models.simple_models import GCNConv, GCN, SAGE, GAT from sgl.models.homo.gda.utils import RoundNoGradient, CeilNoGradient from sgl.utils import sparse_mx_to_torch_sparse_tensor +from sgl.operators.graph_op import LaplacianGraphOp -class GAug(nn.Module): + +class GAugO(nn.Module): def __init__(self, in_dim, hidden_dim, @@ -17,7 +22,7 @@ def __init__(self, n_classes, n_layers, dropout, - gnnlayer_type, + gnn_type, activation=F.relu, temperature=1, gae=False, @@ -25,18 +30,19 @@ def __init__(self, feat_norm="row", sample_type="add_sample", **kwargs): - super(GAug, self).__init__() + super(GAugO, self).__init__() + self.__pre_graph_op = LaplacianGraphOp() self.__temperature = temperature - self.__gnnlayer_type = gnnlayer_type self.__alpha = alpha self.__sample_type = sample_type # edge prediction network self.__gae = gae self.__feat_norm = feat_norm self.ep_net = VGAE(in_dim, hidden_dim, emb_size, activation, gae=gae) - # node classification network - select_model = {"gcn": GCN, "gsage": SAGE, "gat": GAT} - if gnnlayer_type == 'gat': + # node classification network + self.__gnn_type = gnn_type + gnn_backbone = {"gcn": GCN, "gsage": SAGE, "gat": GAT} + if gnn_type == 'gat': if kwargs.get("n_heads"): n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) else: @@ -44,7 +50,7 @@ def __init__(self, kwargs.update({"n_heads": n_heads}) activation = F.elu - self.nc_net = select_model.get(gnnlayer_type)(in_dim, hidden_dim, n_classes, nlayers=n_layers, dropout=dropout, activation=activation, **kwargs) + self.nc_net = gnn_backbone.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) @property def gae(self): @@ -70,49 +76,48 @@ def preprocess(self, features, adj_matrix, device): assert sp.issparse(adj_matrix) if not isinstance(adj_matrix, sp.coo_matrix): adj_matrix = sp.coo_matrix(adj_matrix) - adj_matrix.setdiag(1) - adj_orig = sparse_mx_to_torch_sparse_tensor(adj_matrix).to_dense().to(device) - # normalized adj_matrix used as input for ep_net (torch.sparse.FloatTensor) - degrees = np.array(adj_matrix.sum(1)) - degree_mat_inv_sqrt = sp.diags(np.power(degrees, -0.5).flatten()) - adj_norm_matrix = degree_mat_inv_sqrt @ adj_matrix @ degree_mat_inv_sqrt + adj_matrix_sl = adj_matrix + sp.eye(*adj_matrix.shape) + adj_orig = sparse_mx_to_torch_sparse_tensor(adj_matrix_sl).to_dense().to(device) + adj_norm_matrix = self.__pre_graph_op._construct_adj(adj_matrix) adj_norm = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix) # adj_matrix used as input for nc_net (torch.sparse.FloatTensor) - if self.__gnnlayer_type == 'gcn': + if self.__gnn_type == 'gcn': adj = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix) - elif self.__gnnlayer_type == 'gsage': - adj_matrix_noselfloop = sp.coo_matrix(adj_matrix) - adj_matrix_noselfloop = sp.coo_matrix(adj_matrix_noselfloop / adj_matrix_noselfloop.sum(1)) - adj = 
sparse_mx_to_torch_sparse_tensor(adj_matrix_noselfloop) - elif self.__gnnlayer_type == 'gat': - adj = torch.FloatTensor(adj_matrix.todense()) + elif self.__gnn_type == 'gsage': + adj = adj_matrix_sl / adj_matrix_sl.sum(1) + adj = sparse_mx_to_torch_sparse_tensor(adj) + elif self.__gnn_type == 'gat': + adj = torch.FloatTensor(adj_matrix_sl.todense()) adj_norm = adj_norm.to(device) adj = adj.to(device) return features, adj_orig, adj_norm, adj - - def sample_adj(self, adj_logits): + + @staticmethod + def sample_adj(adj_logits, temp): """ sample an adj from the predicted edge probabilities of ep_net """ edge_probs = adj_logits / torch.max(adj_logits) # sampling - adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=self.__temperature, probs=edge_probs).rsample() + adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=temp, probs=edge_probs).rsample() # making adj_sampled symmetric adj_sampled = adj_sampled.triu(1) adj_sampled = adj_sampled + adj_sampled.T return adj_sampled - def sample_adj_add_bernoulli(self, adj_logits, adj_orig, alpha): + @staticmethod + def sample_adj_add_bernoulli(adj_logits, adj_orig, alpha, temp): edge_probs = adj_logits / torch.max(adj_logits) edge_probs = alpha * edge_probs + (1-alpha) * adj_orig # sampling - adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=self.__temperature, probs=edge_probs).rsample() + adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=temp, probs=edge_probs).rsample() # making adj_sampled symmetric adj_sampled = adj_sampled.triu(1) adj_sampled = adj_sampled + adj_sampled.T return adj_sampled - def sample_adj_add_round(self, adj_logits, adj_orig, alpha): + @staticmethod + def sample_adj_add_round(adj_logits, adj_orig, alpha): edge_probs = adj_logits / torch.max(adj_logits) edge_probs = alpha * edge_probs + (1-alpha) * adj_orig # sampling @@ -122,14 +127,16 @@ def sample_adj_add_round(self, adj_logits, adj_orig, alpha): adj_sampled = adj_sampled + adj_sampled.T return adj_sampled - def sample_adj_random(self, adj_logits): + @staticmethod + def sample_adj_random(adj_logits): adj_rand = torch.rand(adj_logits.size()) adj_rand = adj_rand.triu(1) adj_rand = torch.round(adj_rand) adj_rand = adj_rand + adj_rand.T return adj_rand - def sample_adj_edge(self, adj_logits, adj_orig, change_frac): + @staticmethod + def sample_adj_edge(adj_logits, adj_orig, change_frac): adj = adj_orig.to_dense() if adj_orig.is_sparse else adj_orig n_edges = adj.nonzero().size(0) n_change = int(n_edges * change_frac / 2) @@ -163,14 +170,14 @@ def sample_adj_edge(self, adj_logits, adj_orig, change_frac): return adj_new def normalize_adj(self, adj): - if self.__gnnlayer_type == 'gcn': + if self.__gnn_type == 'gcn': adj.fill_diagonal_(1) # normalize adj with A = D^{-1/2} @ A @ D^{-1/2} D_norm = torch.diag(torch.pow(adj.sum(1), -0.5)).to(adj.device) adj = D_norm @ adj @ D_norm - elif self.__gnnlayer_type == 'gat': + elif self.__gnn_type == 'gat': adj.fill_diagonal_(1) - elif self.__gnnlayer_type == 'gsage': + elif self.__gnn_type == 'gsage': adj.fill_diagonal_(1) adj = F.normalize(adj, p=1, dim=1) return adj @@ -185,11 +192,12 @@ def forward(self, adj, adj_orig, features): adj_new = self.sample_adj_random(adj_logits) elif self.__sample_type == 'add_sample': if self.__alpha == 1: - adj_new = self.sample_adj(adj_logits) + adj_new = self.sample_adj(adj_logits, self.__temperature) else: - adj_new = self.sample_adj_add_bernoulli(adj_logits, adj_orig, self.__alpha) + adj_new = 
self.sample_adj_add_bernoulli(adj_logits, adj_orig, self.__alpha, self.__temperature) adj_new_normed = self.normalize_adj(adj_new) nc_logits = self.nc_net(features, adj_new_normed) + return nc_logits, adj_logits @@ -218,4 +226,93 @@ def forward(self, adj, features): Z = sampled_Z # inner product decoder adj_logits = Z @ Z.T - return adj_logits \ No newline at end of file + return adj_logits + + +class GAugM(nn.Module): + def __init__(self, in_dim, hidden_dim, n_classes, n_layers, gnn_type, rm_pct, add_pct, choose_idx, dropout=0.5, activation=F.relu, **kwargs): + super(GAugM, self).__init__() + + self.__rm_pct = rm_pct + self.__add_pct = add_pct + self.__choose_idx = choose_idx + self.__pre_graph_op = None + gnn_backbone = {'gcn': GCN, 'gsage': SAGE, 'gat': GAT} + self.__gnn_type = gnn_type + if gnn_type == 'gcn': + self.__pre_graph_op = LaplacianGraphOp() + if gnn_type == 'gat': + if kwargs.get("n_heads"): + n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) + else: + n_heads = [8] * (n_layers - 1) + [1] + kwargs.update({"n_heads": n_heads}) + activation = F.elu + + self.nc_net = gnn_backbone.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) + + @staticmethod + def sample_graph_det(adj_orig, A_pred, remove_pct, add_pct): + if remove_pct == 0 and add_pct == 0: + return copy.deepcopy(adj_orig) + + orig_upper = sp.triu(adj_orig, 1) + n_edges = orig_upper.nnz + edges = np.asarray(orig_upper.nonzero()).T + + if remove_pct: + n_remove = int(n_edges * remove_pct / 100) + pos_probs = A_pred[edges.T[0], edges.T[1]] + e_index_2b_remove = np.argpartition(pos_probs, n_remove)[:n_remove] + mask = np.ones(len(edges), dtype=bool) + mask[e_index_2b_remove] = False + edges_pred = edges[mask] + else: + edges_pred = edges + + if add_pct: + n_add = int(n_edges * add_pct / 100) + # deep copy to avoid modifying A_pred + A_probs = np.array(A_pred) + # make the probabilities of the lower half to be zero (including diagonal) + A_probs[np.tril_indices(A_probs.shape[0])] = 0 + # make the probabilities of existing edges to be zero + A_probs[edges.T[0], edges.T[1]] = 0 + all_probs = A_probs.reshape(-1) + e_index_2b_add = np.argpartition(all_probs, -n_add)[-n_add:] + new_edges = [] + for index in e_index_2b_add: + i = int(index / A_probs.shape[0]) + j = index % A_probs.shape[0] + new_edges.append([i, j]) + edges_pred = np.concatenate((edges_pred, new_edges), axis=0) + adj_pred = sp.csr_matrix((np.ones(len(edges_pred)), edges_pred.T), shape=adj_orig.shape) + adj_pred = adj_pred + adj_pred.T + + return adj_pred + + def preprocess(self, adj_orig, features, A_pred_dir, device): + if features.size(1) in (1433, 3703): + features = F.normalize(features, p=1, dim=1) + features = features.to(device) + + A_pred = pkl.load(open(os.path.join(A_pred_dir, f'{self.__choose_idx}_logits.pkl'), 'rb')) + adj_pred = self.sample_graph_det(adj_orig, A_pred, self.__rm_pct, self.__add_pct) + + if self.__pre_graph_op is not None: + adj_norm_matrix = self.__pre_graph_op._construct_adj(adj_pred) + adj_processed = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix).to(device) + else: + if not isinstance(adj_pred, sp.coo_matrix): + adj_pred = sp.coo_matrix(adj_pred) + adj_pred.setdiag(1) + if self.__gnn_type == 'gsage': + adj_processed = sparse_mx_to_torch_sparse_tensor(adj_pred).to(device) + elif self.__gnn_type == 'gat': + adj_processed = torch.FloatTensor(adj_pred.todense()).to(device) + + return adj_processed, features + + def forward(self, adj, features): + return 
self.nc_net(features, adj) + diff --git a/sgl/models/homo/gda/__init__.py b/sgl/models/homo/gda/__init__.py index bac6997..316fdd9 100644 --- a/sgl/models/homo/gda/__init__.py +++ b/sgl/models/homo/gda/__init__.py @@ -1,5 +1,6 @@ -from .GAug import GAug +from .GAug import GAugO, GAugM __all__ = [ - "GAug" + "GAugO", + "GAugM" ] \ No newline at end of file diff --git a/sgl/models/homo/graphsage.py b/sgl/models/homo/graphsage.py index a8ac314..ac1a6dc 100644 --- a/sgl/models/homo/graphsage.py +++ b/sgl/models/homo/graphsage.py @@ -11,5 +11,5 @@ def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout= self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler self._base_model = SAGE( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + n_feat=dataset.num_features, n_hid=hidden_dim, n_class=dataset.num_classes, n_layers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/homo/graphsaint.py b/sgl/models/homo/graphsaint.py index 6232fdc..44f0c3b 100644 --- a/sgl/models/homo/graphsaint.py +++ b/sgl/models/homo/graphsaint.py @@ -12,7 +12,7 @@ def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout= self._eval_sampling_op = eval_sampler self.device = device self._base_model = GCN( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + n_feat=dataset.num_features, n_hid=hidden_dim, n_class=dataset.num_classes, n_layers=num_layers, dropout=dropout ).to(device) def pre_sample(self, mode="train"): diff --git a/sgl/models/homo/lazygnn.py b/sgl/models/homo/lazygnn.py index 61ecdaf..37dad4a 100644 --- a/sgl/models/homo/lazygnn.py +++ b/sgl/models/homo/lazygnn.py @@ -26,7 +26,7 @@ def __init__(self, dataset, training_sampler, eval_sampler=None, hidden_dim=128, self._tau = tau # define the base model self._base_model = getattr(SimpleModels, basemodel)( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + n_feat=dataset.num_features, n_hid=hidden_dim, n_class=dataset.num_classes, n_layers=num_layers, dropout=dropout ).to(device) def preprocess(self, adj, x, val_dataloader=None, test_dataloader=None): diff --git a/sgl/models/homo/vanillagnn.py b/sgl/models/homo/vanillagnn.py index b421fda..bb1c565 100644 --- a/sgl/models/homo/vanillagnn.py +++ b/sgl/models/homo/vanillagnn.py @@ -16,5 +16,5 @@ def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, basemode self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler self._base_model = getattr(SimpleModels, basemodel)( - nfeat=dataset.num_features, nhid=hidden_dim, nclass=dataset.num_classes, nlayers=num_layers, dropout=dropout + n_feat=dataset.num_features, n_hid=hidden_dim, n_class=dataset.num_classes, n_layers=num_layers, dropout=dropout ).to(device) diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index e5a79dc..dc111d9 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -256,7 +256,8 @@ def forward(self, x, adj): output = self.norm(output) return output - + + class GATConv(nn.Module): """ Simple GAT layer @@ -308,16 +309,17 @@ def forward(self, x, adj): if self.n_heads > 1: repr = repr.flatten(start_dim=1) return repr - + + class SAGE(nn.Module): - def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, activation=F.relu): + def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, 
activation=F.relu): super(SAGE, self).__init__() self.gcs = nn.ModuleList() - self.gcs.append(SAGEConv(nfeat, nhid)) - self.nlayers = nlayers - for _ in range(nlayers-2): - self.gcs.append(SAGEConv(nhid, nhid)) - self.gcs.append(SAGEConv(nhid, nclass, normalize=False)) + self.gcs.append(SAGEConv(n_feat, n_hid)) + self.n_layers = n_layers + for _ in range(n_layers-2): + self.gcs.append(SAGEConv(n_hid, n_hid)) + self.gcs.append(SAGEConv(n_hid, n_class, normalize=False)) self.dropout = dropout self.activation = activation @@ -329,8 +331,8 @@ def forward(self, x, block): repr = x if isinstance(block, torch.Tensor): block = [block] - if len(block) == self.nlayers: - for i in range(self.nlayers-1): + if len(block) == self.n_layers: + for i in range(self.n_layers-1): repr = self.gcs[i](repr, block[i]) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) @@ -365,15 +367,16 @@ def inference(self, x_all, subgraph_loader, device): return x_all + class GCN(nn.Module): - def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, activation=F.relu): + def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F.relu): super(GCN, self).__init__() self.gcs = nn.ModuleList() - self.gcs.append(GCNConv(nfeat, nhid)) - self.nlayers = nlayers - for _ in range(nlayers-2): - self.gcs.append(GCNConv(nhid, nhid)) - self.gcs.append(GCNConv(nhid, nclass)) + self.gcs.append(GCNConv(n_feat, n_hid)) + self.n_layers = n_layers + for _ in range(n_layers-2): + self.gcs.append(GCNConv(n_hid, n_hid)) + self.gcs.append(GCNConv(n_hid, n_class)) self.dropout = dropout self.activation = activation @@ -385,7 +388,7 @@ def forward(self, x, block): repr = x if isinstance(block, torch.Tensor): block = [block] - if len(block) == self.nlayers: + if len(block) == self.n_layers: for i in range(self.nlayers-1): repr = self.gcs[i](repr, block[i]) repr = self.activation(repr) @@ -420,16 +423,17 @@ def inference(self, x_all, subgraph_loader, device): x_all = torch.cat(xs, dim=0) return x_all - + + class GAT(nn.Module): - def __init__(self, nfeat, nhid, nclass, n_heads, nlayers=2, dropout=0.6, activation=F.elu): + def __init__(self, n_feat, n_hid, n_class, n_heads, n_layers=2, dropout=0.6, activation=F.elu): super(GAT, self).__init__() self.gcs = nn.ModuleList() - self.gcs.append(GATConv(nfeat, nhid // n_heads[0], n_heads[0])) - self.nlayers = nlayers - for i in range(nlayers-2): - self.gcs.append(GATConv(nhid, nhid // n_heads[i+1], n_heads[i+1])) - self.gcs.append(GATConv(nhid, nclass, n_heads[-1])) + self.gcs.append(GATConv(n_feat, n_hid // n_heads[0], n_heads[0])) + self.n_layers = n_layers + for i in range(n_layers-2): + self.gcs.append(GATConv(n_hid, n_hid // n_heads[i + 1], n_heads[i + 1])) + self.gcs.append(GATConv(n_hid, n_class, n_heads[-1])) self.dropout = dropout self.activation = activation @@ -437,8 +441,8 @@ def forward(self, x, block): repr = x if isinstance(block, torch.Tensor): block = [block] - if len(block) == self.nlayers: - for i in range(self.nlayers-1): + if len(block) == self.n_layers: + for i in range(self.n_layers-1): repr = self.gcs[i](repr, block[i]) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) diff --git a/sgl/tasks/__init__.py b/sgl/tasks/__init__.py index f68c17e..4f6683a 100644 --- a/sgl/tasks/__init__.py +++ b/sgl/tasks/__init__.py @@ -8,7 +8,7 @@ from .correct_and_smooth import NodeClassification_With_CorrectAndSmooth from .node_classification_with_label_use import NodeClassificationWithLabelUse from 
.node_classification_dist import NodeClassificationDist -from .node_classification_GAug import NodeClassification_GAug +from .node_classification_GAug import NodeClassificationGAugO, NodeClassificationGAugM __all__ = [ "NodeClassification", @@ -22,5 +22,6 @@ "NodeClassificationDist", "NodeClassification_Sampling", "NodeClassification_RecycleSampling", - "NodeClassification_GAug" + "NodeClassificationGAugO", + "NodeClassificationGAugM" ] diff --git a/sgl/tasks/node_classification_GAug.py b/sgl/tasks/node_classification_GAug.py index 54fdf1d..1d54c7a 100644 --- a/sgl/tasks/node_classification_GAug.py +++ b/sgl/tasks/node_classification_GAug.py @@ -1,18 +1,25 @@ import gc +import os +import time import torch import torch.nn as nn +from torch.optim import Adam import torch.nn.functional as F import numpy as np from sgl.tasks.base_task import BaseTask from sgl.tasks.utils import set_seed, accuracy, MultipleOptimizer -class NodeClassification_GAug(BaseTask): +class NodeClassificationGAugO(BaseTask): def __init__(self, dataset, model, lr, weight_decay, epochs, device, seed, beta, warmup, pretrain_ep, pretrain_nc): - super(NodeClassification_GAug, self).__init__() + super(NodeClassificationGAugO, self).__init__() self.__dataset = dataset - self.__model = model + self.__labels = self.__dataset.y + + self.__model = model + self.__optimizer = MultipleOptimizer(Adam(model.ep_net.parameters(), lr=lr), + Adam(model.nc_net.parameters(), lr=lr, weight_decay=weight_decay)) self.__lr = lr self.__weight_decay = weight_decay @@ -44,129 +51,214 @@ def get_lr_schedule_by_sigmoid(n_epochs, lr, warmup): factors = (factors - factors[0]) / (factors[-1] - factors[0]) lr_schedule = factors * lr return lr_schedule - + @staticmethod - def col_normalization(features): - """ column normalization for feature matrix """ - features = features.numpy() - m = features.mean(axis=0) - s = features.std(axis=0, ddof=0, keepdims=True) + 1e-12 - features -= m - features /= s - return torch.FloatTensor(features) + def loss_fn(nc_logits, norm_w, adj_logits, adj_orig, pos_weight, labels, idx, beta): + if labels.dim() == 2: + nc_criterion = nn.BCEWithLogitsLoss() + else: + nc_criterion = nn.CrossEntropyLoss() + loss = nc_criterion(nc_logits[idx], labels[idx]) + ep_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) + loss += beta * ep_loss + + return loss - def pretrain_ep_net(self, model, adj, features, adj_orig, norm_w, pos_weight): + def pretrain_ep_net(self, adj, features, adj_orig, norm_w, pos_weight): """ pretrain the edge prediction network """ - optimizer = torch.optim.Adam(model.ep_net.parameters(), - lr=self.__lr) - model.train() + optimizer = Adam(self.__model.ep_net.parameters(), lr=self.__lr) + + self.__model.train() for _ in range(self.__pretrain_ep): - adj_logits = model.ep_net(adj, features) + adj_logits = self.__model.ep_net(adj, features) loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) - if not model.gae: - mu = model.ep_net.mean - lgstd = model.ep_net.logstd + if not self.__model.gae: + mu = self.__model.ep_net.mean + lgstd = self.__model.ep_net.logstd kl_divergence = 0.5 / adj_logits.size(0) * (1 + 2*lgstd - mu**2 - torch.exp(2*lgstd)).sum(1).mean() loss -= kl_divergence optimizer.zero_grad() loss.backward() optimizer.step() - def pretrain_nc_net(self, model, adj, features, labels): + def pretrain_nc_net(self, adj, features): """ pretrain the node classification network """ - optimizer = 
torch.optim.Adam(model.nc_net.parameters(), - lr=self.__lr, - weight_decay=self.__weight_decay) + optimizer = Adam(self.__model.nc_net.parameters(), lr=self.__lr, weight_decay=self.__weight_decay) # loss function for node classification - if labels.dim() == 2: + if self.__labels.dim() == 2: nc_criterion = nn.BCEWithLogitsLoss() else: nc_criterion = nn.CrossEntropyLoss() - - best_val_acc = 0. + for _ in range(self.__pretrain_nc): - model.train() - nc_logits = model.nc_net(features, adj) + self.__model.train() + nc_logits = self.__model.nc_net(features, adj) # losses - loss = nc_criterion(nc_logits[self.__dataset.train_idx], labels[self.__dataset.train_idx]) + loss = nc_criterion(nc_logits[self.__dataset.train_idx], self.__labels[self.__dataset.train_idx]) optimizer.zero_grad() loss.backward() optimizer.step() - model.eval() - with torch.no_grad(): - nc_logits_eval = model.nc_net(features, adj) - val_acc = accuracy(nc_logits_eval[self.__dataset.val_idx], labels[self.__dataset.val_idx]) - if val_acc > best_val_acc: - best_val_acc = val_acc + + def train(self, adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule): + # update the learning rate for ep_net if needed + if self.__warmup: + self.__optimizer.update_lr(0, ep_lr_schedule[epoch]) + + self.__model.train() + nc_logits, adj_logits = self.__model(adj_norm, adj_orig, features) + loss_train = self.loss_fn(nc_logits, norm_w, adj_logits, adj_orig, pos_weight, self.__labels, self.__dataset.train_idx, self.__beta) + acc_train = accuracy(nc_logits[self.__dataset.train_idx], self.__labels[self.__dataset.train_idx]) + self.__optimizer.zero_grad() + loss_train.backward() + self.__optimizer.step() + + return loss_train, acc_train + + def evaluate(self, features, adj): + self.__model.eval() + with torch.no_grad(): + nc_logits_eval = self.__model.nc_net(features, adj) + acc_val = accuracy(nc_logits_eval[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) + acc_test = accuracy(nc_logits_eval[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) + + return acc_val, acc_test def _execute(self): set_seed(self.__seed) features, adj_orig, adj_norm, adj = self.__model.preprocess(self.__dataset.x, self.__dataset.adj, self.__device) - - model = self.__model.to(self.__device) - labels = self.__dataset.y.to(self.__device) + + self.__model = self.__model.to(self.__device) + self.__labels = self.__labels.to(self.__device) # weights for log_lik loss when training EP net - adj_t = adj_orig - norm_w = adj_t.shape[0]**2 / float((adj_t.shape[0]**2 - adj_t.sum()) * 2) - pos_weight = torch.FloatTensor([float(adj_t.shape[0]**2 - adj_t.sum()) / adj_t.sum()]).to(self.__device) + norm_w = adj_orig.shape[0]**2 / float((adj_orig.shape[0]**2 - adj_orig.sum()) * 2) + pos_weight = torch.FloatTensor([float(adj_orig.shape[0]**2 - adj_orig.sum()) / adj_orig.sum()]).to(self.__device) # pretrain VGAE if needed if self.__pretrain_ep: - self.pretrain_ep_net(model, adj_norm, features, adj_orig, norm_w, pos_weight) + self.pretrain_ep_net(adj_norm, features, adj_orig, norm_w, pos_weight) # pretrain GCN if needed if self.__pretrain_nc: - self.pretrain_nc_net(model, adj, features, labels) - # optimizers - optims = MultipleOptimizer(torch.optim.Adam(model.ep_net.parameters(), - lr=self.__lr), - torch.optim.Adam(model.nc_net.parameters(), - lr=self.__lr, - weight_decay=self.__weight_decay)) + self.pretrain_nc_net(adj, features) # get the learning rate schedule for the optimizer of ep_net if needed if self.__warmup: ep_lr_schedule = 
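# Sketch of a sigmoid-shaped warmup schedule like the one consumed above through
# MultipleOptimizer.update_lr(0, ep_lr_schedule[epoch]). Only the normalisation step of
# get_lr_schedule_by_sigmoid is visible in this hunk, so the sigmoid input range below
# (-10..10 over the warmup epochs) and the constant tail are assumptions for illustration.
import torch

def sigmoid_warmup_schedule(n_epochs, lr, warmup):
    factors = torch.sigmoid(torch.linspace(-10, 10, warmup))       # assumed ramp shape
    # normalise the ramp to [0, 1] exactly as in the hunk, then scale by the target lr
    factors = (factors - factors[0]) / (factors[-1] - factors[0])
    schedule = factors * lr
    return torch.cat([schedule, torch.full((n_epochs - warmup,), lr)])

ep_lr_schedule = sigmoid_warmup_schedule(n_epochs=200, lr=0.01, warmup=20)
print(ep_lr_schedule[:3], ep_lr_schedule[-1])                      # ramps from 0 up to lr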
self.get_lr_schedule_by_sigmoid(self.__epochs, self.__lr, self.__warmup) - # loss function for node classification - if labels.dim() == 2: - nc_criterion = nn.BCEWithLogitsLoss() else: - nc_criterion = nn.CrossEntropyLoss() - + ep_lr_schedule = None + # keep record of the best validation accuracy for early stopping - best_val_acc = 0. - patience_step = 0 + best_acc_val, best_acc_test, patience_step = 0., 0., 0 # train model for epoch in range(self.__epochs): - # update the learning rate for ep_net if needed - if self.__warmup: - optims.update_lr(0, ep_lr_schedule[epoch]) + t = time.time() + loss_train, acc_train = self.train(adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule) + acc_val, acc_test = self.evaluate(features, adj) - model.train() - nc_logits, adj_logits = model(adj_norm, adj_orig, features) + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) - # losses - loss = nc_criterion(nc_logits[self.__dataset.train_idx], labels[self.__dataset.train_idx]) - ep_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) - loss += self.__beta * ep_loss - optims.zero_grad() - loss.backward() - optims.step() - # validate (without dropout) - model.eval() - with torch.no_grad(): - nc_logits_eval = model.nc_net(features, adj) - val_acc = accuracy(nc_logits_eval[self.__dataset.val_idx], labels[self.__dataset.val_idx]) - if val_acc > best_val_acc: - best_val_acc = val_acc - test_acc = accuracy(nc_logits_eval[self.__dataset.test_idx], labels[self.__dataset.test_idx]) + if acc_val > best_acc_val: + best_acc_val = acc_val + best_acc_test = acc_test patience_step = 0 else: patience_step += 1 - if patience_step == 100: + if patience_step == 50: break + # release RAM and GPU memory - del adj, features, labels, adj_orig + del adj, features, adj_orig, adj_norm torch.cuda.empty_cache() gc.collect() - return test_acc \ No newline at end of file + return best_acc_test + + +class NodeClassificationGAugM(BaseTask): + def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn.CrossEntropyLoss(), seed=42): + super(NodeClassificationGAugM, self).__init__() + + self.__dataset = dataset + self.__labels = self.__dataset.y + + self.__model = model + self.__optimizer = Adam(model.parameters(), lr=lr, + weight_decay=weight_decay) + self.__epochs = epochs + self.__loss_fn = loss_fn + self.__device = device + self.__seed = seed + + self.__test_acc = self._execute() + + @property + def test_acc(self): + return self.__test_acc + + def train(self, adj_norm, features): + self.__model.train() + pred_y = self.__model(adj_norm, features)[self.__dataset.train_idx] + ground_truth_y = self.__labels[self.__dataset.train_idx] + loss_train = self.__loss_fn(pred_y, ground_truth_y) + acc_train = accuracy(pred_y, ground_truth_y) + + self.__optimizer.zero_grad() + loss_train.backward() + self.__optimizer.step() + + return loss_train, acc_train + + def evaluate(self, adj_norm, features): + self.__model.eval() + with torch.no_grad(): + pred_y = self.__model(adj_norm, features) + acc_val = accuracy(pred_y[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) + acc_test = accuracy(pred_y[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) + + return acc_val, acc_test + + def _execute(self): + set_seed(self.__seed) + + pre_time_st = time.time() + 
A_pred_dir = os.path.join(self.__dataset.processed_dir, "GAugM_edge_probabilities") + adj_norm, features = self.__model.preprocess(self.__dataset.adj, self.__dataset.x, A_pred_dir, self.__device) + pre_time_ed = time.time() + print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") + + self.__model = self.__model.to(self.__device) + self.__labels = self.__labels.to(self.__device) + + t_total = time.time() + best_val = 0. + best_test = 0. + for epoch in range(self.__epochs): + t = time.time() + loss_train, acc_train = self.train(adj_norm, features) + acc_val, acc_test = self.evaluate(adj_norm, features) + + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) + + if acc_val > best_val: + best_val = acc_val + best_test = acc_test + + print("Optimization Finished!") + print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) + print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') + + del adj_norm, features + torch.cuda.empty_cache() + gc.collect() + + return best_test + From 077d4530a19ec0de2ea7c89d6fd725a9cec1fc1a Mon Sep 17 00:00:00 2001 From: infinity Date: Wed, 13 Dec 2023 07:39:17 +0000 Subject: [PATCH 21/28] add FLAG, but currently it doesn't support sampling-based model, and doesn't test GAT(only GCN and SAGE for ogbn-arxiv) --- examples/GDA/configs/FLAG.yml | 37 +++++++++++ examples/GDA/test_FLAG.py | 30 +++++++++ sgl/models/homo/gda/FLAG.py | 101 +++++++++++++++++++++++++++++++ sgl/models/homo/gda/__init__.py | 4 +- sgl/models/simple_models.py | 56 +++++++++++++---- sgl/tasks/node_classification.py | 28 +++++++-- 6 files changed, 236 insertions(+), 20 deletions(-) create mode 100644 examples/GDA/configs/FLAG.yml create mode 100644 examples/GDA/test_FLAG.py create mode 100644 sgl/models/homo/gda/FLAG.py diff --git a/examples/GDA/configs/FLAG.yml b/examples/GDA/configs/FLAG.yml new file mode 100644 index 0000000..342eeef --- /dev/null +++ b/examples/GDA/configs/FLAG.yml @@ -0,0 +1,37 @@ +# dataset: +# classname: "Ogbn" +# name: "arxiv" +# root: "/home/ssq/test_data/" +# model: +# gnn_type: 'gcn' +# hidden_dim: 256 +# dropout: 0.5 +# n_layers: 3 +# step_size: 0.001 +# augM: 3 +# batch_norm: True +# task: +# lr: 0.01 +# seed: 12345 +# epochs: 500 +# patience: 80 +# weight_decay: 0 +dataset: + classname: "Ogbn" + name: "arxiv" + root: "/home/ssq/test_data/" +model: + gnn_type: 'sage' + hidden_dim: 256 + dropout: 0.5 + n_layers: 3 + step_size: 0.001 + augM: 3 + batch_norm: True + normalize: False +task: + lr: 0.01 + seed: 12345 + epochs: 500 + patience: 80 + weight_decay: 0 \ No newline at end of file diff --git a/examples/GDA/test_FLAG.py b/examples/GDA/test_FLAG.py new file mode 100644 index 0000000..f317f84 --- /dev/null +++ b/examples/GDA/test_FLAG.py @@ -0,0 +1,30 @@ +import yaml +import argparse + +import sgl.dataset as Dataset +from sgl.models.homo.gda import FLAG +from sgl.tasks import NodeClassification + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description = "FLAG-Model.") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/FLAG.yml", help="save path of the configuration file" + ) + args = parser.parse_args() + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + 
dataset_kwargs = config["dataset"] + model_kwargs = config["model"] + task_kwargs = config["task"] + + dataset_classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, dataset_classname)(**dataset_kwargs) + for seed in range(10): + model = FLAG(in_dim=dataset.num_features, n_classes=dataset.num_classes, **model_kwargs) + task_kwargs.update({"device": device}) + task_kwargs.update({"seed": seed}) + test_acc = NodeClassification(dataset, model, **task_kwargs).test_acc + print(f"test acc: {test_acc:.4f}") \ No newline at end of file diff --git a/sgl/models/homo/gda/FLAG.py b/sgl/models/homo/gda/FLAG.py new file mode 100644 index 0000000..41d8923 --- /dev/null +++ b/sgl/models/homo/gda/FLAG.py @@ -0,0 +1,101 @@ +import scipy.sparse as sp +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sgl.models.simple_models import GCN, SAGE, GAT +from sgl.operators.graph_op import LaplacianGraphOp, RwGraphOp +from sgl.utils import sparse_mx_to_torch_sparse_tensor + +class FLAG(nn.Module): + def __init__(self, in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, step_size, augM, activation=F.relu, **kwargs): + super(FLAG, self).__init__() + self.__step_size = step_size + self.__augM = augM + self.__gnn_type = gnn_type + gnn_backbone = {"gcn": GCN, "sage": SAGE, "gat": GAT} + if gnn_type == 'gat': + if kwargs.get("n_heads"): + n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) + else: + n_heads = [8] * (n_layers - 1) + [1] + kwargs.update({"n_heads": n_heads}) + activation = F.elu + self.nc_net = gnn_backbone.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) + + @property + def processed_feature(self): + return self.__features + + @property + def processed_adj(self): + return self.__processed_adj + + def preprocess(self, adj, features, device): + self.__features = features.to(device) + if self.__gnn_type == "gcn": + adj_norm = LaplacianGraphOp()._construct_adj(adj) + self.__processed_adj = sparse_mx_to_torch_sparse_tensor(adj_norm).to(device) + elif self.__gnn_type == "sage": + adj_norm = RwGraphOp()._construct_adj(adj) + self.__processed_adj = sparse_mx_to_torch_sparse_tensor(adj_norm).to(device) + elif self.__gnn_type == "gat": + adj_sl = sp.coo_matrix(adj) + adj_sl = adj_sl + sp.eye(*adj_sl.shape) + self.__processed_adj = torch.FloatTensor(adj_sl.todense()).to(device) + + def flag(self, ground_truth_y, optimizer, device, train_idx, loss_fn): + x = self.__features + adj = self.__processed_adj + + self.nc_net.train() + optimizer.zero_grad() + + perturb = torch.FloatTensor(x.shape).uniform_(-self.__step_size, self.__step_size).to(device) + perturb.requires_grad_() + pred_y = self.nc_net(x+perturb, adj)[train_idx] + loss = loss_fn(pred_y, ground_truth_y) + loss /= self.__augM + + for _ in range(self.__augM-1): + loss.backward() + perturb_data = perturb.detach() + self.__step_size * torch.sign(perturb.grad.detach()) + perturb.data = perturb_data.data + perturb.grad[:] = 0 + + pred_y = self.nc_net(x+perturb, adj)[train_idx] + loss = loss_fn(pred_y, ground_truth_y) + loss /= self.__augM + + loss.backward() + optimizer.step() + + return loss + + def train_func(self, train_idx, labels, device, optimizer, loss_fn, metric): + loss_train = self.flag(labels[train_idx], optimizer, device, train_idx, loss_fn) + + self.nc_net.eval() + pred_y = self.nc_net(self.__features, self.__processed_adj) + acc_train = metric(pred_y[train_idx], labels[train_idx]) + + return loss_train.item(), acc_train 
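# Sketch of the FLAG inner loop defined in flag() above (free adversarial feature perturbation):
# ascend on the perturbation with a sign step augM times while accumulating weight gradients,
# then take a single optimizer step. A plain linear classifier stands in for the GNN backbone.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(32, 16)                       # stand-in node features
y = torch.randint(0, 4, (32,))
model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
step_size, augM = 1e-3, 3

model.train()
optimizer.zero_grad()
perturb = torch.empty_like(x).uniform_(-step_size, step_size).requires_grad_()
loss = F.cross_entropy(model(x + perturb), y) / augM
for _ in range(augM - 1):
    loss.backward()
    # gradient *ascent* on the perturbation, gradient accumulation on the weights
    perturb.data = perturb.detach() + step_size * torch.sign(perturb.grad.detach())
    perturb.grad[:] = 0
    loss = F.cross_entropy(model(x + perturb), y) / augM
loss.backward()
optimizer.step()
print(float(loss))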
+ + @torch.no_grad() + def evaluate_func(self, val_idx, test_idx, labels, device, metric): + self.nc_net.eval() + pred_y = self.nc_net(self.__features, self.__processed_adj) + + acc_val = metric(pred_y[val_idx], labels[val_idx]) + acc_test = metric(pred_y[test_idx], labels[test_idx]) + return acc_val, acc_test + + def model_forward(self, idx, device): + pred_y = self.forward(self.__features, self.__processed_adj) + return pred_y[idx] + + def forward(self, x, adj): + return self.nc_net(x, adj) + + def postprocess(self, adj, outputs): + return outputs \ No newline at end of file diff --git a/sgl/models/homo/gda/__init__.py b/sgl/models/homo/gda/__init__.py index 316fdd9..123cae6 100644 --- a/sgl/models/homo/gda/__init__.py +++ b/sgl/models/homo/gda/__init__.py @@ -1,6 +1,8 @@ from .GAug import GAugO, GAugM +from .FLAG import FLAG __all__ = [ "GAugO", - "GAugM" + "GAugM", + "FLAG" ] \ No newline at end of file diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index dc111d9..3646495 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -312,13 +312,19 @@ def forward(self, x, adj): class SAGE(nn.Module): - def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F.relu): + def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F.relu, batch_norm=False, normalize=True): super(SAGE, self).__init__() self.gcs = nn.ModuleList() - self.gcs.append(SAGEConv(n_feat, n_hid)) + self.gcs.append(SAGEConv(n_feat, n_hid, normalize=normalize)) + self.batch_norm = batch_norm + if self.batch_norm: + self.bns = nn.ModuleList() + self.bns.append(nn.BatchNorm1d(n_hid)) self.n_layers = n_layers for _ in range(n_layers-2): - self.gcs.append(SAGEConv(n_hid, n_hid)) + self.gcs.append(SAGEConv(n_hid, n_hid, normalize=normalize)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1d(n_hid)) self.gcs.append(SAGEConv(n_hid, n_class, normalize=False)) self.dropout = dropout self.activation = activation @@ -326,6 +332,9 @@ def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F def reset_parameter(self): for conv in self.gcs: conv.reset_parameters() + if self.batch_norm: + for bn in self.bns: + bn.reset_parameters() def forward(self, x, block): repr = x @@ -334,12 +343,16 @@ def forward(self, x, block): if len(block) == self.n_layers: for i in range(self.n_layers-1): repr = self.gcs[i](repr, block[i]) + if self.batch_norm: + repr = self.bns[i](repr) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, block[-1]) elif len(block) == 1: - for gc in self.gcs[:-1]: - repr = gc(repr, block[0]) + for i in range(self.n_layers-1): + repr = self.gcs[i](repr, block[0]) + if self.batch_norm: + repr = self.bns[i](repr) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, block[0]) @@ -352,14 +365,16 @@ def inference(self, x_all, subgraph_loader, device): # Compute representations of nodes layer by layer, using *all* # available edges. This leads to faster computation in contrast to # immediately computing the final representations of each batch. 
- for i, conv in enumerate(self.gcs): + for i in range(self.n_layers): xs = [] for batch in subgraph_loader: batch_in, _, block = batch block.to_device(device) x = x_all[batch_in].to(device) - x = conv(x, block[0]) # one-layer sampling + x = self.gcs[i](x, block[0]) # one-layer sampling if i != self.nlayers - 1: + if self.batch_norm: + x = self.bns[i](x) x = F.relu(x) xs.append(x.cpu()) @@ -369,13 +384,19 @@ def inference(self, x_all, subgraph_loader, device): class GCN(nn.Module): - def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F.relu): + def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F.relu, batch_norm=False): super(GCN, self).__init__() self.gcs = nn.ModuleList() self.gcs.append(GCNConv(n_feat, n_hid)) + self.batch_norm = batch_norm + if self.batch_norm: + self.bns = nn.ModuleList() + self.bns.append(nn.BatchNorm1d(n_hid)) self.n_layers = n_layers for _ in range(n_layers-2): self.gcs.append(GCNConv(n_hid, n_hid)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1d(n_hid)) self.gcs.append(GCNConv(n_hid, n_class)) self.dropout = dropout self.activation = activation @@ -383,6 +404,9 @@ def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F def reset_parameter(self): for conv in self.gcs: conv.reset_parameters() + if self.batch_norm: + for bn in self.bns: + bn.reset_parameters() def forward(self, x, block): repr = x @@ -391,12 +415,16 @@ def forward(self, x, block): if len(block) == self.n_layers: for i in range(self.nlayers-1): repr = self.gcs[i](repr, block[i]) + if self.batch_norm: + repr = self.bns[i](repr) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, block[-1]) elif len(block) == 1: - for gc in self.gcs[:-1]: - repr = gc(repr, block[0]) + for i in range(self.n_layers-1): + repr = self.gcs[i](repr, block[0]) + if self.batch_norm: + repr = self.bns[i](repr) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) repr = self.gcs[-1](repr, block[0]) @@ -409,15 +437,17 @@ def inference(self, x_all, subgraph_loader, device): # Compute representations of nodes layer by layer, using *all* # available edges. This leads to faster computation in contrast to # immediately computing the final representations of each batch. 
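# Sketch of the layer-wise inference pattern used by SAGE.inference / GCN.inference above:
# instead of recursively sampling L hops for every batch, apply one layer at a time to *all*
# nodes and materialise x_all between layers. Dense matrices, nn.Linear layers and a simple
# index split stand in for the sampled blocks, graph convolutions and subgraph_loader.
import torch
import torch.nn as nn
import torch.nn.functional as F

n_nodes, n_layers = 10, 2
adj = torch.eye(n_nodes)                                # stand-in for a normalised adjacency
layers = nn.ModuleList([nn.Linear(8, 16), nn.Linear(16, 3)])
x_all = torch.randn(n_nodes, 8)
batches = torch.arange(n_nodes).split(4)                # stand-in for subgraph_loader

for i in range(n_layers):
    xs = []
    for batch_in in batches:
        x = adj[batch_in] @ x_all                       # one-layer aggregation for this batch
        x = layers[i](x)
        if i != n_layers - 1:
            x = F.relu(x)
        xs.append(x)
    x_all = torch.cat(xs, dim=0)                        # full matrix feeds the next layer
print(x_all.shape)                                      # torch.Size([10, 3])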
- for i, conv in enumerate(self.gcs): + for i in range(self.n_layers): xs = [] for batch in subgraph_loader: batch_in, _, block = batch block.to_device(device) x = x_all[batch_in].to(device) - x = conv(x, block[0]) # one-layer sampling + x = self.gcs[i](x, block[0]) # one-layer sampling if i != self.nlayers - 1: - x = F.relu(x) + if self.batch_norm: + x = self.bns[i](x) + x = self.activation(x) xs.append(x.cpu()) x_all = torch.cat(xs, dim=0) diff --git a/sgl/tasks/node_classification.py b/sgl/tasks/node_classification.py index e352d4c..ec83b10 100644 --- a/sgl/tasks/node_classification.py +++ b/sgl/tasks/node_classification.py @@ -3,6 +3,7 @@ import torch.nn as nn from torch.optim import Adam from torch.utils.data import DataLoader +from typing import Callable from sgl.tasks.base_task import BaseTask from sgl.tasks.utils import accuracy, set_seed, train, mini_batch_train, evaluate, mini_batch_evaluate @@ -10,7 +11,7 @@ class NodeClassification(BaseTask): def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn.CrossEntropyLoss(), seed=42, - train_batch_size=None, eval_batch_size=None): + patience=100, train_batch_size=None, eval_batch_size=None): super(NodeClassification, self).__init__() self.__dataset = dataset @@ -23,6 +24,7 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn. self.__loss_fn = loss_fn self.__device = device self.__seed = seed + self.__patience = patience self.__mini_batch = False if train_batch_size is not None: @@ -46,7 +48,7 @@ def _execute(self): set_seed(self.__seed) pre_time_st = time.time() - self.__model.preprocess(self.__dataset.adj, self.__dataset.x) + self.__model.preprocess(self.__dataset.adj, self.__dataset.x, self.__device) pre_time_ed = time.time() print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") @@ -56,13 +58,22 @@ def _execute(self): t_total = time.time() best_val = 0. best_test = 0. 
+ patience = 0 for epoch in range(self.__epochs): t = time.time() if self.__mini_batch is False: - loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, - self.__optimizer, self.__loss_fn) - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, - self.__labels, self.__device) + if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): + loss_train, acc_train = self.__model.train_func(self.__dataset.train_idx, self.__labels, self.__device, + self.__optimizer, self.__loss_fn, accuracy) + else: + loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, + self.__optimizer, self.__loss_fn, accuracy) + if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): + acc_val, acc_test = self.__model.evaluate_func(self.__dataset.val_idx, self.__dataset.test_idx, + self.__labels, self.__device, accuracy) + else: + acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, + self.__labels, self.__device, accuracy) else: loss_train, acc_train = mini_batch_train(self.__model, self.__dataset.train_idx, self.__train_loader, self.__labels, self.__device, self.__optimizer, self.__loss_fn) @@ -77,8 +88,13 @@ def _execute(self): 'acc_test: {:.4f}'.format(acc_test), 'time: {:.4f}s'.format(time.time() - t)) if acc_val > best_val: + patience = 0 best_val = acc_val best_test = acc_test + else: + patience += 1 + if patience == self.__patience: + break acc_val, acc_test = self._postprocess() if acc_val > best_val: From d0419b7495d44046efa530c9990b4ce4cc0f05aa Mon Sep 17 00:00:00 2001 From: infinity Date: Thu, 14 Dec 2023 13:44:57 +0000 Subject: [PATCH 22/28] add PyG-Style GNN models, add sampling-based FLAG (Test for Ogbn-Products) --- examples/GDA/configs/FLAG.yml | 60 ++++--- examples/GDA/configs/GAugM.yml | 83 ++++----- examples/GDA/configs/GAugO.yml | 17 +- examples/GDA/configs/SampleFLAG.yml | 38 ++++ examples/GDA/test_SampleFLAG.py | 46 +++++ sgl/data/base_data.py | 5 +- sgl/models/base_model.py | 5 +- sgl/models/homo/gda/FLAG.py | 184 ++++++++++++++++--- sgl/models/homo/gda/GAug.py | 148 ++++++---------- sgl/models/homo/gda/__init__.py | 5 +- sgl/models/pyg_simple_models.py | 204 ++++++++++++++++++++++ sgl/models/simple_models.py | 8 +- sgl/sampler/base_sampler.py | 8 +- sgl/tasks/node_classification_GAug.py | 26 ++- sgl/tasks/node_classification_sampling.py | 61 ++++--- sgl/utils/__init__.py | 3 +- sgl/utils/basic_operations.py | 12 +- 17 files changed, 675 insertions(+), 238 deletions(-) create mode 100644 examples/GDA/configs/SampleFLAG.yml create mode 100644 examples/GDA/test_SampleFLAG.py create mode 100644 sgl/models/pyg_simple_models.py diff --git a/examples/GDA/configs/FLAG.yml b/examples/GDA/configs/FLAG.yml index 342eeef..ac50617 100644 --- a/examples/GDA/configs/FLAG.yml +++ b/examples/GDA/configs/FLAG.yml @@ -1,9 +1,27 @@ +dataset: + classname: "Ogbn" + name: "arxiv" + root: "/home/ssq/test_data/" +model: + gnn_type: 'gcn' + hidden_dim: 256 + dropout: 0.5 + n_layers: 3 + step_size: 0.001 + augM: 3 + batch_norm: True +task: + lr: 0.01 + seed: 12345 + epochs: 500 + patience: 80 + weight_decay: 0 # dataset: # classname: "Ogbn" # name: "arxiv" # root: "/home/ssq/test_data/" # model: -# gnn_type: 'gcn' +# gnn_type: 'sage' # hidden_dim: 256 # dropout: 0.5 # n_layers: 3 @@ -16,22 +34,24 @@ # epochs: 500 # patience: 80 # weight_decay: 0 -dataset: - classname: "Ogbn" - name: "arxiv" 
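# Sketch of the dispatch added above: NodeClassification now prefers a model-supplied
# train_func/evaluate_func (as FLAG provides) over the generic helpers, and stops early after
# `patience` epochs without a validation improvement. Minimal stand-in classes below.
from typing import Callable

def generic_train(model):
    return "generic train step"

class PlainModel: ...

class FlagLikeModel:
    def train_func(self):
        return "model-specific train step (e.g. FLAG perturbation)"

def run_one_epoch(model):
    if hasattr(model, "train_func") and isinstance(model.train_func, Callable):
        return model.train_func()
    return generic_train(model)

print(run_one_epoch(PlainModel()))     # generic train step
print(run_one_epoch(FlagLikeModel()))  # model-specific train step (e.g. FLAG perturbation)

# early stopping on validation accuracy, mirroring the patience counter in the hunk
best_val, patience, max_patience = 0.0, 0, 3
for acc_val in [0.60, 0.62, 0.61, 0.61, 0.61]:
    if acc_val > best_val:
        best_val, patience = acc_val, 0
    else:
        patience += 1
        if patience == max_patience:
            break
print(best_val)                        # 0.62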
- root: "/home/ssq/test_data/" -model: - gnn_type: 'sage' - hidden_dim: 256 - dropout: 0.5 - n_layers: 3 - step_size: 0.001 - augM: 3 - batch_norm: True - normalize: False -task: - lr: 0.01 - seed: 12345 - epochs: 500 - patience: 80 - weight_decay: 0 \ No newline at end of file +# dataset: +# classname: "Ogbn" +# name: "arxiv" +# root: "/home/ssq/test_data/" +# model: +# gnn_type: 'gat' +# hidden_dim: 256 +# n_heads: '8,8,1' +# dropout: 0.6 +# attn_dropout: 0.6 +# n_layers: 3 +# step_size: 0.001 +# augM: 3 +# amp: 2 +# batch_norm: True +# task: +# lr: 0.002 +# seed: 12345 +# epochs: 500 +# patience: 100 +# weight_decay: 0 \ No newline at end of file diff --git a/examples/GDA/configs/GAugM.yml b/examples/GDA/configs/GAugM.yml index 7559a07..e5a3b8e 100644 --- a/examples/GDA/configs/GAugM.yml +++ b/examples/GDA/configs/GAugM.yml @@ -1,54 +1,59 @@ -#dataset: -# classname: "Planetoid" -# name: "cora" -# root: "/home/ssq/test_data/" -#model: -# model_name: 'GAugM' -# gnn_type: 'gcn' -# hidden_dim: 128 -# dropout: 0.5 -# n_layers: 2 -# choose_idx: 5 -# rm_pct: 2 -# add_pct: 57 -#task: -# lr: 0.01 -# seed: 42 -# epochs: 200 -# weight_decay: 0.0005 -#dataset: +dataset: + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" +model: + model_name: 'GAugM' + gnn_type: 'gcn' + feat_norm: 'row' + hidden_dim: 128 + dropout: 0.5 + n_layers: 2 + choose_idx: 5 + rm_pct: 2 + add_pct: 57 +task: + lr: 0.01 + seed: 42 + epochs: 200 + weight_decay: 0.0005 +# dataset: # classname: "Planetoid" # name: "cora" # root: "/home/ssq/test_data/" -#model: +# model: # model_name: 'GAugM' # gnn_type: 'gsage' +# feat_norm: 'row' +# normalize: True # hidden_dim: 128 # dropout: 0.5 # n_layers: 2 # choose_idx: 2 # rm_pct: 1 # add_pct: 80 -#task: +# task: # lr: 0.01 # seed: 42 # epochs: 200 # weight_decay: 0.0005 -dataset: - classname: "Planetoid" - name: "cora" - root: "/home/ssq/test_data/" -model: - model_name: 'GAugM' - gnn_type: 'gat' - hidden_dim: 128 - dropout: 0.5 - n_layers: 2 - choose_idx: 2 - rm_pct: 1 - add_pct: 68 -task: - lr: 0.01 - seed: 42 - epochs: 200 - weight_decay: 0.0005 \ No newline at end of file +# dataset: +# classname: "Planetoid" +# name: "cora" +# root: "/home/ssq/test_data/" +# model: +# model_name: 'GAugM' +# gnn_type: 'gat' +# feat_norm: 'row' +# activation: 'elu' +# hidden_dim: 128 +# dropout: 0.5 +# n_layers: 2 +# choose_idx: 2 +# rm_pct: 1 +# add_pct: 68 +# task: +# lr: 0.01 +# seed: 42 +# epochs: 200 +# weight_decay: 0.0005 \ No newline at end of file diff --git a/examples/GDA/configs/GAugO.yml b/examples/GDA/configs/GAugO.yml index 971e62a..9fce08b 100644 --- a/examples/GDA/configs/GAugO.yml +++ b/examples/GDA/configs/GAugO.yml @@ -1,20 +1,20 @@ -#dataset: +# dataset: # classname: "Planetoid" # name: "cora" # root: "/home/ssq/test_data/" -#model: +# model: # model_name: 'GAugO' # gnn_type: 'gcn' # alpha: 1.0 # temperature: 1.2 -# hidden_dim: 128 +# hidden_dim: 256 # emb_size: 32 # dropout: 0.5 # n_layers: 2 # gae: true # feat_norm: 'row' # sample_type: 'add_sample' -#task: +# task: # lr: 0.01 # seed: 42 # warmup: 0 @@ -23,6 +23,7 @@ # weight_decay: 0.0005 # pretrain_ep: 160 # pretrain_nc: 30 +# max_patience: 50 dataset: classname: "Planetoid" name: "cora" @@ -38,6 +39,7 @@ model: n_layers: 2 gae: true feat_norm: 'row' + normalize: True sample_type: 'add_sample' task: lr: 0.01 @@ -48,6 +50,11 @@ task: weight_decay: 0.0005 pretrain_ep: 195 pretrain_nc: 35 + max_patience: 50 +# dataset: +# classname: "Planetoid" +# name: "cora" +# root: "/home/ssq/test_data/" # model: # 
model_name: 'GAugO' # gnn_type: 'gat' @@ -57,6 +64,7 @@ task: # emb_size: 32 # dropout: 0.6 # n_layers: 2 +# activation: "elu" # gae: true # feat_norm: 'row' # sample_type: 'add_sample' @@ -69,3 +77,4 @@ task: # weight_decay: 0.0005 # pretrain_ep: 175 # pretrain_nc: 45 +# max_patience: 50 diff --git a/examples/GDA/configs/SampleFLAG.yml b/examples/GDA/configs/SampleFLAG.yml new file mode 100644 index 0000000..a8080ea --- /dev/null +++ b/examples/GDA/configs/SampleFLAG.yml @@ -0,0 +1,38 @@ +dataset: + classname: "Ogbn" + name: "products" + root: "/home/ssq/test_data/" +sampler: + training: + name: "NeighborSampler" + layer_sizes: "15,10,5" + prob_type: "normalize" + replace: False + eval: + name: "NeighborSampler" + layer_sizes: "-1" + replace: False +model: + gnn_type: 'sage' + hidden_dim: 256 + dropout: 0.5 + n_layers: 3 + step_size: 0.008 + augM: 3 + amp: 2 + batch_norm: False +task: + name: "NodeClassification_Sampling" + lr: 0.003 + seed: 12345 + epochs: 20 + patience: 10 + weight_decay: 0 + train_batch_size: 1024 + eval_batch_size: 4096 + train_num_workers: 12 + eval_num_workers: 12 + eval_together: True + eval_freq: 2 + eval_start: 10 + loss_fn: "nll_loss" \ No newline at end of file diff --git a/examples/GDA/test_SampleFLAG.py b/examples/GDA/test_SampleFLAG.py new file mode 100644 index 0000000..7803ee3 --- /dev/null +++ b/examples/GDA/test_SampleFLAG.py @@ -0,0 +1,46 @@ +import yaml +import argparse + +import sgl.dataset as Dataset +import sgl.sampler as Sampler +import sgl.tasks as Tasks +from sgl.models.homo.gda import SampleFLAG + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Sampler-Models") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/SampleFLAG.yml", help="save path of the configuration file" + ) + args = parser.parse_args() + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + dataset_kwargs = config["dataset"] + task_kwargs = config["task"] + classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, classname)(**dataset_kwargs) + training_sampler_kwargs = config["sampler"]["training"] + if "inductive" in training_sampler_kwargs.keys(): + inductive = training_sampler_kwargs.pop("inductive") + else: + inductive = False + task_kwargs.update({"inductive": inductive}) + training_sampler_name = training_sampler_kwargs.pop("name") + training_sampler_kwargs.update({"save_dir": dataset.processed_dir}) + training_sampler = getattr(Sampler, training_sampler_name)(dataset.adj[dataset.train_idx, :][:, dataset.train_idx] if inductive else dataset.adj, **training_sampler_kwargs) + if "eval" in config["sampler"].keys(): + eval_sampler_kwargs = config["sampler"]["eval"] + eval_sampler_name = eval_sampler_kwargs.pop("name") + eval_sampler_kwargs.update({"save_dir": dataset.processed_dir}) + eval_sampler = getattr(Sampler, eval_sampler_name)(dataset.adj, **eval_sampler_kwargs) + else: + eval_sampler = None + model_kwargs = config["model"] + model = SampleFLAG(training_sampler, eval_sampler, in_dim=dataset.num_features, n_classes=dataset.num_classes, **model_kwargs) + task_kwargs.update({"device": device}) + task_name = task_kwargs.pop("name") + test_acc = getattr(Tasks, task_name)(dataset, model, **task_kwargs).test_acc + print(f"final test acc: {test_acc}") \ No newline at end of file diff --git a/sgl/data/base_data.py b/sgl/data/base_data.py index a5b15a7..2cb0c40 100644 
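# Sketch of the inductive setting handled in test_SampleFLAG.py above: when the training
# sampler is marked inductive, it is built on the subgraph induced by the training nodes,
# obtained by row- then column-slicing the CSR adjacency. Toy adjacency below.
import numpy as np
import scipy.sparse as sp

adj = sp.csr_matrix(np.array([[0, 1, 0, 0],
                              [1, 0, 1, 1],
                              [0, 1, 0, 1],
                              [0, 1, 1, 0]]))
train_idx = np.array([0, 1, 2])
train_adj = adj[train_idx, :][:, train_idx]   # induced training subgraph
print(train_adj.toarray())
# [[0 1 0]
#  [1 0 1]
#  [0 1 0]]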
--- a/sgl/data/base_data.py +++ b/sgl/data/base_data.py @@ -3,7 +3,7 @@ import numpy as np from scipy.sparse import csr_matrix -from sgl.utils import sparse_mx_to_torch_sparse_tensor +from sgl.utils import sparse_mx_to_torch_sparse_tensor, sparse_mx_to_pyg_sparse_tensor # A lighter wrapper class for sampled adjacency matrices, # as the Edge class seems contains useless information @@ -29,7 +29,8 @@ def to_device(self, device): if self.__device == device: return if not isinstance(self.__adjs[0], torch.sparse.FloatTensor): - self.__adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in self.__adjs] + # self.__adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in self.__adjs] + self.__adjs = [sparse_mx_to_pyg_sparse_tensor(adj) for adj in self.__adjs] self.__adjs = [adj.to(device) for adj in self.__adjs] self.__device = device diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index 19f528d..d9d6433 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from sgl.data.base_data import Block from sgl.data.base_dataset import HeteroNodeDataset -from sgl.utils import sparse_mx_to_torch_sparse_tensor +from sgl.utils import sparse_mx_to_torch_sparse_tensor, sparse_mx_to_pyg_sparse_tensor class BaseSGAPModel(nn.Module): @@ -127,7 +127,8 @@ def preprocess(self, adj, x, y, device, **kwargs): norm_adj = self._pre_graph_op._construct_adj(adj) else: norm_adj = adj - norm_adj = sparse_mx_to_torch_sparse_tensor(norm_adj) + # norm_adj = sparse_mx_to_torch_sparse_tensor(norm_adj) + norm_adj = sparse_mx_to_pyg_sparse_tensor(norm_adj) self._processed_block = Block(norm_adj) if hasattr(self, "_pre_feature_op"): diff --git a/sgl/models/homo/gda/FLAG.py b/sgl/models/homo/gda/FLAG.py index 41d8923..6905fed 100644 --- a/sgl/models/homo/gda/FLAG.py +++ b/sgl/models/homo/gda/FLAG.py @@ -1,27 +1,28 @@ -import scipy.sparse as sp import torch import torch.nn as nn import torch.nn.functional as F -from sgl.models.simple_models import GCN, SAGE, GAT -from sgl.operators.graph_op import LaplacianGraphOp, RwGraphOp -from sgl.utils import sparse_mx_to_torch_sparse_tensor +from sgl.models.base_model import BaseSAMPLEModel +from sgl.utils import sparse_mx_to_pyg_sparse_tensor +from sgl.models.pyg_simple_models import GCN, SAGE, GAT + +GNN_BACKBONE = {"gcn": GCN, "sage": SAGE, "gat": GAT} class FLAG(nn.Module): def __init__(self, in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, step_size, augM, activation=F.relu, **kwargs): super(FLAG, self).__init__() - self.__step_size = step_size + self.__step_size = float(step_size) self.__augM = augM - self.__gnn_type = gnn_type - gnn_backbone = {"gcn": GCN, "sage": SAGE, "gat": GAT} + self.__amp = kwargs.pop("amp", 1) + if isinstance(activation, str): + activation = getattr(F, activation) if gnn_type == 'gat': if kwargs.get("n_heads"): n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) else: n_heads = [8] * (n_layers - 1) + [1] kwargs.update({"n_heads": n_heads}) - activation = F.elu - self.nc_net = gnn_backbone.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) + self._base_model = GNN_BACKBONE.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) @property def processed_feature(self): @@ -33,37 +34,33 @@ def processed_adj(self): def preprocess(self, adj, features, device): self.__features = features.to(device) - if self.__gnn_type == "gcn": - adj_norm = 
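# The hunks above switch from sparse_mx_to_torch_sparse_tensor to a new
# sparse_mx_to_pyg_sparse_tensor helper whose body is not shown in this excerpt. Below is a
# plausible sketch of such a conversion (scipy matrix to torch_sparse.SparseTensor, the format
# PyG convolutions accept); the real helper in sgl.utils may differ.
import numpy as np
import scipy.sparse as sp
import torch
from torch_sparse import SparseTensor

def to_pyg_sparse_tensor(mx):
    mx = mx.tocoo().astype(np.float32)
    return SparseTensor(row=torch.from_numpy(mx.row).long(),
                        col=torch.from_numpy(mx.col).long(),
                        value=torch.from_numpy(mx.data),
                        sparse_sizes=mx.shape)

adj = sp.random(5, 5, density=0.4, format="csr")
print(to_pyg_sparse_tensor(adj))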
LaplacianGraphOp()._construct_adj(adj) - self.__processed_adj = sparse_mx_to_torch_sparse_tensor(adj_norm).to(device) - elif self.__gnn_type == "sage": - adj_norm = RwGraphOp()._construct_adj(adj) - self.__processed_adj = sparse_mx_to_torch_sparse_tensor(adj_norm).to(device) - elif self.__gnn_type == "gat": - adj_sl = sp.coo_matrix(adj) - adj_sl = adj_sl + sp.eye(*adj_sl.shape) - self.__processed_adj = torch.FloatTensor(adj_sl.todense()).to(device) + self.__processed_adj = sparse_mx_to_pyg_sparse_tensor(adj).to(device) def flag(self, ground_truth_y, optimizer, device, train_idx, loss_fn): x = self.__features adj = self.__processed_adj - self.nc_net.train() + self._base_model.train() optimizer.zero_grad() perturb = torch.FloatTensor(x.shape).uniform_(-self.__step_size, self.__step_size).to(device) + unlabel_idx = list(set(range(perturb.shape[0])) - set(train_idx)) + perturb.data[unlabel_idx] *= self.__amp + perturb.requires_grad_() - pred_y = self.nc_net(x+perturb, adj)[train_idx] + pred_y = self._base_model(x+perturb, adj)[train_idx] loss = loss_fn(pred_y, ground_truth_y) loss /= self.__augM for _ in range(self.__augM-1): loss.backward() - perturb_data = perturb.detach() + self.__step_size * torch.sign(perturb.grad.detach()) - perturb.data = perturb_data.data + perturb_data = perturb[train_idx].detach() + self.__step_size * torch.sign(perturb.grad[train_idx].detach()) + perturb.data[train_idx] = perturb_data.data + perturb_data = perturb[unlabel_idx].detach() + self.__amp * self.__step_size * torch.sign(perturb.grad[unlabel_idx].detach()) + perturb.data[unlabel_idx] = perturb_data.data perturb.grad[:] = 0 - pred_y = self.nc_net(x+perturb, adj)[train_idx] + pred_y = self._base_model(x+perturb, adj)[train_idx] loss = loss_fn(pred_y, ground_truth_y) loss /= self.__augM @@ -75,16 +72,16 @@ def flag(self, ground_truth_y, optimizer, device, train_idx, loss_fn): def train_func(self, train_idx, labels, device, optimizer, loss_fn, metric): loss_train = self.flag(labels[train_idx], optimizer, device, train_idx, loss_fn) - self.nc_net.eval() - pred_y = self.nc_net(self.__features, self.__processed_adj) + self._base_model.eval() + pred_y = self._base_model(self.__features, self.__processed_adj) acc_train = metric(pred_y[train_idx], labels[train_idx]) return loss_train.item(), acc_train @torch.no_grad() def evaluate_func(self, val_idx, test_idx, labels, device, metric): - self.nc_net.eval() - pred_y = self.nc_net(self.__features, self.__processed_adj) + self._base_model.eval() + pred_y = self._base_model(self.__features, self.__processed_adj) acc_val = metric(pred_y[val_idx], labels[val_idx]) acc_test = metric(pred_y[test_idx], labels[test_idx]) @@ -95,7 +92,136 @@ def model_forward(self, idx, device): return pred_y[idx] def forward(self, x, adj): - return self.nc_net(x, adj) + return self._base_model(x, adj) + + def postprocess(self, adj, outputs): + return outputs + + +class SampleFLAG(BaseSAMPLEModel): + def __init__(self, training_sampler, eval_sampler, in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, step_size, augM, activation=F.relu, **kwargs): + super(SampleFLAG, self).__init__() + self.__step_size = float(step_size) + self.__augM = augM + self.__amp = kwargs.pop("amp", 1) + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler + if isinstance(activation, str): + activation = getattr(F, activation) + if gnn_type == 'gat': + if kwargs.get("n_heads"): + n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) + else: + n_heads = [8] * 
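# Sketch of the "amp" refinement in the updated flag() above: unlabeled nodes receive an
# initial perturbation and ascent step amplified by `amp`, while training nodes keep the base
# step_size. Toy tensors only; the gradient is a random stand-in for perturb.grad.
import torch

n_nodes, n_feat, step_size, amp = 6, 4, 1e-3, 2
train_idx = [0, 1, 2]
unlabel_idx = list(set(range(n_nodes)) - set(train_idx))

perturb = torch.empty(n_nodes, n_feat).uniform_(-step_size, step_size)
perturb[unlabel_idx] *= amp                     # larger initial noise on unlabeled rows
perturb.requires_grad_()

grad = torch.randn_like(perturb)                # stand-in for perturb.grad after backward()
with torch.no_grad():
    perturb[train_idx] += step_size * torch.sign(grad[train_idx])
    perturb[unlabel_idx] += amp * step_size * torch.sign(grad[unlabel_idx])
print(perturb.abs().max() <= 2 * amp * step_size)   # tensor(True)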
(n_layers - 1) + [1] + kwargs.update({"n_heads": n_heads}) + self._base_model = GNN_BACKBONE.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) + + def flag(self, clean, ground_truth_y, adjs, batch_out, optimizer, device, loss_fn): + self._base_model.train() + optimizer.zero_grad() + batch_size = len(batch_out) + perturb_t = torch.FloatTensor(clean[:batch_size].shape).uniform_(-self.__step_size, self.__step_size).to(device) + perturb_un = torch.FloatTensor(clean[batch_size:].shape).uniform_(-self.__amp * self.__step_size, self.__amp * self.__step_size).to(device) + perturb_t.requires_grad_() + perturb_un.requires_grad_() + + perturb = torch.cat([perturb_t, perturb_un], dim=0) + pred_y = self._base_model(clean+perturb, adjs) + loss = loss_fn(pred_y, ground_truth_y) + loss /= self.__augM + + for _ in range(self.__augM-1): + loss.backward() + + perturb_data_t = perturb_t.detach() + self.__step_size * torch.sign(perturb_t.grad.detach()) + perturb_t.data = perturb_data_t.data + perturb_t.grad[:] = 0 + + perturb_data_un = perturb_un.detach() + self.__amp * self.__step_size * torch.sign(perturb_un.grad.detach()) + perturb_un.data = perturb_data_un.data + perturb_un.grad[:] = 0 + + perturb = torch.cat((perturb_t, perturb_un), dim=0) + + pred_y = self._base_model(clean+perturb, adjs) + loss = loss_fn(pred_y, ground_truth_y) + loss /= self.__augM + + loss.backward() + optimizer.step() + + return loss, pred_y + + def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizers, inductive=False, transfer_y_to_device=True): + batch_in, batch_out, block = batch + + if inductive is False: + in_x = self._processed_feature[batch_in].to(device) + y_truth = self._vanilla_y[batch_out] + else: + in_x = self._processed_train_feature[batch_in].to(device) + y_truth = self._vanilla_train_y[batch_out] + + if transfer_y_to_device is True: + y_truth = y_truth.to(device) + + block.to_device(device) + loss, pred_y = self.flag(in_x, y_truth, block, batch_out, optimizers, device, loss_fn) + + return loss, pred_y, y_truth + + def train_func(self, train_loader, inductive, device, optimizer, loss_fn): + correct_num = 0 + loss_train_sum = 0. 
+ train_num = 0 + + for batch in train_loader: + loss_train, y_out, y_truth = self.mini_batch_prepare_forward(batch, device, loss_fn, optimizer, inductive=inductive) + pred = y_out.max(1)[1].type_as(y_truth) + correct_num += pred.eq(y_truth).double().sum() + loss_train_sum += loss_train.item() + train_num += len(y_truth) + + loss_train = loss_train_sum / len(train_loader) + acc_train = correct_num / train_num + + return loss_train, acc_train.item() + + @torch.no_grad() + def evaluate_func(self, val_loader, test_loader, device): + self._base_model.eval() + + correct_num_val, correct_num_test = 0, 0 + val_num = 0 + for batch in val_loader: + val_output, out_y = self.model_forward(batch, device) + pred = val_output.max(1)[1].type_as(out_y) + correct_num_val += pred.eq(out_y).double().sum() + val_num += len(out_y) + + acc_val = correct_num_val / val_num + + test_num = 0 + for batch in test_loader: + test_output, out_y = self.model_forward(batch, device) + pred = test_output.max(1)[1].type_as(out_y) + correct_num_test += pred.eq(out_y).double().sum() + test_num += len(out_y) + acc_test = correct_num_test / test_num + + return acc_val.item(), acc_test.item() + + def model_forward(self, batch, device): + batch_in, batch_out, block = batch + in_x = self._processed_feature[batch_in].to(device) + y_truth = self._vanilla_y[batch_out].to(device) + block.to_device(device) + + y_pred = self.forward(in_x, block) + return y_pred, y_truth + + def forward(self, x, adj): + return self._base_model(x, adj) def postprocess(self, adj, outputs): return outputs \ No newline at end of file diff --git a/sgl/models/homo/gda/GAug.py b/sgl/models/homo/gda/GAug.py index acc5b5a..0daf218 100644 --- a/sgl/models/homo/gda/GAug.py +++ b/sgl/models/homo/gda/GAug.py @@ -8,47 +8,33 @@ import pickle as pkl import scipy.sparse as sp -from sgl.models.simple_models import GCNConv, GCN, SAGE, GAT -from sgl.models.homo.gda.utils import RoundNoGradient, CeilNoGradient -from sgl.utils import sparse_mx_to_torch_sparse_tensor +from sgl.utils import sparse_mx_to_pyg_sparse_tensor from sgl.operators.graph_op import LaplacianGraphOp +from sgl.models.pyg_simple_models import GCNConv, GCN, SAGE, GAT +from sgl.models.homo.gda.utils import RoundNoGradient, CeilNoGradient class GAugO(nn.Module): - def __init__(self, - in_dim, - hidden_dim, - emb_size, - n_classes, - n_layers, - dropout, - gnn_type, - activation=F.relu, - temperature=1, - gae=False, - alpha=1, - feat_norm="row", - sample_type="add_sample", - **kwargs): + def __init__(self, in_dim, hidden_dim, emb_size, n_classes, n_layers, dropout, gnn_type, + activation=F.relu, temperature=1, gae=False, alpha=1, feat_norm="row", sample_type="add_sample", **kwargs): super(GAugO, self).__init__() - self.__pre_graph_op = LaplacianGraphOp() self.__temperature = temperature self.__alpha = alpha self.__sample_type = sample_type # edge prediction network self.__gae = gae self.__feat_norm = feat_norm - self.ep_net = VGAE(in_dim, hidden_dim, emb_size, activation, gae=gae) + self.ep_net = VGAE(in_dim, hidden_dim, emb_size, F.relu, gae=gae) # node classification network - self.__gnn_type = gnn_type gnn_backbone = {"gcn": GCN, "gsage": SAGE, "gat": GAT} - if gnn_type == 'gat': + if isinstance(activation, str): + activation = getattr(F, activation) + if gnn_type == "gat": if kwargs.get("n_heads"): n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) else: n_heads = [8] * (n_layers - 1) + [1] kwargs.update({"n_heads": n_heads}) - activation = F.elu self.nc_net = 
gnn_backbone.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) @@ -67,9 +53,9 @@ def col_normalization(features): return torch.FloatTensor(features) def preprocess(self, features, adj_matrix, device): - if self.__feat_norm == 'row': + if self.__feat_norm == "row": features = F.normalize(features, p=1, dim=1) - elif self.__feat_norm == 'col': + elif self.__feat_norm == "col": features = self.col_normalization(features) features = features.to(device) @@ -77,22 +63,12 @@ def preprocess(self, features, adj_matrix, device): if not isinstance(adj_matrix, sp.coo_matrix): adj_matrix = sp.coo_matrix(adj_matrix) adj_matrix_sl = adj_matrix + sp.eye(*adj_matrix.shape) - adj_orig = sparse_mx_to_torch_sparse_tensor(adj_matrix_sl).to_dense().to(device) - adj_norm_matrix = self.__pre_graph_op._construct_adj(adj_matrix) - adj_norm = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix) - # adj_matrix used as input for nc_net (torch.sparse.FloatTensor) - if self.__gnn_type == 'gcn': - adj = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix) - elif self.__gnn_type == 'gsage': - adj = adj_matrix_sl / adj_matrix_sl.sum(1) - adj = sparse_mx_to_torch_sparse_tensor(adj) - elif self.__gnn_type == 'gat': - adj = torch.FloatTensor(adj_matrix_sl.todense()) - - adj_norm = adj_norm.to(device) - adj = adj.to(device) + adj_orig = sparse_mx_to_pyg_sparse_tensor(adj_matrix_sl).to_dense().to(device) + adj_norm_matrix = LaplacianGraphOp()._construct_adj(adj_matrix) + adj_norm = sparse_mx_to_pyg_sparse_tensor(adj_norm_matrix).to(device) + adj = sparse_mx_to_pyg_sparse_tensor(adj_matrix).to(device) - return features, adj_orig, adj_norm, adj + return features, adj_orig, adj, adj_norm @staticmethod def sample_adj(adj_logits, temp): @@ -169,34 +145,23 @@ def sample_adj_edge(adj_logits, adj_orig, change_frac): adj_new = adj_new + mask_add return adj_new - def normalize_adj(self, adj): - if self.__gnn_type == 'gcn': - adj.fill_diagonal_(1) - # normalize adj with A = D^{-1/2} @ A @ D^{-1/2} - D_norm = torch.diag(torch.pow(adj.sum(1), -0.5)).to(adj.device) - adj = D_norm @ adj @ D_norm - elif self.__gnn_type == 'gat': - adj.fill_diagonal_(1) - elif self.__gnn_type == 'gsage': - adj.fill_diagonal_(1) - adj = F.normalize(adj, p=1, dim=1) - return adj - - def forward(self, adj, adj_orig, features): - adj_logits = self.ep_net(adj, features) - if self.__sample_type == 'edge': + def forward(self, adj_norm, adj_orig, features): + adj_logits = self.ep_net(adj_norm, features) + if self.__sample_type == "edge": adj_new = self.sample_adj_edge(adj_logits, adj_orig, self.__alpha) - elif self.__sample_type == 'add_round': + elif self.__sample_type == "add_round": adj_new = self.sample_adj_add_round(adj_logits, adj_orig, self.__alpha) - elif self.__sample_type == 'rand': + elif self.__sample_type == "rand": adj_new = self.sample_adj_random(adj_logits) - elif self.__sample_type == 'add_sample': + elif self.__sample_type == "add_sample": if self.__alpha == 1: adj_new = self.sample_adj(adj_logits, self.__temperature) else: adj_new = self.sample_adj_add_bernoulli(adj_logits, adj_orig, self.__alpha, self.__temperature) - adj_new_normed = self.normalize_adj(adj_new) - nc_logits = self.nc_net(features, adj_new_normed) + + row, col = adj_new.nonzero(as_tuple=True) + edge_index = torch.vstack([row, col]) + nc_logits = self.nc_net(features, edge_index) return nc_logits, adj_logits @@ -207,13 +172,13 @@ def __init__(self, in_dim, hidden_dim, emb_size, activation, gae=False): super(VGAE, 
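# Sketch of the conversion done at the end of the rewritten GAugO.forward above: the sampled
# dense adjacency is turned into a PyG-style edge_index so that nc_net (now built from
# pyg_simple_models) can consume it. Toy adjacency below.
import torch

adj_new = torch.tensor([[0., 1., 0.],
                        [1., 0., 1.],
                        [0., 1., 0.]])
row, col = adj_new.nonzero(as_tuple=True)
edge_index = torch.vstack([row, col])       # shape [2, num_edges]
print(edge_index)
# tensor([[0, 1, 1, 2],
#         [1, 0, 2, 1]])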
self).__init__() self.gae = gae self.activation = activation - self.gcn_base = GCNConv(in_dim, hidden_dim, bias=False) - self.gcn_mean = GCNConv(hidden_dim, emb_size, bias=False) - self.gcn_logstd = GCNConv(hidden_dim, emb_size, bias=False) + self.gcn_base = GCNConv(in_dim, hidden_dim, add_self_loops=False, normalize=False, bias=False) + self.gcn_mean = GCNConv(hidden_dim, emb_size, add_self_loops=False, normalize=False, bias=False) + self.gcn_logstd = GCNConv(hidden_dim, emb_size, add_self_loops=False, normalize=False, bias=False) def forward(self, adj, features): # GCN encoder - hidden = self.gcn_base(features, adj, ) + hidden = self.gcn_base(features, adj) self.mean = self.activation(self.gcn_mean(hidden, adj)) if self.gae: # GAE (no sampling at bottleneck) @@ -230,29 +195,27 @@ def forward(self, adj, features): class GAugM(nn.Module): - def __init__(self, in_dim, hidden_dim, n_classes, n_layers, gnn_type, rm_pct, add_pct, choose_idx, dropout=0.5, activation=F.relu, **kwargs): + def __init__(self, in_dim, hidden_dim, n_classes, n_layers, gnn_type, rm_pct, add_pct, choose_idx, dropout=0.5, activation=F.relu, feat_norm='none', **kwargs): super(GAugM, self).__init__() + self.__feat_norm = feat_norm self.__rm_pct = rm_pct self.__add_pct = add_pct self.__choose_idx = choose_idx - self.__pre_graph_op = None - gnn_backbone = {'gcn': GCN, 'gsage': SAGE, 'gat': GAT} - self.__gnn_type = gnn_type - if gnn_type == 'gcn': - self.__pre_graph_op = LaplacianGraphOp() - if gnn_type == 'gat': + if isinstance(activation, str): + activation = getattr(F, activation) + gnn_backbone = {"gcn": GCN, "gsage": SAGE, "gat": GAT} + if gnn_type == "gat": if kwargs.get("n_heads"): n_heads = list(map(lambda x: int(x), kwargs["n_heads"].split(","))) else: n_heads = [8] * (n_layers - 1) + [1] kwargs.update({"n_heads": n_heads}) - activation = F.elu self.nc_net = gnn_backbone.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) @staticmethod - def sample_graph_det(adj_orig, A_pred, remove_pct, add_pct): + def sample_graph_det(adj_orig, adj_pred, remove_pct, add_pct): if remove_pct == 0 and add_pct == 0: return copy.deepcopy(adj_orig) @@ -262,7 +225,7 @@ def sample_graph_det(adj_orig, A_pred, remove_pct, add_pct): if remove_pct: n_remove = int(n_edges * remove_pct / 100) - pos_probs = A_pred[edges.T[0], edges.T[1]] + pos_probs = adj_pred[edges.T[0], edges.T[1]] e_index_2b_remove = np.argpartition(pos_probs, n_remove)[:n_remove] mask = np.ones(len(edges), dtype=bool) mask[e_index_2b_remove] = False @@ -272,18 +235,18 @@ def sample_graph_det(adj_orig, A_pred, remove_pct, add_pct): if add_pct: n_add = int(n_edges * add_pct / 100) - # deep copy to avoid modifying A_pred - A_probs = np.array(A_pred) + # deep copy to avoid modifying adj_pred + adj_probs = np.array(adj_pred) # make the probabilities of the lower half to be zero (including diagonal) - A_probs[np.tril_indices(A_probs.shape[0])] = 0 + adj_probs[np.tril_indices(adj_probs.shape[0])] = 0 # make the probabilities of existing edges to be zero - A_probs[edges.T[0], edges.T[1]] = 0 - all_probs = A_probs.reshape(-1) + adj_probs[edges.T[0], edges.T[1]] = 0 + all_probs = adj_probs.reshape(-1) e_index_2b_add = np.argpartition(all_probs, -n_add)[-n_add:] new_edges = [] for index in e_index_2b_add: - i = int(index / A_probs.shape[0]) - j = index % A_probs.shape[0] + i = int(index / adj_probs.shape[0]) + j = index % adj_probs.shape[0] new_edges.append([i, j]) edges_pred = np.concatenate((edges_pred, new_edges), axis=0) 
adj_pred = sp.csr_matrix((np.ones(len(edges_pred)), edges_pred.T), shape=adj_orig.shape) @@ -291,25 +254,14 @@ def sample_graph_det(adj_orig, A_pred, remove_pct, add_pct): return adj_pred - def preprocess(self, adj_orig, features, A_pred_dir, device): - if features.size(1) in (1433, 3703): + def preprocess(self, adj_orig, features, adj_pred_dir, device): + if self.__feat_norm == "row": features = F.normalize(features, p=1, dim=1) features = features.to(device) - A_pred = pkl.load(open(os.path.join(A_pred_dir, f'{self.__choose_idx}_logits.pkl'), 'rb')) - adj_pred = self.sample_graph_det(adj_orig, A_pred, self.__rm_pct, self.__add_pct) - - if self.__pre_graph_op is not None: - adj_norm_matrix = self.__pre_graph_op._construct_adj(adj_pred) - adj_processed = sparse_mx_to_torch_sparse_tensor(adj_norm_matrix).to(device) - else: - if not isinstance(adj_pred, sp.coo_matrix): - adj_pred = sp.coo_matrix(adj_pred) - adj_pred.setdiag(1) - if self.__gnn_type == 'gsage': - adj_processed = sparse_mx_to_torch_sparse_tensor(adj_pred).to(device) - elif self.__gnn_type == 'gat': - adj_processed = torch.FloatTensor(adj_pred.todense()).to(device) + adj_pred = pkl.load(open(os.path.join(adj_pred_dir, f"{self.__choose_idx}_logits.pkl"), "rb")) + adj_pred = self.sample_graph_det(adj_orig, adj_pred, self.__rm_pct, self.__add_pct) + adj_processed = sparse_mx_to_pyg_sparse_tensor(adj_pred).to(device) return adj_processed, features diff --git a/sgl/models/homo/gda/__init__.py b/sgl/models/homo/gda/__init__.py index 123cae6..8d3bbf4 100644 --- a/sgl/models/homo/gda/__init__.py +++ b/sgl/models/homo/gda/__init__.py @@ -1,8 +1,9 @@ from .GAug import GAugO, GAugM -from .FLAG import FLAG +from .FLAG import FLAG, SampleFLAG __all__ = [ "GAugO", "GAugM", - "FLAG" + "FLAG", + "SampleFLAG" ] \ No newline at end of file diff --git a/sgl/models/pyg_simple_models.py b/sgl/models/pyg_simple_models.py new file mode 100644 index 0000000..70136b1 --- /dev/null +++ b/sgl/models/pyg_simple_models.py @@ -0,0 +1,204 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_sparse import SparseTensor +from torch_geometric.nn import GCNConv, SAGEConv, GATConv + +class GCN(nn.Module): + def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F.relu, batch_norm=False, add_self_loops=True, normalize=True, cached=False): + super(GCN, self).__init__() + self.gcs = nn.ModuleList() + self.gcs.append(GCNConv(n_feat, n_hid, cached=cached, add_self_loops=add_self_loops, normalize=normalize)) + self.batch_norm = batch_norm + if self.batch_norm: + self.bns = nn.ModuleList() + self.bns.append(nn.BatchNorm1d(n_hid)) + self.n_layers = n_layers + for _ in range(n_layers-2): + self.gcs.append(GCNConv(n_hid, n_hid, cached=cached, add_self_loops=add_self_loops, normalize=normalize)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1d(n_hid)) + self.gcs.append(GCNConv(n_hid, n_class, cached=cached, add_self_loops=add_self_loops, normalize=normalize)) + self.dropout = dropout + self.activation = activation + + def reset_parameter(self): + for conv in self.gcs: + conv.reset_parameters() + if self.batch_norm: + for bn in self.bns: + bn.reset_parameters() + + def forward(self, x, block): + repr = x + if isinstance(block, (SparseTensor, torch.Tensor)): + block = [block] + if len(block) == self.n_layers: + for i in range(self.n_layers-1): + repr = self.gcs[i](repr, block[i]) + if self.batch_norm: + repr = self.bns[i](repr) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, 
training=self.training) + repr = self.gcs[-1](repr, block[-1]) + elif len(block) == 1: + for i in range(self.n_layers-1): + repr = self.gcs[i](repr, block[0]) + if self.batch_norm: + repr = self.bns[i](repr) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, block[0]) + else: + raise ValueError('The sampling layer must be equal to GNN layer.') + + return F.log_softmax(repr, dim=1) + + def inference(self, x_all, subgraph_loader, device): + # Compute representations of nodes layer by layer, using *all* + # available edges. This leads to faster computation in contrast to + # immediately computing the final representations of each batch. + for i in range(self.n_layers): + xs = [] + for batch in subgraph_loader: + batch_in, _, block = batch + block.to_device(device) + x = x_all[batch_in].to(device) + x = self.gcs[i](x, block[0]) # one-layer sampling + if i != self.n_layers - 1: + if self.batch_norm: + x = self.bns[i](x) + x = self.activation(x) + xs.append(x.cpu()) + + x_all = torch.cat(xs, dim=0) + + return x_all + +class SAGE(nn.Module): + def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F.relu, batch_norm=False, normalize=False): + super(SAGE, self).__init__() + self.gcs = nn.ModuleList() + self.gcs.append(SAGEConv(n_feat, n_hid)) + self.batch_norm = batch_norm + self.normalize = normalize + if normalize: + self.norm = lambda x: F.normalize(x, p=1, dim=1) + if self.batch_norm: + self.bns = nn.ModuleList() + self.bns.append(nn.BatchNorm1d(n_hid)) + self.n_layers = n_layers + for _ in range(n_layers-2): + self.gcs.append(SAGEConv(n_hid, n_hid)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1d(n_hid)) + self.gcs.append(SAGEConv(n_hid, n_class)) + self.dropout = dropout + self.activation = activation + + def reset_parameter(self): + for conv in self.gcs: + conv.reset_parameters() + if self.batch_norm: + for bn in self.bns: + bn.reset_parameters() + + def forward(self, x, block): + repr = x + if isinstance(block, (SparseTensor, torch.Tensor)): + block = [block] + if len(block) == self.n_layers: + for i in range(self.n_layers-1): + root_size = block[i].sparse_size(0) + root_repr = repr[:root_size] + repr = self.gcs[i]((repr, root_repr), block[i]) + if self.normalize: + repr = self.norm(repr) + if self.batch_norm: + repr = self.bns[i](repr) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + root_size = block[-1].sparse_size(0) + root_repr = repr[:root_size] + repr = self.gcs[-1]((repr, root_repr), block[-1]) + elif len(block) == 1: + for i in range(self.n_layers-1): + repr = self.gcs[i](repr, block[0]) + if self.normalize: + repr = self.norm(repr) + if self.batch_norm: + repr = self.bns[i](repr) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, block[0]) + else: + raise ValueError('The sampling layer must be equal to GNN layer.') + + return F.log_softmax(repr, dim=1) + + def inference(self, x_all, subgraph_loader, device): + # Compute representations of nodes layer by layer, using *all* + # available edges. This leads to faster computation in contrast to + # immediately computing the final representations of each batch. 
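# Sketch of the bipartite call convention used by the PyG-style SAGE above: with a sampled
# block stored as a torch_sparse.SparseTensor of shape [num_target, num_source], SAGEConv is
# called as conv((x_source, x_target), block), where x_target = x_source[:num_target] because
# the sampler places the target (root) nodes first. The tiny block below is hand-built.
import torch
from torch_sparse import SparseTensor
from torch_geometric.nn import SAGEConv

num_target, num_source, n_feat, n_hid = 2, 4, 8, 16
# rows index target nodes, columns index their sampled source neighbours
block = SparseTensor(row=torch.tensor([0, 0, 1, 1]),
                     col=torch.tensor([0, 2, 1, 3]),
                     sparse_sizes=(num_target, num_source))

x_source = torch.randn(num_source, n_feat)
x_target = x_source[:block.sparse_size(0)]      # root nodes come first
conv = SAGEConv(n_feat, n_hid)
out = conv((x_source, x_target), block)
print(out.shape)                                # torch.Size([2, 16])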
+ for i in range(self.n_layers): + xs = [] + for batch in subgraph_loader: + batch_in, _, block = batch + block.to_device(device) + x = x_all[batch_in].to(device) + root_size = block[0].sparse_size(0) + root_x = x[:root_size] + x = self.gcs[i]((x, root_x), block[0]) + # one-layer sampling + if i != self.n_layers - 1: + if self.batch_norm: + x = self.bns[i](x) + x = F.relu(x) + xs.append(x.cpu()) + + x_all = torch.cat(xs, dim=0) + + return x_all + +class GAT(nn.Module): + def __init__(self, n_feat, n_hid, n_class, n_heads, n_layers=2, dropout=0.6, activation=F.elu, attn_dropout=0.6, batch_norm=False): + super(GAT, self).__init__() + self.gcs = nn.ModuleList() + self.gcs.append(GATConv(n_feat, n_hid // n_heads[0], n_heads[0], dropout=attn_dropout)) + self.n_layers = n_layers + self.batch_norm = batch_norm + if self.batch_norm: + self.bns = nn.ModuleList() + self.bns.append(nn.BatchNorm1d(n_hid)) + for i in range(n_layers-2): + self.gcs.append(GATConv(n_hid, n_hid // n_heads[i + 1], n_heads[i + 1], dropout=attn_dropout)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1d(n_hid)) + self.gcs.append(GATConv(n_hid, n_class, n_heads[-1], concat=False, dropout=attn_dropout)) + self.dropout = dropout + self.activation = activation + + def forward(self, x, block): + repr = x + if isinstance(block, (SparseTensor, torch.Tensor)): + block = [block] + if len(block) == self.n_layers: + for i in range(self.n_layers-1): + repr = self.gcs[i](repr, block[i]) + if self.batch_norm: + repr = self.bns[i](repr) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, block[-1]) + elif len(block) == 1: + for i in range(self.n_layers-1): + repr = self.gcs[i](repr, block[0]) + if self.batch_norm: + repr = self.bns[i](repr) + repr = self.activation(repr) + repr = F.dropout(repr, self.dropout, training=self.training) + repr = self.gcs[-1](repr, block[0]) + else: + raise ValueError('The sampling layer must be equal to GNN layer.') + + return F.log_softmax(repr, dim=-1) \ No newline at end of file diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index 3646495..7869493 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -1,4 +1,3 @@ -import math import torch import torch.nn as nn import torch.nn.functional as F @@ -372,7 +371,7 @@ def inference(self, x_all, subgraph_loader, device): block.to_device(device) x = x_all[batch_in].to(device) x = self.gcs[i](x, block[0]) # one-layer sampling - if i != self.nlayers - 1: + if i != self.n_layers - 1: if self.batch_norm: x = self.bns[i](x) x = F.relu(x) @@ -413,7 +412,7 @@ def forward(self, x, block): if isinstance(block, torch.Tensor): block = [block] if len(block) == self.n_layers: - for i in range(self.nlayers-1): + for i in range(self.n_layers-1): repr = self.gcs[i](repr, block[i]) if self.batch_norm: repr = self.bns[i](repr) @@ -444,7 +443,7 @@ def inference(self, x_all, subgraph_loader, device): block.to_device(device) x = x_all[batch_in].to(device) x = self.gcs[i](x, block[0]) # one-layer sampling - if i != self.nlayers - 1: + if i != self.n_layers - 1: if self.batch_norm: x = self.bns[i](x) x = self.activation(x) @@ -454,7 +453,6 @@ def inference(self, x_all, subgraph_loader, device): return x_all - class GAT(nn.Module): def __init__(self, n_feat, n_hid, n_class, n_heads, n_layers=2, dropout=0.6, activation=F.elu): super(GAT, self).__init__() diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py index b0249a7..7382598 100644 --- 
a/sgl/sampler/base_sampler.py +++ b/sgl/sampler/base_sampler.py @@ -8,7 +8,7 @@ from sgl.data.base_data import Block import sgl.operators.graph_op as GraphOps from sgl.sampler.utils import adj_train_analysis -from sgl.utils import sparse_mx_to_torch_sparse_tensor +from sgl.utils import sparse_mx_to_torch_sparse_tensor, sparse_mx_to_pyg_sparse_tensor from sampling_ops import NodeWiseOneLayer @@ -85,12 +85,14 @@ def _post_process(self, adjs, to_sparse_tensor=True): if self._post_sampling_op is not None: adjs = [self._post_sampling_op._construct_adj(adj) for adj in adjs] if to_sparse_tensor: - adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs] + # adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs] + adjs = [sparse_mx_to_pyg_sparse_tensor(adj) for adj in adjs] else: if self._post_sampling_op is not None: adjs = self._post_sampling_op._construct_adj(adjs) if to_sparse_tensor: - adjs = sparse_mx_to_torch_sparse_tensor(adjs) + # adjs = sparse_mx_to_torch_sparse_tensor(adjs) + adjs = sparse_mx_to_pyg_sparse_tensor(adjs) return adjs def _to_Block(self, adjs): diff --git a/sgl/tasks/node_classification_GAug.py b/sgl/tasks/node_classification_GAug.py index 1d54c7a..0c0a071 100644 --- a/sgl/tasks/node_classification_GAug.py +++ b/sgl/tasks/node_classification_GAug.py @@ -11,7 +11,7 @@ from sgl.tasks.utils import set_seed, accuracy, MultipleOptimizer class NodeClassificationGAugO(BaseTask): - def __init__(self, dataset, model, lr, weight_decay, epochs, device, seed, beta, warmup, pretrain_ep, pretrain_nc): + def __init__(self, dataset, model, lr, weight_decay, epochs, device, seed, beta, warmup, max_patience, pretrain_ep, pretrain_nc): super(NodeClassificationGAugO, self).__init__() self.__dataset = dataset @@ -30,6 +30,7 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, seed, beta, self.__warmup = warmup self.__beta = beta + self.__max_patience = max_patience self.__pretrain_ep = pretrain_ep self.__pretrain_nc = pretrain_nc @@ -126,7 +127,7 @@ def evaluate(self, features, adj): def _execute(self): set_seed(self.__seed) - features, adj_orig, adj_norm, adj = self.__model.preprocess(self.__dataset.x, self.__dataset.adj, self.__device) + features, adj_orig, adj, adj_norm = self.__model.preprocess(self.__dataset.x, self.__dataset.adj, self.__device) self.__model = self.__model.to(self.__device) self.__labels = self.__labels.to(self.__device) @@ -167,7 +168,7 @@ def _execute(self): patience_step = 0 else: patience_step += 1 - if patience_step == 50: + if patience_step == self.__max_patience: break # release RAM and GPU memory @@ -179,7 +180,7 @@ def _execute(self): class NodeClassificationGAugM(BaseTask): - def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn.CrossEntropyLoss(), seed=42): + def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn.CrossEntropyLoss(), seed=42, max_patience=100): super(NodeClassificationGAugM, self).__init__() self.__dataset = dataset @@ -192,6 +193,7 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn. 
self.__loss_fn = loss_fn self.__device = device self.__seed = seed + self.__max_patience = max_patience self.__test_acc = self._execute() @@ -225,8 +227,8 @@ def _execute(self): set_seed(self.__seed) pre_time_st = time.time() - A_pred_dir = os.path.join(self.__dataset.processed_dir, "GAugM_edge_probabilities") - adj_norm, features = self.__model.preprocess(self.__dataset.adj, self.__dataset.x, A_pred_dir, self.__device) + adj_pred_dir = os.path.join(self.__dataset.processed_dir, "GAugM_edge_probabilities") + adj, features = self.__model.preprocess(self.__dataset.adj, self.__dataset.x, adj_pred_dir, self.__device) pre_time_ed = time.time() print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") @@ -236,10 +238,11 @@ def _execute(self): t_total = time.time() best_val = 0. best_test = 0. + patience = 0 for epoch in range(self.__epochs): t = time.time() - loss_train, acc_train = self.train(adj_norm, features) - acc_val, acc_test = self.evaluate(adj_norm, features) + loss_train, acc_train = self.train(adj, features) + acc_val, acc_test = self.evaluate(adj, features) print('Epoch: {:03d}'.format(epoch + 1), 'loss_train: {:.4f}'.format(loss_train), @@ -251,12 +254,17 @@ def _execute(self): if acc_val > best_val: best_val = acc_val best_test = acc_test + patience = 0 + else: + patience += 1 + if patience == self.__max_patience: + break print("Optimization Finished!") print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') - del adj_norm, features + del adj, features torch.cuda.empty_cache() gc.collect() diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index b132421..ded2d60 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -1,6 +1,7 @@ import time import torch import numpy as np +from typing import Callable from torch.optim import Adam import torch.nn.functional as F from torch.utils.data import DataLoader @@ -12,7 +13,7 @@ class NodeClassification_Sampling(BaseTask): def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nll_loss", seed=42, - inductive=False, train_batch_size=None, eval_batch_size=None, **kwargs): + inductive=False, train_batch_size=None, eval_batch_size=None, eval_freq=1, eval_start=1, **kwargs): super(NodeClassification_Sampling, self).__init__() self.__dataset = dataset @@ -21,6 +22,8 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nl self.__optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) self.__epochs = epochs + self.__eval_freq = eval_freq + self.__eval_start = eval_start self.__loss_fn = getattr(F, loss_fn) if isinstance(loss_fn, str) else loss_fn self.__device = device self.__seed = seed @@ -98,33 +101,45 @@ def _execute(self): for epoch in range(self.__epochs): t = time.time() if self.__mini_batch_train: - loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__inductive, self.__device, - self.__optimizer, self.__loss_fn) + if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): + loss_train, acc_train = self.__model.train_func(self.__train_loader, self.__inductive, self.__device, self.__optimizer, self.__loss_fn) + else: + loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__inductive, self.__device, + self.__optimizer, self.__loss_fn) else: loss_train, acc_train = train(self.__model, self.__dataset.train_idx, 
self.__optimizer, self.__loss_fn) - if self.__mini_batch_eval: - if self.__eval_together is False: - acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__device) + if epoch + 1 >= self.__eval_start and (epoch + 1) % self.__eval_freq == 0: + if self.__mini_batch_eval: + if self.__eval_together is False: + if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): + acc_val, acc_test = self.evaluate_func(self.__val_loader, self.__test_loader, self.__device) + else: + acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__device) + else: + self.__model.eval() + outputs = self.__model.inference(self.__all_eval_loader, self.__device) + acc_train = accuracy(outputs[self.__dataset.train_idx], self.__dataset.y[self.__dataset.train_idx]) + acc_val = accuracy(outputs[self.__dataset.val_idx], self.__dataset.y[self.__dataset.val_idx]) + acc_test = accuracy(outputs[self.__dataset.test_idx], self.__dataset.y[self.__dataset.test_idx]) else: - self.__model.eval() - outputs = self.__model.inference(self.__all_eval_loader, self.__device) - acc_train = accuracy(outputs[self.__dataset.train_idx], self.__dataset.y[self.__dataset.train_idx]) - acc_val = accuracy(outputs[self.__dataset.val_idx], self.__dataset.y[self.__dataset.val_idx]) - acc_test = accuracy(outputs[self.__dataset.test_idx], self.__dataset.y[self.__dataset.test_idx]) + acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx) + + if acc_val > best_val: + best_val = acc_val + best_test = acc_test + + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) else: - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx) - - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'acc_val: {:.4f}'.format(acc_val), - 'acc_test: {:.4f}'.format(acc_test), - 'time: {:.4f}s'.format(time.time() - t)) - - if acc_val > best_val: - best_val = acc_val - best_test = acc_test + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'time: {:.4f}s'.format(time.time() - t)) acc_val, acc_test = self._postprocess() if acc_val > best_val: diff --git a/sgl/utils/__init__.py b/sgl/utils/__init__.py index 369b4d6..d6768a4 100644 --- a/sgl/utils/__init__.py +++ b/sgl/utils/__init__.py @@ -1,7 +1,8 @@ from .auto_choose_gpu import GpuWithMaxFreeMem -from .basic_operations import sparse_mx_to_torch_sparse_tensor +from .basic_operations import sparse_mx_to_torch_sparse_tensor, sparse_mx_to_pyg_sparse_tensor __all__ = [ "GpuWithMaxFreeMem", "sparse_mx_to_torch_sparse_tensor", + "sparse_mx_to_pyg_sparse_tensor" ] diff --git a/sgl/utils/basic_operations.py b/sgl/utils/basic_operations.py index 5766ba3..c5c397d 100644 --- a/sgl/utils/basic_operations.py +++ b/sgl/utils/basic_operations.py @@ -1,5 +1,6 @@ import torch import numpy as np +from torch_sparse import SparseTensor def sparse_mx_to_torch_sparse_tensor(sparse_mx): """Convert a scipy sparse matrix to a torch sparse tensor.""" @@ -8,4 +9,13 @@ def sparse_mx_to_torch_sparse_tensor(sparse_mx): np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) values = torch.from_numpy(sparse_mx.data) shape = 
torch.Size(sparse_mx.shape) - return torch.sparse.FloatTensor(indices, values, shape) \ No newline at end of file + return torch.sparse.FloatTensor(indices, values, shape) + +def sparse_mx_to_pyg_sparse_tensor(sparse_mx): + """Convert a scipy sparse matrix to a PyG SparseTensor""" + sparse_mx = sparse_mx.tocoo() + row = torch.from_numpy(sparse_mx.row).to(torch.long) + col = torch.from_numpy(sparse_mx.col).to(torch.long) + value = torch.from_numpy(sparse_mx.data) + sparse_sizes = torch.Size(sparse_mx.shape) + return SparseTensor(row=row, col=col, value=value, sparse_sizes=sparse_sizes, is_sorted=True, trust_data=True) \ No newline at end of file From ec4aac117ae822c32935076b6edc5cb3e66dc0ad Mon Sep 17 00:00:00 2001 From: infinity Date: Sat, 16 Dec 2023 02:41:13 +0000 Subject: [PATCH 23/28] fix OOM problem when FLAG evaluates on Ogbn-products; test mini-batch FLAG with GAT --- examples/GDA/configs/SampleFLAG.yml | 54 ++++++++++++++++++++--- examples/configs/clustergcn.yml | 1 + examples/configs/fastgcn.yml | 1 + examples/configs/graphsage.yml | 2 + examples/configs/graphsaint.yml | 1 + examples/configs/lazygnn.yml | 1 + examples/configs/vanillagnn.yml | 1 + sgl/data/base_data.py | 25 ++++++++--- sgl/models/base_model.py | 13 +++--- sgl/models/homo/clustergcn.py | 4 +- sgl/models/homo/fastgcn.py | 4 +- sgl/models/homo/gda/FLAG.py | 8 ++-- sgl/models/homo/graphsage.py | 4 +- sgl/models/homo/graphsaint.py | 4 +- sgl/models/homo/lazygnn.py | 4 +- sgl/models/homo/vanillagnn.py | 4 +- sgl/models/pyg_simple_models.py | 54 +++++++++++++++++++---- sgl/models/simple_models.py | 3 ++ sgl/sampler/base_sampler.py | 19 +++++--- sgl/sampler/sampler.py | 8 ++-- sgl/tasks/node_classification_sampling.py | 21 ++++----- 21 files changed, 174 insertions(+), 62 deletions(-) diff --git a/examples/GDA/configs/SampleFLAG.yml b/examples/GDA/configs/SampleFLAG.yml index a8080ea..fe7db14 100644 --- a/examples/GDA/configs/SampleFLAG.yml +++ b/examples/GDA/configs/SampleFLAG.yml @@ -1,6 +1,44 @@ +# dataset: +# classname: "Ogbn" +# name: "products" +# root: "/home/ssq/test_data/" +# sampler: +# training: +# name: "NeighborSampler" +# layer_sizes: "15,10,5" +# prob_type: "normalize" +# replace: False +# eval: +# name: "NeighborSampler" +# layer_sizes: "-1" +# replace: False +# model: +# gnn_type: 'sage' +# hidden_dim: 256 +# dropout: 0.5 +# n_layers: 3 +# step_size: 0.008 +# augM: 3 +# amp: 2 +# batch_norm: False +# task: +# name: "NodeClassification_Sampling" +# lr: 0.003 +# seed: 12345 +# epochs: 20 +# patience: 10 +# weight_decay: 0 +# train_batch_size: 1024 +# eval_batch_size: 4096 +# train_num_workers: 12 +# eval_num_workers: 12 +# eval_together: True +# eval_freq: 2 +# eval_start: 10 +# loss_fn: "nll_loss" dataset: classname: "Ogbn" - name: "products" + name: "arxiv" root: "/home/ssq/test_data/" sampler: training: @@ -13,8 +51,9 @@ sampler: layer_sizes: "-1" replace: False model: - gnn_type: 'sage' - hidden_dim: 256 + gnn_type: 'gat' + hidden_dim: 64 + n_heads: "8,8,1" dropout: 0.5 n_layers: 3 step_size: 0.008 @@ -23,16 +62,17 @@ model: batch_norm: False task: name: "NodeClassification_Sampling" - lr: 0.003 + lr: 0.002 seed: 12345 - epochs: 20 - patience: 10 + epochs: 500 + patience: 50 weight_decay: 0 train_batch_size: 1024 eval_batch_size: 4096 train_num_workers: 12 eval_num_workers: 12 eval_together: True + pin_memory: True eval_freq: 2 - eval_start: 10 + eval_start: 1 loss_fn: "nll_loss" \ No newline at end of file diff --git a/examples/configs/clustergcn.yml b/examples/configs/clustergcn.yml index 
660e9ce..1423719 100644 --- a/examples/configs/clustergcn.yml +++ b/examples/configs/clustergcn.yml @@ -7,6 +7,7 @@ sampler: cluster_method: "metis" cluster_number: 10 post_sampling_op: "LaplacianGraphOp" + sparse_type: "torch" model: hidden_dim: 128 dropout: 0.5 diff --git a/examples/configs/fastgcn.yml b/examples/configs/fastgcn.yml index a5d4f09..e4c16f2 100644 --- a/examples/configs/fastgcn.yml +++ b/examples/configs/fastgcn.yml @@ -11,6 +11,7 @@ sampler: layer_sizes: "2048,2048" prob_type: "normalize" replace: False + sparse_type: "torch" model: name: "FastGCN" hidden_dim: 128 diff --git a/examples/configs/graphsage.yml b/examples/configs/graphsage.yml index d12c9a1..8d96760 100644 --- a/examples/configs/graphsage.yml +++ b/examples/configs/graphsage.yml @@ -11,12 +11,14 @@ sampler: prob_type: "normalize" replace: False post_sampling_op: "RwGraphOp" + sparse_type: "torch" # eval: # name: "NeighborSampler" # inductive: False # layer_sizes: "-1" # prob_type: "normalize" # post_sampling_op: "RwGraphOp" + # sparse_type: "torch" model: name: "GraphSAGE" hidden_dim: 256 diff --git a/examples/configs/graphsaint.yml b/examples/configs/graphsaint.yml index 3b44515..4d98f59 100644 --- a/examples/configs/graphsaint.yml +++ b/examples/configs/graphsaint.yml @@ -11,6 +11,7 @@ sampler: r: 500 h: 4 pre_sampling_op: "RwGraphOp" + sparse_type: "torch" model: hidden_dim: 128 dropout: 0.5 diff --git a/examples/configs/lazygnn.yml b/examples/configs/lazygnn.yml index 9f64cc7..ffe25a5 100644 --- a/examples/configs/lazygnn.yml +++ b/examples/configs/lazygnn.yml @@ -11,6 +11,7 @@ sampler: prob_type: "normalize" replace: False post_sampling_op: "LaplacianGraphOp" + sparse_type: "torch" model: name: "LazyGNN" basemodel: "GCN" diff --git a/examples/configs/vanillagnn.yml b/examples/configs/vanillagnn.yml index 9eae60b..2de9fb0 100644 --- a/examples/configs/vanillagnn.yml +++ b/examples/configs/vanillagnn.yml @@ -7,6 +7,7 @@ sampler: training: name: "FullSampler" inductive: False + sparse_type: "torch" model: name: "VanillaGNN" basemodel: "SAGE" diff --git a/sgl/data/base_data.py b/sgl/data/base_data.py index 2cb0c40..f2e2cad 100644 --- a/sgl/data/base_data.py +++ b/sgl/data/base_data.py @@ -1,18 +1,28 @@ import torch from torch import Tensor import numpy as np +import scipy.sparse as sp from scipy.sparse import csr_matrix - +from torch_sparse import SparseTensor from sgl.utils import sparse_mx_to_torch_sparse_tensor, sparse_mx_to_pyg_sparse_tensor # A lighter wrapper class for sampled adjacency matrices, # as the Edge class seems contains useless information class Block: - def __init__(self, adjs): + def __init__(self, adjs, sparse_type): + self.__sparse_type = sparse_type if not isinstance(adjs, list): self.__adjs = [adjs] + if isinstance(adjs, SparseTensor): + self.__root_sizes = [adjs.sparse_size(0)] + else: + self.__root_sizes = [adjs.shape[0]] else: self.__adjs = adjs + if isinstance(adjs[0], SparseTensor): + self.__root_sizes = [adj.sparse_size(0) for adj in adjs] + else: + self.__root_sizes = [adj.shape[0] for adj in adjs] self.__device = None def __len__(self): @@ -24,13 +34,18 @@ def __iter__(self): def __getitem__(self, id): return self.__adjs[id] + + def root_size(self, id): + return self.__root_sizes[id] def to_device(self, device): if self.__device == device: return - if not isinstance(self.__adjs[0], torch.sparse.FloatTensor): - # self.__adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in self.__adjs] - self.__adjs = [sparse_mx_to_pyg_sparse_tensor(adj) for adj in self.__adjs] + if 
isinstance(self.__adjs[0], sp.spmatrix): + if self.__sparse_type == "pyg": + self.__adjs = [sparse_mx_to_pyg_sparse_tensor(adj) for adj in self.__adjs] + else: + self.__adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in self.__adjs] self.__adjs = [adj.to(device) for adj in self.__adjs] self.__device = device diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index d9d6433..8458a44 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -67,9 +67,12 @@ def forward(self, idx, device): return output class BaseSAMPLEModel(nn.Module): - def __init__(self, evaluate_mode="full"): + def __init__(self, evaluate_mode="full", sparse_type="pyg"): super(BaseSAMPLEModel, self).__init__() self._evaluate_mode = evaluate_mode + if sparse_type not in ["pyg", "torch"]: + raise ValueError(f"sparse type {sparse_type} is not supported, please use either pyg or torch.") + self._sparse_type = sparse_type self._pre_graph_op, self._post_graph_op = None, None self._training_sampling_op, self._eval_sampling_op = None, None self._base_model = None @@ -117,7 +120,8 @@ def full_batch_prepare_forward(self, node_idx): y_pred = self._base_model(self._processed_feature, self._processed_block)[node_idx] y_truth = self._vanilla_y[node_idx] return y_pred, y_truth - + + @torch.no_grad() def inference(self, dataloader, device): preds = self._base_model.inference(self.processed_feature, dataloader, device) return preds @@ -127,9 +131,8 @@ def preprocess(self, adj, x, y, device, **kwargs): norm_adj = self._pre_graph_op._construct_adj(adj) else: norm_adj = adj - # norm_adj = sparse_mx_to_torch_sparse_tensor(norm_adj) - norm_adj = sparse_mx_to_pyg_sparse_tensor(norm_adj) - self._processed_block = Block(norm_adj) + + self._processed_block = Block(norm_adj, self._sparse_type) if hasattr(self, "_pre_feature_op"): self._processed_feature = self._pre_feature_op._transform_x(x) diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index 3b5308c..b6570ad 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -3,8 +3,8 @@ from sgl.operators.graph_op import LaplacianGraphOp class ClusterGCN(BaseSAMPLEModel): - def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, dropout=0.5, num_layers=2, device="cpu"): - super(ClusterGCN, self).__init__(evaluate_mode="sampling") + def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, sparse_type="torch", dropout=0.5, num_layers=2, device="cpu"): + super(ClusterGCN, self).__init__(evaluate_mode="sampling", sparse_type=sparse_type) self._pre_graph_op = LaplacianGraphOp(r=0.5) self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler diff --git a/sgl/models/homo/fastgcn.py b/sgl/models/homo/fastgcn.py index a11f031..16fba2c 100644 --- a/sgl/models/homo/fastgcn.py +++ b/sgl/models/homo/fastgcn.py @@ -3,8 +3,8 @@ from sgl.operators.graph_op import LaplacianGraphOp class FastGCN(BaseSAMPLEModel): - def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): - super(FastGCN, self).__init__() + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, sparse_type="torch", dropout=0.5, num_layers=2, device="cpu"): + super(FastGCN, self).__init__(sparse_type=sparse_type) self._pre_graph_op = LaplacianGraphOp(r=0.5) self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler diff --git a/sgl/models/homo/gda/FLAG.py b/sgl/models/homo/gda/FLAG.py index 6905fed..e8b8c15 
100644 --- a/sgl/models/homo/gda/FLAG.py +++ b/sgl/models/homo/gda/FLAG.py @@ -67,7 +67,7 @@ def flag(self, ground_truth_y, optimizer, device, train_idx, loss_fn): loss.backward() optimizer.step() - return loss + return loss.item() def train_func(self, train_idx, labels, device, optimizer, loss_fn, metric): loss_train = self.flag(labels[train_idx], optimizer, device, train_idx, loss_fn) @@ -76,7 +76,7 @@ def train_func(self, train_idx, labels, device, optimizer, loss_fn, metric): pred_y = self._base_model(self.__features, self.__processed_adj) acc_train = metric(pred_y[train_idx], labels[train_idx]) - return loss_train.item(), acc_train + return loss_train, acc_train @torch.no_grad() def evaluate_func(self, val_idx, test_idx, labels, device, metric): @@ -150,7 +150,7 @@ def flag(self, clean, ground_truth_y, adjs, batch_out, optimizer, device, loss_f loss.backward() optimizer.step() - return loss, pred_y + return loss.item(), pred_y def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizers, inductive=False, transfer_y_to_device=True): batch_in, batch_out, block = batch @@ -179,7 +179,7 @@ def train_func(self, train_loader, inductive, device, optimizer, loss_fn): loss_train, y_out, y_truth = self.mini_batch_prepare_forward(batch, device, loss_fn, optimizer, inductive=inductive) pred = y_out.max(1)[1].type_as(y_truth) correct_num += pred.eq(y_truth).double().sum() - loss_train_sum += loss_train.item() + loss_train_sum += loss_train train_num += len(y_truth) loss_train = loss_train_sum / len(train_loader) diff --git a/sgl/models/homo/graphsage.py b/sgl/models/homo/graphsage.py index ac1a6dc..c25962d 100644 --- a/sgl/models/homo/graphsage.py +++ b/sgl/models/homo/graphsage.py @@ -4,8 +4,8 @@ from sgl.operators.graph_op import RwGraphOp class GraphSAGE(BaseSAMPLEModel): - def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): - super(GraphSAGE, self).__init__() + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, sparse_type="torch", dropout=0.5, num_layers=2, device="cpu"): + super(GraphSAGE, self).__init__(sparse_type=sparse_type) self._pre_graph_op = RwGraphOp() self._pre_feature_op = PreNormMessageOp(p=1, dim=1) self._training_sampling_op = training_sampler diff --git a/sgl/models/homo/graphsaint.py b/sgl/models/homo/graphsaint.py index 44f0c3b..7a0d0a6 100644 --- a/sgl/models/homo/graphsaint.py +++ b/sgl/models/homo/graphsaint.py @@ -5,8 +5,8 @@ from torch.nn.functional import nll_loss class GraphSAINT(BaseSAMPLEModel): - def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, dropout=0.5, num_layers=2, device="cpu"): - super(GraphSAINT, self).__init__() + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, sparse_type="torch", dropout=0.5, num_layers=2, device="cpu"): + super(GraphSAINT, self).__init__(sparse_type=sparse_type) self._pre_graph_op = RwGraphOp() self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler diff --git a/sgl/models/homo/lazygnn.py b/sgl/models/homo/lazygnn.py index 37dad4a..09e4cbd 100644 --- a/sgl/models/homo/lazygnn.py +++ b/sgl/models/homo/lazygnn.py @@ -10,8 +10,8 @@ import concurrent.futures class LazyGNN(BaseSAMPLEModel): - def __init__(self, dataset, training_sampler, eval_sampler=None, hidden_dim=128, basemodel="GCN", dropout=0.5, num_layers=2, max_workers=5, max_threads=-1, rho=1.1, tau=2, device="cpu"): - super(LazyGNN, self).__init__() + def __init__(self, dataset, training_sampler, eval_sampler=None, 
hidden_dim=128, basemodel="GCN", sparse_type="torch", dropout=0.5, num_layers=2, max_workers=5, max_threads=-1, rho=1.1, tau=2, device="cpu"): + super(LazyGNN, self).__init__(sparse_type=sparse_type) if basemodel == "SAGE": self._pre_graph_op = RwGraphOp() elif basemodel == "GCN": diff --git a/sgl/models/homo/vanillagnn.py b/sgl/models/homo/vanillagnn.py index bb1c565..104e0fa 100644 --- a/sgl/models/homo/vanillagnn.py +++ b/sgl/models/homo/vanillagnn.py @@ -7,8 +7,8 @@ class VanillaGNN(BaseSAMPLEModel): """ It is a naive version of Graph Convolutional Network which works in full-batch training. """ - def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, basemodel="GCN", dropout=0.5, num_layers=2, device="cpu"): - super(VanillaGNN, self).__init__(evaluate_mode="full") + def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, basemodel="GCN", sparse_type="torch", dropout=0.5, num_layers=2, device="cpu"): + super(VanillaGNN, self).__init__(evaluate_mode="full", sparse_type=sparse_type) if basemodel == "SAGE": self._pre_graph_op = RwGraphOp() elif basemodel == "GCN": diff --git a/sgl/models/pyg_simple_models.py b/sgl/models/pyg_simple_models.py index 70136b1..f3763b3 100644 --- a/sgl/models/pyg_simple_models.py +++ b/sgl/models/pyg_simple_models.py @@ -54,6 +54,7 @@ def forward(self, x, block): return F.log_softmax(repr, dim=1) + @torch.no_grad() def inference(self, x_all, subgraph_loader, device): # Compute representations of nodes layer by layer, using *all* # available edges. This leads to faster computation in contrast to @@ -109,7 +110,7 @@ def forward(self, x, block): block = [block] if len(block) == self.n_layers: for i in range(self.n_layers-1): - root_size = block[i].sparse_size(0) + root_size = block.root_size(i) root_repr = repr[:root_size] repr = self.gcs[i]((repr, root_repr), block[i]) if self.normalize: @@ -118,7 +119,7 @@ def forward(self, x, block): repr = self.bns[i](repr) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) - root_size = block[-1].sparse_size(0) + root_size = block.root_size(-1) root_repr = repr[:root_size] repr = self.gcs[-1]((repr, root_repr), block[-1]) elif len(block) == 1: @@ -136,6 +137,7 @@ def forward(self, x, block): return F.log_softmax(repr, dim=1) + @torch.no_grad() def inference(self, x_all, subgraph_loader, device): # Compute representations of nodes layer by layer, using *all* # available edges. 
This leads to faster computation in contrast to @@ -143,17 +145,17 @@ def inference(self, x_all, subgraph_loader, device): for i in range(self.n_layers): xs = [] for batch in subgraph_loader: - batch_in, _, block = batch + batch_in, batch_out, block = batch block.to_device(device) x = x_all[batch_in].to(device) - root_size = block[0].sparse_size(0) + root_size = len(batch_out) root_x = x[:root_size] x = self.gcs[i]((x, root_x), block[0]) # one-layer sampling if i != self.n_layers - 1: if self.batch_norm: x = self.bns[i](x) - x = F.relu(x) + x = self.activation(x) xs.append(x.cpu()) x_all = torch.cat(xs, dim=0) @@ -178,18 +180,29 @@ def __init__(self, n_feat, n_hid, n_class, n_heads, n_layers=2, dropout=0.6, act self.dropout = dropout self.activation = activation + def reset_parameter(self): + for conv in self.gcs: + conv.reset_parameters() + if self.batch_norm: + for bn in self.bns: + bn.reset_parameters() + def forward(self, x, block): repr = x if isinstance(block, (SparseTensor, torch.Tensor)): block = [block] if len(block) == self.n_layers: for i in range(self.n_layers-1): - repr = self.gcs[i](repr, block[i]) + root_size = block.root_size(i) + root_repr = repr[:root_size] + repr = self.gcs[i]((repr, root_repr), block[i]) if self.batch_norm: repr = self.bns[i](repr) repr = self.activation(repr) repr = F.dropout(repr, self.dropout, training=self.training) - repr = self.gcs[-1](repr, block[-1]) + root_size = block.root_size(-1) + root_repr = repr[:root_size] + repr = self.gcs[-1]((repr, root_repr), block[-1]) elif len(block) == 1: for i in range(self.n_layers-1): repr = self.gcs[i](repr, block[0]) @@ -201,4 +214,29 @@ def forward(self, x, block): else: raise ValueError('The sampling layer must be equal to GNN layer.') - return F.log_softmax(repr, dim=-1) \ No newline at end of file + return F.log_softmax(repr, dim=-1) + + @torch.no_grad() + def inference(self, x_all, subgraph_loader, device): + # Compute representations of nodes layer by layer, using *all* + # available edges. This leads to faster computation in contrast to + # immediately computing the final representations of each batch. 
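+ # Same layer-wise scheme as the SAGE/GCN inference above: len(batch_out) gives the number of target (root) nodes in each sampled batch, the (x, root_x) pair feeds GATConv in its bipartite form, and activations are moved back to the CPU between layers, which keeps full-graph evaluation from exhausting GPU memory.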
+ for i in range(self.n_layers): + xs = [] + for batch in subgraph_loader: + batch_in, batch_out, block = batch + block.to_device(device) + x = x_all[batch_in].to(device) + root_size = len(batch_out) + root_x = x[:root_size] + x = self.gcs[i]((x, root_x), block[0]) + # one-layer sampling + if i != self.n_layers - 1: + if self.batch_norm: + x = self.bns[i](x) + x = self.activation(x) + xs.append(x.cpu()) + + x_all = torch.cat(xs, dim=0) + + return x_all \ No newline at end of file diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index 7869493..56c36fb 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -454,6 +454,9 @@ def inference(self, x_all, subgraph_loader, device): return x_all class GAT(nn.Module): + """ + This GAT only accepts dense tensor as input (doesn't support torch.sparse.tensor) + """ def __init__(self, n_feat, n_hid, n_class, n_heads, n_layers=2, dropout=0.6, activation=F.elu): super(GAT, self).__init__() self.gcs = nn.ModuleList() diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py index 7382598..e2df988 100644 --- a/sgl/sampler/base_sampler.py +++ b/sgl/sampler/base_sampler.py @@ -12,6 +12,8 @@ from sampling_ops import NodeWiseOneLayer +SPARSE_TRANSFORM = {"pyg": sparse_mx_to_pyg_sparse_tensor, "torch": sparse_mx_to_torch_sparse_tensor} + class BaseSampler: def __init__(self, adj, **kwargs): self._adj = adj @@ -35,6 +37,8 @@ def __init__(self, adj, **kwargs): elif graph_op == "RwGraphOp": self._post_sampling_op = getattr(GraphOps, "RwGraphOp")() + self._sparse_type = kwargs.get("sparse_type", "pyg") + self._pre_process(**kwargs) def _pre_process(self, **kwargs): @@ -85,18 +89,19 @@ def _post_process(self, adjs, to_sparse_tensor=True): if self._post_sampling_op is not None: adjs = [self._post_sampling_op._construct_adj(adj) for adj in adjs] if to_sparse_tensor: - # adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in adjs] - adjs = [sparse_mx_to_pyg_sparse_tensor(adj) for adj in adjs] + sparse_transform_func = SPARSE_TRANSFORM.get(self._sparse_type) + adjs = [sparse_transform_func(adj) for adj in adjs] else: if self._post_sampling_op is not None: adjs = self._post_sampling_op._construct_adj(adjs) if to_sparse_tensor: - # adjs = sparse_mx_to_torch_sparse_tensor(adjs) - adjs = sparse_mx_to_pyg_sparse_tensor(adjs) + sparse_transform_func = SPARSE_TRANSFORM.get(self._sparse_type) + adjs = [sparse_transform_func(adj) for adj in adjs] return adjs - def _to_Block(self, adjs): - return Block(adjs) + @staticmethod + def to_Block(adjs, sparse_type): + return Block(adjs, sparse_type) def collate_fn(self, *args): raise NotImplementedError @@ -111,7 +116,7 @@ def __init__(self, adj, **kwargs): self.sample_level = "graph" self.pre_sampling = False self.full_batch = kwargs.get("node_ids", range(self._adj.shape[0])) - self.full_block = self._to_Block(self._adj) + self.full_block = self.to_Block(self._adj, self._sparse_type) def sampling(self): return self.full_batch, self.full_batch, self.full_block diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index 9cee82b..b38a5fd 100644 --- a/sgl/sampler/sampler.py +++ b/sgl/sampler/sampler.py @@ -49,7 +49,7 @@ def collate_fn(self, batch_inds): all_adjs = self._post_process(all_adjs, to_sparse_tensor=False) - return cur_tgt_nodes, batch_inds, self._to_Block(all_adjs) + return cur_tgt_nodes, batch_inds, self.to_Block(all_adjs, self._sparse_type) class FastGCNSampler(LayerWiseSampler): def __init__(self, adj, **kwargs): @@ -84,7 +84,7 @@ def collate_fn(self, 
batch_inds): all_adjs = self._post_process(all_adjs, to_sparse_tensor=False) - return cur_out_nodes, batch_inds, self._to_Block(all_adjs) + return cur_out_nodes, batch_inds, self.to_Block(all_adjs, self._sparse_type) class ClusterGCNSampler(GraphWiseSampler): """ @@ -131,7 +131,7 @@ def collate_fn(self, batch_inds, mode): node_idx = torch.cat([torch.arange(s, e) for s, e in zip(start, end)]) global_node_idx = self.perm_node_idx[node_idx] composed_sparse_mx = sp.block_diag([self.splitted_perm_adjs[batch_ind.item()] for batch_ind in batch_inds]) - block = self._to_Block(composed_sparse_mx) + block = self.to_Block(composed_sparse_mx, self._sparse_type) if mode in ["train", "val", "test"]: mask = self._masks[mode][global_node_idx] global_inds = global_node_idx[mask] @@ -371,4 +371,4 @@ def collate_fn(self, batch_ids, mode): self.cur_index = global_inds - return batch_in,batch_out,self._to_Block(batched_adj) \ No newline at end of file + return batch_in, batch_out, self.to_Block(batched_adj, self._sparse_type) \ No newline at end of file diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index ded2d60..1a9a255 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -43,6 +43,7 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nl self.__eval_determined_sample = True self.__train_num_workers = kwargs.get("train_num_workers", 0) self.__eval_num_workers = kwargs.get("eval_num_workers", 0) + self.__pin_memory = kwargs.get("pin_memory", False) self.__test_acc = self._execute() @property @@ -65,32 +66,32 @@ def _execute(self): if self.__train_determined_sample: self.__model.pre_sample("train") self.__train_loader = DataLoader( - range(self.__train_graph_number), batch_size=self.__train_batch_size, num_workers=self.__train_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "train"), shuffle=True, drop_last=False) + range(self.__train_graph_number), batch_size=self.__train_batch_size, num_workers=self.__train_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "train"), shuffle=True, drop_last=False, pin_memory=self.__pin_memory) else: if self.__inductive is False: self.__train_loader = DataLoader( - self.__dataset.train_idx, batch_size=self.__train_batch_size, num_workers=self.__train_num_workers, collate_fn=self.__model.train_collate_fn, shuffle=True, drop_last=False) + self.__dataset.train_idx, batch_size=self.__train_batch_size, num_workers=self.__train_num_workers, collate_fn=self.__model.train_collate_fn, shuffle=True, drop_last=False, pin_memory=self.__pin_memory) else: self.__train_loader = DataLoader( - range(len(self.__dataset.train_idx)), batch_size=self.__train_batch_size, num_workers=self.__train_num_workers, collate_fn=self.__model.train_collate_fn, shuffle=True, drop_last=False) + range(len(self.__dataset.train_idx)), batch_size=self.__train_batch_size, num_workers=self.__train_num_workers, collate_fn=self.__model.train_collate_fn, shuffle=True, drop_last=False, pin_memory=self.__pin_memory) if self.__mini_batch_eval: if self.__eval_determined_sample: self.__model.pre_sample("eval") self.__val_loader = DataLoader( - range(self.__eval_graph_number), batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "val"), shuffle=False, drop_last=False) + range(self.__eval_graph_number), batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=lambda x: 
self.__model.collate_fn(x, "val"), shuffle=False, drop_last=False, pin_memory=self.__pin_memory) self.__test_loader = DataLoader( - range(self.__eval_graph_number), batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "test"), shuffle=False, drop_last=False) + range(self.__eval_graph_number), batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "test"), shuffle=False, drop_last=False, pin_memory=self.__pin_memory) self.__all_eval_loader = DataLoader( - range(self.__eval_graph_number), batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "val_test"), shuffle=False, drop_last=False) + range(self.__eval_graph_number), batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=lambda x: self.__model.collate_fn(x, "val_test"), shuffle=False, drop_last=False, pin_memory=self.__pin_memory) else: if self.__eval_together is False: self.__val_loader = DataLoader( - self.__dataset.val_idx, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False) + self.__dataset.val_idx, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False, pin_memory=self.__pin_memory) self.__test_loader = DataLoader( - self.__dataset.test_idx, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False) + self.__dataset.test_idx, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False, pin_memory=self.__pin_memory) self.__all_eval_loader = DataLoader( - self.__dataset.node_ids, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False) + self.__dataset.node_ids, batch_size=self.__eval_batch_size, num_workers=self.__eval_num_workers, collate_fn=self.__model.eval_collate_fn, shuffle=False, drop_last=False, pin_memory=self.__pin_memory) self.__model = self.__model.to(self.__device) @@ -113,7 +114,7 @@ def _execute(self): if self.__mini_batch_eval: if self.__eval_together is False: if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): - acc_val, acc_test = self.evaluate_func(self.__val_loader, self.__test_loader, self.__device) + acc_val, acc_test = self.__model.evaluate_func(self.__val_loader, self.__test_loader, self.__device) else: acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__device) else: From d6294cba29be6b49a10d3a3db22917015d583393 Mon Sep 17 00:00:00 2001 From: infinity Date: Tue, 19 Dec 2023 06:47:35 +0000 Subject: [PATCH 24/28] add deterministic graph generation part of GAugM --- sgl/models/homo/gda/__init__.py | 5 +- sgl/models/homo/gda/gen_graphs.py | 201 ++++++++++++++++++++++++++++++ sgl/models/homo/gda/utils.py | 43 ++++++- 3 files changed, 247 insertions(+), 2 deletions(-) create mode 100644 sgl/models/homo/gda/gen_graphs.py diff --git a/sgl/models/homo/gda/__init__.py b/sgl/models/homo/gda/__init__.py index 8d3bbf4..ed111c1 100644 --- a/sgl/models/homo/gda/__init__.py +++ b/sgl/models/homo/gda/__init__.py @@ -1,9 +1,12 @@ from .GAug import GAugO, GAugM from .FLAG import FLAG, 
SampleFLAG +from .gen_graphs import graph_generate, VGAE __all__ = [ "GAugO", "GAugM", "FLAG", - "SampleFLAG" + "SampleFLAG", + "graph_generate", + "VGAE" ] \ No newline at end of file diff --git a/sgl/models/homo/gda/gen_graphs.py b/sgl/models/homo/gda/gen_graphs.py new file mode 100644 index 0000000..4da6017 --- /dev/null +++ b/sgl/models/homo/gda/gen_graphs.py @@ -0,0 +1,201 @@ +import os +import copy +import torch +import torch.nn as nn +from torch.optim import Adam +import torch.nn.functional as F +import argparse +import numpy as np +import pickle as pkl +import scipy.sparse as sp +from sklearn.preprocessing import normalize +from torch_geometric.utils import negative_sampling, from_scipy_sparse_matrix + +import sgl.dataset as Dataset +from sgl.tasks.utils import set_seed +from sgl.operators.graph_op import LaplacianGraphOp +from sgl.utils import sparse_mx_to_torch_sparse_tensor +from utils import sparse_to_tuple, get_scores_gen_graphs + +class GraphConv(nn.Module): + def __init__(self, input_dim, output_dim, activation=True): + super(GraphConv, self).__init__() + self.weight = self.glorot_init(input_dim, output_dim) + self.activation = activation + + def glorot_init(self, input_dim, output_dim): + init_range = np.sqrt(6.0 / (input_dim + output_dim)) + initial = torch.rand(input_dim, output_dim) * 2 * init_range - init_range + return nn.Parameter(initial) + + def forward(self, adj, inputs): + x = inputs @ self.weight + x = adj @ x + if self.activation: + return F.elu(x) + else: + return x + +class VGAE(nn.Module): + def __init__(self, dim_in, dim_h, dim_z, gae): + super(VGAE,self).__init__() + self.dim_z = dim_z + self.gae = gae + self.base_gcn = GraphConv(dim_in, dim_h) + self.gcn_mean = GraphConv(dim_h, dim_z, activation=False) + self.gcn_logstd = GraphConv(dim_h, dim_z, activation=False) + + def encode(self, adj, X): + hidden = self.base_gcn(adj, X) + self.mean = self.gcn_mean(adj, hidden) + if self.gae: + return self.mean + else: + self.logstd = self.gcn_logstd(adj, hidden) + gaussian_noise = torch.randn_like(self.mean) + sampled_z = gaussian_noise * torch.exp(self.logstd) + self.mean + return sampled_z + + def decode(self, Z): + A_pred = Z @ Z.T + return A_pred + + def forward(self, adj, X): + Z = self.encode(adj, X) + A_pred = self.decode(Z) + return A_pred + +def prepare_data(dataset, val_frac, test_frac, no_mask, norm_feat=True): + adj_ori, features_orig = dataset.adj, dataset.x + if adj_ori.diagonal().sum() > 0: + adj_ori = sp.coo_matrix(adj_ori) + adj_ori.setdiag(0) + adj_ori.eliminate_zeros() + adj_ori = sp.csr_matrix(adj_ori) + if isinstance(features_orig, torch.Tensor): + features_orig = features_orig.numpy() + features_orig = sp.csr_matrix(features_orig) + if norm_feat: + features_orig = normalize(features_orig, norm="l1", axis=1) + adj_triu = sp.triu(adj_ori) + edges = sparse_to_tuple(adj_triu)[0] + num_val = int(np.floor(edges.shape[0] * val_frac)) + num_test = int(np.floor(edges.shape[0] * test_frac)) + + all_edge_idx = list(range(edges.shape[0])) + np.random.shuffle(all_edge_idx) + val_edge_idx = all_edge_idx[:num_val] + test_edge_idx = all_edge_idx[num_val:(num_val+num_test)] + val_edges = edges[val_edge_idx] + test_edges = edges[test_edge_idx] + if no_mask: + train_edges = edges + else: + train_edge_idx = all_edge_idx[num_val+num_test:] + train_edges = edges[train_edge_idx] + + num_nodes = adj_ori.shape[0] + test_edges_false = negative_sampling(from_scipy_sparse_matrix(adj_ori+sp.eye(adj_ori.shape[0]))[0], num_nodes, num_test) + test_edges_false = 
test_edges_false.numpy() + + val_edges_false = negative_sampling(from_scipy_sparse_matrix(adj_ori+sp.eye(adj_ori.shape[0]))[0], num_nodes, num_val) + val_edges_false = val_edges_false.numpy() + + adj_train = sp.csr_matrix((np.ones(train_edges.shape[0]), (train_edges[:, 0], train_edges[:, 1])), shape=adj_ori.shape) + adj_train = adj_train + adj_train.T + adj_norm = LaplacianGraphOp()._construct_adj(adj_train) + adj_norm = sparse_mx_to_torch_sparse_tensor(adj_norm) + adj_label = adj_train + sp.eye(adj_train.shape[0]) + adj_label = sparse_mx_to_torch_sparse_tensor(adj_label) + features = sparse_mx_to_torch_sparse_tensor(features_orig) + + return features, adj_ori, adj_train, adj_norm, adj_label, val_edges, val_edges_false, test_edges, test_edges_false + +def train_model(data, model, lr, epochs, gae, device, verbose=False, criterion="roc"): + features, _, adj_train, adj_norm, adj_label, val_edges, val_edges_false, test_edges, test_edges_false = data + optimizer = Adam(model.parameters(), lr=lr) + adj_t = adj_train + norm_w = adj_t.shape[0]**2 / float((adj_t.shape[0]**2 - adj_t.sum()) * 2) + pos_weight = torch.FloatTensor([float(adj_t.shape[0]**2 - adj_t.sum()) / adj_t.sum()]).to(device) + features = features.to(device) + adj_norm = adj_norm.to(device) + adj_label = adj_label.to_dense().to(device) + best_val = 0 + best_state_dict = None + model.train() + for epoch in range(epochs): + adj_pred = model(adj_norm, features) + optimizer.zero_grad() + loss = norm_w * F.binary_cross_entropy_with_logits(adj_pred, adj_label, pos_weight=pos_weight) + if gae is False: + kl_divergence = 0.5 / adj_pred.size(0) * (1 + 2 * model.logstd - model.mean**2 - torch.exp(2*model.logstd)).sum(1).mean() + loss -= kl_divergence + + adj_pred = torch.sigmoid(adj_pred).detach().cpu() + scores_val = get_scores_gen_graphs(val_edges, val_edges_false, adj_pred, adj_label) + if verbose: + print("Epoch{:3}: train_loss: {:.4f} recon_acc: {:.4f} val_roc: {:.4f} val_ap: {:.4f} val_f1: {:.4f}".format( + epoch+1, loss.item(), scores_val["acc"], scores_val["roc"], scores_val["ap"], scores_val["f1"])) + if scores_val[criterion] > best_val: + best_val = scores_val[criterion] + best_state_dict = copy.deepcopy(model.state_dict()) + if verbose: + scores_test = get_scores_gen_graphs(test_edges, test_edges_false, adj_pred, adj_label) + print("test_roc: {:.4f} test_ap: {:.4f} test_f1: {:.4f} test_recon_acc: {:.4f}".format( + scores_test["roc"], scores_test["ap"], scores_test["f1"], scores_test["acc"])) + loss.backward() + optimizer.step() + + model.load_state_dict(best_state_dict) + return model + +def graph_generate(dataset, model, lr, epochs, val_frac, test_frac, no_mask, num_gen_graphs, device, criterion, norm_feat=True, gae=True, verbose=False): + data = prepare_data(dataset, val_frac, test_frac, no_mask, norm_feat) + model = model.to(device) + model = train_model(data, model, lr, epochs, gae, device, verbose, criterion) + adj_ori = data[1] + save_dir = os.path.join(dataset.processed_dir, "GAugM_edge_probabilities") + if gae: + save_path = os.path.join(save_dir, "0_gae.pkl") + else: + save_path = os.path.join(save_dir, "0.pkl") + pkl.dump(adj_ori, open(save_path, "wb")) + features = data[0].to(device) + adj_norm = data[3].to(device) + for i in range(num_gen_graphs): + with torch.no_grad(): + adj_pred = model(adj_norm, features) + adj_pred = torch.sigmoid(adj_pred).detach().cpu() + adj_recon = adj_pred.numpy() + np.fill_diagonal(adj_recon, 0) + if gae: + save_path = os.path.join(save_dir, f"{i+1}_logits_gae.pkl") + else: + save_path = 
os.path.join(save_dir, f"{i+1}_logits.pkl") + pkl.dump(adj_recon, open(save_path, "wb")) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="generate graphs for GAugM") + parser.add_argument("--emb_size", type=int, default=16) + parser.add_argument("--hidden_size", type=int, default=32) + parser.add_argument("--epochs", type=int, default=200) + parser.add_argument("--seed", type=int, default=7) + parser.add_argument("--num_gen_graphs", type=int, default=0) + parser.add_argument("--lr", type=float, default=0.01, help="learning rate") + parser.add_argument("--val_frac", type=float, default=0.05) + parser.add_argument("--test_frac", type=float, default=0.1) + parser.add_argument("--dataset_classname", type=str, default="Planetoid") + parser.add_argument("--dataset_name", type=str, default="cora") + parser.add_argument("--criterion", type=str, default="roc") + parser.add_argument("--no_mask", action="store_true") + parser.add_argument("--gae", action="store_true") + parser.add_argument("--root", type=str, default="/home/ssq/test_data/") + parser.add_argument("--device", type=int, default=0) + args = parser.parse_args() + + device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu" + set_seed(args.seed) + + dataset = getattr(Dataset, args.dataset_classname)(root=args.root, name=args.dataset_name) + model = VGAE(dataset.num_features, args.hidden_size, args.emb_size, args.gae) + graph_generate(dataset, model, args.lr, args.epochs, args.val_frac, args.test_frac, args.no_mask, args.num_gen_graphs, device, args.criterion, True, args.gae, verbose=True) \ No newline at end of file diff --git a/sgl/models/homo/gda/utils.py b/sgl/models/homo/gda/utils.py index e1d0b8f..75f0b99 100644 --- a/sgl/models/homo/gda/utils.py +++ b/sgl/models/homo/gda/utils.py @@ -1,4 +1,9 @@ +import copy import torch +import warnings +import numpy as np +import scipy.sparse as sp +from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, auc class RoundNoGradient(torch.autograd.Function): @staticmethod @@ -17,4 +22,40 @@ def forward(ctx, x): @staticmethod def backward(ctx, g): - return g \ No newline at end of file + return g + +def sparse_to_tuple(sparse_mx): + if not sp.isspmatrix_coo(sparse_mx): + sparse_mx = sparse_mx.tocoo() + coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose() + values = sparse_mx.data + shape = sparse_mx.shape + return coords, values, shape + +def get_scores_gen_graphs(edges_pos, edges_neg, adj_pred, adj_label): + # get logists and labels + preds = adj_pred[edges_pos.T] + preds_neg = adj_pred[edges_neg] + logists = np.hstack([preds, preds_neg]) + labels = np.hstack([np.ones(preds.size(0)), np.zeros(preds_neg.size(0))]) + roc_auc = roc_auc_score(labels, logists) + ap_score = average_precision_score(labels, logists) + precisions, recalls, thresholds = precision_recall_curve(labels, logists) + pr_auc = auc(recalls, precisions) + warnings.simplefilter("ignore", RuntimeWarning) + f1s = np.nan_to_num(2 * precisions * recalls / (precisions + recalls)) + best_comb = np.argmax(f1s) + f1 = f1s[best_comb] + pre = precisions[best_comb] + rec = recalls[best_comb] + thresh = thresholds[best_comb] + # calc reconstracted adj_mat and accuracy with the threshold for best f1 + adj_rec = copy.deepcopy(adj_pred) + adj_rec[adj_rec < thresh] = 0 + adj_rec[adj_rec >= thresh] = 1 + labels_all = adj_label.view(-1).long().cpu() + preds_all = adj_rec.view(-1).long() + recon_acc = (preds_all == labels_all).sum().float() / labels_all.size(0) + results 
= {"roc": roc_auc, "pr": pr_auc, "ap": ap_score, "pre": pre, "rec": rec, "f1": f1, "acc": recon_acc, "adj_recon": adj_rec} + + return results \ No newline at end of file From 94daf5654ddbf1f7da96127812e6fff5ed2a3751 Mon Sep 17 00:00:00 2001 From: infinity Date: Sun, 24 Dec 2023 02:23:01 +0000 Subject: [PATCH 25/28] add Mixup model, which supports minibatch training. add HPO codes for GAug --- .gitignore | 4 +- examples/GDA/configs/GAugM.yml | 2 +- examples/GDA/configs/GAugO.yml | 76 ++-- examples/GDA/configs/GAugOMini.yml | 30 ++ examples/GDA/configs/Mixup.yml | 20 + examples/GDA/configs/SampleMixup.yml | 36 ++ examples/GDA/test_GAug.py | 2 +- examples/GDA/test_Mixup.py | 31 ++ examples/GDA/test_SampleMixup.py | 53 +++ examples/GDA/test_search_GAug.py | 46 ++ sgl/data/base_data.py | 5 +- sgl/models/base_model.py | 6 +- sgl/models/homo/gda/FLAG.py | 5 +- sgl/models/homo/gda/GAug.py | 44 +- sgl/models/homo/gda/Mixup.py | 407 ++++++++++++++++++ sgl/models/homo/gda/__init__.py | 5 +- sgl/models/homo/gda/gen_graphs.py | 2 +- sgl/models/pyg_simple_models.py | 6 +- sgl/models/simple_models.py | 8 +- sgl/sampler/__init__.py | 2 +- sgl/search/gda_hpo/GAug_search_config.py | 137 ++++++ sgl/search/gda_hpo/search_config.py | 14 + sgl/search/search_config.py | 5 - sgl/search/search_config_dist.py | 1 - sgl/search/search_models.py | 2 +- sgl/tasks/__init__.py | 2 +- .../node_classification_GAug.py | 374 ++++++++++++++++ sgl/tasks/node_classification.py | 103 +++-- sgl/tasks/node_classification_GAug.py | 272 ------------ sgl/tasks/node_classification_sampling.py | 118 ++--- 30 files changed, 1375 insertions(+), 443 deletions(-) create mode 100644 examples/GDA/configs/GAugOMini.yml create mode 100644 examples/GDA/configs/Mixup.yml create mode 100644 examples/GDA/configs/SampleMixup.yml create mode 100644 examples/GDA/test_Mixup.py create mode 100644 examples/GDA/test_SampleMixup.py create mode 100644 examples/GDA/test_search_GAug.py create mode 100644 sgl/models/homo/gda/Mixup.py create mode 100644 sgl/search/gda_hpo/GAug_search_config.py create mode 100644 sgl/search/gda_hpo/search_config.py create mode 100644 sgl/tasks/gda_specific_tasks/node_classification_GAug.py delete mode 100644 sgl/tasks/node_classification_GAug.py diff --git a/.gitignore b/.gitignore index 5e704a5..fe6ba9d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ __pycache__/ *.egg-info/ build/ -dist/ \ No newline at end of file +dist/ + +logs/ \ No newline at end of file diff --git a/examples/GDA/configs/GAugM.yml b/examples/GDA/configs/GAugM.yml index e5a3b8e..6c55082 100644 --- a/examples/GDA/configs/GAugM.yml +++ b/examples/GDA/configs/GAugM.yml @@ -6,7 +6,7 @@ model: model_name: 'GAugM' gnn_type: 'gcn' feat_norm: 'row' - hidden_dim: 128 + hidden_dim: 256 dropout: 0.5 n_layers: 2 choose_idx: 5 diff --git a/examples/GDA/configs/GAugO.yml b/examples/GDA/configs/GAugO.yml index 9fce08b..d1e52e3 100644 --- a/examples/GDA/configs/GAugO.yml +++ b/examples/GDA/configs/GAugO.yml @@ -24,57 +24,57 @@ # pretrain_ep: 160 # pretrain_nc: 30 # max_patience: 50 +# dataset: +# classname: "Planetoid" +# name: "cora" +# root: "/home/ssq/test_data/" +# model: +# model_name: 'GAugO' +# gnn_type: 'gsage' +# alpha: 0.13 +# temperature: 1.0 +# hidden_dim: 256 +# emb_size: 32 +# dropout: 0.5 +# n_layers: 2 +# gae: true +# feat_norm: 'row' +# normalize: True +# sample_type: 'add_sample' +# task: +# lr: 0.01 +# seed: 42 +# warmup: 2 +# beta: 3.2 +# epochs: 200 +# weight_decay: 0.0005 +# pretrain_ep: 195 +# pretrain_nc: 35 +# max_patience: 50 dataset: - 
classname: "Planetoid" - name: "cora" - root: "/home/ssq/test_data/" + classname: "Planetoid" + name: "cora" + root: "/home/ssq/test_data/" model: model_name: 'GAugO' - gnn_type: 'gsage' - alpha: 0.13 - temperature: 1.0 + gnn_type: 'gat' + alpha: 0.02 + temperature: 1.7 hidden_dim: 128 emb_size: 32 - dropout: 0.5 + dropout: 0.6 n_layers: 2 + activation: "elu" gae: true feat_norm: 'row' - normalize: True sample_type: 'add_sample' task: lr: 0.01 seed: 42 - warmup: 2 + warmup: 1 beta: 3.2 epochs: 200 weight_decay: 0.0005 - pretrain_ep: 195 - pretrain_nc: 35 + pretrain_ep: 175 + pretrain_nc: 45 max_patience: 50 -# dataset: -# classname: "Planetoid" -# name: "cora" -# root: "/home/ssq/test_data/" -# model: -# model_name: 'GAugO' -# gnn_type: 'gat' -# alpha: 0.02 -# temperature: 1.7 -# hidden_dim: 128 -# emb_size: 32 -# dropout: 0.6 -# n_layers: 2 -# activation: "elu" -# gae: true -# feat_norm: 'row' -# sample_type: 'add_sample' -# task: -# lr: 0.01 -# seed: 42 -# warmup: 1 -# beta: 3.2 -# epochs: 200 -# weight_decay: 0.0005 -# pretrain_ep: 175 -# pretrain_nc: 45 -# max_patience: 50 diff --git a/examples/GDA/configs/GAugOMini.yml b/examples/GDA/configs/GAugOMini.yml new file mode 100644 index 0000000..a56eea9 --- /dev/null +++ b/examples/GDA/configs/GAugOMini.yml @@ -0,0 +1,30 @@ +dataset: + classname: "Planetoid" + name: "pubmed" + root: "/home/ssq/test_data/" +model: + model_name: 'GAugO' + gnn_type: 'gcn' + alpha: 1.0 + temperature: 0.2 + hidden_dim: 128 + emb_size: 64 + dropout: 0.5 + n_layers: 2 + gae: true + feat_norm: 'row' + sample_type: 'add_sample' + minibatch: True +task: + lr: 0.01 + ep_lr: 0.002 + seed: 42 + warmup: 0 + beta: 0.8 + epochs: 200 + weight_decay: 0.0005 + pretrain_ep: 160 + pretrain_nc: 30 + max_patience: 50 + train_batch_size: 250 + pretrain_batch_size: 4096 \ No newline at end of file diff --git a/examples/GDA/configs/Mixup.yml b/examples/GDA/configs/Mixup.yml new file mode 100644 index 0000000..ec38742 --- /dev/null +++ b/examples/GDA/configs/Mixup.yml @@ -0,0 +1,20 @@ +dataset: + classname: "Planetoid" + name: "pubmed" + root: "/home/ssq/test_data/" + split: "full" +model: + gnn_type: 'sage' + hidden_dim: 256 + dropout: 0.5 + alpha: 4 + beta: 4 + n_layers: 3 + feat_norm: "none" + batch_norm: True +task: + lr: 0.01 + seed: 12345 + epochs: 300 + patience: 30 + weight_decay: 0 \ No newline at end of file diff --git a/examples/GDA/configs/SampleMixup.yml b/examples/GDA/configs/SampleMixup.yml new file mode 100644 index 0000000..07ec815 --- /dev/null +++ b/examples/GDA/configs/SampleMixup.yml @@ -0,0 +1,36 @@ +dataset: + classname: "Ogbn" + name: "arxiv" + root: "/home/ssq/test_data/" +sampler: + training: + name: "NeighborSampler" + layer_sizes: "15, 10, 5" + prob_type: "normalize" + replace: False + eval: + name: "NeighborSampler" + layer_sizes: "-1" + replace: False +model: + gnn_type: 'sage' + hidden_dim: 256 + dropout: 0.5 + alpha: 4 + beta: 4 + n_layers: 3 + feat_norm: "none" + batch_norm: True +task: + name: "NodeClassification_Sampling" + lr: 0.003 + seed: 12345 + epochs: 300 + patience: 30 + weight_decay: 0 + train_batch_size: 1024 + eval_batch_size: 4096 + train_num_workers: 12 + eval_num_workers: 12 + eval_together: True + eval_freq: 2 \ No newline at end of file diff --git a/examples/GDA/test_GAug.py b/examples/GDA/test_GAug.py index 74ccd3b..3aa60c0 100644 --- a/examples/GDA/test_GAug.py +++ b/examples/GDA/test_GAug.py @@ -6,7 +6,7 @@ from sgl.tasks import NodeClassificationGAugO, NodeClassificationGAugM if __name__ == "__main__": - parser = 
argparse.ArgumentParser(description = "GAug-Model.") + parser = argparse.ArgumentParser(description="GAug-Model.") parser.add_argument( "--device", type=int, default=0, help="gpu device id or cpu (-1)" ) diff --git a/examples/GDA/test_Mixup.py b/examples/GDA/test_Mixup.py new file mode 100644 index 0000000..6e70d46 --- /dev/null +++ b/examples/GDA/test_Mixup.py @@ -0,0 +1,31 @@ +import yaml +import argparse + +import sgl.dataset as Dataset +from sgl.models.homo.gda import Mixup +from sgl.tasks import NodeClassification + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description = "Mixup-Model.") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/Mixup.yml", help="save path of the configuration file" + ) + args = parser.parse_args() + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + dataset_kwargs = config["dataset"] + model_kwargs = config["model"] + task_kwargs = config["task"] + + dataset_classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, dataset_classname)(**dataset_kwargs) + for seed in range(10): + model = Mixup(in_dim=dataset.num_features, n_classes=dataset.num_classes, **model_kwargs) + task_kwargs.update({"loss_fn": model.loss_fn}) + task_kwargs.update({"device": device}) + task_kwargs.update({"seed": seed}) + test_acc = NodeClassification(dataset, model, **task_kwargs).test_acc + print(f"test acc: {test_acc:.4f}") \ No newline at end of file diff --git a/examples/GDA/test_SampleMixup.py b/examples/GDA/test_SampleMixup.py new file mode 100644 index 0000000..3c17562 --- /dev/null +++ b/examples/GDA/test_SampleMixup.py @@ -0,0 +1,53 @@ +import yaml +import argparse +import scipy.sparse as sp + +import sgl.tasks as Tasks +import sgl.dataset as Dataset +import sgl.sampler as Sampler +from sgl.models.homo.gda import SampleMixup + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Sampler-Models") + parser.add_argument( + "--device", type=int, default=0, help="gpu device id or cpu (-1)" + ) + parser.add_argument( + "--config_path", type=str, default="./configs/SampleMixup.yml", help="save path of the configuration file" + ) + args = parser.parse_args() + config = yaml.safe_load(open(args.config_path, "rb")) + device = f"cuda:{args.device}" if args.device >= 0 else "cpu" + dataset_kwargs = config["dataset"] + task_kwargs = config["task"] + classname = dataset_kwargs.pop("classname") + dataset = getattr(Dataset, classname)(**dataset_kwargs) + adj_matrix = dataset.adj + if isinstance(adj_matrix, sp.coo_matrix) is False: + adj_matrix = sp.coo_matrix(adj_matrix) + adj_matrix.setdiag(0) + adj_matrix = adj_matrix.tocsr() + training_sampler_kwargs = config["sampler"]["training"] + if "inductive" in training_sampler_kwargs.keys(): + inductive = training_sampler_kwargs.pop("inductive") + else: + inductive = False + task_kwargs.update({"inductive": inductive}) + training_sampler_name = training_sampler_kwargs.pop("name") + training_sampler_kwargs.update({"save_dir": dataset.processed_dir}) + training_sampler = getattr(Sampler, training_sampler_name)(adj_matrix[dataset.train_idx, :][:, dataset.train_idx] if inductive else adj_matrix, **training_sampler_kwargs) + if "eval" in config["sampler"].keys(): + eval_sampler_kwargs = config["sampler"]["eval"] + eval_sampler_name = eval_sampler_kwargs.pop("name") + eval_sampler_kwargs.update({"save_dir": 
dataset.processed_dir}) + eval_sampler = getattr(Sampler, eval_sampler_name)(adj_matrix, **eval_sampler_kwargs) + else: + eval_sampler = None + model_kwargs = config["model"] + model = SampleMixup(training_sampler, eval_sampler, in_dim=dataset.num_features, n_classes=dataset.num_classes, **model_kwargs) + task_kwargs.update({"device": device}) + task_kwargs.update({"loss_fn": model.loss_fn}) + task_name = task_kwargs.pop("name") + test_acc = getattr(Tasks, task_name)(dataset, model, **task_kwargs).test_acc + print(f"final test acc: {test_acc}") \ No newline at end of file diff --git a/examples/GDA/test_search_GAug.py b/examples/GDA/test_search_GAug.py new file mode 100644 index 0000000..6b18ac7 --- /dev/null +++ b/examples/GDA/test_search_GAug.py @@ -0,0 +1,46 @@ +import torch +import argparse +from openbox import Optimizer + +import sgl.dataset as Dataset +from sgl.search.gda_hpo.GAug_search_config import GAugOConfigManager, GAugMConfigManager + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="HPO-GAug-Model.") + parser.add_argument("--device", type=int, default=0, help="gpu device id or cpu(-1)") + parser.add_argument("--dataset_classname", type=str, default="Planetoid", help="class name of the dataset") + parser.add_argument("--name", type=str, default="cora", help="dataset name") + parser.add_argument("--root", type=str, default="/home/ssq/test_data/", help="root dir for dataset") + parser.add_argument("--gnn_type", type=str, default="gcn", choices=["gcn", "gsage", "gat"], help="gnn backbone") + parser.add_argument("--not_gae", action="store_true", default=False, help="whether not to use gae") + parser.add_argument("--minibatch", action="store_true", default=False, help="whether to use minibatch") + parser.add_argument("--pretrain_batch_size", type=int, default=-1, help="batch size when pretraining ep net") + parser.add_argument("--train_batch_size", type=int, default=-1, help="batch size when training") + parser.add_argument("--model", type=str, default="GAugO", choices=["GAugO", "GAugM"], help="choose the target mnodel") + parser.add_argument("--num_logits", type=int, default=10, help="number of candidate edge logits") + parser.add_argument("--runs_per_config", type=int, default=5, help="repeat runs for each configuration") + parser.add_argument("--max_patience", type=int, default=50, help="patience for early stop") + args = parser.parse_args() + device = f"cuda:{args.device}" if args.device >= 0 and torch.cuda.is_available() else "cpu" + dataset = getattr(Dataset, args.dataset_classname)(name=args.name, root=args.root) + pretrain_batch_size = args.pretrain_batch_size if args.pretrain_batch_size > 0 else None + train_batch_size = args.train_batch_size if args.train_batch_size > 0 else None + if args.model == "GAugO": + configer = GAugOConfigManager(dataset, args.gnn_type, not args.not_gae, device, minibatch=args.minibatch, pretrain_batch_size=pretrain_batch_size, train_batch_size=train_batch_size, runs=args.runs_per_config, max_patience=args.max_patience) + else: + configer = GAugMConfigManager(dataset, args.gnn_type, not args.not_gae, device, args.num_logits, runs=args.runs_per_config, max_patience=args.max_patience) + + opt = Optimizer(configer._configFunction, + configer._configSpace(), + num_objectives=1, + num_constraints=0, + max_runs=400, + surrogate_type="prf", + acq_type='ei', + acq_optimizer_type='local_random', + initial_runs=20, + task_id='quick_start', + random_state=1) + + history = opt.run() + print(history) \ No newline at end of file diff 
--git a/sgl/data/base_data.py b/sgl/data/base_data.py index f2e2cad..342513f 100644 --- a/sgl/data/base_data.py +++ b/sgl/data/base_data.py @@ -4,6 +4,7 @@ import scipy.sparse as sp from scipy.sparse import csr_matrix from torch_sparse import SparseTensor +from torch_geometric.utils import from_scipy_sparse_matrix from sgl.utils import sparse_mx_to_torch_sparse_tensor, sparse_mx_to_pyg_sparse_tensor # A lighter wrapper class for sampled adjacency matrices, @@ -44,8 +45,10 @@ def to_device(self, device): if isinstance(self.__adjs[0], sp.spmatrix): if self.__sparse_type == "pyg": self.__adjs = [sparse_mx_to_pyg_sparse_tensor(adj) for adj in self.__adjs] - else: + elif self.__sparse_type == "torch": self.__adjs = [sparse_mx_to_torch_sparse_tensor(adj) for adj in self.__adjs] + else: + self.__adjs = [from_scipy_sparse_matrix(adj)[0] for adj in self.__adjs] self.__adjs = [adj.to(device) for adj in self.__adjs] self.__device = device diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index 8458a44..01baf6a 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -3,7 +3,6 @@ import torch.nn.functional as F from sgl.data.base_data import Block from sgl.data.base_dataset import HeteroNodeDataset -from sgl.utils import sparse_mx_to_torch_sparse_tensor, sparse_mx_to_pyg_sparse_tensor class BaseSGAPModel(nn.Module): @@ -70,7 +69,7 @@ class BaseSAMPLEModel(nn.Module): def __init__(self, evaluate_mode="full", sparse_type="pyg"): super(BaseSAMPLEModel, self).__init__() self._evaluate_mode = evaluate_mode - if sparse_type not in ["pyg", "torch"]: + if sparse_type not in ["pyg", "torch", "2d-tensor"]: raise ValueError(f"sparse type {sparse_type} is not supported, please use either pyg or torch.") self._sparse_type = sparse_type self._pre_graph_op, self._post_graph_op = None, None @@ -97,6 +96,9 @@ def train_collate_fn(self): def eval_collate_fn(self): return self._eval_sampling_op.collate_fn + def reset_parameters(self): + self._base_model.reset_parameters() + def mini_batch_prepare_forward(self, batch, device, inductive=False, transfer_y_to_device=True): batch_in, batch_out, block = batch diff --git a/sgl/models/homo/gda/FLAG.py b/sgl/models/homo/gda/FLAG.py index e8b8c15..56f42bd 100644 --- a/sgl/models/homo/gda/FLAG.py +++ b/sgl/models/homo/gda/FLAG.py @@ -152,7 +152,7 @@ def flag(self, clean, ground_truth_y, adjs, batch_out, optimizer, device, loss_f return loss.item(), pred_y - def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizers, inductive=False, transfer_y_to_device=True): + def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizer, inductive=False, transfer_y_to_device=True): batch_in, batch_out, block = batch if inductive is False: @@ -166,7 +166,7 @@ def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizers, inducti y_truth = y_truth.to(device) block.to_device(device) - loss, pred_y = self.flag(in_x, y_truth, block, batch_out, optimizers, device, loss_fn) + loss, pred_y = self.flag(in_x, y_truth, block, batch_out, optimizer, device, loss_fn) return loss, pred_y, y_truth @@ -207,6 +207,7 @@ def evaluate_func(self, val_loader, test_loader, device): pred = test_output.max(1)[1].type_as(out_y) correct_num_test += pred.eq(out_y).double().sum() test_num += len(out_y) + acc_test = correct_num_test / test_num return acc_val.item(), acc_test.item() diff --git a/sgl/models/homo/gda/GAug.py b/sgl/models/homo/gda/GAug.py index 0daf218..41ba487 100644 --- a/sgl/models/homo/gda/GAug.py +++ b/sgl/models/homo/gda/GAug.py @@ -21,6 +21,7 
@@ def __init__(self, in_dim, hidden_dim, emb_size, n_classes, n_layers, dropout, g self.__temperature = temperature self.__alpha = alpha self.__sample_type = sample_type + self.__minibatch = kwargs.pop("minibatch", False) # edge prediction network self.__gae = gae self.__feat_norm = feat_norm @@ -52,6 +53,10 @@ def col_normalization(features): features /= s return torch.FloatTensor(features) + def reset_parameters(self): + self.ep_net.reset_parameters() + self.nc_net.reset_parameters() + def preprocess(self, features, adj_matrix, device): if self.__feat_norm == "row": features = F.normalize(features, p=1, dim=1) @@ -62,12 +67,16 @@ def preprocess(self, features, adj_matrix, device): assert sp.issparse(adj_matrix) if not isinstance(adj_matrix, sp.coo_matrix): adj_matrix = sp.coo_matrix(adj_matrix) + adj_matrix.setdiag(0) # remove incompelte self-loops before adding self-loops adj_matrix_sl = adj_matrix + sp.eye(*adj_matrix.shape) - adj_orig = sparse_mx_to_pyg_sparse_tensor(adj_matrix_sl).to_dense().to(device) + adj_orig = sparse_mx_to_pyg_sparse_tensor(adj_matrix_sl).to_dense() adj_norm_matrix = LaplacianGraphOp()._construct_adj(adj_matrix) adj_norm = sparse_mx_to_pyg_sparse_tensor(adj_norm_matrix).to(device) adj = sparse_mx_to_pyg_sparse_tensor(adj_matrix).to(device) + if self.__minibatch is False: + adj_orig = adj_orig.to(device) + return features, adj_orig, adj, adj_norm @staticmethod @@ -83,9 +92,8 @@ def sample_adj(adj_logits, temp): @staticmethod def sample_adj_add_bernoulli(adj_logits, adj_orig, alpha, temp): - edge_probs = adj_logits / torch.max(adj_logits) + edge_probs = adj_logits / (torch.max(adj_logits) + 1e-5) edge_probs = alpha * edge_probs + (1-alpha) * adj_orig - # sampling adj_sampled = pyro.distributions.RelaxedBernoulliStraightThrough(temperature=temp, probs=edge_probs).rsample() # making adj_sampled symmetric adj_sampled = adj_sampled.triu(1) @@ -145,8 +153,8 @@ def sample_adj_edge(adj_logits, adj_orig, change_frac): adj_new = adj_new + mask_add return adj_new - def forward(self, adj_norm, adj_orig, features): - adj_logits = self.ep_net(adj_norm, features) + def forward(self, adj_norm, adj_orig, features, nodes_batch=None): + adj_logits = self.ep_net(adj_norm, features, nodes_batch) if self.__sample_type == "edge": adj_new = self.sample_adj_edge(adj_logits, adj_orig, self.__alpha) elif self.__sample_type == "add_round": @@ -161,11 +169,13 @@ def forward(self, adj_norm, adj_orig, features): row, col = adj_new.nonzero(as_tuple=True) edge_index = torch.vstack([row, col]) - nc_logits = self.nc_net(features, edge_index) + if nodes_batch is not None: + nc_logits = self.nc_net(features[nodes_batch], edge_index) + else: + nc_logits = self.nc_net(features, edge_index) return nc_logits, adj_logits - class VGAE(nn.Module): """ GAE/VGAE as edge prediction model """ def __init__(self, in_dim, hidden_dim, emb_size, activation, gae=False): @@ -176,7 +186,12 @@ def __init__(self, in_dim, hidden_dim, emb_size, activation, gae=False): self.gcn_mean = GCNConv(hidden_dim, emb_size, add_self_loops=False, normalize=False, bias=False) self.gcn_logstd = GCNConv(hidden_dim, emb_size, add_self_loops=False, normalize=False, bias=False) - def forward(self, adj, features): + def reset_parameters(self): + self.gcn_base.reset_parameters() + self.gcn_mean.reset_parameters() + self.gcn_logstd.reset_parameters() + + def forward(self, adj, features, nodes_batch=None): # GCN encoder hidden = self.gcn_base(features, adj) self.mean = self.activation(self.gcn_mean(hidden, adj)) @@ -189,19 +204,22 @@ def 
forward(self, adj, features): gaussian_noise = torch.randn_like(self.mean) sampled_Z = gaussian_noise * torch.exp(self.logstd) + self.mean Z = sampled_Z + if nodes_batch is not None: + Z = Z[nodes_batch] # inner product decoder adj_logits = Z @ Z.T return adj_logits class GAugM(nn.Module): - def __init__(self, in_dim, hidden_dim, n_classes, n_layers, gnn_type, rm_pct, add_pct, choose_idx, dropout=0.5, activation=F.relu, feat_norm='none', **kwargs): + def __init__(self, in_dim, hidden_dim, n_classes, n_layers, gnn_type, rm_pct, add_pct, choose_idx, gae=False, dropout=0.5, activation=F.relu, feat_norm='none', **kwargs): super(GAugM, self).__init__() self.__feat_norm = feat_norm self.__rm_pct = rm_pct self.__add_pct = add_pct self.__choose_idx = choose_idx + self.__gae = gae if isinstance(activation, str): activation = getattr(F, activation) gnn_backbone = {"gcn": GCN, "gsage": SAGE, "gat": GAT} @@ -214,6 +232,9 @@ def __init__(self, in_dim, hidden_dim, n_classes, n_layers, gnn_type, rm_pct, ad self.nc_net = gnn_backbone.get(gnn_type)(in_dim, hidden_dim, n_classes, n_layers=n_layers, dropout=dropout, activation=activation, **kwargs) + def reset_parameters(self): + self.nc_net.reset_parameters() + @staticmethod def sample_graph_det(adj_orig, adj_pred, remove_pct, add_pct): if remove_pct == 0 and add_pct == 0: @@ -259,7 +280,10 @@ def preprocess(self, adj_orig, features, adj_pred_dir, device): features = F.normalize(features, p=1, dim=1) features = features.to(device) - adj_pred = pkl.load(open(os.path.join(adj_pred_dir, f"{self.__choose_idx}_logits.pkl"), "rb")) + if self.__gae is True: + adj_pred = pkl.load(open(os.path.join(adj_pred_dir, f"{self.__choose_idx}_logits_gae.pkl"), "rb")) + else: + adj_pred = pkl.load(open(os.path.join(adj_pred_dir, f"{self.__choose_idx}_logits.pkl"), "rb")) adj_pred = self.sample_graph_det(adj_orig, adj_pred, self.__rm_pct, self.__add_pct) adj_processed = sparse_mx_to_pyg_sparse_tensor(adj_pred).to(device) diff --git a/sgl/models/homo/gda/Mixup.py b/sgl/models/homo/gda/Mixup.py new file mode 100644 index 0000000..3007006 --- /dev/null +++ b/sgl/models/homo/gda/Mixup.py @@ -0,0 +1,407 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import scipy.sparse as sp +from torch_sparse import SparseTensor +from torch_geometric.nn import SAGEConv + +from sgl.data.base_data import Block +from sgl.models.base_model import BaseSAMPLEModel + +class Mixup(nn.Module): + def __init__(self, in_dim, hidden_dim, n_classes, n_layers, dropout, alpha, beta, gnn_type="sage", feat_norm="row", activation=F.relu, **kwargs): + super(Mixup, self).__init__() + self.__alpha = alpha + self.__beta = beta + self.__feat_norm = feat_norm + self.nc_net = TwoBranchGNN(in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, activation, **kwargs) + + def preprocess(self, adj, features, device): + if self.__feat_norm == "row": + features = F.normalize(features, p=1, dim=1) + self.__num_nodes = features.size(0) + self.__features = features.to(device) + if isinstance(adj, sp.coo_matrix) is False: + adj = sp.coo_matrix(adj) + adj.setdiag(0) + self.__row = torch.from_numpy(adj.row).to(torch.long) + self.__col = torch.from_numpy(adj.col).to(torch.long) + self.__adj = torch.vstack([self.__row, self.__col]).to(device) + + @property + def processed_feature(self): + return self.__features + + @property + def processed_block(self): + return self.__adj + + @staticmethod + def loss_fn(mix_ratio, output, y_raw, y_b, train_idx): + loss = F.nll_loss(output[train_idx], 
y_raw[train_idx]) * mix_ratio + \ + F.nll_loss(output[train_idx], y_b[train_idx]) * (1 - mix_ratio) + return loss + + def reset_parameters(self): + self.nc_net.reset_parameters() + + def train_func(self, train_idx, y_raw, device, optimizer, loss_fn, metric): + self.nc_net.train() + mix_ratio = np.random.beta(self.__alpha, self.__beta) + id_old_value_new, adj_b, y_b = self._mixup(train_idx, y_raw, device) + output = self.nc_net(self.__features, self.__adj, adj_b, mix_ratio, id_old_value_new) + + loss = loss_fn(mix_ratio, output, y_raw, y_b, train_idx) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.nc_net.eval() + output = self.forward(self.__features, self.__adj) + acc = metric(output[train_idx], y_raw[train_idx]) + + return loss.item(), acc + + @torch.no_grad() + def evaluate_func(self, val_idx, test_idx, labels, device, metric): + self.nc_net.eval() + + pred_y = self.forward(self.__features, self.__adj) + + acc_val = metric(pred_y[val_idx], labels[val_idx]) + acc_test = metric(pred_y[test_idx], labels[test_idx]) + + return acc_val, acc_test + + def _mixup(self, train_idx, y_raw, device): + id_old_value_new = torch.arange(self.__num_nodes, dtype=torch.long) + train_idx_shuffle = np.asarray(train_idx) + np.random.shuffle(train_idx_shuffle) + # map raw node id to its pair node id + id_old_value_new[train_idx] = torch.from_numpy(train_idx_shuffle).to(torch.long) + id_new_value_old = torch.zeros_like(id_old_value_new) + # map the pair node id to the raw node id + id_new_value_old[id_old_value_new] = torch.arange(self.__num_nodes, dtype=torch.long) + row_b = id_old_value_new[self.__row] + col_b = id_old_value_new[self.__col] + adj_b = torch.vstack([row_b, col_b]).to(device) + y_b = y_raw[id_old_value_new] + + return id_old_value_new, adj_b, y_b + + def model_forward(self, idx, device): + output = self.forward(self.__features, self.__adj) + + return output[idx] + + def forward(self, x, adj): + output = self.nc_net(x, adj, adj, 1, np.arange(self.__num_nodes)) + + return output + + def postprocess(self, adj, output): + return output + + +class SampleMixup(BaseSAMPLEModel): + def __init__(self, training_sampler, eval_sampler, in_dim, hidden_dim, n_classes, n_layers, dropout, alpha, beta, gnn_type="sage", feat_norm="row", activation=F.relu, **kwargs): + super(SampleMixup, self).__init__(sparse_type="pyg") + self.__alpha = alpha + self.__beta = beta + self.__feat_norm = feat_norm + self._training_sampling_op = training_sampler + self._eval_sampling_op = eval_sampler + self._base_model = MinibatchTwoBranchGNN(in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, activation, **kwargs) + + def preprocess(self, adj, x, y, device, **kwargs): + if self.__feat_norm == "row": + x = F.normalize(x, p=1, dim=1) + self.__num_nodes = x.size(0) + self.__features = x.to(device) + if isinstance(adj, sp.coo_matrix) is False: + adj = sp.coo_matrix(adj) + adj.setdiag(0) + self.__adj = Block(adj, sparse_type="pyg") + + self.__vanilla_y = y + + inductive = kwargs.get("inductive", False) + if inductive is True: + train_idx = kwargs.get("train_idx", None) + if train_idx is None: + raise ValueError(f"For inductive learning, " + "please pass train idx " + "as the parameters of preprocess function.") + self.__train_features = x[train_idx] + self.__vanilla_train_y = y[train_idx] + + @property + def processed_feature(self): + return self.__features + + @property + def processed_block(self): + return self.__adj + + @staticmethod + def loss_fn(mix_ratio, output, y_raw, y_b): + loss = 
F.nll_loss(output, y_raw) * mix_ratio + \ + F.nll_loss(output, y_b) * (1 - mix_ratio) + return loss + + def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizer, inductive=False, transfer_y_to_device=True, mix_ratio=1): + batch_in, batch_out, block = batch + + if inductive is False: + in_x = self.__features[batch_in].to(device) + y_raw = self.__vanilla_y[batch_out] + else: + in_x = self.__train_features[batch_in].to(device) + y_raw = self.__vanilla_train_y[batch_out] + + if transfer_y_to_device is True: + y_raw = y_raw.to(device) + + id_old_value_new, block_b, y_b = self._mixup(batch_out.shape[0], batch_in.shape[0], block, y_raw) + block.to_device(device) + block_b.to_device(device) + output = self._base_model(in_x, block, block_b, mix_ratio, id_old_value_new) + + loss = loss_fn(mix_ratio, output, y_raw, y_b) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return loss.item(), output, y_raw + + def train_func(self, train_loader, inductive, device, optimizer, loss_fn): + correct_num = 0 + loss_train_sum = 0. + train_num = 0 + + self._base_model.train() + mix_ratio = np.random.beta(self.__alpha, self.__beta) + + for batch in train_loader: + loss_train, y_out, y_truth = self.mini_batch_prepare_forward(batch, device, loss_fn, optimizer, inductive=inductive, mix_ratio=mix_ratio) + pred = y_out.max(1)[1].type_as(y_truth) + correct_num += pred.eq(y_truth).double().sum() + loss_train_sum += loss_train + train_num += len(y_truth) + + loss_train = loss_train_sum / len(train_loader) + acc_train = correct_num / train_num + + return loss_train, acc_train.item() + + @torch.no_grad() + def evaluate_func(self, val_loader, test_loader, device): + self._base_model.eval() + + correct_num_val, correct_num_test = 0, 0 + val_num = 0 + for batch in val_loader: + val_output, out_y = self.model_forward(batch, device) + pred = val_output.max(1)[1].type_as(out_y) + correct_num_val += pred.eq(out_y).double().sum() + val_num += len(out_y) + + acc_val = correct_num_val / val_num + + test_num = 0 + for batch in test_loader: + test_output, out_y = self.model_forward(batch, device) + pred = test_output.max(1)[1].type_as(out_y) + correct_num_test += pred.eq(out_y).double().sum() + test_num += len(out_y) + + acc_test = correct_num_test / test_num + + return acc_val.item(), acc_test.item() + + def _mixup(self, num_train_nodes, batch_size, block, y_raw): + id_old_value_new = torch.arange(batch_size, dtype=torch.long) + train_idx_shuffle = np.arange(num_train_nodes) + np.random.shuffle(train_idx_shuffle) + # map raw node id to its pair node id + id_old_value_new[:num_train_nodes] = torch.from_numpy(train_idx_shuffle).to(torch.long) + id_new_value_old = torch.zeros_like(id_old_value_new) + # map the pair node id to the raw node id + id_new_value_old[id_old_value_new] = torch.arange(batch_size, dtype=torch.long) + adjs_b = [] + for i in range(len(block)): + adj = block[i] + if isinstance(adj, sp.coo_matrix) is False: + adj = sp.coo_matrix(adj) + row, col = adj.row, adj.col + row_b = id_old_value_new[row] + col_b = id_old_value_new[col] + adj_b = SparseTensor(row=row_b, col=col_b, value=torch.ones_like(row_b)) + adjs_b.append(adj_b) + + block_b = Block(adjs_b, sparse_type="pyg") + + y_b = y_raw[train_idx_shuffle] + + return id_old_value_new, block_b, y_b + + def postprocess(self, adj, output): + return output + + def model_forward(self, batch_in, block, device): + x = self.__features[batch_in].to(device) + block.to_device(device) + output = self.forward(x, block) + + return output + + def 
forward(self, x, block): + output = self._base_model(x, block, block, 1, np.arange(self.__num_nodes)) + + return output + +class TwoBranchGNN(nn.Module): + def __init__(self, in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, activation=F.relu, **kwargs): + super(TwoBranchGNN, self).__init__() + self.gcs = nn.ModuleList() + if gnn_type != "sage": + raise NotImplementedError + self.gcs.append(SAGEConv(in_dim, hidden_dim)) + self.batch_norm = kwargs.get("batch_norm", False) + if self.batch_norm: + self.bns = nn.ModuleList() + self.bns.append(nn.BatchNorm1d(hidden_dim)) + for _ in range(n_layers-1): + self.gcs.append(SAGEConv(hidden_dim, hidden_dim)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1d(hidden_dim)) + self.lin = nn.Linear(hidden_dim, n_classes) + self.n_layers = n_layers + self.dropout = dropout + self.activation = activation + + def reset_parameters(self): + for gc in self.gcs: + gc.reset_parameters() + self.lin.reset_parameters() + + def forward(self, x0, adj, adj_b, mix_ratio, id_old_value_new): + aggr_xs = [x0] + for i in range(self.n_layers-1): + x = self.gcs[i](aggr_xs[-1], adj) + if self.batch_norm: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=self.dropout, training=self.training) + aggr_xs.append(x) + + aggr_xs_b = [] + for x in aggr_xs: + aggr_xs_b.append(x[id_old_value_new]) + + x_mix = aggr_xs[0] * mix_ratio + aggr_xs_b[0] * (1 - mix_ratio) + for i in range(self.n_layers): + x_new = self.gcs[i]((aggr_xs[i], x_mix), adj) + if self.batch_norm: + x_new = self.bns[i](x_new) + x_new = self.activation(x_new) + + x_new_b = self.gcs[i]((aggr_xs_b[i], x_mix), adj_b) + if self.batch_norm: + x_new_b = self.bns[i](x_new_b) + x_new_b = self.activation(x_new_b) + + x_mix = x_new * mix_ratio + x_new_b * (1 - mix_ratio) + x_mix = F.dropout(x_mix, self.dropout, training=self.training) + + x = self.lin(x_mix) + return F.log_softmax(x, dim=-1) + +class MinibatchTwoBranchGNN(nn.Module): + def __init__(self, in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, activation=F.relu, **kwargs): + super(MinibatchTwoBranchGNN, self).__init__() + self.gcs = nn.ModuleList() + if gnn_type != "sage": + raise NotImplementedError + self.gcs.append(SAGEConv(in_dim, hidden_dim)) + self.batch_norm = kwargs.get("batch_norm", False) + if self.batch_norm: + self.bns = nn.ModuleList() + self.bns.append(nn.BatchNorm1d(hidden_dim)) + for _ in range(n_layers-1): + self.gcs.append(SAGEConv(hidden_dim, hidden_dim)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1d(hidden_dim)) + self.lin = nn.Linear(hidden_dim, n_classes) + self.n_layers = n_layers + self.dropout = dropout + self.activation = activation + + def reset_parameters(self): + for gc in self.gcs: + gc.reset_parameters() + self.lin.reset_parameters() + + def forward(self, x0, block, block_b, mix_ratio, id_old_value_new): + aggr_xs = [x0] + for i in range(self.n_layers): + root_size = block.root_size(i) + root_x = aggr_xs[-1][:root_size] + x = self.gcs[i]((aggr_xs[-1], root_x), block[i]) + if self.batch_norm: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=self.dropout, training=self.training) + aggr_xs.append(x) + + aggr_xs_b = [] + for x in aggr_xs: + num_nodes = x.size(0) + aggr_xs_b.append(x[id_old_value_new[:num_nodes]]) + + x_mix = aggr_xs[0] * mix_ratio + aggr_xs_b[0] * (1 - mix_ratio) + for i in range(self.n_layers): + root_size = block.root_size(i) + root_x = x_mix[:root_size] + x_new = self.gcs[i]((aggr_xs[i], root_x), block[i]) + if self.batch_norm: + x_new = self.bns[i](x_new) + 
x_new = self.activation(x_new) + + root_size = block_b.root_size(i) + root_x_b = x_mix[:root_size] + x_new_b = self.gcs[i]((aggr_xs_b[i], root_x_b), block_b[i]) + if self.batch_norm: + x_new_b = self.bns[i](x_new_b) + x_new_b = self.activation(x_new_b) + x_mix = x_new * mix_ratio + x_new_b * (1 - mix_ratio) + x_mix = F.dropout(x_mix, self.dropout, training=self.training) + + x = self.lin(x_mix) + return F.log_softmax(x, dim=-1) + + @torch.no_grad() + def inference(self, x_all, subgraph_loader, device): + + for i in range(self.n_layers): + xs = [] + for batch in subgraph_loader: + batch_in, batch_out, block = batch + block.to_device(device) + x = x_all[batch_in].to(device) + root_size = len(batch_out) + root_x = x[:root_size] + x = self.gcs[i]((x, root_x), block[0]) # one-layer sampling + if self.batch_norm: + x = self.bns[i](x) + x = self.activation(x) + if i == self.n_layers-1: + x = self.lin(x) + xs.append(x.cpu()) + + x_all = torch.cat(xs, dim=0) + + return x_all + + \ No newline at end of file diff --git a/sgl/models/homo/gda/__init__.py b/sgl/models/homo/gda/__init__.py index ed111c1..b0537a7 100644 --- a/sgl/models/homo/gda/__init__.py +++ b/sgl/models/homo/gda/__init__.py @@ -1,5 +1,6 @@ from .GAug import GAugO, GAugM from .FLAG import FLAG, SampleFLAG +from .Mixup import Mixup, SampleMixup from .gen_graphs import graph_generate, VGAE __all__ = [ @@ -8,5 +9,7 @@ "FLAG", "SampleFLAG", "graph_generate", - "VGAE" + "VGAE", + "Mixup", + "SampleMixup" ] \ No newline at end of file diff --git a/sgl/models/homo/gda/gen_graphs.py b/sgl/models/homo/gda/gen_graphs.py index 4da6017..62de9cd 100644 --- a/sgl/models/homo/gda/gen_graphs.py +++ b/sgl/models/homo/gda/gen_graphs.py @@ -15,7 +15,7 @@ from sgl.tasks.utils import set_seed from sgl.operators.graph_op import LaplacianGraphOp from sgl.utils import sparse_mx_to_torch_sparse_tensor -from utils import sparse_to_tuple, get_scores_gen_graphs +from sgl.models.homo.gda.utils import sparse_to_tuple, get_scores_gen_graphs class GraphConv(nn.Module): def __init__(self, input_dim, output_dim, activation=True): diff --git a/sgl/models/pyg_simple_models.py b/sgl/models/pyg_simple_models.py index f3763b3..fe6f6f5 100644 --- a/sgl/models/pyg_simple_models.py +++ b/sgl/models/pyg_simple_models.py @@ -22,7 +22,7 @@ def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F self.dropout = dropout self.activation = activation - def reset_parameter(self): + def reset_parameters(self): for conv in self.gcs: conv.reset_parameters() if self.batch_norm: @@ -97,7 +97,7 @@ def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F self.dropout = dropout self.activation = activation - def reset_parameter(self): + def reset_parameters(self): for conv in self.gcs: conv.reset_parameters() if self.batch_norm: @@ -180,7 +180,7 @@ def __init__(self, n_feat, n_hid, n_class, n_heads, n_layers=2, dropout=0.6, act self.dropout = dropout self.activation = activation - def reset_parameter(self): + def reset_parameters(self): for conv in self.gcs: conv.reset_parameters() if self.batch_norm: diff --git a/sgl/models/simple_models.py b/sgl/models/simple_models.py index 56c36fb..c32fed7 100644 --- a/sgl/models/simple_models.py +++ b/sgl/models/simple_models.py @@ -328,7 +328,7 @@ def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F self.dropout = dropout self.activation = activation - def reset_parameter(self): + def reset_parameters(self): for conv in self.gcs: conv.reset_parameters() if self.batch_norm: @@ 
-400,7 +400,7 @@ def __init__(self, n_feat, n_hid, n_class, n_layers=2, dropout=0.5, activation=F self.dropout = dropout self.activation = activation - def reset_parameter(self): + def reset_parameters(self): for conv in self.gcs: conv.reset_parameters() if self.batch_norm: @@ -468,6 +468,10 @@ def __init__(self, n_feat, n_hid, n_class, n_heads, n_layers=2, dropout=0.6, act self.dropout = dropout self.activation = activation + def reset_parameters(self): + for gc in self.gcs: + gc.reset_parameters() + def forward(self, x, block): repr = x if isinstance(block, torch.Tensor): diff --git a/sgl/sampler/__init__.py b/sgl/sampler/__init__.py index 1df6716..1bfb294 100644 --- a/sgl/sampler/__init__.py +++ b/sgl/sampler/__init__.py @@ -1,4 +1,4 @@ -from .sampler import FastGCNSampler, ClusterGCNSampler, GraphSAINTSampler,NeighborSampler +from .sampler import FastGCNSampler, ClusterGCNSampler, GraphSAINTSampler, NeighborSampler from .base_sampler import FullSampler, NodeWiseSampler, LayerWiseSampler, GraphWiseSampler __all__ = [ diff --git a/sgl/search/gda_hpo/GAug_search_config.py b/sgl/search/gda_hpo/GAug_search_config.py new file mode 100644 index 0000000..c3835e5 --- /dev/null +++ b/sgl/search/gda_hpo/GAug_search_config.py @@ -0,0 +1,137 @@ +import torch.nn.functional as F +from openbox import space as osp + +from sgl.models.homo.gda import GAugO, GAugM +from sgl.tasks import NodeClassificationGAugO, NodeClassificationGAugM +from sgl.search.gda_hpo.search_config import BaseGDAConfigManager + +class GAugOConfigManager(BaseGDAConfigManager): + def __init__(self, dataset, gnn_type, gae, device, runs=5, activation=F.relu, minibatch=False, epochs=200, max_patience=100, pretrain_batch_size=None, train_batch_size=None): + super(GAugOConfigManager, self).__init__() + # basic information + self.__dataset = dataset + self.__gnn_type = gnn_type + self.__gae = gae + self.__minibatch = minibatch + self.__activation = activation + self.__epochs = epochs + self.__device = device + self.__max_patience = max_patience + self.__pretrain_batch_size = pretrain_batch_size + self.__train_batch_size = train_batch_size + self.__runs = runs + self.__config_space = osp.Space() + # model hyperparameters + alpha = osp.Real("alpha", 0, 1, default_value=0.4, q=0.01) + temperature = osp.Real("temperature", 0.1, 2.1, default_value=1.5, q=0.1) + hidden_dim = osp.Categorical("hidden_dim", [32, 64, 128, 256], default_value=128) + emb_size = osp.Constant("emb_size", 32) + n_layers = osp.Constant("n_layers", 2) + dropout = osp.Constant("dropout", 0.5) + feat_norm = osp.Constant("feat_norm", "row") + # task hyperparameters + lr = osp.Constant("lr", 0.01) if self.__gnn_type != "gat" else osp.Constant("lr", 0.005) + if pretrain_batch_size is not None: + ep_lr = osp.Real("ep_lr", 0.001, 0.01, default_value=0.002, q=0.001) + else: + ep_lr = osp.Constant("ep_lr", 0.01) + weight_decay = osp.Constant("weight_decay", 0.0005) + warmup = osp.Int("warmup", 0, 10, default_value=2, q=1) + beta = osp.Real("beta", 0, 4, default_value=2, q=0.1) + pretrain_ep = osp.Int("pretrain_ep", 5, 300, default_value=100, q=5) + pretrain_nc = osp.Int("pretrain_nc", 5, 300, default_value=200, q=5) + self.__config_space.add_variables([alpha, temperature, hidden_dim, emb_size, n_layers, dropout, feat_norm, \ + lr, weight_decay, warmup, beta, pretrain_ep, pretrain_nc, ep_lr]) + + def _configSpace(self): + return self.__config_space + + def _configTarget(self, params): + model_kwargs = dict() + model_kwargs["in_dim"] = self.__dataset.num_features + 
model_kwargs["hidden_dim"] = params["hidden_dim"] + model_kwargs["emb_size"] = params["emb_size"] + model_kwargs["n_classes"] = self.__dataset.num_classes + model_kwargs["n_layers"] = params["n_layers"] + model_kwargs["dropout"] = params["dropout"] + model_kwargs["gnn_type"] = self.__gnn_type + model_kwargs["activation"] = self.__activation + model_kwargs["temperature"] = params["temperature"] + model_kwargs["gae"] = self.__gae + model_kwargs["alpha"] = params["alpha"] + model_kwargs["feat_norm"] = params["feat_norm"] + model_kwargs["minibatch"] = self.__minibatch + model = GAugO(**model_kwargs) + task_kwargs = dict() + task_kwargs["lr"] = params["lr"] + task_kwargs["weight_decay"] = params["weight_decay"] + task_kwargs["epochs"] = self.__epochs + task_kwargs["device"] = self.__device + task_kwargs["beta"] = params["beta"] + task_kwargs["warmup"] = params["warmup"] + task_kwargs["max_patience"] = self.__max_patience + task_kwargs["pretrain_ep"] = params["pretrain_ep"] + task_kwargs["pretrain_nc"] = params["pretrain_nc"] + task_kwargs["pretrain_batch_size"] = self.__pretrain_batch_size + task_kwargs["train_batch_size"] = self.__train_batch_size + task_kwargs["ep_lr"] = params["ep_lr"] + task = NodeClassificationGAugO(self.__dataset, model, runs=self.__runs, verbose=False, **task_kwargs) + acc_res = task._execute() + + return dict(objectives=[-acc_res]) + +class GAugMConfigManager(BaseGDAConfigManager): + def __init__(self, dataset, gnn_type, gae, device, num_logits, runs=5, activation=F.relu, epochs=200, max_patience=100): + super(GAugMConfigManager, self).__init__() + # basic information + self.__dataset = dataset + self.__gnn_type = gnn_type + self.__gae = gae + self.__activation = activation + self.__device = device + self.__epochs = epochs + self.__max_patience = max_patience + self.__runs = runs + self.__config_space = osp.Space() + # model hyperparameters + choose_idx = osp.Int("choose_idx", 1, num_logits, default_value=1, q=1) + rm_pct = osp.Int("rm_pct", 0, 80, default_value=20, q=1) + add_pct = osp.Int("add_pct", 0, 80, default_value=20, q=1) + hidden_dim = osp.Categorical("hidden_dim", [32, 64, 128, 256], default_value=128) + n_layers = osp.Constant("n_layers", 2) + dropout = osp.Constant("dropout", 0.5) + feat_norm = osp.Constant("feat_norm", "row") + # task hyperparameters + lr = osp.Constant("lr", 0.01) + weight_decay = osp.Constant("weight_decay", 0.0005) + self.__config_space.add_variables([choose_idx, rm_pct, add_pct, hidden_dim, n_layers, dropout, feat_norm, \ + lr, weight_decay]) + + def _configSpace(self): + return self.__config_space + + def _configTarget(self, params): + model_kwargs = dict() + model_kwargs["in_dim"] = self.__dataset.num_features + model_kwargs["hidden_dim"] = params["hidden_dim"] + model_kwargs["n_classes"] = self.__dataset.num_classes + model_kwargs["n_layers"] = params["n_layers"] + model_kwargs["dropout"] = params["dropout"] + model_kwargs["gnn_type"] = self.__gnn_type + model_kwargs["activation"] = self.__activation + model_kwargs["gae"] = self.__gae + model_kwargs["feat_norm"] = params["feat_norm"] + model_kwargs["choose_idx"] = params["choose_idx"] + model_kwargs["rm_pct"] = params["rm_pct"] + model_kwargs["add_pct"] = params["add_pct"] + model = GAugM(**model_kwargs) + task_kwargs = dict() + task_kwargs["lr"] = params["lr"] + task_kwargs["weight_decay"] = params["weight_decay"] + task_kwargs["epochs"] = self.__epochs + task_kwargs["device"] = self.__device + task_kwargs["max_patience"] = self.__max_patience + task = 
NodeClassificationGAugM(self.__dataset, model, runs=self.__runs, verbose=False, **task_kwargs) + acc_res = task._execute() + + return dict(objectives=[-acc_res]) \ No newline at end of file diff --git a/sgl/search/gda_hpo/search_config.py b/sgl/search/gda_hpo/search_config.py new file mode 100644 index 0000000..f98aa3a --- /dev/null +++ b/sgl/search/gda_hpo/search_config.py @@ -0,0 +1,14 @@ +from openbox import space as osp + +class BaseGDAConfigManager(): + def __init__(self): + super(BaseGDAConfigManager, self).__init__() + self.__config_space = None + + def _configTarget(self, params): + raise NotImplementedError + + def _configFunction(self, config_space: osp.Configuration): + params = config_space.get_dictionary().copy() + result = self._configTarget(params) + return result diff --git a/sgl/search/search_config.py b/sgl/search/search_config.py index e22b8fa..591b74f 100644 --- a/sgl/search/search_config.py +++ b/sgl/search/search_config.py @@ -1,9 +1,4 @@ import numpy as np -<<<<<<< Updated upstream -from sgl.search.search_models import SearchModel -from sgl.search.auto_search import SearchManager -======= ->>>>>>> Stashed changes from openbox.utils.config_space import ConfigurationSpace, UniformIntegerHyperparameter from sgl.search.auto_search import SearchManager diff --git a/sgl/search/search_config_dist.py b/sgl/search/search_config_dist.py index 97c2316..dbd1685 100644 --- a/sgl/search/search_config_dist.py +++ b/sgl/search/search_config_dist.py @@ -1,4 +1,3 @@ -import argparse import numpy as np from sgl.search.auto_search_dist import SearchManagerDist from sgl.search.search_models_dist import SearchModelDist diff --git a/sgl/search/search_models.py b/sgl/search/search_models.py index fe6a9e4..c4195df 100644 --- a/sgl/search/search_models.py +++ b/sgl/search/search_models.py @@ -2,7 +2,7 @@ from sgl.models.simple_models import LogisticRegression, ResMultiLayerPerceptron from sgl.operators.graph_op import LaplacianGraphOp, PprGraphOp from sgl.operators.message_op import LastMessageOp, ConcatMessageOp, MeanMessageOp, SimpleWeightedMessageOp, \ - LearnableWeightedMessageOp, IterateLearnableWeightedMessageOp, SumMessageOp, MaxMessageOp, MinMessageOp + LearnableWeightedMessageOp, SumMessageOp, MaxMessageOp, MinMessageOp class SearchModel(BaseSGAPModel): diff --git a/sgl/tasks/__init__.py b/sgl/tasks/__init__.py index 4f6683a..6a9f9d8 100644 --- a/sgl/tasks/__init__.py +++ b/sgl/tasks/__init__.py @@ -8,7 +8,7 @@ from .correct_and_smooth import NodeClassification_With_CorrectAndSmooth from .node_classification_with_label_use import NodeClassificationWithLabelUse from .node_classification_dist import NodeClassificationDist -from .node_classification_GAug import NodeClassificationGAugO, NodeClassificationGAugM +from .gda_specific_tasks.node_classification_GAug import NodeClassificationGAugO, NodeClassificationGAugM __all__ = [ "NodeClassification", diff --git a/sgl/tasks/gda_specific_tasks/node_classification_GAug.py b/sgl/tasks/gda_specific_tasks/node_classification_GAug.py new file mode 100644 index 0000000..03d835a --- /dev/null +++ b/sgl/tasks/gda_specific_tasks/node_classification_GAug.py @@ -0,0 +1,374 @@ +import gc +import os +import time +import torch +import torch.nn as nn +from torch.optim import Adam +from torch.utils.data import DataLoader +import torch.nn.functional as F +import numpy as np +import scipy.sparse as sp + +from sgl.tasks.base_task import BaseTask +from sgl.tasks.utils import set_seed, accuracy, MultipleOptimizer + +class NodeClassificationGAugO(BaseTask): + def 
__init__(self, dataset, model, lr, weight_decay, epochs, device, beta, warmup, max_patience, pretrain_ep, pretrain_nc, runs=1, verbose=True, seed=12345, pretrain_batch_size=None, train_batch_size=None, ep_lr=None): + super(NodeClassificationGAugO, self).__init__() + + self.__dataset = dataset + self.__labels = self.__dataset.y + + self.__model = model + self.__optimizer = MultipleOptimizer(Adam(model.ep_net.parameters(), lr=lr), + Adam(model.nc_net.parameters(), lr=lr, weight_decay=weight_decay)) + + self.__lr = lr + self.__ep_lr = ep_lr if ep_lr is not None else lr + self.__weight_decay = weight_decay + + self.__epochs = epochs + self.__device = device + self.__seed = seed + self.__runs = runs + self.__verbose = verbose + + self.__warmup = warmup + self.__beta = beta + self.__max_patience = max_patience + + self.__pretrain_ep = pretrain_ep + self.__pretrain_nc = pretrain_nc + self.__pretrain_batch_size = pretrain_batch_size + self.__train_batch_size = train_batch_size + + self.__test_acc = self._execute() + + @property + def test_acc(self): + return self.__test_acc + + @staticmethod + def get_lr_schedule_by_sigmoid(n_epochs, lr, warmup): + """ schedule the learning rate with the sigmoid function. + The learning rate will start with near zero and end with near lr """ + factors = torch.FloatTensor(np.arange(n_epochs)) + factors = ((factors / factors[-1]) * (warmup * 2)) - warmup + factors = torch.sigmoid(factors) + # range the factors to [0, 1] + factors = (factors - factors[0]) / (factors[-1] - factors[0]) + lr_schedule = factors * lr + return lr_schedule + + @staticmethod + def loss_fn(nc_logits, norm_w, adj_logits, adj_orig, pos_weight, labels, global_idx, beta, local_idx=None): + if labels.dim() == 2: + nc_criterion = nn.BCEWithLogitsLoss() + else: + nc_criterion = nn.CrossEntropyLoss() + if local_idx is None: + local_idx = global_idx + loss = nc_criterion(nc_logits[local_idx], labels[global_idx]) + ep_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) + loss += beta * ep_loss + + return loss + + @staticmethod + def extend_batch(seed_batch, hops, adj_matrix): + nodes_batch = seed_batch + for _ in range(hops): + neigh_block = adj_matrix[nodes_batch] + nodes_batch = neigh_block.sum(0).nonzero()[1] + nodes_batch = np.setdiff1d(nodes_batch, seed_batch, assume_unique=True) + nodes_batch = np.concatenate((seed_batch, nodes_batch)) + return nodes_batch + + def _pretrain_ep_net(self, adj, features, adj_orig, norm_w, pos_weight): + """ pretrain the edge prediction network """ + optimizer = Adam(self.__model.ep_net.parameters(), lr=self.__ep_lr) + + self.__model.train() + for _ in range(self.__pretrain_ep): + adj_logits = self.__model.ep_net(adj, features) + loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) + if not self.__model.gae: + mu = self.__model.ep_net.mean + lgstd = self.__model.ep_net.logstd + kl_divergence = 0.5 / adj_logits.size(0) * (1 + 2*lgstd - mu**2 - torch.exp(2*lgstd)).sum(1).mean() + loss -= kl_divergence + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def _minibatch_pretrain_ep_net(self, adj, features, adj_orig, norm_w, pos_weight): + """ pretrain the edge prediction network in mini-batches""" + optimizer = Adam(self.__model.ep_net.parameters(), lr=self.__ep_lr) + num_nodes = features.size(0) + train_loader = DataLoader(range(num_nodes), batch_size=self.__pretrain_batch_size, shuffle=True, drop_last=False) + + self.__model.train() + for _ in range(self.__pretrain_ep): + 
for node_batch in train_loader: + sub_adj_orig = adj_orig[node_batch][:, node_batch].to(self.__device) + sub_adj_logits = self.__model.ep_net(adj, features, node_batch) + loss = norm_w * F.binary_cross_entropy_with_logits(sub_adj_logits, sub_adj_orig, pos_weight=pos_weight) + if not self.__model.gae: + mu = self.__model.ep_net.mean + lgstd = self.__model.ep_net.logstd + kl_divergence = 0.5 / sub_adj_logits.size(0) * (1 + 2*lgstd - mu**2 - torch.exp(2*lgstd)).sum(1).mean() + loss -= kl_divergence + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def _pretrain_nc_net(self, adj, features): + """ pretrain the node classification network """ + optimizer = Adam(self.__model.nc_net.parameters(), lr=self.__lr, weight_decay=self.__weight_decay) + # loss function for node classification + if self.__labels.dim() == 2: + nc_criterion = nn.BCEWithLogitsLoss() + else: + nc_criterion = nn.CrossEntropyLoss() + + for _ in range(self.__pretrain_nc): + self.__model.train() + nc_logits = self.__model.nc_net(features, adj) + # losses + loss = nc_criterion(nc_logits[self.__dataset.train_idx], self.__labels[self.__dataset.train_idx]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def _train(self, adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule): + # update the learning rate for ep_net if needed + if self.__warmup: + self.__optimizer.update_lr(0, ep_lr_schedule[epoch]) + + self.__model.train() + nc_logits, adj_logits = self.__model(adj_norm, adj_orig, features) + loss_train = self.loss_fn(nc_logits, norm_w, adj_logits, adj_orig, pos_weight, self.__labels, self.__dataset.train_idx, self.__beta) + acc_train = accuracy(nc_logits[self.__dataset.train_idx], self.__labels[self.__dataset.train_idx]) + self.__optimizer.zero_grad() + loss_train.backward() + self.__optimizer.step() + + return loss_train, acc_train + + def _minibatch_train(self, adj_matrix, adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule): + # update the learning rate for ep_net if needed + if self.__warmup: + self.__optimizer.update_lr(0, ep_lr_schedule[epoch]) + + seed_size = self.__train_batch_size // 20 + num_batches = int((len(self.__dataset.train_idx) + seed_size - 1) / seed_size) + node_idx_all = np.array(self.__dataset.train_idx) + np.random.shuffle(node_idx_all) + seed_batches = np.array_split(node_idx_all, num_batches) + + train_loss = 0. 
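+        # running totals over the seed batches drawn in this epoch; the loss is re-weighted by seed-batch size before averaging below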
+ train_num, num_correct = 0, 0 + self.__model.train() + for seed_batch in seed_batches: + nodes_batch = self.extend_batch(seed_batch, 2, adj_matrix) + if len(nodes_batch) >= self.__train_batch_size: + nodes_batch = nodes_batch[:self.__train_batch_size] + + sub_adj_orig = adj_orig[nodes_batch][:, nodes_batch].to(self.__device) + nc_logits, sub_adj_logits = self.__model(adj_norm, sub_adj_orig, features, nodes_batch) + num_correct += nc_logits[range(len(seed_batch))].argmax(dim=1).eq(self.__labels[seed_batch]).sum().long().item() + loss = self.loss_fn(nc_logits, norm_w, sub_adj_logits, sub_adj_orig, pos_weight, self.__labels, seed_batch, self.__beta, range(len(seed_batch))) + self.__optimizer.zero_grad() + loss.backward() + self.__optimizer.step() + train_loss += loss.item() * len(seed_batch) + train_num += len(seed_batch) + + return train_loss / train_num, num_correct / train_num + + def _evaluate(self, features, adj): + self.__model.eval() + with torch.no_grad(): + nc_logits_eval = self.__model.nc_net(features, adj) + acc_val = accuracy(nc_logits_eval[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) + acc_test = accuracy(nc_logits_eval[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) + + return acc_val, acc_test + + def _execute(self): + set_seed(self.__seed) + + features, adj_orig, adj, adj_norm = self.__model.preprocess(self.__dataset.x, self.__dataset.adj, self.__device) + if self.__train_batch_size is not None: + adj_matrix = sp.csr_matrix(adj_orig.numpy()) + + self.__model = self.__model.to(self.__device) + self.__labels = self.__labels.to(self.__device) + + # weights for log_lik loss when training EP net + norm_w = adj_orig.shape[0]**2 / float((adj_orig.shape[0]**2 - adj_orig.sum()) * 2) + pos_weight = torch.FloatTensor([float(adj_orig.shape[0]**2 - adj_orig.sum()) / adj_orig.sum()]).to(self.__device) + + acc_test_list = [] + for _ in range(self.__runs): + # reset model parameters at the beginning of each run + self.__model.reset_parameters() + # pretrain VGAE if needed + if self.__pretrain_ep: + if self.__pretrain_batch_size is None: + self._pretrain_ep_net(adj_norm, features, adj_orig, norm_w, pos_weight) + else: + self._minibatch_pretrain_ep_net(adj_norm, features, adj_orig, norm_w, pos_weight) + # pretrain GCN if needed + if self.__pretrain_nc: + self._pretrain_nc_net(adj, features) + # get the learning rate schedule for the optimizer of ep_net if needed + if self.__warmup: + ep_lr_schedule = self.get_lr_schedule_by_sigmoid(self.__epochs, self.__lr, self.__warmup) + else: + ep_lr_schedule = None + + # keep record of the best validation accuracy for early stopping + best_acc_val, best_acc_test, patience_step = 0., 0., 0 + # train model + for epoch in range(self.__epochs): + t = time.time() + + if self.__train_batch_size is None: + loss_train, acc_train = self._train(adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule) + else: + loss_train, acc_train = self._minibatch_train(adj_matrix, adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule) + acc_val, acc_test = self._evaluate(features, adj) + + if self.__verbose: + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) + + if acc_val > best_acc_val: + best_acc_val = acc_val + best_acc_test = acc_test + patience_step = 0 + else: + patience_step += 1 + if patience_step == 
self.__max_patience: + break + + acc_test_list.append(best_acc_test) + + # release RAM and GPU memory + del adj, features, adj_orig, adj_norm + torch.cuda.empty_cache() + gc.collect() + + return np.mean(acc_test_list) + + +class NodeClassificationGAugM(BaseTask): + def __init__(self, dataset, model, lr, weight_decay, epochs, device, runs=1, verbose=True, loss_fn=nn.CrossEntropyLoss(), seed=42, max_patience=100): + super(NodeClassificationGAugM, self).__init__() + + self.__dataset = dataset + self.__labels = self.__dataset.y + + self.__model = model + self.__optimizer = Adam(model.parameters(), lr=lr, + weight_decay=weight_decay) + self.__epochs = epochs + self.__loss_fn = loss_fn + self.__device = device + self.__seed = seed + self.__max_patience = max_patience + self.__runs = runs + self.__verbose = verbose + + self.__test_acc = self._execute() + + @property + def test_acc(self): + return self.__test_acc + + def _train(self, adj_norm, features): + self.__model.train() + pred_y = self.__model(adj_norm, features)[self.__dataset.train_idx] + ground_truth_y = self.__labels[self.__dataset.train_idx] + loss_train = self.__loss_fn(pred_y, ground_truth_y) + acc_train = accuracy(pred_y, ground_truth_y) + + self.__optimizer.zero_grad() + loss_train.backward() + self.__optimizer.step() + + return loss_train, acc_train + + def _evaluate(self, adj_norm, features): + self.__model.eval() + with torch.no_grad(): + pred_y = self.__model(adj_norm, features) + acc_val = accuracy(pred_y[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) + acc_test = accuracy(pred_y[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) + + return acc_val, acc_test + + def _execute(self): + set_seed(self.__seed) + + pre_time_st = time.time() + adj_pred_dir = os.path.join(self.__dataset.processed_dir, "GAugM_edge_probabilities") + adj, features = self.__model.preprocess(self.__dataset.adj, self.__dataset.x, adj_pred_dir, self.__device) + pre_time_ed = time.time() + if self.__verbose: + print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") + + self.__model = self.__model.to(self.__device) + self.__labels = self.__labels.to(self.__device) + + acc_val_list = [] + acc_test_list = [] + + for _ in range(self.__runs): + self.__model.reset_parameters() + t_total = time.time() + best_val = 0. + best_test = 0. 
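+            # early-stopping bookkeeping for this run: record the test accuracy observed at the best validation accuracy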
+ patience = 0 + for epoch in range(self.__epochs): + t = time.time() + loss_train, acc_train = self._train(adj, features) + acc_val, acc_test = self._evaluate(adj, features) + + if self.__verbose: + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) + + if acc_val > best_val: + best_val = acc_val + best_test = acc_test + patience = 0 + else: + patience += 1 + if patience == self.__max_patience: + break + + acc_val_list.append(best_val) + acc_test_list.append(best_test) + + mean_acc_test = np.mean(acc_test_list) + if self.__verbose: + print("Optimization Finished!") + print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) + print(f'Best val: {np.mean(acc_val_list):.4f}, best test: {mean_acc_test:.4f}') + + del adj, features + torch.cuda.empty_cache() + gc.collect() + + return mean_acc_test + diff --git a/sgl/tasks/node_classification.py b/sgl/tasks/node_classification.py index ec83b10..21772b8 100644 --- a/sgl/tasks/node_classification.py +++ b/sgl/tasks/node_classification.py @@ -3,6 +3,7 @@ import torch.nn as nn from torch.optim import Adam from torch.utils.data import DataLoader +import numpy as np from typing import Callable from sgl.tasks.base_task import BaseTask @@ -11,7 +12,7 @@ class NodeClassification(BaseTask): def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn.CrossEntropyLoss(), seed=42, - patience=100, train_batch_size=None, eval_batch_size=None): + patience=100, runs=1, verbose=True, train_batch_size=None, eval_batch_size=None): super(NodeClassification, self).__init__() self.__dataset = dataset @@ -25,6 +26,8 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn. self.__device = device self.__seed = seed self.__patience = patience + self.__runs = runs + self.__verbose = verbose self.__mini_batch = False if train_batch_size is not None: @@ -50,61 +53,71 @@ def _execute(self): pre_time_st = time.time() self.__model.preprocess(self.__dataset.adj, self.__dataset.x, self.__device) pre_time_ed = time.time() - print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") + if self.__verbose: + print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") self.__model = self.__model.to(self.__device) self.__labels = self.__labels.to(self.__device) - t_total = time.time() - best_val = 0. - best_test = 0. - patience = 0 - for epoch in range(self.__epochs): - t = time.time() - if self.__mini_batch is False: - if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): - loss_train, acc_train = self.__model.train_func(self.__dataset.train_idx, self.__labels, self.__device, - self.__optimizer, self.__loss_fn, accuracy) + best_test_list = [] + for _ in range(self.__runs): + + self.__model.reset_parameters() + t_total = time.time() + best_val = 0. + best_test = 0. 
+ patience = 0 + for epoch in range(self.__epochs): + t = time.time() + if self.__mini_batch is False: + if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): + loss_train, acc_train = self.__model.train_func(self.__dataset.train_idx, self.__labels, self.__device, + self.__optimizer, self.__loss_fn, accuracy) + else: + loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, + self.__optimizer, self.__loss_fn, accuracy) + if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): + acc_val, acc_test = self.__model.evaluate_func(self.__dataset.val_idx, self.__dataset.test_idx, + self.__labels, self.__device, accuracy) + else: + acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, + self.__labels, self.__device, accuracy) else: - loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, - self.__optimizer, self.__loss_fn, accuracy) - if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): - acc_val, acc_test = self.__model.evaluate_func(self.__dataset.val_idx, self.__dataset.test_idx, - self.__labels, self.__device, accuracy) + loss_train, acc_train = mini_batch_train(self.__model, self.__dataset.train_idx, self.__train_loader, + self.__labels, self.__device, self.__optimizer, self.__loss_fn) + acc_val, acc_test = mini_batch_evaluate(self.__model, self.__dataset.val_idx, self.__val_loader, + self.__dataset.test_idx, self.__test_loader, self.__labels, + self.__device) + if self.__verbose: + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) + + if acc_val > best_val: + patience = 0 + best_val = acc_val + best_test = acc_test else: - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, - self.__labels, self.__device, accuracy) - else: - loss_train, acc_train = mini_batch_train(self.__model, self.__dataset.train_idx, self.__train_loader, - self.__labels, self.__device, self.__optimizer, self.__loss_fn) - acc_val, acc_test = mini_batch_evaluate(self.__model, self.__dataset.val_idx, self.__val_loader, - self.__dataset.test_idx, self.__test_loader, self.__labels, - self.__device) + patience += 1 + if patience == self.__patience: + break - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'acc_val: {:.4f}'.format(acc_val), - 'acc_test: {:.4f}'.format(acc_test), - 'time: {:.4f}s'.format(time.time() - t)) + acc_val, acc_test = self._postprocess() if acc_val > best_val: - patience = 0 best_val = acc_val best_test = acc_test - else: - patience += 1 - if patience == self.__patience: - break - - acc_val, acc_test = self._postprocess() - if acc_val > best_val: - best_val = acc_val - best_test = acc_test - print("Optimization Finished!") - print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) - print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') - return best_test + best_test_list.append(best_test) + if self.__verbose: + print("Optimization Finished!") + print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) + print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') + + mean_best_test = np.mean(best_test_list) + return 
mean_best_test def _postprocess(self): self.__model.eval() diff --git a/sgl/tasks/node_classification_GAug.py b/sgl/tasks/node_classification_GAug.py deleted file mode 100644 index 0c0a071..0000000 --- a/sgl/tasks/node_classification_GAug.py +++ /dev/null @@ -1,272 +0,0 @@ -import gc -import os -import time -import torch -import torch.nn as nn -from torch.optim import Adam -import torch.nn.functional as F -import numpy as np - -from sgl.tasks.base_task import BaseTask -from sgl.tasks.utils import set_seed, accuracy, MultipleOptimizer - -class NodeClassificationGAugO(BaseTask): - def __init__(self, dataset, model, lr, weight_decay, epochs, device, seed, beta, warmup, max_patience, pretrain_ep, pretrain_nc): - super(NodeClassificationGAugO, self).__init__() - - self.__dataset = dataset - self.__labels = self.__dataset.y - - self.__model = model - self.__optimizer = MultipleOptimizer(Adam(model.ep_net.parameters(), lr=lr), - Adam(model.nc_net.parameters(), lr=lr, weight_decay=weight_decay)) - - self.__lr = lr - self.__weight_decay = weight_decay - - self.__epochs = epochs - self.__device = device - self.__seed = seed - - self.__warmup = warmup - self.__beta = beta - self.__max_patience = max_patience - - self.__pretrain_ep = pretrain_ep - self.__pretrain_nc = pretrain_nc - - self.__test_acc = self._execute() - - @property - def test_acc(self): - return self.__test_acc - - @staticmethod - def get_lr_schedule_by_sigmoid(n_epochs, lr, warmup): - """ schedule the learning rate with the sigmoid function. - The learning rate will start with near zero and end with near lr """ - factors = torch.FloatTensor(np.arange(n_epochs)) - factors = ((factors / factors[-1]) * (warmup * 2)) - warmup - factors = torch.sigmoid(factors) - # range the factors to [0, 1] - factors = (factors - factors[0]) / (factors[-1] - factors[0]) - lr_schedule = factors * lr - return lr_schedule - - @staticmethod - def loss_fn(nc_logits, norm_w, adj_logits, adj_orig, pos_weight, labels, idx, beta): - if labels.dim() == 2: - nc_criterion = nn.BCEWithLogitsLoss() - else: - nc_criterion = nn.CrossEntropyLoss() - loss = nc_criterion(nc_logits[idx], labels[idx]) - ep_loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) - loss += beta * ep_loss - - return loss - - def pretrain_ep_net(self, adj, features, adj_orig, norm_w, pos_weight): - """ pretrain the edge prediction network """ - optimizer = Adam(self.__model.ep_net.parameters(), lr=self.__lr) - - self.__model.train() - for _ in range(self.__pretrain_ep): - adj_logits = self.__model.ep_net(adj, features) - loss = norm_w * F.binary_cross_entropy_with_logits(adj_logits, adj_orig, pos_weight=pos_weight) - if not self.__model.gae: - mu = self.__model.ep_net.mean - lgstd = self.__model.ep_net.logstd - kl_divergence = 0.5 / adj_logits.size(0) * (1 + 2*lgstd - mu**2 - torch.exp(2*lgstd)).sum(1).mean() - loss -= kl_divergence - optimizer.zero_grad() - loss.backward() - optimizer.step() - - def pretrain_nc_net(self, adj, features): - """ pretrain the node classification network """ - optimizer = Adam(self.__model.nc_net.parameters(), lr=self.__lr, weight_decay=self.__weight_decay) - # loss function for node classification - if self.__labels.dim() == 2: - nc_criterion = nn.BCEWithLogitsLoss() - else: - nc_criterion = nn.CrossEntropyLoss() - - for _ in range(self.__pretrain_nc): - self.__model.train() - nc_logits = self.__model.nc_net(features, adj) - # losses - loss = nc_criterion(nc_logits[self.__dataset.train_idx], 
self.__labels[self.__dataset.train_idx]) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - def train(self, adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule): - # update the learning rate for ep_net if needed - if self.__warmup: - self.__optimizer.update_lr(0, ep_lr_schedule[epoch]) - - self.__model.train() - nc_logits, adj_logits = self.__model(adj_norm, adj_orig, features) - loss_train = self.loss_fn(nc_logits, norm_w, adj_logits, adj_orig, pos_weight, self.__labels, self.__dataset.train_idx, self.__beta) - acc_train = accuracy(nc_logits[self.__dataset.train_idx], self.__labels[self.__dataset.train_idx]) - self.__optimizer.zero_grad() - loss_train.backward() - self.__optimizer.step() - - return loss_train, acc_train - - def evaluate(self, features, adj): - self.__model.eval() - with torch.no_grad(): - nc_logits_eval = self.__model.nc_net(features, adj) - acc_val = accuracy(nc_logits_eval[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) - acc_test = accuracy(nc_logits_eval[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) - - return acc_val, acc_test - - def _execute(self): - set_seed(self.__seed) - - features, adj_orig, adj, adj_norm = self.__model.preprocess(self.__dataset.x, self.__dataset.adj, self.__device) - - self.__model = self.__model.to(self.__device) - self.__labels = self.__labels.to(self.__device) - - # weights for log_lik loss when training EP net - norm_w = adj_orig.shape[0]**2 / float((adj_orig.shape[0]**2 - adj_orig.sum()) * 2) - pos_weight = torch.FloatTensor([float(adj_orig.shape[0]**2 - adj_orig.sum()) / adj_orig.sum()]).to(self.__device) - # pretrain VGAE if needed - if self.__pretrain_ep: - self.pretrain_ep_net(adj_norm, features, adj_orig, norm_w, pos_weight) - # pretrain GCN if needed - if self.__pretrain_nc: - self.pretrain_nc_net(adj, features) - # get the learning rate schedule for the optimizer of ep_net if needed - if self.__warmup: - ep_lr_schedule = self.get_lr_schedule_by_sigmoid(self.__epochs, self.__lr, self.__warmup) - else: - ep_lr_schedule = None - - # keep record of the best validation accuracy for early stopping - best_acc_val, best_acc_test, patience_step = 0., 0., 0 - # train model - for epoch in range(self.__epochs): - t = time.time() - loss_train, acc_train = self.train(adj_norm, adj_orig, features, norm_w, pos_weight, epoch, ep_lr_schedule) - acc_val, acc_test = self.evaluate(features, adj) - - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'acc_val: {:.4f}'.format(acc_val), - 'acc_test: {:.4f}'.format(acc_test), - 'time: {:.4f}s'.format(time.time() - t)) - - if acc_val > best_acc_val: - best_acc_val = acc_val - best_acc_test = acc_test - patience_step = 0 - else: - patience_step += 1 - if patience_step == self.__max_patience: - break - - # release RAM and GPU memory - del adj, features, adj_orig, adj_norm - torch.cuda.empty_cache() - gc.collect() - - return best_acc_test - - -class NodeClassificationGAugM(BaseTask): - def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn=nn.CrossEntropyLoss(), seed=42, max_patience=100): - super(NodeClassificationGAugM, self).__init__() - - self.__dataset = dataset - self.__labels = self.__dataset.y - - self.__model = model - self.__optimizer = Adam(model.parameters(), lr=lr, - weight_decay=weight_decay) - self.__epochs = epochs - self.__loss_fn = loss_fn - self.__device = device - self.__seed = seed - self.__max_patience = max_patience - - 
self.__test_acc = self._execute() - - @property - def test_acc(self): - return self.__test_acc - - def train(self, adj_norm, features): - self.__model.train() - pred_y = self.__model(adj_norm, features)[self.__dataset.train_idx] - ground_truth_y = self.__labels[self.__dataset.train_idx] - loss_train = self.__loss_fn(pred_y, ground_truth_y) - acc_train = accuracy(pred_y, ground_truth_y) - - self.__optimizer.zero_grad() - loss_train.backward() - self.__optimizer.step() - - return loss_train, acc_train - - def evaluate(self, adj_norm, features): - self.__model.eval() - with torch.no_grad(): - pred_y = self.__model(adj_norm, features) - acc_val = accuracy(pred_y[self.__dataset.val_idx], self.__labels[self.__dataset.val_idx]) - acc_test = accuracy(pred_y[self.__dataset.test_idx], self.__labels[self.__dataset.test_idx]) - - return acc_val, acc_test - - def _execute(self): - set_seed(self.__seed) - - pre_time_st = time.time() - adj_pred_dir = os.path.join(self.__dataset.processed_dir, "GAugM_edge_probabilities") - adj, features = self.__model.preprocess(self.__dataset.adj, self.__dataset.x, adj_pred_dir, self.__device) - pre_time_ed = time.time() - print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") - - self.__model = self.__model.to(self.__device) - self.__labels = self.__labels.to(self.__device) - - t_total = time.time() - best_val = 0. - best_test = 0. - patience = 0 - for epoch in range(self.__epochs): - t = time.time() - loss_train, acc_train = self.train(adj, features) - acc_val, acc_test = self.evaluate(adj, features) - - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'acc_val: {:.4f}'.format(acc_val), - 'acc_test: {:.4f}'.format(acc_test), - 'time: {:.4f}s'.format(time.time() - t)) - - if acc_val > best_val: - best_val = acc_val - best_test = acc_test - patience = 0 - else: - patience += 1 - if patience == self.__max_patience: - break - - print("Optimization Finished!") - print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) - print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') - - del adj, features - torch.cuda.empty_cache() - gc.collect() - - return best_test - diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index 1a9a255..88286ed 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -13,7 +13,7 @@ class NodeClassification_Sampling(BaseTask): def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nll_loss", seed=42, - inductive=False, train_batch_size=None, eval_batch_size=None, eval_freq=1, eval_start=1, **kwargs): + inductive=False, train_batch_size=None, eval_batch_size=None, eval_freq=1, eval_start=1, runs=1, verbose=True, **kwargs): super(NodeClassification_Sampling, self).__init__() self.__dataset = dataset @@ -27,6 +27,8 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nl self.__loss_fn = getattr(F, loss_fn) if isinstance(loss_fn, str) else loss_fn self.__device = device self.__seed = seed + self.__runs = runs + self.__verbose = verbose self.__inductive = inductive self.__train_batch_size= train_batch_size self.__eval_batch_size = eval_batch_size @@ -60,7 +62,8 @@ def _execute(self): kwargs.update({"inductive": self.__inductive, "train_idx": self.__dataset.train_idx}) self.__model.preprocess(adj=self.__dataset.adj, x=self.__dataset.x, y=self.__dataset.y, device=self.__device, **kwargs) pre_time_ed = 
time.time() - print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") + if self.__verbose: + print(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s") if self.__mini_batch_train: if self.__train_determined_sample: @@ -95,62 +98,69 @@ def _execute(self): self.__model = self.__model.to(self.__device) - t_total = time.time() - best_val = 0. - best_test = 0. - - for epoch in range(self.__epochs): - t = time.time() - if self.__mini_batch_train: - if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): - loss_train, acc_train = self.__model.train_func(self.__train_loader, self.__inductive, self.__device, self.__optimizer, self.__loss_fn) - else: - loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__inductive, self.__device, - self.__optimizer, self.__loss_fn) - else: - loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__optimizer, self.__loss_fn) + best_test_list = [] + for _ in range(self.__runs): + self.__model.reset_parameters() - if epoch + 1 >= self.__eval_start and (epoch + 1) % self.__eval_freq == 0: - if self.__mini_batch_eval: - if self.__eval_together is False: - if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): - acc_val, acc_test = self.__model.evaluate_func(self.__val_loader, self.__test_loader, self.__device) + t_total = time.time() + best_val = 0. + best_test = 0. + + for epoch in range(self.__epochs): + t = time.time() + if self.__mini_batch_train: + if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): + loss_train, acc_train = self.__model.train_func(self.__train_loader, self.__inductive, self.__device, self.__optimizer, self.__loss_fn) + else: + loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__inductive, self.__device, + self.__optimizer, self.__loss_fn) + else: + loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__optimizer, self.__loss_fn) + + if epoch + 1 >= self.__eval_start and (epoch + 1) % self.__eval_freq == 0: + if self.__mini_batch_eval: + if self.__eval_together is False: + if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): + acc_val, acc_test = self.__model.evaluate_func(self.__val_loader, self.__test_loader, self.__device) + else: + acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__device) else: - acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__device) + self.__model.eval() + outputs = self.__model.inference(self.__all_eval_loader, self.__device) + acc_train = accuracy(outputs[self.__dataset.train_idx], self.__dataset.y[self.__dataset.train_idx]) + acc_val = accuracy(outputs[self.__dataset.val_idx], self.__dataset.y[self.__dataset.val_idx]) + acc_test = accuracy(outputs[self.__dataset.test_idx], self.__dataset.y[self.__dataset.test_idx]) else: - self.__model.eval() - outputs = self.__model.inference(self.__all_eval_loader, self.__device) - acc_train = accuracy(outputs[self.__dataset.train_idx], self.__dataset.y[self.__dataset.train_idx]) - acc_val = accuracy(outputs[self.__dataset.val_idx], self.__dataset.y[self.__dataset.val_idx]) - acc_test = accuracy(outputs[self.__dataset.test_idx], self.__dataset.y[self.__dataset.test_idx]) + acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx) + + if acc_val > best_val: + best_val = acc_val + 
best_test = acc_test + + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) else: - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx) - - if acc_val > best_val: - best_val = acc_val - best_test = acc_test - - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'acc_val: {:.4f}'.format(acc_val), - 'acc_test: {:.4f}'.format(acc_test), - 'time: {:.4f}s'.format(time.time() - t)) - else: - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'time: {:.4f}s'.format(time.time() - t)) - - acc_val, acc_test = self._postprocess() - if acc_val > best_val: - best_val = acc_val - best_test = acc_test - - print("Optimization Finished!") - print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) - print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') - return best_test + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'time: {:.4f}s'.format(time.time() - t)) + + acc_val, acc_test = self._postprocess() + if acc_val > best_val: + best_val = acc_val + best_test = acc_test + + best_test_list.append(best_test) + + print("Optimization Finished!") + print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) + print(f'Best val: {best_val:.4f}, best test: {best_test:.4f}') + + return np.mean(best_test_list) def _postprocess(self): self.__model.eval() From 9b33e765c057cdf943a1facb77030ed27f159dfa Mon Sep 17 00:00:00 2001 From: infinity Date: Sun, 24 Dec 2023 14:01:46 +0000 Subject: [PATCH 26/28] change search_config.py --- examples/GDA/test_search_GAug.py | 26 ++++- sgl/search/gda_hpo/GAug_search_config.py | 137 ----------------------- sgl/search/gda_hpo/search_config.py | 46 +++++++- 3 files changed, 65 insertions(+), 144 deletions(-) delete mode 100644 sgl/search/gda_hpo/GAug_search_config.py diff --git a/examples/GDA/test_search_GAug.py b/examples/GDA/test_search_GAug.py index 6b18ac7..05d20e7 100644 --- a/examples/GDA/test_search_GAug.py +++ b/examples/GDA/test_search_GAug.py @@ -1,9 +1,10 @@ import torch import argparse +import torch.nn.functional as F from openbox import Optimizer import sgl.dataset as Dataset -from sgl.search.gda_hpo.GAug_search_config import GAugOConfigManager, GAugMConfigManager +from sgl.search.gda_hpo.search_config import BaseGDAConfigManager if __name__ == "__main__": parser = argparse.ArgumentParser(description="HPO-GAug-Model.") @@ -26,9 +27,28 @@ pretrain_batch_size = args.pretrain_batch_size if args.pretrain_batch_size > 0 else None train_batch_size = args.train_batch_size if args.train_batch_size > 0 else None if args.model == "GAugO": - configer = GAugOConfigManager(dataset, args.gnn_type, not args.not_gae, device, minibatch=args.minibatch, pretrain_batch_size=pretrain_batch_size, train_batch_size=train_batch_size, runs=args.runs_per_config, max_patience=args.max_patience) + model_keys = ["in_dim", "hidden_dim", "emb_size", "n_classes", "n_layers", "dropout", "gnn_type", "activation", "temperature", "gae", "alpha", "feat_norm", "sample_type", "minibatch", "n_heads"] + task_keys = ["dataset", "model", "lr", "weight_decay", "epochs", "device", "beta", "warmup", "max_patience", "pretrain_ep", 
"pretrain_nc", "runs", "verbose", "seed", "pretrain_batch_size", "train_batch_size", "ep_lr"] + const_model_kwargs = dict(in_dim=dataset.num_features, n_classes=dataset.num_classes, gnn_type=args.gnn_type, activation=F.relu, gae=not args.not_gae, minibatch=args.minibatch, emb_size=32, n_layers=2, dropout=0.5, feat_norm="row") + const_task_kwargs = dict(dataset=dataset, epochs=200, device=device, max_patience=args.max_patience, pretrain_batch_size=pretrain_batch_size, train_batch_size=train_batch_size, runs=args.runs_per_config, verbose=False, lr=0.01, weight_decay=0.0005) + Reals = dict(alpha=dict(lower=0, upper=1, default_value=0.4, q=0.01), temperature=dict(lower=0.1, upper=2.1, default_value=1.5, q=0.1), beta=dict(lower=0, upper=4, default_value=2, q=0.1)) + if pretrain_batch_size is not None: + Reals.update(ep_lr=dict(lower=0.001, upper=0.01, default_value=0.002, q=0.001)) + else: + const_task_kwargs.update(ep_lr=0.01) + Categoricals = dict(hidden_dim=dict(choices=[32, 64, 128, 256], default_value=128)) + Ints = dict(warmup=dict(lower=0, upper=10, default_value=2, q=1), pretrain_ep=dict(lower=5, upper=300, default_value=100, q=5), pretrain_nc=dict(lower=5, upper=300, default_value=100, q=5)) + hier_params = dict(Real=Reals, Categorical=Categoricals, Int=Ints) + configer = BaseGDAConfigManager(args.model, f"NodeClassification{args.model}", model_keys, task_keys, const_model_kwargs, const_task_kwargs, hier_params) else: - configer = GAugMConfigManager(dataset, args.gnn_type, not args.not_gae, device, args.num_logits, runs=args.runs_per_config, max_patience=args.max_patience) + model_keys = ["in_dim", "hidden_dim", "n_classes", "n_layers", "gnn_type", "rm_pct", "add_pct", "choose_idx", "gae", "dropout", "activation", "feat_norm", "n_heads"] + task_keys = ["dataset", "model", "lr", "weight_decay", "epochs", "device", "max_patience", "runs", "verbose", "seed"] + const_model_kwargs = dict(in_dim=dataset.num_features, n_classes=dataset.num_classes, gnn_type=args.gnn_type, activation=F.relu, gae=not args.not_gae, n_layers=2, dropout=0.5, feat_norm="row") + const_task_kwargs = dict(dataset=dataset, epochs=200, device=device, max_patience=args.max_patience, runs=args.runs_per_config, verbose=False, lr=0.01, weight_decay=0.0005) + Categoricals = dict(hidden_dim=dict(choices=[32, 64, 128, 256], default_value=128)) + Ints = dict(choose_idx=dict(lower=1, upper=args.num_logits, default_value=1, q=1), rm_pct=dict(lower=0, upper=80, default_value=20, q=1), add_pct=dict(lower=0, upper=80, default_value=20, q=1)) + hier_params = dict(Categorical=Categoricals, Int=Ints) + configer = BaseGDAConfigManager(args.model, f"NodeClassification{args.model}", model_keys, task_keys, const_model_kwargs, const_task_kwargs, hier_params) opt = Optimizer(configer._configFunction, configer._configSpace(), diff --git a/sgl/search/gda_hpo/GAug_search_config.py b/sgl/search/gda_hpo/GAug_search_config.py deleted file mode 100644 index c3835e5..0000000 --- a/sgl/search/gda_hpo/GAug_search_config.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch.nn.functional as F -from openbox import space as osp - -from sgl.models.homo.gda import GAugO, GAugM -from sgl.tasks import NodeClassificationGAugO, NodeClassificationGAugM -from sgl.search.gda_hpo.search_config import BaseGDAConfigManager - -class GAugOConfigManager(BaseGDAConfigManager): - def __init__(self, dataset, gnn_type, gae, device, runs=5, activation=F.relu, minibatch=False, epochs=200, max_patience=100, pretrain_batch_size=None, train_batch_size=None): - 
super(GAugOConfigManager, self).__init__() - # basic information - self.__dataset = dataset - self.__gnn_type = gnn_type - self.__gae = gae - self.__minibatch = minibatch - self.__activation = activation - self.__epochs = epochs - self.__device = device - self.__max_patience = max_patience - self.__pretrain_batch_size = pretrain_batch_size - self.__train_batch_size = train_batch_size - self.__runs = runs - self.__config_space = osp.Space() - # model hyperparameters - alpha = osp.Real("alpha", 0, 1, default_value=0.4, q=0.01) - temperature = osp.Real("temperature", 0.1, 2.1, default_value=1.5, q=0.1) - hidden_dim = osp.Categorical("hidden_dim", [32, 64, 128, 256], default_value=128) - emb_size = osp.Constant("emb_size", 32) - n_layers = osp.Constant("n_layers", 2) - dropout = osp.Constant("dropout", 0.5) - feat_norm = osp.Constant("feat_norm", "row") - # task hyperparameters - lr = osp.Constant("lr", 0.01) if self.__gnn_type != "gat" else osp.Constant("lr", 0.005) - if pretrain_batch_size is not None: - ep_lr = osp.Real("ep_lr", 0.001, 0.01, default_value=0.002, q=0.001) - else: - ep_lr = osp.Constant("ep_lr", 0.01) - weight_decay = osp.Constant("weight_decay", 0.0005) - warmup = osp.Int("warmup", 0, 10, default_value=2, q=1) - beta = osp.Real("beta", 0, 4, default_value=2, q=0.1) - pretrain_ep = osp.Int("pretrain_ep", 5, 300, default_value=100, q=5) - pretrain_nc = osp.Int("pretrain_nc", 5, 300, default_value=200, q=5) - self.__config_space.add_variables([alpha, temperature, hidden_dim, emb_size, n_layers, dropout, feat_norm, \ - lr, weight_decay, warmup, beta, pretrain_ep, pretrain_nc, ep_lr]) - - def _configSpace(self): - return self.__config_space - - def _configTarget(self, params): - model_kwargs = dict() - model_kwargs["in_dim"] = self.__dataset.num_features - model_kwargs["hidden_dim"] = params["hidden_dim"] - model_kwargs["emb_size"] = params["emb_size"] - model_kwargs["n_classes"] = self.__dataset.num_classes - model_kwargs["n_layers"] = params["n_layers"] - model_kwargs["dropout"] = params["dropout"] - model_kwargs["gnn_type"] = self.__gnn_type - model_kwargs["activation"] = self.__activation - model_kwargs["temperature"] = params["temperature"] - model_kwargs["gae"] = self.__gae - model_kwargs["alpha"] = params["alpha"] - model_kwargs["feat_norm"] = params["feat_norm"] - model_kwargs["minibatch"] = self.__minibatch - model = GAugO(**model_kwargs) - task_kwargs = dict() - task_kwargs["lr"] = params["lr"] - task_kwargs["weight_decay"] = params["weight_decay"] - task_kwargs["epochs"] = self.__epochs - task_kwargs["device"] = self.__device - task_kwargs["beta"] = params["beta"] - task_kwargs["warmup"] = params["warmup"] - task_kwargs["max_patience"] = self.__max_patience - task_kwargs["pretrain_ep"] = params["pretrain_ep"] - task_kwargs["pretrain_nc"] = params["pretrain_nc"] - task_kwargs["pretrain_batch_size"] = self.__pretrain_batch_size - task_kwargs["train_batch_size"] = self.__train_batch_size - task_kwargs["ep_lr"] = params["ep_lr"] - task = NodeClassificationGAugO(self.__dataset, model, runs=self.__runs, verbose=False, **task_kwargs) - acc_res = task._execute() - - return dict(objectives=[-acc_res]) - -class GAugMConfigManager(BaseGDAConfigManager): - def __init__(self, dataset, gnn_type, gae, device, num_logits, runs=5, activation=F.relu, epochs=200, max_patience=100): - super(GAugMConfigManager, self).__init__() - # basic information - self.__dataset = dataset - self.__gnn_type = gnn_type - self.__gae = gae - self.__activation = activation - self.__device = device - 
self.__epochs = epochs - self.__max_patience = max_patience - self.__runs = runs - self.__config_space = osp.Space() - # model hyperparameters - choose_idx = osp.Int("choose_idx", 1, num_logits, default_value=1, q=1) - rm_pct = osp.Int("rm_pct", 0, 80, default_value=20, q=1) - add_pct = osp.Int("add_pct", 0, 80, default_value=20, q=1) - hidden_dim = osp.Categorical("hidden_dim", [32, 64, 128, 256], default_value=128) - n_layers = osp.Constant("n_layers", 2) - dropout = osp.Constant("dropout", 0.5) - feat_norm = osp.Constant("feat_norm", "row") - # task hyperparameters - lr = osp.Constant("lr", 0.01) - weight_decay = osp.Constant("weight_decay", 0.0005) - self.__config_space.add_variables([choose_idx, rm_pct, add_pct, hidden_dim, n_layers, dropout, feat_norm, \ - lr, weight_decay]) - - def _configSpace(self): - return self.__config_space - - def _configTarget(self, params): - model_kwargs = dict() - model_kwargs["in_dim"] = self.__dataset.num_features - model_kwargs["hidden_dim"] = params["hidden_dim"] - model_kwargs["n_classes"] = self.__dataset.num_classes - model_kwargs["n_layers"] = params["n_layers"] - model_kwargs["dropout"] = params["dropout"] - model_kwargs["gnn_type"] = self.__gnn_type - model_kwargs["activation"] = self.__activation - model_kwargs["gae"] = self.__gae - model_kwargs["feat_norm"] = params["feat_norm"] - model_kwargs["choose_idx"] = params["choose_idx"] - model_kwargs["rm_pct"] = params["rm_pct"] - model_kwargs["add_pct"] = params["add_pct"] - model = GAugM(**model_kwargs) - task_kwargs = dict() - task_kwargs["lr"] = params["lr"] - task_kwargs["weight_decay"] = params["weight_decay"] - task_kwargs["epochs"] = self.__epochs - task_kwargs["device"] = self.__device - task_kwargs["max_patience"] = self.__max_patience - task = NodeClassificationGAugM(self.__dataset, model, runs=self.__runs, verbose=False, **task_kwargs) - acc_res = task._execute() - - return dict(objectives=[-acc_res]) \ No newline at end of file diff --git a/sgl/search/gda_hpo/search_config.py b/sgl/search/gda_hpo/search_config.py index f98aa3a..47bb3f2 100644 --- a/sgl/search/gda_hpo/search_config.py +++ b/sgl/search/gda_hpo/search_config.py @@ -1,14 +1,52 @@ +from typing import List from openbox import space as osp +import sgl.models.homo.gda as GDAModel +import sgl.tasks as Task + class BaseGDAConfigManager(): - def __init__(self): + def __init__(self, gda_model_name: str, task_name: str, model_keys: List[str], task_keys: List[str], const_model_kwargs: dict, const_task_kwargs: dict, hier_params: dict): super(BaseGDAConfigManager, self).__init__() - self.__config_space = None + self._gda_model_name = gda_model_name + self._task_name = task_name + self._model_keys = model_keys + self._task_keys = task_keys + self._const_model_kwargs = const_model_kwargs + self._const_task_kwargs = const_task_kwargs + self._config_space = osp.Space() + self._setupSpace(hier_params) + + def _configTarget(self, params: dict): + model_kwargs, task_kwargs = self._const_model_kwargs.copy(), self._const_task_kwargs.copy() + for p_name, p_value in params.items(): + if p_name in self._model_keys: + model_kwargs.update({p_name: p_value}) + elif p_name in self._task_keys: + task_kwargs.update({p_name: p_value}) + else: + raise ValueError(f"Get unexpected parameter {p_name}") + model = getattr(GDAModel, self._gda_model_name)(**model_kwargs) + task = getattr(Task, self._task_name)(model=model, **task_kwargs) + acc_res = task._execute() + + return dict(objectives=[-acc_res]) - def _configTarget(self, params): - raise 
NotImplementedError + def _configSpace(self): + return self._config_space def _configFunction(self, config_space: osp.Configuration): params = config_space.get_dictionary().copy() result = self._configTarget(params) return result + + def _setupSpace(self, hier_params: dict): + for cls, variables in hier_params.items(): + """ + cls: str, variable class, Real, Int, Constant + variables: dict, key = variable name (e.g., alpha, temperature), + value = variable property (e.g., lower=0, upper=1, default_value=0.4, q=0.01) + """ + variable_list = [] + for var_name, var_kwargs in variables.items(): + variable_list.append(getattr(osp, cls)(var_name, **var_kwargs)) + self._config_space.add_variables(variable_list) \ No newline at end of file From c5fe9bf123c6c19c1033bddb4300126d7ef95e3e Mon Sep 17 00:00:00 2001 From: infinity Date: Mon, 25 Dec 2023 10:27:19 +0000 Subject: [PATCH 27/28] update hp for GAugOMini; change the adj stack method for ClusterGCN; fix GraphSAINT --- examples/GDA/configs/GAugOMini.yml | 20 ++-- examples/clustergcn_nodeclass.py | 7 +- examples/configs/clustergcn.yml | 31 +++++- examples/configs/graphsaint.yml | 4 +- examples/graphsaint_nodeclass.py | 5 +- sgl/models/homo/clustergcn.py | 16 +-- sgl/models/homo/graphsaint.py | 31 +++--- sgl/sampler/base_sampler.py | 6 +- sgl/sampler/sampler.py | 115 +++++++++++----------- sgl/search/gda_hpo/search_config.py | 3 +- sgl/tasks/node_classification_sampling.py | 33 ++++--- 11 files changed, 159 insertions(+), 112 deletions(-) diff --git a/examples/GDA/configs/GAugOMini.yml b/examples/GDA/configs/GAugOMini.yml index a56eea9..a70ba88 100644 --- a/examples/GDA/configs/GAugOMini.yml +++ b/examples/GDA/configs/GAugOMini.yml @@ -5,9 +5,9 @@ dataset: model: model_name: 'GAugO' gnn_type: 'gcn' - alpha: 1.0 - temperature: 0.2 - hidden_dim: 128 + alpha: 0.79 + temperature: 1.6 + hidden_dim: 256 emb_size: 64 dropout: 0.5 n_layers: 2 @@ -17,14 +17,14 @@ model: minibatch: True task: lr: 0.01 - ep_lr: 0.002 + ep_lr: 0.006 seed: 42 - warmup: 0 - beta: 0.8 + warmup: 7 + beta: 3.3 epochs: 200 weight_decay: 0.0005 - pretrain_ep: 160 - pretrain_nc: 30 + pretrain_ep: 10 + pretrain_nc: 40 max_patience: 50 - train_batch_size: 250 - pretrain_batch_size: 4096 \ No newline at end of file + train_batch_size: 8192 + pretrain_batch_size: 8192 \ No newline at end of file diff --git a/examples/clustergcn_nodeclass.py b/examples/clustergcn_nodeclass.py index 04ea21a..41e3025 100644 --- a/examples/clustergcn_nodeclass.py +++ b/examples/clustergcn_nodeclass.py @@ -29,7 +29,12 @@ train_sampler_kwargs.update({"save_dir": dataset.processed_dir}) train_cluster_number = train_sampler_kwargs["cluster_number"] task_kwargs.update({"train_graph_number": train_cluster_number}) - train_sampler = ClusterGCNSampler(dataset, **train_sampler_kwargs) + if "inductive" in train_sampler_kwargs.keys(): + inductive = train_sampler_kwargs.pop("inductive") + else: + inductive = False + task_kwargs.update({"inductive": inductive}) + train_sampler = ClusterGCNSampler(dataset, inductive=inductive, **train_sampler_kwargs) if "eval" in sampler_kwargs: eval_sampler_kwargs = sampler_kwargs["eval"] eval_sampler_name = eval_sampler_kwargs["name"] diff --git a/examples/configs/clustergcn.yml b/examples/configs/clustergcn.yml index 1423719..d723e7d 100644 --- a/examples/configs/clustergcn.yml +++ b/examples/configs/clustergcn.yml @@ -1,20 +1,43 @@ +# dataset: +# classname: "Planetoid" +# name: "cora" +# root: "/home/ssq/test_data/" +# sampler: +# train: +# cluster_method: "metis" +# cluster_number: 10 
+# sparse_type: "pyg" +# model: +# hidden_dim: 128 +# dropout: 0.5 +# num_layers: 2 +# sparse_type: "pyg" +# task: +# train_batch_size: 5 +# epochs: 30 +# lr: 0.01 +# weight_decay: 0.00005 +# loss_fn: "nll_loss" +# seed: 42 dataset: classname: "Planetoid" - name: "cora" + name: "pubmed" root: "/home/ssq/test_data/" + split: "full" sampler: train: cluster_method: "metis" cluster_number: 10 - post_sampling_op: "LaplacianGraphOp" - sparse_type: "torch" + sparse_type: "pyg" + inductive: True model: hidden_dim: 128 dropout: 0.5 num_layers: 2 + sparse_type: "pyg" task: train_batch_size: 5 - epochs: 20 + epochs: 50 lr: 0.01 weight_decay: 0.00005 loss_fn: "nll_loss" diff --git a/examples/configs/graphsaint.yml b/examples/configs/graphsaint.yml index 4d98f59..3f47956 100644 --- a/examples/configs/graphsaint.yml +++ b/examples/configs/graphsaint.yml @@ -5,11 +5,11 @@ dataset: sampler: train: pre_sampling_graphs: 20 - sampler_type: "Node" + sampler_type: "random_walk" nodebudget: 1000 edgebudget: 3000 r: 500 - h: 4 + h: 3 pre_sampling_op: "RwGraphOp" sparse_type: "torch" model: diff --git a/examples/graphsaint_nodeclass.py b/examples/graphsaint_nodeclass.py index b4a57a4..3d7520f 100644 --- a/examples/graphsaint_nodeclass.py +++ b/examples/graphsaint_nodeclass.py @@ -1,8 +1,9 @@ import yaml import argparse -from torch.nn.functional import nll_loss + import sgl.dataset as Dataset from sgl.models.homo import GraphSAINT + import sgl.sampler as Sampler from sgl.sampler import GraphSAINTSampler from sgl.tasks import NodeClassification_Sampling @@ -45,5 +46,5 @@ model_kwargs.update({"device": device}) model = GraphSAINT(dataset, train_sampler, eval_sampler, **model_kwargs) task_kwargs.update({"device": device}) - task_kwargs.update({"loss_fn":model.loss}) + task_kwargs.update({"loss_fn": model.loss_fn}) test_acc = NodeClassification_Sampling(dataset, model, **task_kwargs).test_acc diff --git a/sgl/models/homo/clustergcn.py b/sgl/models/homo/clustergcn.py index b6570ad..75f8038 100644 --- a/sgl/models/homo/clustergcn.py +++ b/sgl/models/homo/clustergcn.py @@ -1,11 +1,9 @@ -from sgl.models.simple_models import GCN +from sgl.models.pyg_simple_models import GCN from sgl.models.base_model import BaseSAMPLEModel -from sgl.operators.graph_op import LaplacianGraphOp class ClusterGCN(BaseSAMPLEModel): def __init__(self, training_sampler, eval_sampler, nfeat, hidden_dim, nclass, sparse_type="torch", dropout=0.5, num_layers=2, device="cpu"): super(ClusterGCN, self).__init__(evaluate_mode="sampling", sparse_type=sparse_type) - self._pre_graph_op = LaplacianGraphOp(r=0.5) self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler self._base_model = GCN(n_feat=nfeat, n_hid=hidden_dim, n_class=nclass, n_layers=num_layers, dropout=dropout).to(device) @@ -16,11 +14,17 @@ def pre_sample(self, mode="train"): else: self._eval_sampling_op.multiple_graphs_sampling() - def mini_batch_prepare_forward(self, batch, device, **kwargs): + def mini_batch_prepare_forward(self, batch, device, inductive=False): batch_in, batch_out, block = batch local_inds, global_inds = batch_out - in_x = self._processed_feature[batch_in].to(device) - y_truth = self._vanilla_y[global_inds].to(device) + + if inductive is False: + in_x = self._processed_feature[batch_in].to(device) + y_truth = self._vanilla_y[global_inds].to(device) + else: + in_x = self._processed_train_feature[batch_in].to(device) + y_truth = self._vanilla_train_y[global_inds].to(device) + block.to_device(device) y_pred = self._base_model(in_x, block)[local_inds] 
return y_pred, y_truth diff --git a/sgl/models/homo/graphsaint.py b/sgl/models/homo/graphsaint.py index 7a0d0a6..7e29b09 100644 --- a/sgl/models/homo/graphsaint.py +++ b/sgl/models/homo/graphsaint.py @@ -1,31 +1,38 @@ +import torch.nn.functional as F + from sgl.models.simple_models import GCN from sgl.models.base_model import BaseSAMPLEModel from sgl.operators.graph_op import RwGraphOp -from torch.nn.functional import nll_loss - class GraphSAINT(BaseSAMPLEModel): def __init__(self, dataset, training_sampler, eval_sampler, hidden_dim, sparse_type="torch", dropout=0.5, num_layers=2, device="cpu"): super(GraphSAINT, self).__init__(sparse_type=sparse_type) self._pre_graph_op = RwGraphOp() self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler - self.device = device + self._device = device self._base_model = GCN( n_feat=dataset.num_features, n_hid=hidden_dim, n_class=dataset.num_classes, n_layers=num_layers, dropout=dropout ).to(device) def pre_sample(self, mode="train"): - self._training_sampling_op._calc_norm() - self._training_sampling_op.loss_norm.to(device=self.device) - return + if mode == "train": + self._training_sampling_op._calc_norm() + self._loss_norm = self._training_sampling_op.loss_norm.to(self._device) + else: + raise ValueError("GraphSAINT sampler now only support training mode.") - def mini_batch_prepare_forward(self, batch, device, **kwargs): + def mini_batch_prepare_forward(self, batch, device, inductive=False): batch_in, batch_out, block = batch local_inds, global_inds = batch_out - in_x = self._processed_feature[batch_in].to(device) - y_truth = self._vanilla_y[global_inds].to(device) + if inductive is False: + in_x = self._processed_feature[batch_in].to(device) + y_truth = self._vanilla_y[global_inds].to(device) + else: + in_x = self._processed_train_feature[batch_in].to(device) + y_truth = self._vanilla_train_y[global_inds].to(device) + block.to_device(device) y_pred = self._base_model(in_x, block)[local_inds] return y_pred, y_truth @@ -36,11 +43,11 @@ def collate_fn(self, batch_ids, mode): else: return self._eval_sampling_op.collate_fn(batch_ids, mode) - def loss(self, pred, labels): - loss = nll_loss(pred, labels, reduction="none") + def loss_fn(self, pred, labels): + loss = F.nll_loss(pred, labels, reduction="none") loss = (loss / self.cur_loss_norm).sum() return loss @property def cur_loss_norm(self): - return self._training_sampling_op.loss_norm[self._training_sampling_op.cur_index] \ No newline at end of file + return self._loss_norm[self._training_sampling_op.cur_index] \ No newline at end of file diff --git a/sgl/sampler/base_sampler.py b/sgl/sampler/base_sampler.py index e2df988..29f57da 100644 --- a/sgl/sampler/base_sampler.py +++ b/sgl/sampler/base_sampler.py @@ -99,10 +99,6 @@ def _post_process(self, adjs, to_sparse_tensor=True): adjs = [sparse_transform_func(adj) for adj in adjs] return adjs - @staticmethod - def to_Block(adjs, sparse_type): - return Block(adjs, sparse_type) - def collate_fn(self, *args): raise NotImplementedError @@ -116,7 +112,7 @@ def __init__(self, adj, **kwargs): self.sample_level = "graph" self.pre_sampling = False self.full_batch = kwargs.get("node_ids", range(self._adj.shape[0])) - self.full_block = self.to_Block(self._adj, self._sparse_type) + self.full_block = Block(self._adj, self._sparse_type) def sampling(self): return self.full_batch, self.full_batch, self.full_block diff --git a/sgl/sampler/sampler.py b/sgl/sampler/sampler.py index b38a5fd..51ee871 100644 --- a/sgl/sampler/sampler.py +++ 
b/sgl/sampler/sampler.py @@ -2,12 +2,13 @@ import torch import numpy as np import pickle as pkl -import networkx as nx import scipy.sparse as sp from torch_sparse import SparseTensor -from torch_geometric.utils import from_networkx, mask_to_index +from torch_geometric.utils import mask_to_index +from sgl.data.base_data import Block +from sgl.utils import sparse_mx_to_pyg_sparse_tensor from sgl.sampler.base_sampler import NodeWiseSampler, LayerWiseSampler, GraphWiseSampler @@ -49,7 +50,7 @@ def collate_fn(self, batch_inds): all_adjs = self._post_process(all_adjs, to_sparse_tensor=False) - return cur_tgt_nodes, batch_inds, self.to_Block(all_adjs, self._sparse_type) + return cur_tgt_nodes, batch_inds, Block(all_adjs, self._sparse_type) class FastGCNSampler(LayerWiseSampler): def __init__(self, adj, **kwargs): @@ -84,18 +85,18 @@ def collate_fn(self, batch_inds): all_adjs = self._post_process(all_adjs, to_sparse_tensor=False) - return cur_out_nodes, batch_inds, self.to_Block(all_adjs, self._sparse_type) + return cur_out_nodes, batch_inds, Block(all_adjs, self._sparse_type) class ClusterGCNSampler(GraphWiseSampler): """ Clustering the graph, feature set and target. """ - def __init__(self, dataset, **kwargs): + def __init__(self, dataset, inductive=False, **kwargs): """ Inputs: adj: Adjacency matrix (Networkx Graph). """ - super(ClusterGCNSampler, self).__init__(nx.from_scipy_sparse_matrix(dataset.adj), **kwargs) + super(ClusterGCNSampler, self).__init__(dataset.adj[dataset.train_idx, :][:, dataset.train_idx] if inductive else dataset.adj, **kwargs) self.sampler_name = "ClusterGCNSampler" self.sample_level = "graph" self.pre_sampling = True # conduct sampling only once before training @@ -129,9 +130,23 @@ def collate_fn(self, batch_inds, mode): start = self.partptr[batch_inds].tolist() end = self.partptr[batch_inds + 1].tolist() node_idx = torch.cat([torch.arange(s, e) for s, e in zip(start, end)]) + stack_row, stack_col, stack_value = [], [], [] + num_node = 0 + for i, batch_ind in enumerate(batch_inds): + batch_ind = batch_ind.item() + perm_adj = self.splitted_perm_adjs[batch_ind] + row, col, value = perm_adj.coo() + row = row + num_node + col = col + num_node + num_node += end[i] - start[i] + stack_row.append(row) + stack_col.append(col) + stack_value.append(value) + stack_row = torch.cat(stack_row) + stack_col = torch.cat(stack_col) + stack_value = torch.cat(stack_value) + block = Block(SparseTensor(row=stack_row, col=stack_col, value=stack_value, sparse_sizes=(num_node, num_node)), sparse_type=self._sparse_type) global_node_idx = self.perm_node_idx[node_idx] - composed_sparse_mx = sp.block_diag([self.splitted_perm_adjs[batch_ind.item()] for batch_ind in batch_inds]) - block = self.to_Block(composed_sparse_mx, self._sparse_type) if mode in ["train", "val", "test"]: mask = self._masks[mode][global_node_idx] global_inds = global_node_idx[mask] @@ -148,12 +163,12 @@ def collate_fn(self, batch_inds, mode): return global_node_idx, batch_out, block def _metis_clustering(self): - data = from_networkx(self._adj) - N, E = data.num_nodes, data.num_edges - adj = SparseTensor( - row=data.edge_index[0], col=data.edge_index[1], - value=torch.arange(E, device=data.edge_index.device), - sparse_sizes=(N, N)) + adj = sparse_mx_to_pyg_sparse_tensor(self._adj) + r""" + perm_adjs: SparseTensor + len(self.partptr) == self.cluster_number + 1 + len(self.perm_node_idx) = num_nodes + """ self.perm_adjs, self.partptr, self.perm_node_idx = adj.partition(self.cluster_number, False) self.splitted_perm_adjs = [] for i in 
range(len(self.partptr)-1): @@ -161,12 +176,7 @@ def _metis_clustering(self): node_idx = torch.arange(start, end) perm_adj = self.perm_adjs.narrow(0, start, end-start) perm_adj = perm_adj.index_select(1, node_idx) - row, col, _ = perm_adj.coo() - row, col = row.numpy(), col.numpy() - num_nodes = len(node_idx) - sparse_mx = sp.coo_matrix((np.ones_like(row), (row, col)), shape=(num_nodes, num_nodes)) - sparse_mx = self._post_process(sparse_mx, to_sparse_tensor=False) - self.splitted_perm_adjs.append(sparse_mx) + self.splitted_perm_adjs.append(perm_adj) if self._save_dir is not None: torch.save((self.perm_adjs, self.partptr, self.perm_node_idx), self._save_path_pt) pkl.dump(self.splitted_perm_adjs, open(self._save_path_pkl, "wb")) @@ -184,43 +194,36 @@ def __init__(self, dataset, **kwargs): super(GraphSAINTSampler, self).__init__(dataset.adj, **kwargs) self.replace = True - self.sampler_name = "GraphSaintSampler" + self.sampler_name = "GraphSAINTSampler" self.sample_level = "graph" self.pre_sampling = False self._masks = {"train": dataset.train_mask, "val": dataset.val_mask, "test": dataset.test_mask} - self.n = dataset.adj.shape[0] - self.e = dataset.adj.nnz + def _pre_process(self, **kwargs): + self.num_node = self._adj.shape[0] + self.num_edge = self._adj.nnz self.pre_sampling_times = kwargs.get("pre_sampling_graphs", 1) self.used_sample_graphs = 0 - if kwargs['sampler_type'] == "Node": + if kwargs["sampler_type"] == "node": kwargs.update({"prob_type": "normalize"}) self._calc_probs(**kwargs) self.node_probs = self.probs - self.node_budget = kwargs['nodebudget'] - self.sample_graph_type = "Node" - elif kwargs['sampler_type'] == 'Edge': + self.node_budget = kwargs["nodebudget"] + elif kwargs["sampler_type"] == "edge": self._calc_edge_probs() - self.edge_budget = kwargs['edgebudget'] - self.sample_graph_type = "Edge" - elif kwargs['sampler_type'] == 'RandomWalk': - self.r = kwargs['r'] - self.h = kwargs['h'] - self.sample_graph_type = "RandomWalk" + self.edge_budget = kwargs["edgebudget"] + elif kwargs["sampler_type"] == "random_walk": + self.r = kwargs["r"] + self.h = kwargs["h"] else: raise NotImplementedError + self.sample_graph_type = kwargs["sampler_type"] + @property def sample_graph_ops(self): - if self.sample_graph_type == "Node": - return self.node_sampler() - elif self.sample_graph_type == "Edge": - return self.edge_sampler() - elif self.sample_graph_type == "RandomWalk": - return self.random_walk_sampler() - else: - raise NotImplementedError + return getattr(self, f"{self.sample_graph_type}_sampler") def node_sampler(self): """ @@ -233,11 +236,12 @@ def node_sampler(self): p = self.node_probs - sampled_node = np.random.choice(a=self.n, size=self.node_budget, replace=self.replace, p=p) + sampled_node = np.random.choice(a=self.num_node, size=self.node_budget, replace=self.replace, p=p) sampled_node = np.unique(sampled_node) subadj = self._adj[sampled_node, :] subadj = subadj[:, sampled_node] + return sampled_node, subadj def _calc_edge_probs(self): @@ -251,7 +255,6 @@ def _calc_edge_probs(self): self.edge_probs = 1 / start_degrees + 1 / end_degrees self.edge_probs = self.edge_probs / np.sum(self.edge_probs) - return def edge_sampler(self): """ @@ -263,14 +266,14 @@ def edge_sampler(self): """ p = self.edge_probs - sampled_edges = np.random.choice(a=self.e, size=self.edge_budget, replace=self.replace, p=p) + sampled_edges = np.random.choice(a=self.num_edge, size=self.edge_budget, replace=self.replace, p=p) sampled_edges = np.unique(sampled_edges) edges = self._adj.nonzero() 
sampled_start = edges[0][sampled_edges] sampled_end = edges[1][sampled_edges] - sampled_node = np.unique(np.concatenate([sampled_start,sampled_end])) + sampled_node = np.unique(np.concatenate([sampled_start, sampled_end])) subadj = self._adj[sampled_node, :] subadj = subadj[:, sampled_node] @@ -285,7 +288,7 @@ def random_walk_sampler(self): sampled_node: global node index block: sampled adjs, csr sparse matrix """ - root_nodes = np.random.choice(a=self.n, size=self.r, replace = self.replace) + root_nodes = np.random.choice(a=self.num_node, size=self.r, replace=self.replace) sampled_node = [] for v in root_nodes: sampled_node.append(v) @@ -308,22 +311,22 @@ def _calc_norm(self): """ self.sampled_graphs = [] - node_value = np.zeros(self.n) - edge_value = sp.lil_matrix((self.n,self.n)) + node_value = np.zeros(self.num_node) + edge_value = sp.lil_matrix((self.num_node, self.num_node)) for _ in range(self.pre_sampling_times): - sampled, adj = self.sample_graph_ops - self.sampled_graphs.append((sampled,adj)) + sampled_node, adj = self.sample_graph_ops() + adj = self._post_process(adj, to_sparse_tensor=False) + self.sampled_graphs.append((sampled_node, adj)) adj = adj.tocoo() for row, col in zip(adj.row, adj.col): - edge_value[sampled[row],sampled[col]] += 1 - node_value[sampled] += 1 + edge_value[sampled_node[row], sampled_node[col]] += 1 + node_value[sampled_node] += 1 edge_value = edge_value.tocsr().dot(sp.diags(1.0 / np.maximum(node_value, 1))) self.aggr_norm = edge_value - self.loss_norm = torch.FloatTensor(np.maximum(node_value, 1)) - return + self.loss_norm = torch.FloatTensor(np.maximum(node_value, 1) / self.pre_sampling_times) def collate_fn(self, batch_ids, mode): """ @@ -344,7 +347,7 @@ def collate_fn(self, batch_ids, mode): sampled, adj = self.sampled_graphs[self.used_sample_graphs] self.used_sample_graphs += 1 else: - sampled, adj = self.sample_graph_ops + sampled, adj = self.sample_graph_ops() sampled_aggr_norm = self.aggr_norm[sampled, :] sampled_aggr_norm = sampled_aggr_norm[:, sampled] @@ -371,4 +374,4 @@ def collate_fn(self, batch_ids, mode): self.cur_index = global_inds - return batch_in, batch_out, self.to_Block(batched_adj, self._sparse_type) \ No newline at end of file + return batch_in, batch_out, Block(batched_adj, self._sparse_type) \ No newline at end of file diff --git a/sgl/search/gda_hpo/search_config.py b/sgl/search/gda_hpo/search_config.py index 47bb3f2..84dcb6a 100644 --- a/sgl/search/gda_hpo/search_config.py +++ b/sgl/search/gda_hpo/search_config.py @@ -4,9 +4,8 @@ import sgl.models.homo.gda as GDAModel import sgl.tasks as Task -class BaseGDAConfigManager(): +class BaseGDAConfigManager: def __init__(self, gda_model_name: str, task_name: str, model_keys: List[str], task_keys: List[str], const_model_kwargs: dict, const_task_kwargs: dict, hier_params: dict): - super(BaseGDAConfigManager, self).__init__() self._gda_model_name = gda_model_name self._task_name = task_name self._model_keys = model_keys diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index 88286ed..5614be6 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -13,7 +13,7 @@ class NodeClassification_Sampling(BaseTask): def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nll_loss", seed=42, - inductive=False, train_batch_size=None, eval_batch_size=None, eval_freq=1, eval_start=1, runs=1, verbose=True, **kwargs): + inductive=False, train_batch_size=None, eval_batch_size=None, 
eval_freq=1, eval_start=1, runs=1, verbose=True, max_patience=50, **kwargs): super(NodeClassification_Sampling, self).__init__() self.__dataset = dataset @@ -29,6 +29,7 @@ def __init__(self, dataset, model, lr, weight_decay, epochs, device, loss_fn="nl self.__seed = seed self.__runs = runs self.__verbose = verbose + self.__max_patience = max_patience self.__inductive = inductive self.__train_batch_size= train_batch_size self.__eval_batch_size = eval_batch_size @@ -103,6 +104,7 @@ def _execute(self): self.__model.reset_parameters() t_total = time.time() + patience = 0 best_val = 0. best_test = 0. @@ -136,18 +138,25 @@ def _execute(self): if acc_val > best_val: best_val = acc_val best_test = acc_test - - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'acc_val: {:.4f}'.format(acc_val), - 'acc_test: {:.4f}'.format(acc_test), - 'time: {:.4f}s'.format(time.time() - t)) + patience = 0 + else: + patience += 1 + if patience == self.__max_patience: + break + + if self.__verbose: + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'acc_val: {:.4f}'.format(acc_val), + 'acc_test: {:.4f}'.format(acc_test), + 'time: {:.4f}s'.format(time.time() - t)) else: - print('Epoch: {:03d}'.format(epoch + 1), - 'loss_train: {:.4f}'.format(loss_train), - 'acc_train: {:.4f}'.format(acc_train), - 'time: {:.4f}s'.format(time.time() - t)) + if self.__verbose: + print('Epoch: {:03d}'.format(epoch + 1), + 'loss_train: {:.4f}'.format(loss_train), + 'acc_train: {:.4f}'.format(acc_train), + 'time: {:.4f}s'.format(time.time() - t)) acc_val, acc_test = self._postprocess() if acc_val > best_val: From c81151390f9c4fdc270a5d87021045a3dead9d9d Mon Sep 17 00:00:00 2001 From: infinity Date: Sat, 6 Jan 2024 14:02:30 +0000 Subject: [PATCH 28/28] remove private information. 
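Besides scrubbing the private dataset paths, this patch hoists the shared train/evaluate loops into the base model as static helpers (model_train, model_evaluate, model_mini_batch_train, model_mini_batch_evaluate in the sgl/models/base_model.py hunk below), which models such as FLAG then override with their own variants. The following is a minimal full-batch sketch of how a task loop might drive these helpers; run_full_batch and the accuracy metric are illustrative stand-ins rather than symbols added by this patch, and the sketch assumes model.preprocess(...) has already been called and that model/labels live on `device`.

    import torch.nn as nn
    from torch.optim import Adam

    def accuracy(output, labels):
        # share of correct top-1 predictions
        return output.max(1)[1].eq(labels).double().mean().item()

    def run_full_batch(model, dataset, labels, device,
                       epochs=200, lr=0.01, weight_decay=5e-4):
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        loss_fn = nn.CrossEntropyLoss()
        best_val, best_test = 0., 0.
        for _ in range(epochs):
            # one optimization step on the train split -> (loss, train accuracy)
            loss_train, acc_train = type(model).model_train(
                model, dataset.train_idx, labels, device, optimizer, loss_fn, accuracy)
            # no-grad evaluation on the val/test splits -> (val accuracy, test accuracy)
            acc_val, acc_test = type(model).model_evaluate(
                model, dataset.val_idx, dataset.test_idx, labels, device, accuracy)
            if acc_val > best_val:
                best_val, best_test = acc_val, acc_test
        return best_test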
--- .gitignore | 4 +- examples/GDA/test_search_GAug.py | 2 +- examples/gamlp_products.py | 8 ++- sgl/models/base_model.py | 71 ++++++++++++++++++++++- sgl/models/homo/gda/FLAG.py | 30 +++++----- sgl/models/homo/gda/Mixup.py | 52 +++++++++-------- sgl/models/homo/gda/gen_graphs.py | 4 +- sgl/operators/base_op.py | 5 +- sgl/tasks/node_classification.py | 33 ++++++----- sgl/tasks/node_classification_sampling.py | 18 +++--- 10 files changed, 156 insertions(+), 71 deletions(-) diff --git a/.gitignore b/.gitignore index fe6ba9d..952e624 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,6 @@ __pycache__/ build/ dist/ -logs/ \ No newline at end of file +logs/ + +configs/ \ No newline at end of file diff --git a/examples/GDA/test_search_GAug.py b/examples/GDA/test_search_GAug.py index 05d20e7..170185c 100644 --- a/examples/GDA/test_search_GAug.py +++ b/examples/GDA/test_search_GAug.py @@ -11,7 +11,7 @@ parser.add_argument("--device", type=int, default=0, help="gpu device id or cpu(-1)") parser.add_argument("--dataset_classname", type=str, default="Planetoid", help="class name of the dataset") parser.add_argument("--name", type=str, default="cora", help="dataset name") - parser.add_argument("--root", type=str, default="/home/ssq/test_data/", help="root dir for dataset") + parser.add_argument("--root", type=str, default="data/", help="root dir for dataset") parser.add_argument("--gnn_type", type=str, default="gcn", choices=["gcn", "gsage", "gat"], help="gnn backbone") parser.add_argument("--not_gae", action="store_true", default=False, help="whether not to use gae") parser.add_argument("--minibatch", action="store_true", default=False, help="whether to use minibatch") diff --git a/examples/gamlp_products.py b/examples/gamlp_products.py index a9ec6aa..89e41b3 100644 --- a/examples/gamlp_products.py +++ b/examples/gamlp_products.py @@ -5,13 +5,15 @@ if __name__ == "__main__": parser = argparse.ArgumentParser("GMLP") + parser.add_argument("--device", type=int, default=0, help="GPU ID or CPU (-1)") parser.add_argument("--hidden-dim", type=int, default=512, help="dimension of hidden layer") parser.add_argument("--num-layers", type=int, default=3, help="number of layers") + parser.add_argument("--dataset_root", type=str, default="data/", help="dataset path") args = parser.parse_args() - dataset = Ogbn("products", "./", "official") + dataset = Ogbn("products", args.dataset_root, "official") model = GAMLP(prop_steps=3, feat_dim=dataset.num_features, output_dim=dataset.num_classes, hidden_dim=args.hidden_dim, num_layers=args.num_layers) - device = "cuda:0" - test_acc = NodeClassification(dataset, model, lr=0.1, weight_decay=5e-5, epochs=200, device=device).test_acc + device = f"cuda:{args.device}" + test_acc = NodeClassification(dataset, model, lr=0.1, weight_decay=5e-5, epochs=200, device=device, train_batch_size=100000, eval_batch_size=200000).test_acc diff --git a/sgl/models/base_model.py b/sgl/models/base_model.py index 01baf6a..07ec5aa 100644 --- a/sgl/models/base_model.py +++ b/sgl/models/base_model.py @@ -20,7 +20,10 @@ def __init__(self, prop_steps, feat_dim, output_dim): self._processed_feature = None self._pre_msg_learnable = False - def preprocess(self, adj, feature): + def reset_parameters(self): + pass # TODO + + def preprocess(self, adj, feature, *args): if self._pre_graph_op is not None: self._processed_feat_list = self._pre_graph_op.propagate( adj, feature) @@ -34,7 +37,73 @@ def preprocess(self, adj, feature): else: self._pre_msg_learnable = False self._processed_feature = feature + + 
@staticmethod + def model_train(model, train_idx, labels, device, optimizer, loss_fn, accuracy): + model.train() + optimizer.zero_grad() + + train_output = model.model_forward(train_idx, device) + loss_train = loss_fn(train_output, labels[train_idx]) + acc_train = accuracy(train_output, labels[train_idx]) + loss_train.backward() + optimizer.step() + + return loss_train.item(), acc_train + + @staticmethod + @torch.no_grad() + def model_evaluate(model, val_idx, test_idx, labels, device, metric): + model.eval() + val_output = model.model_forward(val_idx, device) + test_output = model.model_forward(test_idx, device) + acc_val = metric(val_output, labels[val_idx]) + acc_test = metric(test_output, labels[test_idx]) + + return acc_val, acc_test + + @staticmethod + def model_mini_batch_train(model, train_idx, train_loader, labels, device, optimizer, loss_fn): + model.train() + correct_num = 0 + loss_train_sum = 0. + for batch in train_loader: + train_output = model.model_forward(batch, device) + loss_train = loss_fn(train_output, labels[batch]) + + pred = train_output.max(1)[1].type_as(labels) + correct_num += pred.eq(labels[batch]).double().sum() + loss_train_sum += loss_train.item() + + optimizer.zero_grad() + loss_train.backward() + optimizer.step() + + loss_train = loss_train_sum / len(train_loader) + acc_train = correct_num / len(train_idx) + + return loss_train, acc_train.item() + + @staticmethod + @torch.no_grad() + def model_mini_batch_evaluate(model, val_idx, val_loader, test_idx, test_loader, labels, device): + model.eval() + correct_num_val, correct_num_test = 0, 0 + for batch in val_loader: + val_output = model.model_forward(batch, device) + pred = val_output.max(1)[1].type_as(labels) + correct_num_val += pred.eq(labels[batch]).double().sum() + acc_val = correct_num_val / len(val_idx) + + for batch in test_loader: + test_output = model.model_forward(batch, device) + pred = test_output.max(1)[1].type_as(labels) + correct_num_test += pred.eq(labels[batch]).double().sum() + acc_test = correct_num_test / len(test_idx) + + return acc_val.item(), acc_test.item() + def postprocess(self, adj, output): if self._post_graph_op is not None: if self._post_msg_op.aggr_type in [ diff --git a/sgl/models/homo/gda/FLAG.py b/sgl/models/homo/gda/FLAG.py index 56f42bd..0a6a487 100644 --- a/sgl/models/homo/gda/FLAG.py +++ b/sgl/models/homo/gda/FLAG.py @@ -69,19 +69,21 @@ def flag(self, ground_truth_y, optimizer, device, train_idx, loss_fn): return loss.item() - def train_func(self, train_idx, labels, device, optimizer, loss_fn, metric): - loss_train = self.flag(labels[train_idx], optimizer, device, train_idx, loss_fn) + @staticmethod + def model_train(model, train_idx, labels, device, optimizer, loss_fn, metric): + loss_train = model.flag(labels[train_idx], optimizer, device, train_idx, loss_fn) - self._base_model.eval() - pred_y = self._base_model(self.__features, self.__processed_adj) + model.eval() + pred_y = model(model.processed_feature, model.processed_adj) acc_train = metric(pred_y[train_idx], labels[train_idx]) return loss_train, acc_train + @staticmethod @torch.no_grad() - def evaluate_func(self, val_idx, test_idx, labels, device, metric): - self._base_model.eval() - pred_y = self._base_model(self.__features, self.__processed_adj) + def model_evaluate(model, val_idx, test_idx, labels, device, metric): + model.eval() + pred_y = model(model.processed_feature, model.processed_adj) acc_val = metric(pred_y[val_idx], labels[val_idx]) acc_test = metric(pred_y[test_idx], labels[test_idx]) @@ -170,13 
+172,14 @@ def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizer, inductiv return loss, pred_y, y_truth - def train_func(self, train_loader, inductive, device, optimizer, loss_fn): + @staticmethod + def model_train(model, train_loader, inductive, device, optimizer, loss_fn): correct_num = 0 loss_train_sum = 0. train_num = 0 for batch in train_loader: - loss_train, y_out, y_truth = self.mini_batch_prepare_forward(batch, device, loss_fn, optimizer, inductive=inductive) + loss_train, y_out, y_truth = model.mini_batch_prepare_forward(batch, device, loss_fn, optimizer, inductive=inductive) pred = y_out.max(1)[1].type_as(y_truth) correct_num += pred.eq(y_truth).double().sum() loss_train_sum += loss_train @@ -187,14 +190,15 @@ def train_func(self, train_loader, inductive, device, optimizer, loss_fn): return loss_train, acc_train.item() + @staticmethod @torch.no_grad() - def evaluate_func(self, val_loader, test_loader, device): - self._base_model.eval() + def model_evaluate(model, val_loader, test_loader, device): + model.eval() correct_num_val, correct_num_test = 0, 0 val_num = 0 for batch in val_loader: - val_output, out_y = self.model_forward(batch, device) + val_output, out_y = model.model_forward(batch, device) pred = val_output.max(1)[1].type_as(out_y) correct_num_val += pred.eq(out_y).double().sum() val_num += len(out_y) @@ -203,7 +207,7 @@ def evaluate_func(self, val_loader, test_loader, device): test_num = 0 for batch in test_loader: - test_output, out_y = self.model_forward(batch, device) + test_output, out_y = model.model_forward(batch, device) pred = test_output.max(1)[1].type_as(out_y) correct_num_test += pred.eq(out_y).double().sum() test_num += len(out_y) diff --git a/sgl/models/homo/gda/Mixup.py b/sgl/models/homo/gda/Mixup.py index 3007006..a88c6de 100644 --- a/sgl/models/homo/gda/Mixup.py +++ b/sgl/models/homo/gda/Mixup.py @@ -12,8 +12,8 @@ class Mixup(nn.Module): def __init__(self, in_dim, hidden_dim, n_classes, n_layers, dropout, alpha, beta, gnn_type="sage", feat_norm="row", activation=F.relu, **kwargs): super(Mixup, self).__init__() - self.__alpha = alpha - self.__beta = beta + self.alpha = alpha + self.beta = beta self.__feat_norm = feat_norm self.nc_net = TwoBranchGNN(in_dim, hidden_dim, n_classes, n_layers, dropout, gnn_type, activation, **kwargs) @@ -46,35 +46,37 @@ def loss_fn(mix_ratio, output, y_raw, y_b, train_idx): def reset_parameters(self): self.nc_net.reset_parameters() - def train_func(self, train_idx, y_raw, device, optimizer, loss_fn, metric): - self.nc_net.train() - mix_ratio = np.random.beta(self.__alpha, self.__beta) - id_old_value_new, adj_b, y_b = self._mixup(train_idx, y_raw, device) - output = self.nc_net(self.__features, self.__adj, adj_b, mix_ratio, id_old_value_new) + @staticmethod + def model_train(model, train_idx, y_raw, device, optimizer, loss_fn, metric): + model.nc_net.train() + mix_ratio = np.random.beta(model.alpha, model.beta) + id_old_value_new, adj_b, y_b = model.mixup(train_idx, y_raw, device) + output = model.nc_net(model.processed_feature, model.processed_block, adj_b, mix_ratio, id_old_value_new) loss = loss_fn(mix_ratio, output, y_raw, y_b, train_idx) optimizer.zero_grad() loss.backward() optimizer.step() - self.nc_net.eval() - output = self.forward(self.__features, self.__adj) + model.nc_net.eval() + output = model(model.processed_feature, model.processed_block) acc = metric(output[train_idx], y_raw[train_idx]) return loss.item(), acc - + + @staticmethod @torch.no_grad() - def evaluate_func(self, val_idx, test_idx, 
labels, device, metric): - self.nc_net.eval() + def model_evaluate(model, val_idx, test_idx, labels, device, metric): + model.nc_net.eval() - pred_y = self.forward(self.__features, self.__adj) + pred_y = model(model.processed_feature, model.processed_block) acc_val = metric(pred_y[val_idx], labels[val_idx]) acc_test = metric(pred_y[test_idx], labels[test_idx]) return acc_val, acc_test - def _mixup(self, train_idx, y_raw, device): + def mixup(self, train_idx, y_raw, device): id_old_value_new = torch.arange(self.__num_nodes, dtype=torch.long) train_idx_shuffle = np.asarray(train_idx) np.random.shuffle(train_idx_shuffle) @@ -107,8 +109,8 @@ def postprocess(self, adj, output): class SampleMixup(BaseSAMPLEModel): def __init__(self, training_sampler, eval_sampler, in_dim, hidden_dim, n_classes, n_layers, dropout, alpha, beta, gnn_type="sage", feat_norm="row", activation=F.relu, **kwargs): super(SampleMixup, self).__init__(sparse_type="pyg") - self.__alpha = alpha - self.__beta = beta + self.alpha = alpha + self.beta = beta self.__feat_norm = feat_norm self._training_sampling_op = training_sampler self._eval_sampling_op = eval_sampler @@ -175,16 +177,17 @@ def mini_batch_prepare_forward(self, batch, device, loss_fn, optimizer, inductiv return loss.item(), output, y_raw - def train_func(self, train_loader, inductive, device, optimizer, loss_fn): + @staticmethod + def train_func(model, train_loader, inductive, device, optimizer, loss_fn): correct_num = 0 loss_train_sum = 0. train_num = 0 - self._base_model.train() - mix_ratio = np.random.beta(self.__alpha, self.__beta) + model.train() + mix_ratio = np.random.beta(model.alpha, model.beta) for batch in train_loader: - loss_train, y_out, y_truth = self.mini_batch_prepare_forward(batch, device, loss_fn, optimizer, inductive=inductive, mix_ratio=mix_ratio) + loss_train, y_out, y_truth = model.mini_batch_prepare_forward(batch, device, loss_fn, optimizer, inductive=inductive, mix_ratio=mix_ratio) pred = y_out.max(1)[1].type_as(y_truth) correct_num += pred.eq(y_truth).double().sum() loss_train_sum += loss_train @@ -195,14 +198,15 @@ def train_func(self, train_loader, inductive, device, optimizer, loss_fn): return loss_train, acc_train.item() + @staticmethod @torch.no_grad() - def evaluate_func(self, val_loader, test_loader, device): - self._base_model.eval() + def model_evaluate(model, val_loader, test_loader, device): + model.eval() correct_num_val, correct_num_test = 0, 0 val_num = 0 for batch in val_loader: - val_output, out_y = self.model_forward(batch, device) + val_output, out_y = model.model_forward(batch, device) pred = val_output.max(1)[1].type_as(out_y) correct_num_val += pred.eq(out_y).double().sum() val_num += len(out_y) @@ -211,7 +215,7 @@ def evaluate_func(self, val_loader, test_loader, device): test_num = 0 for batch in test_loader: - test_output, out_y = self.model_forward(batch, device) + test_output, out_y = model.model_forward(batch, device) pred = test_output.max(1)[1].type_as(out_y) correct_num_test += pred.eq(out_y).double().sum() test_num += len(out_y) diff --git a/sgl/models/homo/gda/gen_graphs.py b/sgl/models/homo/gda/gen_graphs.py index 62de9cd..ee75d8e 100644 --- a/sgl/models/homo/gda/gen_graphs.py +++ b/sgl/models/homo/gda/gen_graphs.py @@ -45,10 +45,10 @@ def __init__(self, dim_in, dim_h, dim_z, gae): self.gcn_mean = GraphConv(dim_h, dim_z, activation=False) self.gcn_logstd = GraphConv(dim_h, dim_z, activation=False) - def encode(self, adj, X): + def encode(self, adj, X, gen_Z=False): hidden = self.base_gcn(adj, X) self.mean = 
self.gcn_mean(adj, hidden) - if self.gae: + if self.gae or gen_Z: return self.mean else: self.logstd = self.gcn_logstd(adj, hidden) diff --git a/sgl/operators/base_op.py b/sgl/operators/base_op.py index a439c08..b74ad49 100644 --- a/sgl/operators/base_op.py +++ b/sgl/operators/base_op.py @@ -18,10 +18,13 @@ def _construct_adj(self, adj): def propagate(self, adj, feature): self._adj = self._construct_adj(adj) + + if isinstance(feature, Tensor): + feature = feature.numpy() if not isinstance(adj, sp.csr_matrix): raise TypeError("The adjacency matrix must be a scipy csr sparse matrix!") - elif not isinstance(feature, np.ndarray): ###代码Node类中已经转成了torch.FloatTensor + elif not isinstance(feature, np.ndarray): raise TypeError("The feature matrix must be a numpy.ndarray!") elif self._adj.shape[1] != feature.shape[0]: raise ValueError("Dimension mismatch detected for the adjacency and the feature matrix!") diff --git a/sgl/tasks/node_classification.py b/sgl/tasks/node_classification.py index 21772b8..34aa3ca 100644 --- a/sgl/tasks/node_classification.py +++ b/sgl/tasks/node_classification.py @@ -7,7 +7,9 @@ from typing import Callable from sgl.tasks.base_task import BaseTask -from sgl.tasks.utils import accuracy, set_seed, train, mini_batch_train, evaluate, mini_batch_evaluate +from sgl.tasks.utils import accuracy, set_seed +from sgl.tasks.utils import train as vanilla_train, evaluate as vanilla_evaluate +from sgl.tasks.utils import mini_batch_train as vanilla_mini_batch_train, mini_batch_evaluate as vanilla_mini_batch_evaluate class NodeClassification(BaseTask): @@ -70,21 +72,22 @@ def _execute(self): for epoch in range(self.__epochs): t = time.time() if self.__mini_batch is False: - if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): - loss_train, acc_train = self.__model.train_func(self.__dataset.train_idx, self.__labels, self.__device, + train = self.__model.model_train if hasattr(self.__model, "model_train") and isinstance(self.__model.model_train, Callable) \ + else vanilla_train + loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, self.__optimizer, self.__loss_fn, accuracy) - else: - loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, - self.__optimizer, self.__loss_fn, accuracy) - if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): - acc_val, acc_test = self.__model.evaluate_func(self.__dataset.val_idx, self.__dataset.test_idx, - self.__labels, self.__device, accuracy) - else: - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, + + evaluate = self.__model.model_evaluate if hasattr(self.__model, "model_evaluate") and isinstance(self.__model.model_evaluate, Callable) \ + else vanilla_evaluate + acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, self.__labels, self.__device, accuracy) else: + mini_batch_train = self.__model.model_mini_batch_train if hasattr(self.__model, "model_mini_batch_train") and isinstance(self.__model.model_mini_batch_train, Callable) \ + else vanilla_mini_batch_train loss_train, acc_train = mini_batch_train(self.__model, self.__dataset.train_idx, self.__train_loader, self.__labels, self.__device, self.__optimizer, self.__loss_fn) + mini_batch_evaluate = self.__model.model_mini_batch_evaluate if hasattr(self.__model, "model_mini_batch_evaluate") and isinstance(self.__model.model_mini_batch_evaluate, Callable) \ + 
else vanilla_mini_batch_evaluate acc_val, acc_test = mini_batch_evaluate(self.__model, self.__dataset.val_idx, self.__val_loader, self.__dataset.test_idx, self.__test_loader, self.__labels, self.__device) @@ -213,14 +216,14 @@ def _execute(self, random_subgraph_num=-1, subgraph_edge_type_num=-1, for epoch in range(self.__epochs): t = time.time() if self.__mini_batch is False: - loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, + loss_train, acc_train = vanilla_train(self.__model, self.__dataset.train_idx, self.__labels, self.__device, self.__optimizer, self.__loss_fn) - acc_val, acc_test = evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, + acc_val, acc_test = vanilla_evaluate(self.__model, self.__dataset.val_idx, self.__dataset.test_idx, self.__labels, self.__device) else: - loss_train, acc_train = mini_batch_train(self.__model, self.__dataset.train_idx, self.__train_loader, + loss_train, acc_train = vanilla_mini_batch_train(self.__model, self.__dataset.train_idx, self.__train_loader, self.__labels, self.__device, self.__optimizer, self.__loss_fn) - acc_val, acc_test = mini_batch_evaluate(self.__model, self.__dataset.val_idx, self.__val_loader, + acc_val, acc_test = vanilla_mini_batch_evaluate(self.__model, self.__dataset.val_idx, self.__val_loader, self.__dataset.test_idx, self.__test_loader, self.__labels, self.__device) diff --git a/sgl/tasks/node_classification_sampling.py b/sgl/tasks/node_classification_sampling.py index 5614be6..eb90d9a 100644 --- a/sgl/tasks/node_classification_sampling.py +++ b/sgl/tasks/node_classification_sampling.py @@ -8,7 +8,8 @@ from sgl.data.utils import RandomLoader, SplitLoader from sgl.tasks.base_task import BaseTask -from sgl.tasks.utils import accuracy, set_seed, train, mini_batch_train, evaluate, mini_batch_evaluate +from sgl.tasks.utils import accuracy, set_seed, train, evaluate +from sgl.tasks.utils import mini_batch_train as vanilla_mini_batch_train, mini_batch_evaluate as vanilla_mini_batch_evaluate class NodeClassification_Sampling(BaseTask): @@ -111,21 +112,18 @@ def _execute(self): for epoch in range(self.__epochs): t = time.time() if self.__mini_batch_train: - if hasattr(self.__model, "train_func") and isinstance(self.__model.train_func, Callable): - loss_train, acc_train = self.__model.train_func(self.__train_loader, self.__inductive, self.__device, self.__optimizer, self.__loss_fn) - else: - loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__inductive, self.__device, - self.__optimizer, self.__loss_fn) + mini_batch_train = self.__model.model_train if hasattr(self.__model, "model_train") and isinstance(self.__model.model_train, Callable) \ + else vanilla_mini_batch_train + loss_train, acc_train = mini_batch_train(self.__model, self.__train_loader, self.__inductive, self.__device, self.__optimizer, self.__loss_fn) else: loss_train, acc_train = train(self.__model, self.__dataset.train_idx, self.__optimizer, self.__loss_fn) if epoch + 1 >= self.__eval_start and (epoch + 1) % self.__eval_freq == 0: if self.__mini_batch_eval: if self.__eval_together is False: - if hasattr(self.__model, "evaluate_func") and isinstance(self.__model.evaluate_func, Callable): - acc_val, acc_test = self.__model.evaluate_func(self.__val_loader, self.__test_loader, self.__device) - else: - acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__device) + mini_batch_evaluate = self.__model.model_evaluate if 
hasattr(self.__model, "model_evaluate") and isinstance(self.__model.model_evaluate, Callable) \ + else vanilla_mini_batch_evaluate + acc_val, acc_test = mini_batch_evaluate(self.__model, self.__val_loader, self.__test_loader, self.__device) else: self.__model.eval() outputs = self.__model.inference(self.__all_eval_loader, self.__device)