
Commit b6fdec0

Code coverage and setup.py

1 parent fbee89c commit b6fdec0

File tree

11 files changed: +144 -28 lines

.coveragerc

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+[run]
+source=toma/*
+omit=
+    */tests/*
+    setup.py

README.md

Lines changed: 10 additions & 4 deletions

@@ -1,7 +1,13 @@
-# TOrch Memory-Adaptive Algorithms
+# Torch Memory-adaptive Algorithms (TOMA)
+
+[![Build Status](https://www.travis-ci.com/BlackHC/implicit_lambda.svg?branch=master)](https://www.travis-ci.com/BlackHC/implicit_lambda) [![codecov](https://codecov.io/gh/BlackHC/implicit_lambda/branch/master/graph/badge.svg)](https://codecov.io/gh/BlackHC/implicit_lambda) [![PyPI](https://img.shields.io/badge/PyPI-implicit_lambda-blue.svg)](https://pypi.python.org/pypi/implicit_lambda/)
+
+
+A collection of helpers to make it easier to write code that adapts to the available CUDA memory.
+
+## Installation
+
 
 TODOs:
 * [] write readme
-* [] add a test for simple_chunked
-* [x] add tests for explicit_*
-* [] add tests for the decorators
+* [] add doc strings
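For orientation, here is the decorator API as the new tests in this commit exercise it. This is a hedged sketch mirrored from tests/test_toma.py, not official documentation; the exact public signatures may differ:

    # Sketch based on tests/test_toma.py in this commit; treat names and
    # defaults as illustrative rather than the documented API.
    import torch
    from toma import toma, batchsize_cache as tbc


    @toma.batch(initial_batchsize=64, cache_type=tbc.NoBatchsizeCache)
    def train(batchsize):
        # toma supplies `batchsize` and retries with a halved value when
        # the body raises a CUDA out-of-memory RuntimeError.
        data = torch.zeros((batchsize, 4, 4))


    train()  # called without arguments; the decorator injects batchsize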

setup.cfg

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 test=pytest
 
 [tool:pytest]
-addopts = --cov=.
+addopts = --cov toma
 
 [pylama:pycodestyle]
 max_line_length = 120

setup.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+# Always prefer setuptools over distutils
+from setuptools import setup
+
+# To use a consistent encoding
+from codecs import open
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the README file
+with open(path.join(here, "README.md"), encoding="utf-8") as f:
+    long_description = f.read()
+
+setup(
+    name="toma",
+    # Versions should comply with PEP440. For a discussion on single-sourcing
+    # the version across setup.py and the project code, see
+    # https://packaging.python.org/en/latest/single_source_version.html
+    version="0.0.0",
+    description="Make it easy to write algorithms in PyTorch that adapt to the available CUDA memory",
+    # Fix windows newlines.
+    long_description=long_description.replace("\r\n", "\n"),
+    # The project's main homepage.
+    url="https://github.com/blackhc/toma",
+    # Author details
+    author="Andreas @blackhc Kirsch",
+    author_email="[email protected]",
+    # Choose your license
+    license="MIT",
+    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
+    classifiers=[
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        "Development Status :: 3 - Alpha",
+        # Indicate who your project is intended for
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+        # Pick your license as you wish (should match "license" above)
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 3.7",
+    ],
+    # What does your project relate to?
+    keywords="tools pytorch",
+    # You can just specify the packages manually here if your project is
+    # simple. Or you can use find_packages().
+    packages=["toma"],
+    package_dir={"": "src"},
+    # List run-time dependencies here. These will be installed by pip when
+    # your project is installed. For an analysis of "install_requires" vs pip's
+    # requirements files see:
+    # https://packaging.python.org/en/latest/requirements.html
+    install_requires=[],
+    # List additional groups of dependencies here (e.g. development
+    # dependencies). You can install these using the following syntax,
+    # for example:
+    #   $ pip install -e .[dev,test]
+    extras_require={
+        "dev": ["check-manifest"],
+        "test": ["coverage", "codecov", "pytest", "pytest-benchmark", "pytest-cov"],
+    },
+    setup_requires=["pytest-runner"],
+)

tests/test_explicit_toma.py

Lines changed: 36 additions & 8 deletions

@@ -1,3 +1,6 @@
+import torch
+import pytest
+
 from toma import explicit
 from toma.batchsize_cache import NoBatchsizeCache, GlobalBatchsizeCache, StacktraceMemoryBatchsizeCache
 
@@ -6,7 +9,23 @@ def raise_fake_oom():
     raise RuntimeError("CUDA out of memory.")
 
 
-def test_fake_explicit_toma_none():
+def test_fake_explicit_batch_raise():
+    def f(batchsize):
+        raise_fake_oom()
+
+    with pytest.raises(RuntimeError):
+        explicit.batch(f, 64)
+
+
+def test_fake_explicit_range_raise():
+    def f(start, end):
+        raise_fake_oom()
+
+    with pytest.raises(RuntimeError):
+        explicit.range(f, 0, 64, 64)
+
+
+def test_fake_explicit_batch_none():
     batchsizes = []
 
     def f(batchsize):
@@ -22,7 +41,7 @@ def f(batchsize):
     assert batchsizes == [64, 32, 16, 64, 32, 16]
 
 
-def test_fake_explicit_toma_global():
+def test_fake_explicit_batch_global():
     batchsizes = []
 
     def f(batchsize):
@@ -40,7 +59,7 @@ def f(batchsize):
     assert batchsizes == [64, 32, 16, 16, 16, 16]
 
 
-def test_fake_explicit_toma_sm():
+def test_fake_explicit_batch_sm():
     batchsizes = []
 
     def f(batchsize):
@@ -58,7 +77,7 @@ def f(batchsize):
     assert batchsizes == [64, 32, 16, 16, 16, 64, 32, 16]
 
 
-def test_fake_explicit_toma_mix():
+def test_fake_explicit_batch_mix():
     batchsizes = []
 
     def f(batchsize):
@@ -79,7 +98,7 @@ def f(batchsize):
     assert batchsizes == [64, 32, 16, 16, 16, 16, 64, 32, 16, 16]
 
 
-def test_fake_explicit_toma_range_none():
+def test_fake_explicit_range_none():
     batchsizes = []
 
     def f(start, end):
@@ -99,7 +118,7 @@ def f(start, end):
     assert batchsizes == [64, 64, 32, 32, 16, 16] * 2
 
 
-def test_fake_explicit_toma_range_global():
+def test_fake_explicit_range_global():
     batchsizes = []
 
     def f(start, end):
@@ -121,7 +140,7 @@ def f(start, end):
     assert batchsizes == [64, 64, 32, 32, 16, 16] + [16] * 8 * 2
 
 
-def test_fake_explicit_toma_range_sm():
+def test_fake_explicit_range_sm():
     batchsizes = []
 
     def f(start, end):
@@ -143,7 +162,7 @@ def f(start, end):
     assert batchsizes == [64, 64, 32, 32, 16, 16] + [16] * 8 + [64, 64, 32, 32, 16, 16]
 
 
-def test_fake_explicit_toma_range_sm():
+def test_fake_explicit_range_sm():
     batchsizes = []
 
     def f(start, end):
@@ -166,3 +185,12 @@ def f(start, end):
     explicit.range(f, 0, 128, 64, toma_cache_type=StacktraceMemoryBatchsizeCache)
 
     assert batchsizes == ([64, 64, 32, 32, 16, 16] + [16] * 8 * 2 + [64, 64, 32, 32, 16, 16] + [16] * 8)
+
+
+def test_explicit_chunked():
+    def func(tensor, start, end):
+        tensor[:] = 1.
+
+    tensor = torch.zeros((128, 4, 4))
+    explicit.chunked(func, tensor, 32)
+    assert torch.allclose(tensor, torch.tensor(1.))
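The new tests above also pin down the explicit, call-style API. A hedged sketch of explicit.chunked as the tests use it; the argument order is inferred from the calls above (e.g. explicit.batch(f, 64) and explicit.chunked(func, tensor, 32)), and the real signatures may take extra keyword arguments such as toma_cache_type:

    # Inferred from the test calls above; not the documented signature.
    import torch
    from toma import explicit

    def fill(chunk, start, end):
        # Presumably a view of the current [start, end) slice; writing
        # through it mutates the original tensor (inferred from the
        # allclose assertion in test_explicit_chunked).
        chunk[:] = 1.0

    tensor = torch.zeros((128, 4, 4))
    explicit.chunked(fill, tensor, 32)  # initial chunk size of 32
    assert torch.allclose(tensor, torch.tensor(1.0))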

tests/test_simple_toma.py

Lines changed: 5 additions & 5 deletions

@@ -7,7 +7,7 @@ def raise_fake_oom():
     raise RuntimeError("CUDA out of memory.")
 
 
-def test_fake_simple_toma():
+def test_fake_simple_batch():
     hit_16 = False
 
     def f(batch_size):
@@ -24,7 +24,7 @@ def f(batch_size):
     assert hit_16
 
 
-def test_fake_simple_toma_range():
+def test_fake_simple_range():
     hit_16 = False
 
     def f(start, end):
@@ -43,7 +43,7 @@ def f(start, end):
     assert hit_16
 
 
-def test_fake_simple_toma_chunked():
+def test_fake_simple_chunked():
     hit_16 = False
 
     def f(tensor, start, end):
@@ -63,7 +63,7 @@ def f(tensor, start, end):
     assert hit_16
 
 
-def test_simple_toma():
+def test_simple_batch():
     import torch
 
     if not torch.cuda.is_available():
@@ -89,7 +89,7 @@ def f(batch_size):
     assert succeeded
 
 
-def test_simple_toma_range():
+def test_simple_range():
     import torch
 
     if not torch.cuda.is_available():

tests/test_toma.py

Lines changed: 15 additions & 4 deletions

@@ -1,11 +1,12 @@
+import torch
 from toma import toma, explicit, batchsize_cache as tbc
 
 
 def raise_fake_oom():
     raise RuntimeError("CUDA out of memory.")
 
 
-def test_fake_toma_simple():
+def test_fake_batch_none():
     batchsizes = []
 
     @toma.batch(initial_batchsize=64, cache_type=tbc.NoBatchsizeCache)
@@ -22,7 +23,7 @@ def f(batchsize):
     assert batchsizes == [64, 32, 16, 64, 32, 16]
 
 
-def test_fake_toma_explicit():
+def test_fake_batch_global():
     batchsizes = []
 
     @toma.batch(initial_batchsize=64, cache_type=tbc.GlobalBatchsizeCache)
@@ -39,7 +40,7 @@ def f(batchsize):
     assert batchsizes == [64, 32, 16, 16]
 
 
-def test_fake_toma_range_global():
+def test_fake_range_none():
     batchsizes = []
 
    @toma.range(initial_step=64, cache_type=tbc.NoBatchsizeCache)
@@ -60,7 +61,7 @@ def f(start, end):
     assert batchsizes == [64, 64, 32, 32, 16, 16] * 2
 
 
-def test_fake_toma_range_explicit():
+def test_fake_range_global():
     batchsizes = []
 
     @toma.range(initial_step=64, cache_type=tbc.GlobalBatchsizeCache)
@@ -79,3 +80,13 @@ def f(start, end):
     f(0, 128)
 
     assert batchsizes == [64, 64, 32, 32, 16, 16] + [16] * 8
+
+
+def test_chunked():
+    @toma.chunked(initial_step=32)
+    def func(tensor, start, end):
+        tensor[:] = 1.
+
+    tensor = torch.zeros((128, 4, 4))
+    func(tensor)
+    assert torch.allclose(tensor, torch.tensor(1.))
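Read together, the range tests document the stepping behaviour: @toma.range walks [start, end) in steps of the current batchsize and halves the step on a simulated OOM, with cache_type deciding whether the reduced step is remembered across calls. A hedged sketch of the happy path (no OOM), mirroring the decorator usage in the tests above; the expected output assumes the inferred semantics hold:

    # Mirrors the decorator usage in tests/test_toma.py; without an OOM
    # the step simply stays at initial_step.
    from toma import toma, batchsize_cache as tbc

    @toma.range(initial_step=64, cache_type=tbc.NoBatchsizeCache)
    def consume(start, end):
        print(start, end)  # expect (0, 64) then (64, 128)

    consume(0, 128)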
Lines changed: 4 additions & 4 deletions

@@ -178,7 +178,7 @@ def batch(
 
     while True:
         try:
-            value = batchsize.get_batchsize()
+            value = batchsize.get()
             return func(value, *args, **kwargs)
         except RuntimeError as exception:
             if value > 1 and tcm.should_reduce_batch_size(exception):
@@ -206,10 +206,10 @@ def range(
     current = start
     while current < end:
         try:
-            func(current, min(current + batchsize.get_batchsize(), end), *args, **kwargs)
-            current += batchsize.get_batchsize()
+            func(current, min(current + batchsize.get(), end), *args, **kwargs)
+            current += batchsize.get()
         except RuntimeError as exception:
-            if batchsize.get_batchsize() > 1 and tcm.should_reduce_batch_size(exception):
+            if batchsize.get() > 1 and tcm.should_reduce_batch_size(exception):
                 batchsize.decrease_batchsize()
                 tcm.gc_cuda()
             else:
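The rename from get_batchsize() to get() touches a plain retry loop. A condensed, self-contained restatement of the same pattern, with a stand-in is_oom check in place of toma's tcm.should_reduce_batch_size and gc_cuda helpers:

    # Stand-alone restatement of the loop in the hunk above; `is_oom`
    # and the inline halving stand in for toma's helpers and the
    # batchsize object.
    def is_oom(exception: RuntimeError) -> bool:
        return "CUDA out of memory." in str(exception)

    def run_with_adaptive_batchsize(func, batchsize: int):
        while True:
            try:
                return func(batchsize)
            except RuntimeError as exception:
                if batchsize > 1 and is_oom(exception):
                    batchsize //= 2  # cf. decrease_batchsize below
                else:
                    raise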
Lines changed: 2 additions & 1 deletion

@@ -14,11 +14,12 @@ def set_initial_batchsize(self, initial_batchsize: int):
         if not self.value:
             self.value = initial_batchsize
 
-    def get_batchsize(self) -> int:
+    def get(self) -> int:
         return self.value
 
     def decrease_batchsize(self):
         self.value //= 2
+        assert self.value > 0
 
 
 class BatchsizeCache:
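For context, here is the holder class this hunk modifies, reconstructed around the visible lines; everything outside set_initial_batchsize, get, and decrease_batchsize is an assumption about the surrounding code:

    # Only the three methods shown in the hunk are taken from the diff;
    # the class name, attribute, and constructor are assumptions.
    class Batchsize:
        def __init__(self, value: int = 0):
            self.value = value

        def set_initial_batchsize(self, initial_batchsize: int):
            if not self.value:
                self.value = initial_batchsize

        def get(self) -> int:
            return self.value

        def decrease_batchsize(self):
            self.value //= 2
            # New in this commit: halving 1 would yield 0 and make the
            # callers' retry loops spin forever, so fail loudly instead.
            assert self.value > 0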
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ def _constant_code_context(code_context):
 def get_simple_traceback(ignore_top=0):
     """Get a simple traceback that can be hashed and won't create reference
     cycles."""
-    stack = inspect.stack(context=ignore_top + 1)[ignore_top : -__watermark - 1]
+    stack = inspect.stack(context=1)[ignore_top + 1: -__watermark - 1]
     simple_traceback = tuple(
         (fi.filename, fi.lineno, fi.function, _constant_code_context(fi.code_context), fi.index) for fi in stack
     )
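This one-line fix untangles two parameters the old code conflated: `context` tells inspect.stack how many source lines to capture per frame, while the slice start controls how many innermost frames to skip. A minimal standalone sketch of the corrected behaviour, omitting the original's __watermark trimming at the tail of the slice:

    # Standalone illustration; toma's __watermark trimming is left out.
    import inspect

    def simple_traceback(ignore_top: int = 0):
        # context=1 captures one source line per frame; the slice skips
        # this helper's own frame plus `ignore_top` extra frames.
        stack = inspect.stack(context=1)[ignore_top + 1:]
        return tuple((fi.filename, fi.lineno, fi.function) for fi in stack)

    def caller():
        return simple_traceback()

    print(caller())  # first entry describes caller(), not simple_traceback()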
