-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
38 lines (31 loc) · 926 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import json
# Default compute device for this module: the first CUDA GPU when PyTorch
# reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def num_corrects(outputs, label_batch):
    """
    Count how many predictions of a model match the true labels.

    The prediction for each row is the index of the largest value along
    dimension 1 of ``outputs``; it is compared element-wise with
    ``label_batch``.

    :param outputs: Outputs of a model
    :type outputs: :py:class:`torch.Tensor`
    :param label_batch: True labels for the batch
    :type label_batch: :py:class:`torch.Tensor`
    :return: Number of positions where the predicted index equals the label
    :rtype: int
    """
    predicted = outputs.argmax(dim=1)
    matches = predicted == label_batch
    return matches.sum().item()
def init_hyper_params(path='maxout.json'):
    """
    Initialize hyper parameters by loading them from a JSON file.

    Per the original author's notes, the file stores the number of neurons
    in each layer and the number of layers before the max operation for the
    MLP, plus the number of channels and kernel for each maxout layer.
    The schema is not validated here.

    :param path: Path of the JSON file to read. Relative paths are resolved
        against the current working directory. Defaults to ``'maxout.json'``.
    :type path: str or path-like
    :return: The decoded JSON content, exactly as returned by ``json.load``.
    :raises FileNotFoundError: If ``path`` does not exist.
    :raises json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's
    # locale-default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
def gcd(a, b):
    """
    Return the greatest common divisor of ``a`` and ``b`` (Euclid's algorithm).

    The result is always non-negative, matching :func:`math.gcd`. Python's
    ``%`` takes the sign of the divisor, so the loop alone can end on a
    negative value (e.g. ``gcd(4, -6)``). ``gcd(0, 0)`` is 0.

    The loop is hand-rolled rather than delegating to :func:`math.gcd` so
    that float arguments keep working as before.

    :param a: First number (normally an int).
    :param b: Second number (normally an int).
    :return: The non-negative greatest common divisor.
    """
    while b:
        a, b = b, a % b
    return abs(a)
def lcm(a, b):
    """
    Return the least common multiple of ``a`` and ``b``.

    The result is non-negative regardless of the signs of the inputs, and
    ``lcm(0, 0)`` is 0 (matching :func:`math.lcm`) instead of raising
    ``ZeroDivisionError``.

    :param a: First number (normally an int).
    :param b: Second number (normally an int).
    :return: ``|a * b| // gcd(a, b)``, or 0 when both inputs are 0.
    """
    if a == 0 and b == 0:
        # gcd(0, 0) is 0, so the general formula would divide by zero.
        return 0
    # For ints, gcd(a, b) divides a * b exactly, so the floor division is
    # exact; abs() makes the result non-negative whatever sign gcd() returns.
    return abs(a * b // gcd(a, b))