added register_buffer
wonchul committed Jan 24, 2024
1 parent da2ffbd commit 3208f02
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions _posts/pytorch/2024-01-24-torch-register-buffer.md
@@ -0,0 +1,54 @@
---
layout: post
title: nn.Module.register_buffer()
category: Pytorch
tag: torch.nn
---


### When an `nn.Module` is moved to GPU/CUDA, plain `tensor` attributes that are neither an `nn.Parameter()` nor a registered buffer are not moved.
```python
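import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

        # Learnable parameter: moved by .cuda() and returned by parameters().
        self.param = nn.Parameter(torch.randn([2, 2]))

        # Buffer: moved by .cuda(), but not a parameter.
        buff_1 = torch.randn([2, 2])
        self.register_buffer('buff_1', buff_1)

        # Buffer created with requires_grad=True: still not a parameter.
        buff_2 = torch.randn([2, 2], requires_grad=True)
        self.register_buffer('buff_2', buff_2)

        # Plain tensor attribute: the module does not track it.
        self.non_buff = torch.randn([2, 2])

    def forward(self, x):
        return x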
```

```python
model = Model()
print(model.param.device) # cpu
print(model.buff_1.device) # cpu
print(model.buff_2.device) # cpu
print(model.non_buff.device) # cpu

model.cuda()  # requires a CUDA-capable device
print(model.param.device) # cuda:0
print(model.buff_1.device) # cuda:0
print(model.buff_2.device) # cuda:0
print(model.non_buff.device) # cpu
```

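Registering a tensor as a buffer keeps it on the same `device` as the module's parameters, so it seems useful for preventing `device` mismatches when tensors inside an `nn.Module` are combined in an operation.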
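As a minimal sketch of that use case, the module below keeps its constants in buffers so that `forward` can combine them with the input on whatever device the module lives on. The `Normalize` class and its `mean`/`std` values are hypothetical names chosen for illustration, and the code falls back to the CPU when CUDA is unavailable.

```python
import torch
import torch.nn as nn


class Normalize(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # Buffers follow the module through .to() / .cuda().
        self.register_buffer("mean", torch.tensor([0.5, 0.5]))
        self.register_buffer("std", torch.tensor([0.25, 0.25]))

    def forward(self, x):
        # x, self.mean and self.std must share a device for this to work.
        return (x - self.mean) / self.std


# Assumption: use the GPU if one is present, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Normalize().to(device)

x = torch.ones(3, 2, device=device)
out = model(x)
print(model.mean.device.type == device.type)  # True
print(out)  # every element is (1 - 0.5) / 0.25 = 2.0
```

Had `mean` and `std` been stored as plain attributes (like `non_buff` above), they would have stayed on the CPU after `.to(device)`, and the subtraction in `forward` would fail on a CUDA input.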


### Even with `requires_grad=True`, a buffer is not recognized as an `nn.Parameter()`, so the `optimizer` does not update it.

```python
for name, param in model.named_parameters():
    print(name, param.data)

# param tensor([[...]])
```
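The following sketch checks the same point from the other side: where the buffer does show up, and what happens to it after an optimizer step. It uses a reduced version of the `Model` above; the SGD optimizer, the learning rate, and the toy loss are assumptions made only for this illustration.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(2, 2))
        self.register_buffer("buff_2", torch.randn(2, 2, requires_grad=True))
        self.non_buff = torch.randn(2, 2)  # plain attribute, not tracked


model = Model()

param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
state_keys = sorted(model.state_dict().keys())
print(param_names)   # ['param']
print(buffer_names)  # ['buff_2']
print(state_keys)    # ['buff_2', 'param']

# The optimizer only receives model.parameters(), so it only knows `param`.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
param_before = model.param.detach().clone()
buff_before = model.buff_2.detach().clone()

loss = (model.param * model.buff_2).sum()  # toy loss for illustration
loss.backward()
optimizer.step()

print(torch.equal(model.buff_2, buff_before))  # True: the buffer is unchanged
print(torch.equal(model.param, param_before))  # False: `param` was updated
```

Note that the buffer is still saved in `state_dict()` by default, while `non_buff` appears in none of the three listings.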
