Skip to content

Commit 6fefd6e

Browse files
committed
Add clustering binding
Running the example: ``` Clustering 10000 vectors of dimension 64 into 10 clusters Running simple k-means clustering... Simple k-means completed in 8.759083ms Average quantization error: 2698.03 Running clustering with custom parameters... Sampling a subset of 2560 / 10000 for training Clustering 2560 points in 64D to 10 clusters, redo 3 times, 25 iterations Preprocessing in 0.00 s Outer iteration 0 / 3 Iteration 24 (0.01 s, search 0.00 s): objective=2.69803e+07 imbalance=1.296 nsplit=0 Objective improved: keep new clusters Outer iteration 1 / 3 Iteration 24 (0.01 s, search 0.01 s): objective=3.06187e+07 imbalance=1.357 nsplit=0 Outer iteration 2 / 3 Iteration 24 (0.02 s, search 0.01 s): objective=2.0157e+07 imbalance=1.134 nsplit=0 Objective improved: keep new clusters Advanced clustering completed in 19.593958ms Comparing clustering quality... Average distance to nearest centroid: Simple k-means: 10705.22 Advanced method: 8305.35 Improvement: 22.42% better Time comparison: Simple: 8.759083ms Advanced: 19.593958ms (2.2x slower due to multiple runs) ```
1 parent 371fb38 commit 6fefd6e

File tree

2 files changed

+305
-0
lines changed

2 files changed

+305
-0
lines changed

_example/clustering/clustering.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"math/rand"
7+
"time"
8+
9+
"github.com/blevesearch/go-faiss"
10+
)
11+
12+
func main() {
13+
rng := rand.New(rand.NewSource(123456))
14+
15+
const (
16+
d = 64 // vector dimension
17+
n = 10_000 // number of training vectors
18+
k = 10 // number of clusters
19+
)
20+
21+
fmt.Printf("Clustering %d vectors of dimension %d into %d clusters\n\n", n, d, k)
22+
23+
train := make([]float32, n*d)
24+
centers := make([][]float32, k)
25+
for i := 0; i < k; i++ {
26+
centers[i] = make([]float32, d)
27+
for j := 0; j < d; j++ {
28+
centers[i][j] = rng.Float32() * 100
29+
}
30+
}
31+
32+
// Generate points around these centers with some noise
33+
pointsPerCluster := n / k
34+
for i := 0; i < n; i++ {
35+
cluster := i / pointsPerCluster
36+
if cluster >= k {
37+
cluster = k - 1
38+
}
39+
40+
// Add Gaussian noise around the cluster center
41+
for j := 0; j < d; j++ {
42+
noise := float32(rng.NormFloat64() * 5)
43+
train[i*d+j] = centers[cluster][j] + noise
44+
}
45+
}
46+
47+
fmt.Println("Running simple k-means clustering...")
48+
start := time.Now()
49+
50+
centroids, qerr, err := faiss.KMeansClustering(d, n, k, train)
51+
if err != nil {
52+
log.Fatalf("k-means: %v", err)
53+
}
54+
55+
simpleTime := time.Since(start)
56+
fmt.Printf("Simple k-means completed in %v\n", simpleTime)
57+
fmt.Printf("Average quantization error: %.2f\n\n", qerr/float32(n))
58+
59+
fmt.Println("Running clustering with custom parameters...")
60+
61+
params := faiss.NewClusteringParameters()
62+
params.Niter = 25
63+
params.Nredo = 3
64+
params.Verbose = true
65+
params.Seed = 1234
66+
params.MinPointsPerCentroid = 39
67+
params.MaxPointsPerCentroid = 256
68+
69+
clustering, err := faiss.NewClusteringWithParams(d, k, params)
70+
if err != nil {
71+
log.Fatalf("new clustering: %v", err)
72+
}
73+
defer clustering.Close()
74+
75+
// Create an index to accelerate clustering
76+
// For larger datasets, consider using a faster index like IndexIVFFlat
77+
accelIdx, err := faiss.NewIndexFlatL2(d)
78+
if err != nil {
79+
log.Fatalf("index: %v", err)
80+
}
81+
defer accelIdx.Close()
82+
83+
start = time.Now()
84+
if err = clustering.Train(train, accelIdx); err != nil {
85+
log.Fatalf("train: %v", err)
86+
}
87+
advancedTime := time.Since(start)
88+
89+
advCentroids := clustering.Centroids()
90+
fmt.Printf("\nAdvanced clustering completed in %v\n\n", advancedTime)
91+
92+
fmt.Println("Comparing clustering quality...")
93+
94+
baseIdx, err := faiss.NewIndexFlatL2(d)
95+
if err != nil {
96+
log.Fatalf("index: %v", err)
97+
}
98+
defer baseIdx.Close()
99+
if err = baseIdx.Add(centroids); err != nil {
100+
log.Fatalf("add centroids: %v", err)
101+
}
102+
103+
advIdx, err := faiss.NewIndexFlatL2(d)
104+
if err != nil {
105+
log.Fatalf("index: %v", err)
106+
}
107+
defer advIdx.Close()
108+
if err = advIdx.Add(advCentroids); err != nil {
109+
log.Fatalf("add centroids: %v", err)
110+
}
111+
112+
// Find nearest centroid for each training point
113+
baseDist, _, _ := baseIdx.Search(train, 1)
114+
advDist, _, _ := advIdx.Search(train, 1)
115+
116+
avgBase := mean(baseDist)
117+
avgAdv := mean(advDist)
118+
119+
fmt.Printf("Average distance to nearest centroid:\n")
120+
fmt.Printf(" Simple k-means: %.2f\n", avgBase)
121+
fmt.Printf(" Advanced method: %.2f\n", avgAdv)
122+
fmt.Printf(" Improvement: %.2f%% better\n", 100*(avgBase-avgAdv)/avgBase)
123+
fmt.Printf("\nTime comparison:\n")
124+
fmt.Printf(" Simple: %v\n", simpleTime)
125+
fmt.Printf(" Advanced: %v (%.1fx slower due to multiple runs)\n",
126+
advancedTime, float64(advancedTime)/float64(simpleTime))
127+
}
128+
129+
func mean(x []float32) float64 {
130+
var s float64
131+
for _, v := range x {
132+
s += float64(v)
133+
}
134+
return s / float64(len(x))
135+
}

