-
Notifications
You must be signed in to change notification settings - Fork 0
/
cluster.py
32 lines (21 loc) · 829 Bytes
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from sklearn.cluster import KMeans
from sklearn.cluster import MeanShift
import numpy as np
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
# Load the pickled table. The steps below read its 'coords', 'word' and
# 'book_num' columns.
data = pd.read_pickle('./results/coords_and_embeddings.pkl')
print(data)

# Cluster the coordinate vectors with Mean Shift (default parameters).
# Alternative: KMeans(n_clusters=25) fitted on the same list of coordinates
# exposes its assignments through `labels_` in the same way.
points = data['coords'].tolist()
clustering = MeanShift()
clustering.fit(points)

# Attach each row's cluster label as its 'category'.
labels = clustering.labels_
print(labels)
data['category'] = labels

# Report how many distinct words ended up in each category.
words_per_category = data.groupby('category', as_index=False)['word'].nunique()
print(words_per_category.to_string())

# Collapse to one row per (book_num, category) pair; after this step the
# 'word' column holds the distinct-word count for that pair.
data = data.groupby(['book_num', 'category'], as_index=False)['word'].nunique()

# Scatter plot: book number on x, category on y, marker area proportional
# to the distinct-word count (scaled by 5).
fig = plt.figure(figsize=(15, 15))
marker_sizes = data['word'] * 5
plt.scatter(data['book_num'], data['category'], s=marker_sizes)
plt.savefig('./results/book-topics-scatter-category.png')