-
Notifications
You must be signed in to change notification settings - Fork 0
/
check_for_problem_imgs.py
122 lines (90 loc) · 3.26 KB
/
check_for_problem_imgs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import tarfile
import torch
from utils.benthicnet.io import row2basename
from utils.utils import get_df
def extract_filepaths(tar_dir):
    """Collect member paths from every ``.tar`` archive in a directory.

    Parameters
    ----------
    tar_dir : str
        Directory to scan for ``.tar`` files (non-recursive).

    Returns
    -------
    set of str
        Union of all member names found across the tar archives.
    """
    members = set()
    for entry in os.listdir(tar_dir):
        candidate = os.path.join(tar_dir, entry)
        # Skip sub-directories and anything that is not a .tar archive.
        if not (os.path.isfile(candidate) and candidate.endswith(".tar")):
            continue
        with tarfile.open(candidate, "r") as archive:
            # getnames() yields the same strings as member.name over getmembers().
            members.update(archive.getnames())
    return members
class BenthicNetDatasetSkeleton(torch.utils.data.Dataset):
    """Skeleton BenthicNet dataset that checks image availability only.

    Rather than loading pixels, ``__getitem__`` reports whether each
    annotated image's expected path exists inside the tar archives found
    under ``tar_dir``.
    """

    def __init__(
        self,
        tar_dir,
        annotations=None,
    ):
        """
        Build the availability-checking dataset.

        Parameters
        ----------
        tar_dir : str
            Directory containing the ``.tar`` archives of images.
        annotations : pandas.DataFrame, optional
            Annotation rows; presumably one image per row with at least
            ``dataset`` and ``site`` columns — confirm against caller.
        """
        self.dataframe = annotations
        # Precompute the set of member paths once so lookups are O(1).
        self.valid_filepaths = extract_filepaths(tar_dir)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        base = row2basename(record, use_url_extension=True)
        # Force a .jpg extension: strip a trailing extension if present,
        # otherwise append .jpg to the bare name.
        stem, dot, _ext = base.rpartition(".")
        img_name = (stem if dot else base) + ".jpg"
        # Tar member names always use forward slashes, so join with "/".
        path = "/".join((record["dataset"], record["site"], img_name))
        # Need to load the file from the tarball over the network
        return path, path in self.valid_filepaths
def save_list_to_txt(list_data, file_path):
    """Write each item of *list_data* to *file_path*, one item per line.

    Items are converted with ``str``. The file is written with an explicit
    UTF-8 encoding so the output does not depend on the platform's locale
    default (the original relied on the implicit default encoding).

    Parameters
    ----------
    list_data : iterable
        Items to write; each is stringified.
    file_path : str
        Destination text file path (overwritten if it exists).
    """
    with open(file_path, "w", encoding="utf-8") as file:
        # writelines batches the writes instead of one call per item.
        file.writelines(str(item) + "\n" for item in list_data)
def main():
    """Scan every annotated image path and record those missing from the tars."""
    tar_dir = ""  # Need to add tar directory path
    csv_path = "./data_csv/benthicnet_unlabelled_sub_eval.csv"

    df = get_df(csv_path)
    print("Loaded:", csv_path)

    dataset = BenthicNetDatasetSkeleton(tar_dir, df)
    total = len(dataset)
    batch_size = 8192
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=1,
        drop_last=False,
        shuffle=False,
        pin_memory=True,
    )
    # The loader keeps its own reference; drop ours to free the local name.
    del dataset

    missing = []
    for i, (paths, found_flags) in enumerate(loader):
        processed = min((i + 1) * batch_size, total)
        print(f"Processing images: {processed / total * 100:.2f}%", end="\r")
        for path, found in zip(paths, found_flags):
            if not found:
                missing.append(path)

    print("\nTotal number of encountered problem images:", len(missing))
    save_list_to_txt(
        missing,
        "./data_csv/unlabelled_problem_imgs.txt",
    )


if __name__ == "__main__":
    main()