forked from Jwoo5/ai612-project2-2023
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path00000000_dataset.py
63 lines (52 loc) · 1.85 KB
/
00000000_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from . import BaseDataset, register_dataset
@register_dataset("00000000_dataset")
class MyDataset00000000(BaseDataset):
    """Template dataset to be filled in by each student.

    TODO:
        Create your own dataset here. Rename the class, the file, and the
        name passed to ``register_dataset`` after your student number.

    Example:
        - file: 20218078_dataset.py

        @register_dataset("20218078_dataset")
        class MyDataset20218078(BaseDataset):
            (...)
    """

    def __init__(
        self,
        data_path: str,  # data_path should be a path to the processed features
        # ...,
        **kwargs,
    ):
        super().__init__()
        # Keep the path so the loading code added below can locate the features.
        self.data_path = data_path
        # TODO: load the processed features from ``self.data_path``.

    def __getitem__(self, index: int) -> dict:
        """Return the sample at ``index`` as a dictionary.

        Note:
            You must return a dictionary here or in the collator, so that the
            data loader iterator yields samples as python dictionaries. For
            model inputs, each key must match an argument of the model's
            ``forward()`` method.

            Example:
                class MyDataset(...):
                    ...
                    def __getitem__(self, index):
                        (...)
                        return {"data_key": data, "label": label}

                class MyModel(...):
                    ...
                    def forward(self, data_key, **kwargs):
                        (...)

        Raises:
            NotImplementedError: until this template method is implemented.
        """
        # Fail loudly instead of silently returning None for every sample.
        raise NotImplementedError(
            f"{type(self).__name__}.__getitem__ is not implemented yet"
        )

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Raises:
            NotImplementedError: until this template method is implemented.
        """
        # Returning None (a bare `...` body) would make len() fail with an
        # unrelated TypeError; raise a clear error instead.
        raise NotImplementedError(
            f"{type(self).__name__}.__len__ is not implemented yet"
        )

    def collator(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate.

        Returns:
            dict: a mini-batch suitable for forwarding with a Model.

        Note:
            Implement this to build batches yourself, e.g. to output a
            padding mask together with the data. Otherwise, you don't need
            to implement this method.
        """
        # NOTE(review): implementing this is described as optional, so the
        # caller presumably falls back to a default collate function when this
        # raises NotImplementedError — confirm against BaseDataset / the trainer.
        raise NotImplementedError