-
Notifications
You must be signed in to change notification settings - Fork 0
/
weight.py
31 lines (26 loc) · 897 Bytes
/
weight.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import torch
# Default class-weight function.
def my_weight(dataset, device, num_classed=18):
    """Compute per-class loss weights as ``1 - class_frequency``.

    Author: Junhong Kim (T2059). This is the weighting scheme used in
    earlier experiments.

    Labels are read from ``dataset.dataset.targets``, i.e. the targets of the
    dataset wrapped by ``dataset``. NOTE(review): if ``dataset`` is a
    ``torch.utils.data.Subset``, this counts the *whole* underlying dataset
    rather than only the subset's indices. Confirm that this is intended.

    Args:
        dataset: Object whose ``.dataset.targets`` is an iterable of integer
            class labels (list, numpy array or tensor).
        device: Torch device (or device string) the result is placed on.
        num_classed: Number of classes. The name is kept as-is for backward
            compatibility; it means "num_classes".

    Returns:
        A 1-D ``torch.float32`` tensor of length ``num_classed`` whose entry
        ``c`` is ``1 - count(c) / total_count``.

    Raises:
        ValueError: If there are no targets, or a label lies outside
            ``[0, num_classed)``.
    """
    # int() per element accepts Python ints, numpy scalars and tensor elements.
    targets = np.array([int(t) for t in dataset.dataset.targets], dtype=np.int64)
    if targets.size == 0:
        # Without this guard the weights would be 0/0 = NaN and silently
        # poison the loss.
        raise ValueError("dataset has no targets; cannot compute class weights")
    lo, hi = int(targets.min()), int(targets.max())
    if lo < 0 or hi >= num_classed:
        # A negative label would otherwise wrap around and be counted as the
        # last class instead of failing.
        raise ValueError(
            f"labels must be in [0, {num_classed}), got values in [{lo}, {hi}]"
        )
    counts = np.bincount(targets, minlength=num_classed)
    # The total is computed once, not once per class.
    weight = 1.0 - counts / counts.sum()
    return torch.as_tensor(weight, dtype=torch.float32, device=device)
def ins_weight(dataset, device, num_classes=18):
    """Compute inverse-frequency class weights normalized to sum to ``num_classes``.

    Author: Junhong Kim (T2059).
    Reference: https://medium.com/gumgum-tech/handling-class-imbalance-by-introducing-sample-weighting-in-the-loss-function-3bdebd8203b4

    Each class ``c`` gets ``1 / count(c)``. The vector is then rescaled so
    its entries sum to ``num_classes``, so that a perfectly balanced dataset
    yields all-ones weights.

    Args:
        dataset: Object whose ``.targets`` is an iterable of integer class
            labels (list, numpy array or tensor).
        device: Torch device (or device string) the result is placed on.
        num_classes: Number of classes. It now also sizes the count array;
            previously that was hard-coded to 18.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``num_classes``.

    Raises:
        ValueError: If there are no targets, a label lies outside
            ``[0, num_classes)``, or some class has zero samples (its
            inverse frequency would be infinite and turn every weight into
            NaN/0 after normalization).
    """
    # int() per element accepts Python ints, numpy scalars and tensor elements.
    targets = np.array([int(t) for t in dataset.targets], dtype=np.int64)
    if targets.size == 0:
        raise ValueError("dataset has no targets; cannot compute class weights")
    lo, hi = int(targets.min()), int(targets.max())
    if lo < 0 or hi >= num_classes:
        raise ValueError(
            f"labels must be in [0, {num_classes}), got values in [{lo}, {hi}]"
        )
    counts = np.bincount(targets, minlength=num_classes)
    missing = np.flatnonzero(counts == 0)
    if missing.size:
        raise ValueError(
            f"classes {missing.tolist()} have no samples; "
            "inverse-frequency weight is undefined"
        )
    weight = 1.0 / counts
    weight = weight / weight.sum() * num_classes
    return torch.as_tensor(weight, dtype=torch.float32, device=device)