-
Notifications
You must be signed in to change notification settings - Fork 200
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
246 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
package: | ||
name: torchdrug | ||
version: 0.1.1 | ||
version: 0.1.2 | ||
|
||
source: | ||
path: ../.. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
diff --git a/conda/torchdrug/meta.yaml b/conda/torchdrug/meta.yaml | ||
index b366902..55604f0 100644 | ||
--- a/conda/torchdrug/meta.yaml | ||
+++ b/conda/torchdrug/meta.yaml | ||
@@ -1,6 +1,6 @@ | ||
package: | ||
name: torchdrug | ||
- version: 0.1.1 | ||
+ version: 0.1.2 | ||
|
||
source: | ||
path: ../.. | ||
diff --git a/doc/source/paper.rst b/doc/source/paper.rst | ||
index ab8489c..c22a7ae 100644 | ||
--- a/doc/source/paper.rst | ||
+++ b/doc/source/paper.rst | ||
@@ -86,9 +86,9 @@ Readout Layers | ||
|
||
1. `Order Matters: Sequence to sequence for sets <Set2Set_>`_ | ||
|
||
- Oriol Vinyals, Samy Bengio, Manjunath Kudlur | ||
+ Oriol Vinyals, Samy Bengio, Manjunath Kudlur | ||
|
||
- :class:`Set2Set <torchdrug.layers.Set2Set>` | ||
+ :class:`Set2Set <torchdrug.layers.Set2Set>` | ||
|
||
Normalization Layers | ||
^^^^^^^^^^^^^^^^^^^^ | ||
diff --git a/setup.py b/setup.py | ||
index ddf3cdb..9da3d27 100644 | ||
--- a/setup.py | ||
+++ b/setup.py | ||
@@ -13,7 +13,7 @@ if __name__ == "__main__": | ||
long_description_content_type="text/markdown", | ||
url="https://torchdrug.ai/", | ||
author="TorchDrug Team", | ||
- version="0.1.1", | ||
+ version="0.1.2", | ||
license="Apache-2.0", | ||
keywords=["deep-learning", "pytorch", "drug-discovery"], | ||
packages=setuptools.find_packages(), | ||
diff --git a/torchdrug/__init__.py b/torchdrug/__init__.py | ||
index 7058780..7dca7a0 100644 | ||
--- a/torchdrug/__init__.py | ||
+++ b/torchdrug/__init__.py | ||
@@ -12,4 +12,4 @@ handler = logging.StreamHandler(sys.stdout) | ||
handler.setFormatter(format) | ||
logger.addHandler(handler) | ||
|
||
-__version__ = "0.1.1" | ||
\ No newline at end of file | ||
+__version__ = "0.1.2" | ||
\ No newline at end of file | ||
diff --git a/torchdrug/core/core.py b/torchdrug/core/core.py | ||
index 1de312b..4c6ea18 100644 | ||
--- a/torchdrug/core/core.py | ||
+++ b/torchdrug/core/core.py | ||
@@ -355,4 +355,4 @@ def make_configurable(cls, module=None, ignore_args=()): | ||
MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {}) | ||
else: | ||
MetaClass = _Configurable | ||
- return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) | ||
+ return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) | ||
\ No newline at end of file | ||
diff --git a/torchdrug/data/__init__.py b/torchdrug/data/__init__.py | ||
index bfacd4c..66131f6 100644 | ||
--- a/torchdrug/data/__init__.py | ||
+++ b/torchdrug/data/__init__.py | ||
@@ -1,3 +1,4 @@ | ||
+from .dictionary import PerfectHash, Dictionary | ||
from .graph import Graph, PackedGraph, cat | ||
from .molecule import Molecule, PackedMolecule | ||
from .dataset import MoleculeDataset, ReactionDataset, NodeClassificationDataset, KnowledgeGraphDataset, \ | ||
@@ -7,7 +8,7 @@ from . import constant | ||
from . import feature | ||
|
||
__all__ = [ | ||
- "Graph", "PackedGraph", "Molecule", "PackedMolecule", | ||
+ "Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary", | ||
"MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised", | ||
"semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split", | ||
"DataLoader", "graph_collate", "feature", "constant", | ||
diff --git a/torchdrug/data/dataset.py b/torchdrug/data/dataset.py | ||
index 6285244..8db50df 100644 | ||
--- a/torchdrug/data/dataset.py | ||
+++ b/torchdrug/data/dataset.py | ||
@@ -171,16 +171,32 @@ class MoleculeDataset(torch_data.Dataset, core.Configurable): | ||
def atom_types(self): | ||
"""All atom types.""" | ||
atom_types = set() | ||
- for i in range(len(self.data)): | ||
- atom_types.update(self.get_item(i)["graph"].atom_type.tolist()) | ||
+ | ||
+ if getattr(self, "lazy", False): | ||
+ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") | ||
+ for smiles in self.smiles_list: | ||
+ graph = data.Molecule.from_smiles(smiles, **self.kwargs) | ||
+ atom_types.update(graph.atom_type.tolist()) | ||
+ else: | ||
+ for graph in self.data: | ||
+ atom_types.update(graph.atom_type.tolist()) | ||
+ | ||
return sorted(atom_types) | ||
|
||
@utils.cached_property | ||
def bond_types(self): | ||
"""All bond types.""" | ||
bond_types = set() | ||
- for i in range(len(self.data)): | ||
- bond_types.update(self.get_item(i)["graph"].edge_list[:, 2].tolist()) | ||
+ | ||
+ if getattr(self, "lazy", False): | ||
+ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") | ||
+ for smiles in self.smiles_list: | ||
+ graph = data.Molecule.from_smiles(smiles, **self.kwargs) | ||
+ bond_types.update(graph.edge_list[:, 2].tolist()) | ||
+ else: | ||
+ for graph in self.data: | ||
+ bond_types.update(graph.edge_list[:, 2].tolist()) | ||
+ | ||
return sorted(bond_types) | ||
|
||
def __len__(self): | ||
diff --git a/torchdrug/models/neurallp.py b/torchdrug/models/neurallp.py | ||
index db16f7d..ef78c67 100644 | ||
--- a/torchdrug/models/neurallp.py | ||
+++ b/torchdrug/models/neurallp.py | ||
@@ -104,7 +104,7 @@ class NeuralLogicProgramming(nn.Module, core.Configurable): | ||
|
||
h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index) | ||
hr_index = h_index * graph.num_relation + r_index | ||
- hr_index_set, hr_inverse = torch.unique(hr_index, return_inverse=True) | ||
+ hr_index_set, hr_inverse = hr_index.unique(return_inverse=True) | ||
h_index_set = hr_index_set // graph.num_relation | ||
r_index_set = hr_index_set % graph.num_relation | ||
|
||
diff --git a/torchdrug/tasks/generation.py b/torchdrug/tasks/generation.py | ||
index bb7ddc0..942e8e3 100644 | ||
--- a/torchdrug/tasks/generation.py | ||
+++ b/torchdrug/tasks/generation.py | ||
@@ -803,7 +803,7 @@ class GCPNGeneration(tasks.Task, core.Configurable): | ||
self.batch_id += 1 | ||
|
||
# generation takes less time when early_stop=True | ||
- graph = self.generate(len(batch["graph"]), max_resample=5, off_policy=True, max_step=40 * 2, verbose=1) | ||
+ graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1) | ||
if graph.num_nodes.max() == 1: | ||
raise ValueError("Generation results collapse to singleton molecules") | ||
|
||
@@ -1338,7 +1338,7 @@ class GCPNGeneration(tasks.Task, core.Configurable): | ||
self.best_results[task] = best_results | ||
|
||
@torch.no_grad() | ||
- def generate(self, num_sample, max_resample=10, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): | ||
+ def generate(self, num_sample, max_resample=20, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): | ||
is_training = self.training | ||
self.eval() | ||
|
||
diff --git a/torchdrug/utils/comm.py b/torchdrug/utils/comm.py | ||
index 0980131..817c281 100644 | ||
--- a/torchdrug/utils/comm.py | ||
+++ b/torchdrug/utils/comm.py | ||
@@ -147,7 +147,7 @@ def reduce(obj, op="sum", dst=None): | ||
Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. | ||
dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. | ||
|
||
- Examples:: | ||
+ Example:: | ||
|
||
>>> # assume 4 workers | ||
>>> rank = comm.get_rank() | ||
@@ -190,7 +190,7 @@ def stack(obj, dst=None): | ||
obj (Object): any container object. Can be nested list, tuple or dict. | ||
dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. | ||
|
||
- Examples:: | ||
+ Example:: | ||
|
||
>>> # assume 4 workers | ||
>>> rank = comm.get_rank() | ||
@@ -229,7 +229,7 @@ def cat(obj, dst=None): | ||
obj (Object): any container object. Can be nested list, tuple or dict. | ||
dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. | ||
|
||
- Examples:: | ||
+ Example:: | ||
|
||
>>> # assume 4 workers | ||
>>> rank = comm.get_rank() | ||
diff --git a/torchdrug/utils/io.py b/torchdrug/utils/io.py | ||
index 29659cf..d573cde 100644 | ||
--- a/torchdrug/utils/io.py | ||
+++ b/torchdrug/utils/io.py | ||
@@ -77,7 +77,7 @@ def capture_rdkit_log(): | ||
""" | ||
Context manager to capture all rdkit loggings. | ||
|
||
- Examples:: | ||
+ Example:: | ||
|
||
>>> with utils.capture_rdkit_log() as log: | ||
>>> ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,4 +12,4 @@ | |
handler.setFormatter(format) | ||
logger.addHandler(handler) | ||
|
||
__version__ = "0.1.1" | ||
__version__ = "0.1.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters