Commit
Merge branch 'release/v0.2.1'
speedcell4 committed Nov 3, 2021
2 parents dbc2c36 + 2546970 commit b24422f
Showing 32 changed files with 1,325 additions and 540 deletions.
52 changes: 52 additions & 0 deletions cliff.toml
@@ -0,0 +1,52 @@
# configuration file for git-cliff (0.1.0)

[changelog]
# changelog header
header = """
# Changelog
All notable changes to this project will be documented in this file.\n
"""
# template for the changelog body
# https://tera.netlify.app/docs/#introduction
body = """
{% if version %}\
## [{{ version | replace(from="v", to="") }}] - {{ timestamp | date(format="%Y-%m-%d") }}
{% else %}\
## [unreleased]
{% endif %}\
{% for group, commits in commits | group_by(attribute="group") %}
### {{ group | upper_first }}
{% for commit in commits %}
- {{ commit.message | upper_first }}\
{% endfor %}
{% endfor %}\n
"""
# remove the leading and trailing whitespaces from the template
trim = true
# changelog footer
footer = """
<!-- generated by git-cliff -->
"""

[git]
# allow only conventional commits
# https://www.conventionalcommits.org
conventional_commits = true
# regex for parsing and grouping commits
commit_parsers = [
{ message = "^Feat*", group = "Features" },
{ message = "^Fix*", group = "Bug Fixes" },
{ message = "^Doc*", group = "Documentation" },
{ message = "^Perf*", group = "Performance" },
{ message = "^Refactor*", group = "Refactor" },
{ message = "^Style*", group = "Styling" },
{ message = "^Test*", group = "Testing" },
{ message = "^Chore\\(release\\): prepare for*", skip = true },
{ message = "^Chore*", group = "Miscellaneous Tasks" },
]
# filter out the commits that are not matched by commit parsers
filter_commits = false
# glob pattern for matching git tags
tag_pattern = "v[0-9]*"
# regex for skipping tags
skip_tags = "v0.1.0-beta.1"
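A reading note on the parser table above (an illustration, not part of the commit): each `message` field is a regex matched against the start of the commit subject and tried top to bottom, so `^Feat*` actually matches anything beginning with `Fea` followed by zero or more `t`s, and the specific `Chore\(release\)` skip-rule must precede the catch-all `Chore` rule. A minimal Python sketch of that grouping logic:

```python
# Illustrative sketch of how git-cliff's commit_parsers group commit subjects.
import re

COMMIT_PARSERS = [
    (r'^Feat*', 'Features'),
    (r'^Fix*', 'Bug Fixes'),
    (r'^Doc*', 'Documentation'),
    (r'^Perf*', 'Performance'),
    (r'^Refactor*', 'Refactor'),
    (r'^Style*', 'Styling'),
    (r'^Test*', 'Testing'),
    (r'^Chore\(release\): prepare for*', None),  # skip = true
    (r'^Chore*', 'Miscellaneous Tasks'),
]

def group_of(subject: str):
    for pattern, group in COMMIT_PARSERS:
        if re.match(pattern, subject):
            return group  # None means the commit is skipped
    return None  # unmatched commits are kept or dropped per filter_commits

assert group_of('Feat: support IWSLT14') == 'Features'
assert group_of('Chore(release): prepare for v0.2.1') is None
assert group_of('Chore: update dependencies') == 'Miscellaneous Tasks'
```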
2 changes: 1 addition & 1 deletion docs/index.md
@@ -45,7 +45,7 @@ or you can simply manipulate existing `Pipe`s by calling `.with_` method.
 ```python
 class PackedTokSeqPipe(PackedIdxSeqPipe):
     def __init__(self, device, unk_token, special_tokens=(),
-                 threshold=THRESHOLD, dtype=torch.long) -> None:
+                 threshold=10, dtype=torch.long) -> None:
         super(PackedTokSeqPipe, self).__init__(device=device, dtype=dtype)
         self.with_(
             pre=UpdateCounter(),
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@

 setup(
     name=name,
-    version='0.2.0',
+    version='0.2.1',
     packages=[package for package in find_packages() if package.startswith(name)],
     url=f'https://speedcell4.github.io/torchglyph',
     license='MIT',
@@ -1,7 +1,7 @@
 import torch
-from torch.nn.utils.rnn import PackedSequence
 
+from torchglyph.types import PackedSequence
-from torchglyph.datasets.conll2003 import CoNLL2003
+from torchglyph.datasets.named_entity_recognition import CoNLL2003
 
 
 def test_conll2003():
96 changes: 51 additions & 45 deletions torchglyph/dataset.py
@@ -1,10 +1,10 @@
 import itertools
-import uuid
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
 from pathlib import Path
 from typing import Iterable, Any, Type
 from typing import Union, List, Tuple, NamedTuple, Dict
 
+from torch.distributions.utils import lazy_property
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
 from tqdm import tqdm
@@ -22,48 +22,41 @@ class Dataset(TorchDataset, DownloadMixin):
     def __init__(self, pipes: List[Dict[str, Pipe]], **kwargs) -> None:
         super(Dataset, self).__init__()
 
-        self.pipes: Dict[str, Pipe] = {
-            name: pipe
-            for ps in pipes
-            for name, pipe in ps.items()
-        }
-        self.Batch: Type = namedtuple(
-            typename=f'Batch_{str(uuid.uuid4())[:8]}',
-            field_names=list(self.pipes.keys()),
-        )
-        if self.Batch.__name__ not in globals():
-            globals()[self.Batch.__name__] = self.Batch
+        self.pipes = {}
+        self.names = []
 
-        self.data: Dict[str, List[Any]] = {}
+        for ps in pipes:
+            for name, pipe in ps.items():
+                self.pipes[name] = pipe
+                self.names.append(name)
+
+        self.data = {}
         for datum, ps in zip(zip(*self.load(**kwargs)), pipes):
             for name, pipe in ps.items():
                 self.data.setdefault(name, []).extend(datum)
 
-    def transpose(self) -> None:
-        names, data = zip(*self.data.items())
-        names, data = list(names), zip(*data)
-        self.data = [self.Batch(**dict(zip(names, datum))) for datum in data]
-
-    def __getitem__(self, index: int) -> NamedTuple:
-        return self.data[index]
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        return {name: self.data[name][index] for name in self.names}
 
     def __len__(self) -> int:
-        return len(self.data)
+        return len(next(iter(self.data.values())))
 
+    @lazy_property
+    def named_tuple(self) -> Type:
+        return namedtuple(f'{self.__class__.__name__}Batch', field_names=self.names)
+
     @property
     def vocabs(self) -> NamedTuple:
-        return self.Batch(**{
+        return self.named_tuple(**{
             name: pipe.vocab
             for name, pipe in self.pipes.items()
         })
 
-    def collate_fn(self, batch: List[NamedTuple]) -> NamedTuple:
-        data = self.Batch(*zip(*batch))
-        return self.Batch(*[
-            self.pipes[name].collate_fn(datum)
-            for name, datum in zip(self.Batch._fields, data)
-        ])
+    def collate_fn(self, batch: List[Dict[str, Any]]) -> NamedTuple:
+        return self.named_tuple(**{
+            name: pipe.collate_fn([data[name] for data in batch])
+            for name, pipe in self.pipes.items()
+        })
 
     @classmethod
     def load(cls, **kwargs) -> Iterable[Any]:
Expand All @@ -72,10 +65,28 @@ def load(cls, **kwargs) -> Iterable[Any]:
     def dump(self, fp, batch: NamedTuple, prediction: Any, *args, **kwargs) -> None:
         raise NotImplementedError
 
-    def eval(self, path: Path, **kwargs):
-        raise NotImplementedError
-
-    def viz(self, path: Path, **kwargs):
-        raise NotImplementedError
+    def state_dict(self, destination: OrderedDict = None, prefix: str = '',
+                   keep_vars: bool = False) -> OrderedDict:
+        if destination is None:
+            destination = OrderedDict()
+            destination._metadata = OrderedDict()
+
+        for name, datum in self.data.items():
+            destination[prefix + name] = datum
+
+        return destination
+
+    def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
+        names = set(self.names)
+        for name, datum in state_dict.items():
+            self.data[name] = datum
+            if strict:
+                names.remove(name)
+
+        if strict:
+            assert len(names) == 0
+
+    def eval(self, path: Path, **kwargs):
+        raise NotImplementedError
 
     @classmethod
Expand All @@ -84,6 +95,8 @@ def new(cls, **kwargs) -> Tuple['DataLoader', ...]:


 class DataLoader(TorchDataLoader):
+    dataset: Dataset
+
     @property
     def vocabs(self) -> NamedTuple:
         return self.dataset.vocabs
Expand All @@ -98,19 +111,12 @@ def new(cls, datasets: Tuple[Dataset, ...],
         if isinstance(batch_size, int):
             batch_sizes = itertools.repeat(batch_size)
 
-        iteration = tqdm(
-            desc='processing datasets',
-            total=len(datasets) * (len(datasets[0].pipes) + 1),
-        )
-        for dataset in datasets:
-            for name, pipe in dataset.pipes.items():
-                pipe.postprocess_(dataset, name=name)
-                iteration.update(1)
-                iteration.set_postfix_str(f'{name}')
-            dataset.transpose()
-            iteration.update(1)
-            iteration.set_postfix_str('transpose')
-        iteration.close()
+        with tqdm(desc='post-processing', total=sum(len(dataset.pipes) for dataset in datasets)) as progress:
+            for index, dataset in enumerate(datasets):
+                for name, pipe in dataset.pipes.items():
+                    progress.set_postfix_str(f'{index}.{name}')
+                    pipe.postprocess_(dataset)
+                    progress.update(1)
 
         return tuple(
             DataLoader(
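To make the `torchglyph/dataset.py` refactor above concrete, here is a minimal, self-contained sketch (hypothetical `DummyPipe` and toy data, not torchglyph's real classes): items are now plain dicts keyed by pipe name, `collate_fn` collates field by field into the lazily built namedtuple, and `state_dict`/`load_state_dict` simply export and restore the per-column data.

```python
# Sketch under assumption: DummyPipe stands in for torchglyph's Pipe;
# a real pipe would pad/stack the raw lists into tensors.
from collections import OrderedDict, namedtuple
from typing import Any, Dict, List

class DummyPipe:
    def collate_fn(self, data: List[Any]) -> List[Any]:
        return data

pipes = {'src': DummyPipe(), 'tgt': DummyPipe()}
names = list(pipes.keys())
data = {'src': [['a', 'b'], ['c']], 'tgt': [['x'], ['y', 'z']]}

Batch = namedtuple('Batch', field_names=names)  # the lazily created named_tuple

def getitem(index: int) -> Dict[str, Any]:
    # mirrors the new __getitem__: one dict per example
    return {name: data[name][index] for name in names}

def collate_fn(batch: List[Dict[str, Any]]) -> Batch:
    # mirrors the new collate_fn: each field is collated independently
    return Batch(**{
        name: pipe.collate_fn([example[name] for example in batch])
        for name, pipe in pipes.items()
    })

print(collate_fn([getitem(0), getitem(1)]))
# Batch(src=[['a', 'b'], ['c']], tgt=[['x'], ['y', 'z']])

# state_dict/load_state_dict analogue: export, then restore, the columns
state = OrderedDict((name, data[name]) for name in names)
restored = {name: state[name] for name in names}
assert restored == data
```

One consequence of this design is that a `Batch` no longer depends on a per-instance `uuid`-named class stashed in `globals()`; the namedtuple is derived once per dataset class from the pipe names.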
63 changes: 0 additions & 63 deletions torchglyph/datasets/conll2003.py

This file was deleted.

97 changes: 97 additions & 0 deletions torchglyph/datasets/machine_translation.py
@@ -0,0 +1,97 @@
import logging
from pathlib import Path
from typing import Iterable, Any
from typing import Tuple

import torch
from torch.types import Device
from tqdm import tqdm

from torchglyph import data_dir
from torchglyph.dataset import Dataset, DataLoader
from torchglyph.pipe import PadListStrPipe

__all__ = [
'MachineTranslation',
'IWSLT14',
]


class WordPipe(PadListStrPipe):
def __init__(self, device: Device) -> None:
super(WordPipe, self).__init__(
batch_first=True, device=device,
unk_token='<unk>', pad_token='<pad>',
special_tokens=('<bos>', '<eos>'),
threshold=16, dtype=torch.long,
)


class MachineTranslation(Dataset):
@classmethod
def load(cls, path: Path, src_lang: str, tgt_lang: str, encoding: str = 'utf-8', **kwargs) -> Iterable[Any]:
src_path = path.with_name(f'{path.name}.{src_lang}')
tgt_path = path.with_name(f'{path.name}.{tgt_lang}')
with src_path.open(mode='r', encoding=encoding) as src_fp:
with tgt_path.open(mode='r', encoding=encoding) as tgt_fp:
for src, tgt in tqdm(zip(src_fp, tgt_fp), desc=f'{path.resolve()}'):
yield [src.strip().split(' '), tgt.strip().split(' ')]

@classmethod
def new(cls, batch_size: int, share_vocab: bool, src_lang: str, tgt_lang: str, *,
device: Device, root: Path = data_dir, **kwargs) -> Tuple['DataLoader', ...]:
if share_vocab:
src = tgt = WordPipe(device=device)
else:
src = WordPipe(device=device)
tgt = WordPipe(device=device)

pipes = [
dict(src=src),
dict(tgt=tgt),
]

for ps in pipes:
for name, pipe in ps.items():
logging.info(f'{name} => {pipe}')

train, dev, test = cls.paths(root=root)

train = cls(pipes=pipes, path=train, src_lang=src_lang, tgt_lang=tgt_lang)
dev = cls(pipes=pipes, path=dev, src_lang=src_lang, tgt_lang=tgt_lang)
test = cls(pipes=pipes, path=test, src_lang=src_lang, tgt_lang=tgt_lang)

src.build_vocab_(train)
if not share_vocab:
tgt.build_vocab_(train)

return DataLoader.new(
(train, dev, test),
batch_size=batch_size,
shuffle=True, drop_last=False,
)


class IWSLT14(MachineTranslation):
urls = [(
'https://raw.githubusercontent.com/pytorch/fairseq/master/examples/translation/prepare-iwslt14.sh',
'prepare-iwslt14.sh',
)]

@classmethod
def paths(cls, root: Path = data_dir, **kwargs) -> Tuple[Path, ...]:
path, = super(IWSLT14, cls).paths(root=root, **kwargs)
train = path.parent / 'iwslt14.tokenized.de-en' / 'train'
dev = path.parent / 'iwslt14.tokenized.de-en' / 'valid'
test = path.parent / 'iwslt14.tokenized.de-en' / 'test'
return train, dev, test


if __name__ == '__main__':
train, dev, test = IWSLT14.new(
batch_size=32, share_vocab=False,
src_lang='en', tgt_lang='de',
device=torch.device('cpu'),
)
for item in train:
print(item)
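Two reading notes on the loader above (observations, not part of the commit): `load` pairs the two corpora with `zip`, which silently stops at the shorter file if the source and target sides are misaligned; and with `share_vocab=True`, `src` and `tgt` are the very same `WordPipe` object, so the single `src.build_vocab_(train)` call can serve both sides. A toy illustration of both behaviors, with a hypothetical `TinyPipe`:

```python
# Toy illustration, independent of torchglyph.
src_lines = ['ein Haus', 'ein Baum', 'extra line']
tgt_lines = ['a house', 'a tree']

# zip stops at the shorter corpus: the dangling third source line is dropped
assert len(list(zip(src_lines, tgt_lines))) == 2

class TinyPipe:
    def __init__(self):
        self.vocab = set()

share_vocab = True
if share_vocab:
    src = tgt = TinyPipe()  # one object, mirrors `src = tgt = WordPipe(...)`
else:
    src, tgt = TinyPipe(), TinyPipe()

assert src is tgt  # any vocabulary built on src is automatically tgt's too
```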