
Commit 0597dfb

Add cpu_memory.py and track cpu memory by default, too.
1 parent 3088003 commit 0597dfb

9 files changed: +108 -26 lines changed


.gitignore

Lines changed: 4 additions & 0 deletions
@@ -1 +1,5 @@
 /.idea/
+/build/
+/dist/
+.benchmarks/
+*.egg-info

README.md

Lines changed: 10 additions & 2 deletions
@@ -80,9 +80,17 @@ toma.explicit.batch(..., toma_cache_type=toma.GlobalBatchsizeCache)
 ### `StacktraceMemoryBatchsizeCache`: Stacktrace & Available Memory (*the default*)
 
 This memorizes the successful batchsizes for a given call trace and available memory at that point.
-For most machine learning code this is sufficient to know the right batchsize without having to look at the actual arguments and understanding more of the semantics.
+For most machine learning code, this is sufficient to remember the right batchsize without having to look at the actual arguments or understand more of the semantics.
 
-The implicit assumption is that after a few iterations a stable state will be reached in regards to memory usage.
+The implicit assumption is that after a few iterations a stable state will be reached with regard to GPU and CPU memory usage.
+
+To limit the CPU memory of the process, toma provides:
+```python
+import toma.cpu_memory
+
+toma.cpu_memory.set_cpu_memory_limit(8)
+```
+This can also be useful to avoid accidental swap thrashing.
 
 ### `GlobalBatchsizeCache`: Global per Function
 
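For context, a minimal end-to-end sketch of the README additions, assuming the same decorator API that `tests/test_cpu_mem_limit.py` below exercises; the function name, batchsize, and tensor shape are illustrative only:

```python
import torch

from toma import toma
from toma import cpu_memory

# Illustrative cap: limit this process to 8 GB of address space.
cpu_memory.set_cpu_memory_limit(8)


@toma.batch(initial_batchsize=4096)
def run(batchsize):
    # If this allocation exceeds the available CPU (or GPU) memory,
    # toma halves the batchsize and retries the decorated function.
    features = torch.empty((batchsize, 1024), dtype=torch.float32)


run()
```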
setup.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@
     # your project is installed. For an analysis of "install_requires" vs pip's
     # requirements files see:
     # https://packaging.python.org/en/latest/requirements.html
-    install_requires=["torch"],
+    install_requires=["torch", "psutil"],
     # List additional groups of dependencies here (e.g. development
     # dependencies). You can install these using the following syntax,
     # for example:

tests/test_benchmark.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,9 @@
 import torch
 import pytest_benchmark
 
+# Preload this import
+import resource
+
 from toma import simple, toma, explicit, NoBatchsizeCache
 
 
tests/test_cpu_mem_limit.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import torch
+
+from toma import toma
+from toma import cpu_memory
+
+
+def test_cpu_mem_limit():
+    cpu_memory.set_cpu_memory_limit(2)
+
+    batchsize = None
+
+    @toma.batch(initial_batchsize=2048)
+    def allocate_gigabytes(bs):
+        torch.empty((bs, 1024, 1024 // 4), dtype=torch.float32)
+
+        nonlocal batchsize
+        batchsize = bs
+
+    allocate_gigabytes()
+
+    assert batchsize <= 512
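For reference, the sizes in this test work out as follows (my arithmetic, not part of the diff): each call requests `bs * 1024 * (1024 // 4)` float32 values at 4 bytes each, i.e. `bs` MiB, so the initial batchsize of 2048 asks for about 2 GiB against the 2 GB limit set above; with interpreter and torch overhead on top, at least two halvings are presumably needed, which is what the final `<= 512` assertion checks.

```python
# bs * 1024 * 256 float32 elements, 4 bytes each -> bs MiB per call.
def bytes_requested(bs: int) -> int:
    return bs * 1024 * (1024 // 4) * 4

assert bytes_requested(2048) == 2 * 2 ** 30   # 2 GiB, already at the 2 GB cap
assert bytes_requested(512) == 512 * 2 ** 20  # 512 MiB, well below it
```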

toma/__init__.py

Lines changed: 19 additions & 14 deletions
@@ -8,9 +8,10 @@
 
 import torch
 
-import toma.torch_cuda_memory as tcm
 import toma.stacktrace as tst
 from toma.batchsize_cache import StacktraceMemoryBatchsizeCache, NoBatchsizeCache, GlobalBatchsizeCache
+from toma.cpu_memory import is_out_of_cpu_memory
+from toma.torch_cuda_memory import is_cuda_out_of_memory, is_cudnn_snafu, gc_cuda
 
 
 DEFAULT_CACHE_TYPE = StacktraceMemoryBatchsizeCache
@@ -23,22 +24,22 @@ class simple:
 
     @staticmethod
     def batch(func, initial_batchsize: int, *args, **kwargs):
-        tcm.gc_cuda()
+        gc_cuda()
 
         batchsize = initial_batchsize
         while True:
             try:
                 return func(batchsize, *args, **kwargs)
             except RuntimeError as exception:
-                if batchsize > 1 and tcm.should_reduce_batch_size(exception):
+                if batchsize > 1 and should_reduce_batch_size(exception):
                     batchsize //= 2
-                    tcm.gc_cuda()
+                    gc_cuda()
                 else:
                     raise
 
     @staticmethod
     def range(func, start: int, end: int, initial_step: int, *args, **kwargs):
-        tcm.gc_cuda()
+        gc_cuda()
 
         stepsize = initial_step
         current = start
@@ -47,9 +48,9 @@ def range(func, start: int, end: int, initial_step: int, *args, **kwargs):
                 func(current, min(current + stepsize, end), *args, **kwargs)
                 current += stepsize
             except RuntimeError as exception:
-                if stepsize > 1 and tcm.should_reduce_batch_size(exception):
+                if stepsize > 1 and should_reduce_batch_size(exception):
                     stepsize //= 2
-                    tcm.gc_cuda()
+                    gc_cuda()
                 else:
                     raise
 
@@ -170,7 +171,7 @@ class explicit:
     def batch(
         func, initial_batchsize: int, *args, toma_context=None, toma_cache_type: Type = DEFAULT_CACHE_TYPE, **kwargs
     ):
-        tcm.gc_cuda()
+        gc_cuda()
 
         cache = get_cache_for_context(toma_cache_type, toma_context or func)
 
@@ -181,9 +182,9 @@ def batch(
                 value = batchsize.get()
                 return func(value, *args, **kwargs)
             except RuntimeError as exception:
-                if value > 1 and tcm.should_reduce_batch_size(exception):
+                if value > 1 and should_reduce_batch_size(exception):
                     batchsize.decrease_batchsize()
-                    tcm.gc_cuda()
+                    gc_cuda()
                 else:
                     raise
 
@@ -198,22 +199,22 @@ def range(
         toma_cache_type: Type = DEFAULT_CACHE_TYPE,
         **kwargs,
     ):
-        tcm.gc_cuda()
+        gc_cuda()
 
         cache = get_cache_for_context(toma_cache_type, toma_context or func)
 
         batchsize = cache.get_batchsize(initial_step)
 
-        tcm.gc_cuda()
+        gc_cuda()
         current = start
         while current < end:
             try:
                 func(current, min(current + batchsize.get(), end), *args, **kwargs)
                 current += batchsize.get()
             except RuntimeError as exception:
-                if batchsize.get() > 1 and tcm.should_reduce_batch_size(exception):
+                if batchsize.get() > 1 and should_reduce_batch_size(exception):
                     batchsize.decrease_batchsize()
-                    tcm.gc_cuda()
+                    gc_cuda()
                 else:
                     raise
 
@@ -242,3 +243,7 @@ def body(start: int, end: int):
             toma_context=toma_context or func,
             toma_cache_type=toma_cache_type,
         )
+
+
+def should_reduce_batch_size(exception):
+    return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception)
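The retry condition shared by `simple`, `toma`, and `explicit` now lives here and also treats CPU allocator failures as a signal to shrink the batch. A small sketch of that classification, using the exact message substring `is_out_of_cpu_memory()` matches on (the exception object itself is synthetic):

```python
from toma import should_reduce_batch_size

# PyTorch's CPU allocator raises a RuntimeError containing this substring when
# an allocation fails; toma now treats it like a CUDA OOM and retries with a
# smaller batchsize instead of propagating the error.
cpu_oom = RuntimeError("DefaultCPUAllocator: can't allocate memory")
assert should_reduce_batch_size(cpu_oom)
```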

toma/batchsize_cache.py

Lines changed: 12 additions & 4 deletions
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
+import toma.cpu_memory
 from toma import stacktrace as tst, torch_cuda_memory as tcm
 import weakref
 
@@ -50,7 +51,9 @@ def get_batchsize(self, initial_batchsize: int) -> Batchsize:
 
 
 class StacktraceMemoryBatchsizeCache(BatchsizeCache):
-    LRU_CACHE_SIZE = 128
+    LRU_CACHE_SIZE: int = 128
+    TRACK_RAM: bool = True
+
     initial_batchsize: Optional[int]
 
     def __init__(self, lru_cache_size=None):
@@ -59,15 +62,20 @@ def __init__(self, lru_cache_size=None):
         self.initial_batchsize = None
 
         @functools.lru_cache(lru_cache_size or StacktraceMemoryBatchsizeCache.LRU_CACHE_SIZE)
-        def get_batchsize_from_cache(stacktrace, available_memory):
+        def get_batchsize_from_cache(stacktrace, cpu_available_memory, gpu_available_memory):
             return Batchsize(self.initial_batchsize)
 
         self.get_batchsize_from_cache = get_batchsize_from_cache
 
     def get_batchsize(self, initial_batchsize: int):
         stacktrace = tst.get_simple_traceback(2)
-        available_memory_256MB = int(tcm.get_cuda_assumed_available_memory() // 2 ** 28)
 
-        batchsize = self.get_batchsize_from_cache(stacktrace, available_memory_256MB)
+        gpu_available_memory_256MB = int(tcm.get_cuda_assumed_available_memory() // 2 ** 28)
+        if self.TRACK_RAM:
+            cpu_available_memory_256MB = int(toma.cpu_memory.get_available_cpu_memory() // 2 ** 28)
+        else:
+            cpu_available_memory_256MB = -1
+
+        batchsize = self.get_batchsize_from_cache(stacktrace, cpu_available_memory_256MB, gpu_available_memory_256MB)
         batchsize.set_initial_batchsize(initial_batchsize)
         return batchsize
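With this change the cache key becomes (stacktrace, free CPU memory, free GPU memory), with both memory figures quantized into 256 MB buckets (`// 2 ** 28`) so that small fluctuations still map to the same cached batchsize. A hedged sketch of the bucketing and of the new class-level switch; `to_bucket` is an illustrative helper, not part of the library:

```python
from toma.batchsize_cache import StacktraceMemoryBatchsizeCache

# Illustrative helper mirroring the `// 2 ** 28` quantization above.
def to_bucket(free_bytes: int) -> int:
    return int(free_bytes // 2 ** 28)  # 256 MB buckets

assert to_bucket(8 * 2 ** 30) == 32      # exactly 8 GiB free -> bucket 32
assert to_bucket(8 * 2 ** 30 - 1) == 31  # just under 8 GiB -> bucket 31

# Opting out of keying on CPU memory: the CPU slot is then pinned to -1
# and only the stacktrace and the GPU figure distinguish cache entries.
StacktraceMemoryBatchsizeCache.TRACK_RAM = False
```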

toma/cpu_memory.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import psutil
+
+
+def get_available_cpu_memory():
+    this_process = psutil.Process()
+    available_memory = psutil.virtual_memory().available
+
+    try:
+        import resource
+
+        soft_mem_limit, hard_mem_limit = resource.getrlimit(resource.RLIMIT_AS)
+        if hard_mem_limit != resource.RLIM_INFINITY:
+            used_memory = this_process.memory_info().vms
+            available_memory = min(hard_mem_limit - used_memory, available_memory)
+    except ImportError:
+        pass
+
+    return available_memory
+
+
+def set_cpu_memory_limit(num_gigabytes):
+    try:
+        import resource
+
+        num_bytes = int(num_gigabytes * 2 ** 30)
+        resource.setrlimit(resource.RLIMIT_AS, (num_bytes, num_bytes))
+    except ImportError:
+        pass
+
+
+def is_out_of_cpu_memory(exception):
+    return (
+        isinstance(exception, RuntimeError)
+        and len(exception.args) == 1
+        and "DefaultCPUAllocator: can't allocate memory" in exception.args[0]
+    )
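A usage sketch for the new module on a POSIX system, where the `resource` module exists; on platforms without it, `set_cpu_memory_limit()` silently does nothing and `get_available_cpu_memory()` falls back to `psutil.virtual_memory().available`. This assumes no stricter `RLIMIT_AS` hard limit is already in place:

```python
import resource

from toma import cpu_memory

# Cap this process at 4 GB of address space (illustrative value).
cpu_memory.set_cpu_memory_limit(4)

# Both the soft and the hard RLIMIT_AS limit now sit at 4 GiB.
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
assert soft == hard == 4 * 2 ** 30

# Reported availability is min(hard limit - current VMS, free system memory),
# so it already accounts for the cap set above.
print(cpu_memory.get_available_cpu_memory())
```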

toma/torch_cuda_memory.py

Lines changed: 2 additions & 5 deletions
@@ -8,7 +8,7 @@
 
 
 def gc_cuda():
-    """Gargage collect Torch cuda memory."""
+    """Garbage collect Torch (CUDA) memory."""
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -32,6 +32,7 @@ def get_cuda_available_memory():
         return get_cuda_assumed_available_memory() - get_cuda_blocked_memory()
     return 0
 
+
 def get_cuda_blocked_memory():
     if not torch.cuda.is_available():
         return 0
@@ -70,10 +71,6 @@ def is_cudnn_snafu(exception):
     )
 
 
-def should_reduce_batch_size(exception):
-    return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception)
-
-
 def cuda_meminfo():
     if not torch.cuda.is_available():
         return
