-
Notifications
You must be signed in to change notification settings - Fork 0
/
km_color_more.py
111 lines (92 loc) · 2.47 KB
/
km_color_more.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import math
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans
# from mpl_toolkits import mplot3d
def find_vector(subdirect, filename, *, stop=30000, step=10):
    """Build a flat feature vector from the ``posX``/``posZ`` columns of a CSV.

    Every ``step``-th row among the first ``stop`` rows is sampled. The
    result holds all sampled ``posX`` values followed by all sampled
    ``posZ`` values, so two files yield comparable vectors only if both
    fill the whole sampling window.

    Args:
        subdirect: Directory that contains the CSV file.
        filename: Name of the CSV file inside ``subdirect``.
        stop: Number of leading rows to sample from (default 30000).
        step: Sampling stride within those rows (default 10).

    Returns:
        A list with the sampled ``posX`` values followed by the sampled
        ``posZ`` values, or an empty list when the file has too few rows
        to fill the sampling window.
    """
    data = pd.read_csv(os.path.join(subdirect, filename))
    X = data['posX'].values[0:stop:step]
    Y = data['posZ'].values[0:stop:step]

    # Number of samples a sufficiently long file produces (3000 by default).
    expected = len(range(0, stop, step))
    if len(X) != expected or len(Y) != expected:
        # Too short for this decimation strategy: signal "invalid" with [].
        return []

    return list(X) + list(Y)
# Collect one feature vector per usable CSV found anywhere below the
# current working directory.
rootdir = os.getcwd()
arr = []       # arr[i] is the feature vector of the i-th accepted file
row_dict = []  # row_dict[i] is the name of the file that produced arr[i]

for subdir, dirs, files in os.walk(rootdir):
    for file in files:
        if '.csv' not in file:
            continue
        tmp = find_vector(subdir, file)
        if not tmp:
            print('Invalid file (for this decimation strategy): ', file)
            continue
        arr.append(tmp)
        row_dict.append(file)

if not arr:
    # Nothing to cluster; stop with a clear message instead of failing
    # inside KMeans on an empty array.
    raise SystemExit('No usable CSV files found under ' + rootdir)

X = np.array(arr)

# Number of clusters. KMeans rejects n_clusters > n_samples, so cap it at
# the number of files that were actually accepted.
k = min(8, len(arr))

# Machine learning: fit the model and label every row in one step.
kmeans = KMeans(n_clusters=k)
y_kmeans = kmeans.fit_predict(X)
print(y_kmeans)

# Group rows by cluster label. `clusters` holds file names (for printing);
# `cluster_rows` holds the matching row indices into X (for plotting), so
# files sharing a name in different directories stay distinct.
clusters = {}
cluster_rows = {}
for row, label in enumerate(y_kmeans):
    clusters.setdefault(label, []).append(row_dict[row])
    cluster_rows.setdefault(label, []).append(row)

for item in sorted(clusters):
    print("Cluster ", item)
    print(clusters[item])

# Plot points by clusters.
from matplotlib import cm  # not used below; kept from the original script

# Each row of X is [posX samples..., posZ samples...]; half its length is
# the number of points per file. Colour runs from blue (first sample) to
# red (last sample).
n_points = X.shape[1] // 2
colorGrad = [[i / n_points, 0, 1 - i / n_points] for i in range(n_points)]

for item in sorted(cluster_rows):
    print("Cluster ", item)
    for row in cluster_rows[item]:
        # Plot straight from X rather than re-reading the CSV: the row
        # already contains exactly these samples, and re-reading needed the
        # file's directory, which was only remembered for the first file.
        plt.scatter(X[row, :n_points], X[row, n_points:], c=colorGrad)
    # One figure per cluster (assumed placement of plt.show() -- confirm).
    plt.show()