clustering.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
package faiss
2+
3+
/*
4+
#include <faiss/c_api/Clustering_c.h>
5+
#include <faiss/c_api/Index_c.h>
6+
*/
7+
import "C"
8+
import "unsafe"
9+
10+
type ClusteringParameters struct {
11+
Niter int // Number of clustering iterations
12+
Nredo int // Number of times to redo clustering and keep best
13+
Verbose bool // Verbose output
14+
Spherical bool // Do we want normalized centroids?
15+
IntCentroids bool // Round centroids coordinates to integer
16+
UpdateIndex bool // Update index after each iteration?
17+
FrozenCentroids bool // Use the centroids provided as input and do not change them during iterations
18+
MinPointsPerCentroid int // Otherwise you get a warning
19+
MaxPointsPerCentroid int // To limit size of dataset
20+
Seed int // Seed for the random number generator
21+
DecodeBlockSize uint64 // How many vectors at a time to decode
22+
}
23+
24+
// Create a new ClusteringParameters with default values.
25+
func NewClusteringParameters() *ClusteringParameters {
26+
var cparams C.FaissClusteringParameters
27+
C.faiss_ClusteringParameters_init(&cparams)
28+
29+
return &ClusteringParameters{
30+
Niter: int(cparams.niter),
31+
Nredo: int(cparams.nredo),
32+
Verbose: cparams.verbose != 0,
33+
Spherical: cparams.spherical != 0,
34+
IntCentroids: cparams.int_centroids != 0,
35+
UpdateIndex: cparams.update_index != 0,
36+
FrozenCentroids: cparams.frozen_centroids != 0,
37+
MinPointsPerCentroid: int(cparams.min_points_per_centroid),
38+
MaxPointsPerCentroid: int(cparams.max_points_per_centroid),
39+
Seed: int(cparams.seed),
40+
DecodeBlockSize: uint64(cparams.decode_block_size),
41+
}
42+
}
43+
44+
func (p *ClusteringParameters) toCStruct() C.FaissClusteringParameters {
45+
return C.FaissClusteringParameters{
46+
niter: C.int(p.Niter),
47+
nredo: C.int(p.Nredo),
48+
verbose: boolToInt(p.Verbose),
49+
spherical: boolToInt(p.Spherical),
50+
int_centroids: boolToInt(p.IntCentroids),
51+
update_index: boolToInt(p.UpdateIndex),
52+
frozen_centroids: boolToInt(p.FrozenCentroids),
53+
min_points_per_centroid: C.int(p.MinPointsPerCentroid),
54+
max_points_per_centroid: C.int(p.MaxPointsPerCentroid),
55+
seed: C.int(p.Seed),
56+
decode_block_size: C.size_t(p.DecodeBlockSize),
57+
}
58+
}
59+
60+
type Clustering struct {
61+
clustering *C.FaissClustering
62+
d int
63+
k int
64+
}
65+
66+
// Create a new clustering object with default parameters.
67+
func NewClustering(d, k int) (*Clustering, error) {
68+
var clustering *C.FaissClustering
69+
if c := C.faiss_Clustering_new(&clustering, C.int(d), C.int(k)); c != 0 {
70+
return nil, getLastError()
71+
}
72+
return &Clustering{
73+
clustering: clustering,
74+
d: d,
75+
k: k,
76+
}, nil
77+
}
78+
79+
func NewClusteringWithParams(d, k int, params *ClusteringParameters) (*Clustering, error) {
80+
var clustering *C.FaissClustering
81+
cparams := params.toCStruct()
82+
if c := C.faiss_Clustering_new_with_params(&clustering, C.int(d), C.int(k), &cparams); c != 0 {
83+
return nil, getLastError()
84+
}
85+
return &Clustering{
86+
clustering: clustering,
87+
d: d,
88+
k: k,
89+
}, nil
90+
}
91+
92+
// Return the dimension of the vectors.
93+
func (c *Clustering) D() int {
94+
return c.d
95+
}
96+
97+
// Return the number of clusters.
98+
func (c *Clustering) K() int {
99+
return c.k
100+
}
101+
102+
func (c *Clustering) cPtr() *C.FaissClustering {
103+
return c.clustering
104+
}
105+
106+
// Train performs the k-means clustering on the provided vectors.
107+
// The index parameter can be used to accelerate the clustering by providing
108+
// a fast way to perform nearest-neighbor queries. If nil, a default index
109+
// will be used internally.
110+
func (c *Clustering) Train(x []float32, index Index) error {
111+
n := len(x) / c.D()
112+
113+
var idx *C.FaissIndex
114+
if index != nil {
115+
idx = index.cPtr()
116+
}
117+
118+
if code := C.faiss_Clustering_train(
119+
c.clustering,
120+
C.idx_t(n),
121+
(*C.float)(&x[0]),
122+
idx,
123+
); code != 0 {
124+
return getLastError()
125+
}
126+
return nil
127+
}
128+
129+
// Return the cluster centroids after training.
130+
func (c *Clustering) Centroids() []float32 {
131+
var centroids *C.float
132+
var size C.size_t
133+
C.faiss_Clustering_centroids(c.clustering, &centroids, &size)
134+
return (*[1 << 30]float32)(unsafe.Pointer(centroids))[:size:size]
135+
}
136+
137+
// Free the memory used by the clustering object.
138+
func (c *Clustering) Close() {
139+
if c.clustering != nil {
140+
C.faiss_Clustering_free(c.clustering)
141+
c.clustering = nil
142+
}
143+
}
144+
145+
// KMeansClustering is a simplified interface for k-means clustering.
146+
// It performs clustering and returns the centroids and quantization error.
147+
func KMeansClustering(d, n, k int, x []float32) (centroids []float32, qerr float32, err error) {
148+
centroids = make([]float32, k*d)
149+
var cqerr C.float
150+
151+
if c := C.faiss_kmeans_clustering(
152+
C.size_t(d),
153+
C.size_t(n),
154+
C.size_t(k),
155+
(*C.float)(&x[0]),
156+
(*C.float)(&centroids[0]),
157+
&cqerr,
158+
); c != 0 {
159+
return nil, 0, getLastError()
160+
}
161+
162+
return centroids, float32(cqerr), nil
163+
}
164+
165+
func boolToInt(b bool) C.int {
166+
if b {
167+
return 1
168+
}
169+
return 0
170+
}

0 commit comments

Comments
 (0)