
Commit 623a6c2

Add code
1 parent 9ff2831 commit 623a6c2

19 files changed: +1880 -0 lines changed

memes/clustering/cluster_leiden.py (+75 lines)
@@ -0,0 +1,75 @@
"""
Run Leiden algorithm to generate clusters of hashes.
"""

import argparse
from array import array
import sys
from scipy.sparse import csr_matrix
from tqdm import tqdm
import math

import igraph as ig
import leidenalg as la

# from sklearnex import patch_sklearn
# patch_sklearn()
from sklearn.cluster import DBSCAN, OPTICS
from sklearn.neighbors._base import _check_precomputed, _is_sorted_by_data
import numpy as np
import pandas as pd

from memes.utils import read_year, DATA_DIR, construct_output_filename
from memes.clustering.utils import to_binary_array, to_int
from memes.clustering.clustering import read_distances, hash_to_ind, ind_to_hash


np.random.seed(0xB1AB)


def main(args):

    # dist_func turns Hamming distances into similarity weights (smaller distance -> larger weight)
    matrix = read_distances(args.distances, args.sample, threshold=args.threshold, dist_func=lambda x: x.max() + 1 - x, upper=True)
    hash_index = ind_to_hash(args.hash_index)

    print("clustering")
    graph = ig.Graph.Weighted_Adjacency(matrix, mode="upper")
    partitions = la.find_partition(
        graph,
        la.CPMVertexPartition,
        weights="weight",
        resolution_parameter=args.density,
        n_iterations=args.niters,
        seed=0xB1AB
    )

    outpath = construct_output_filename(
        subdir=DATA_DIR / "clusters",
        prefix=args.prefix,
        suffix="leiden",
        ext="tsv",
    )
    with open(outpath, "w") as f:
        for ind, cluster in enumerate(partitions):
            for hash_ind in cluster:
                phash = hash_index[hash_ind]
                f.write(f"{phash}\t{ind}\n")
    print(len(partitions), "clusters")
    # normalize the CPM quality score by twice the total edge weight
    quality = partitions.quality() / (2 * sum(graph.es["weight"]))
    print("Quality score of", quality)
    # print(clusters.value_counts())
    print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("distances")
    parser.add_argument("hash_index")
    parser.add_argument("--threshold", type=int, default=10)
    parser.add_argument("--eps", type=float, default=8)
    parser.add_argument("--density", type=float, default=1.0)
    parser.add_argument("--min_samples", type=float, default=3)
    parser.add_argument("--sample", type=int)
    parser.add_argument("--niters", type=int, default=2)
    parser.add_argument("--prefix", default=None)
    main(parser.parse_args(sys.argv[1:]))
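For context, the Leiden step above can be exercised in isolation. The following is a minimal sketch, not part of this commit: the 4-node distance matrix and its values are invented, but the distance-to-similarity transform and the igraph/leidenalg calls mirror the ones in cluster_leiden.py.

import igraph as ig
import leidenalg as la
import numpy as np

# Toy Hamming-distance matrix (invented values): nodes 0-1 and 2-3 are near-duplicates.
dist = np.array([
    [0, 2, 9, 9],
    [2, 0, 9, 9],
    [9, 9, 0, 1],
    [9, 9, 1, 0],
])
# Same transform as the script's dist_func: small distance -> large edge weight.
weights = dist.max() + 1 - dist
np.fill_diagonal(weights, 0)  # drop self-loops

graph = ig.Graph.Weighted_Adjacency(weights.tolist(), mode="upper")
partition = la.find_partition(
    graph,
    la.CPMVertexPartition,   # resolution_parameter controls how dense a cluster must be
    weights="weight",
    resolution_parameter=1.0,
    n_iterations=2,
    seed=0xB1AB,
)
print(list(partition))       # e.g. [[0, 1], [2, 3]]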

memes/clustering/clustering.py (+177 lines)
@@ -0,0 +1,177 @@
"""
Run DBSCAN to generate clusters of hashes.
"""

import argparse
from array import array
import sys
from scipy.sparse import csr_matrix
from tqdm import tqdm
import math

# from sklearnex import patch_sklearn
# patch_sklearn()
from sklearn.cluster import DBSCAN, OPTICS
from sklearn.neighbors._base import _check_precomputed, _is_sorted_by_data
import numpy as np
import pandas as pd

from memes.utils import read_year, DATA_DIR, construct_output_filename
from memes.clustering.utils import to_binary_array, to_int


def read_distances(path, sample=None, return_csr=True, threshold=10, dist_func=lambda x: x, keeplist=None, upper=False):
    data = array("I")
    rows = array("I")
    cols = array("I")
    # try only taking pairs w distance <= 10
    # per zsavvas
    THRESHOLD = threshold

    class np_buffer_wrapper:
        """
        Create an array interface for numpy so we can directly refer to
        memory location.
        """
        def __init__(self, ptr, shape, typestr):
            self.__array_interface__ = {
                "shape": shape,
                "typestr": typestr,
                "data": (ptr, True),
            }

        @classmethod
        def from_array(cls, array):
            endianness = {"little": "<", "big": ">"}
            ptr, size = array.buffer_info()
            byteorder = endianness[sys.byteorder]
            # TODO: right now we assume unsigned int. best to infer from the
            # array
            basictype = "u"
            numbytes = array.itemsize
            typestr = byteorder + basictype + str(numbytes)
            return cls(ptr, (size,), typestr)

    def add_row(ind1, ind2, dist):
        nonlocal data
        nonlocal rows
        nonlocal cols
        nonlocal keeplist
        if keeplist is not None:
            if not (int(ind1) in keeplist and int(ind2) in keeplist):
                return
            ind1 = keeplist[int(ind1)]
            ind2 = keeplist[int(ind2)]
        if int(ind1) > sample or int(ind2) > sample:
            return
        if int(dist) > THRESHOLD:
            return
        if return_csr:
            if upper:
                data.extend([int(dist)])
                rows.extend([int(ind1)])
                cols.extend([int(ind2)])
            else:
                data.extend([int(dist), int(dist)])
                rows.extend([int(ind1), int(ind2)])
                cols.extend([int(ind2), int(ind1)])

    print("reading distances")
    if sample is None:
        sample = math.inf
    with open(path, "r") as f:
        for i, line in tqdm(enumerate(f)):
            if i > sample:
                break
            try:
                ind1, ind2, dist = line.strip().split("\t")
                add_row(ind1, ind2, dist)
            except Exception as e:
                # NOTE: this exception handling addresses a bug in distance
                # calculations that has since been fixed.
                pass
                # ind1, ind2, dist_ind3, ind4, dist2 = line.strip().split("\t")
                # dist = dist_ind3[: -len(ind1)]
                # add_row(ind1, ind2, dist)

                # ind3 = dist_ind3[len(dist) :]
                # add_row(ind3, ind4, dist2)
    d = np.array(np_buffer_wrapper.from_array(data), copy=False)
    r = np.array(np_buffer_wrapper.from_array(rows), copy=False)
    c = np.array(np_buffer_wrapper.from_array(cols), copy=False)
    d = dist_func(d)
    if return_csr:
        print("constructing matrix")
        return csr_matrix((d, (r, c)))
    return np.stack([d, r, c])


def hash_to_ind(path):
    """path to file of unique hashes.

    same as path being passed into the distance calculation.
    """
    hashes = {}
    print("reading index")
    with open(path, "r") as f:
        for i, line in tqdm(enumerate(f)):
            hashes[line.strip()] = i
    return hashes


def ind_to_hash(path):
    """path to file of unique hashes.

    same as path being passed into the distance calculation.
    """
    hashes = list()
    print("reading index")
    with open(path, "r") as f:
        for i, line in tqdm(enumerate(f)):
            hashes.append(line.strip())
    return hashes


np.random.seed(0xB1AB)


def main(args):

    distances = read_distances(args.distances, args.sample)
    # distances = _check_precomputed(distances)
    hash_index = ind_to_hash(args.hash_index)

    print("clustering")

    dbscan = DBSCAN(metric="precomputed", eps=args.eps, min_samples=args.min_samples, n_jobs=16)
    # try using OPTICS instead for lower memory
    # dbscan = OPTICS(metric="precomputed", min_samples=args.min_samples)

    dbscan.fit(distances)
    clusters = pd.Series(dbscan.labels_, index=list(range(distances.shape[0])))
    cluster_dict = clusters.to_dict()
    print("writing output file")

    outpath = construct_output_filename(
        subdir=DATA_DIR / "clusters",
        prefix=args.prefix,
        suffix="clusters",
        ext="tsv",
    )
    with open(outpath, "w") as f:
        for ind, cluster in cluster_dict.items():
            phash = hash_index[ind]
            f.write(f"{phash}\t{cluster}\n")
    print(clusters.value_counts())
    print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("distances")
    parser.add_argument("hash_index")
    parser.add_argument("--eps", type=float, default=8)
    parser.add_argument("--min_samples", type=float, default=3)
    parser.add_argument("--sample", type=int)
    parser.add_argument("--prefix", default=None)
    main(parser.parse_args(sys.argv[1:]))
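As a quick reference for the DBSCAN step above, here is a minimal sketch, not part of this commit: the 4x4 distance matrix and the eps/min_samples values are invented, but the precomputed-metric usage mirrors main() in clustering.py, and the np.frombuffer line shows the same zero-copy idea that np_buffer_wrapper implements by hand.

import numpy as np
from array import array
from sklearn.cluster import DBSCAN

# Toy precomputed Hamming-distance matrix (invented values).
distances = np.array([
    [0, 2, 20, 20],
    [2, 0, 20, 20],
    [20, 20, 0, 3],
    [20, 20, 3, 0],
])
dbscan = DBSCAN(metric="precomputed", eps=8, min_samples=2)
dbscan.fit(distances)
print(dbscan.labels_)  # e.g. [0 0 1 1]; -1 would mark noise points

# Zero-copy view of an array("I") buffer, analogous to np_buffer_wrapper above
# (assumes array("I") is 4 bytes on this platform).
buf = array("I", [1, 2, 3])
view = np.frombuffer(buf, dtype=np.uint32)  # shares memory with buf, no copy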

memes/clustering/file_to_hash.py (+75 lines)
@@ -0,0 +1,75 @@
import argparse
from functools import partial
import itertools
from multiprocessing import Pool
from memes.utils import DATA_DIR, construct_output_filename

import imagehash
from PIL import Image

from tqdm import tqdm


def get_image_hash(data, hash_size):
    _, item = data
    post_id, filepath = item
    try:
        phash = imagehash.phash(Image.open(filepath), hash_size=hash_size)
        return post_id, str(phash)
    except:
        return post_id, None


def main(args):
    print(args.in_files)

    def iterator(in_file):
        with open(in_file, "rt") as f:
            for ind, line in enumerate(f):
                try:
                    post_id, filepath = line.strip().split("\t")
                except:
                    print(line)
                    continue  # skip malformed lines instead of yielding stale values
                yield ind, (post_id, filepath)

    def data(skip=0):
        for i, d in enumerate(itertools.chain.from_iterable(
            [iterator(in_file) for in_file in args.in_files]
        )):
            if i >= skip:
                yield d


    pool = Pool(args.num_procs)

    outfilename = construct_output_filename(subdir=DATA_DIR / "imagehashes", prefix=args.prefix, suffix="hashes", ext="tsv")
    outfile = open(outfilename, "w")

    buf = []
    for ind, (post_id, phash) in tqdm(
        enumerate(pool.imap(partial(get_image_hash, hash_size=args.hash_size), data(args.skip), chunksize=500))
    ):
        if phash is None:
            continue
        buf.append(post_id + "\t" + phash + "\n")
        if ind % 100_000 == 0:
            outfile.write("".join(buf))
            buf = []
    outfile.write("".join(buf))
    buf = []
    outfile.close()  # ensure buffered writes are flushed to disk


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("in_files", nargs="+", help="Path to input id-filepath map file(s)")
    parser.add_argument("--hash_size", default=8, type=int, help="Size of the hash")
    parser.add_argument(
        "--prefix", default=None, help="Prefix for the output filename"
    )
    parser.add_argument(
        "--num_procs", default=64, type=int, help="Number of processes in pool"
    )
    parser.add_argument(
        "--skip", default=0, type=int, help="Number of lines to skip"
    )
    main(parser.parse_args())
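To tie this script back to the clustering steps: the hex strings it writes are perceptual hashes, and the upstream distance files hold bitwise Hamming distances between them. Below is a minimal sketch, not part of this commit and using made-up hex values, of how imagehash round-trips those strings.

import imagehash

# Parse the hex strings written by file_to_hash.py back into ImageHash objects.
h1 = imagehash.hex_to_hash("ffd7918181c9ffff")  # made-up 64-bit (hash_size=8) phash
h2 = imagehash.hex_to_hash("ffd7918181c9fffe")
print(h1 - h2)  # Hamming distance in bits; clustering keeps pairs <= threshold (default 10)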
