Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
rememberYou committed Jun 9, 2021
2 parents c40c52b + d932a92 commit 3466899
Show file tree
Hide file tree
Showing 10 changed files with 19 additions and 15 deletions.
7 changes: 6 additions & 1 deletion examples/mutag.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
# generated for the entities without hashing as MUTAG is a short KG.
walkers=[
HALKWalker(
2, None, n_jobs=2, random_state=RANDOM_STATE, md5_bytes=None
2,
None,
n_jobs=2,
sampler=WideSampler(),
random_state=RANDOM_STATE,
md5_bytes=None,
)
],
verbose=1,
Expand Down
4 changes: 1 addition & 3 deletions pyrdf2vec/embedders/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Word2Vec(Embedder):
_model: The gensim.models.word2vec model.
Defaults to None.
kwargs: The keyword arguments dictionary.
Defaults to { min_count=0, negative=20, vector_size=500 }.
Defaults to { min_count=0 }.
"""

Expand All @@ -29,8 +29,6 @@ class Word2Vec(Embedder):
def __init__(self, **kwargs):
self.kwargs = {
"min_count": 0,
"negative": 20,
"vector_size": 500,
**kwargs,
}
self._model = W2V(**self.kwargs)
Expand Down
2 changes: 0 additions & 2 deletions pyrdf2vec/rdf2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ def get_walks(self, kg: KG, entities: Entities) -> List[List[SWalk]]:
ValueError: If the provided entities aren't in the Knowledge Graph.
"""
# Avoids duplicate entities for unnecessary walk extractions.
entities = list(set(entities))
if kg.skip_verify is False and not kg.is_exist(entities):
if kg.mul_req:
asyncio.run(kg.connector.close())
Expand Down
3 changes: 2 additions & 1 deletion pyrdf2vec/walkers/anonymous.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class AnonymousWalker(RandomWalker):
Defaults to UniformSampler.
with_reverse: True to extracts parents and children hops from an
entity, creating (max_walks * max_walks) more walks of 2 * depth,
allowing also to centralize this entity in the walks. False otherwise.
allowing also to centralize this entity in the walks. False
otherwise.
Defaults to False.
"""
Expand Down
2 changes: 1 addition & 1 deletion pyrdf2vec/walkers/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _dfs(
sub_walk += (pred_obj[0], pred_obj[1])
d = len(sub_walk) - 1
walks.append(sub_walk)
return list(set(walks))
return list(walks)

def extract(
self, kg: KG, entities: Entities, verbose: int = 0
Expand Down
1 change: 1 addition & 0 deletions pyrdf2vec/walkers/halk.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
"""
return super()._extract(kg, entity)

# flake8: noqa: C901
def _post_extract(self, res: List[EntityWalks]) -> List[List[SWalk]]:
"""Post processed walks.
Expand Down
2 changes: 1 addition & 1 deletion pyrdf2vec/walkers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _dfs(
sub_walk += (pred_obj[0], pred_obj[1])
d = len(sub_walk) - 1
walks.append(sub_walk)
return list(set(walks))
return list(walks)

def extract_walks(self, kg: KG, entity: Vertex) -> List[Walk]:
"""Extracts random walks for an entity based on Knowledge Graph using
Expand Down
8 changes: 5 additions & 3 deletions pyrdf2vec/walkers/split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import re
from typing import Set

Expand Down Expand Up @@ -50,6 +49,7 @@ def __attrs_post_init__(self):
if self.func_split is None:
self.func_split = self.basic_split

# flake8: noqa: C901
def basic_split(self, walks: List[Walk]) -> Set[SWalk]:
"""Splits vertices of random walks for an entity based. To achieve
this, each vertex (except the root node) is split according to symbols
Expand Down Expand Up @@ -77,8 +77,10 @@ def basic_split(self, walks: List[Walk]) -> Set[SWalk]:
"""
canonical_walks: Set[SWalk] = set()
for walk in walks:
tmp_vertices = []
canonical_walk = [] if self.with_reverse else [walk[0].name]
tmp_vertices = [] # type: ignore
canonical_walk = []
if self.with_reverse:
canonical_walk = [walk[0].name]
for i, _ in enumerate(walk[1::], 1):
vertices = []
if "http" in walk[i].name:
Expand Down
3 changes: 1 addition & 2 deletions tests/walkers/test_anonymous.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def test_extract(
else:
assert len(walks) <= max_walks
for walk in walks:
assert not walk[0].isnumeric()
for obj in walk[2::2]:
for obj in walk[1::2]:
assert obj.isnumeric()
if not with_reverse:
assert walk[0] == root
Expand Down
2 changes: 1 addition & 1 deletion tests/walkers/test_halk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

MAX_DEPTHS = range(15)
KGS = [KG_LOOP, KG_CHAIN]
MAX_WALKS = [None, 0, 1, 2, 3, 4, 5]
MAX_WALKS = [None, 1, 2, 3, 4, 5]
ROOTS_WITHOUT_URL = ["Alice", "Bob", "Dean"]
WITH_REVERSE = [False, True]

Expand Down

0 comments on commit 3466899

Please sign in to comment.