Skip to content

Commit

Permalink
Add milvus lite benchmark (1yefuwang1#28)
Browse files Browse the repository at this point in the history
* init milvus api

* test

* fix

* add milvus benchmark

* refactor

* remove unused line

* fix product name

* add var BENCHMARK_MILVUS_LITE

---------

Co-authored-by: xinyupang <[email protected]>
  • Loading branch information
Greedygre and xinyululala authored Sep 9, 2024
1 parent 4d981db commit a045798
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 9 deletions.
117 changes: 109 additions & 8 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def timeit(func):

query_data = {dim: np.float32(np.random.random((NUM_QUERIES, dim))) for dim in DIMS}
query_data_bytes = {dim: [query_data[dim][i].tobytes() for i in range(NUM_QUERIES)] for dim in DIMS}
query_data_for_milvus = {dim: [query_data[dim][i].tolist() for i in range(NUM_QUERIES)] for dim in DIMS}

# search for k nearest neighbors in this benchmark
k = 10
Expand Down Expand Up @@ -87,6 +88,7 @@ class BenchmarkResult:
insert_time_us: float # in micro seconds, per vector
search_time_us: float # in micro seconds, per query
recall_rate: float
product: Optional[str]


@dataclasses.dataclass
Expand All @@ -97,6 +99,7 @@ def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
table = Table()
table.add_column("product\nname")
table.add_column("distance\ntype")
table.add_column("vector\ndimension")
if self.results[0].ef_construction is not None:
Expand All @@ -109,6 +112,7 @@ def __rich_console__(
for result in self.results:
if self.results[0].ef_construction is not None:
table.add_row(
result.product,
result.distance_type,
str(result.dim),
str(result.ef_construction),
Expand All @@ -120,6 +124,7 @@ def __rich_console__(
)
else:
table.add_row(
result.product,
result.distance_type,
str(result.dim),
f"{result.insert_time_us:.2f} us",
Expand All @@ -133,7 +138,6 @@ class PlotData:
time_taken_us: float
column: str


benchmark_results = []
plot_data_for_insertion = defaultdict(list)
plot_data_for_query = defaultdict(list)
Expand All @@ -152,7 +156,7 @@ def wrapper():


def benchmark(distance_type, dim, ef_constructoin, M):
result = BenchmarkResult(distance_type, dim, ef_constructoin, M, 0, 0, 0, 0)
result = BenchmarkResult(distance_type, dim, ef_constructoin, M, 0, 0, 0, 0, "vectorLite")
table_name = f"table_{distance_type}_{dim}_{ef_constructoin}_{M}"
cursor.execute(
f"create virtual table {table_name} using vectorlite(embedding float32[{dim}] {distance_type}, hnsw(max_elements={NUM_ELEMENTS}, ef_construction={ef_constructoin}, M={M}))"
Expand Down Expand Up @@ -200,7 +204,6 @@ def search():
plot_data_for_query[dim].append(PlotData(result.search_time_us, f"vectorlite_{distance_type}_ef_{ef}"))
cursor.execute(f"drop table {table_name}")


for distance_type in distance_types:
for dim in DIMS:
for ef_construction, M in hnsw_params:
Expand All @@ -210,11 +213,10 @@ def search():
result_table = ResultTable(benchmark_results)
console.print(result_table)


hnswlib_benchmark_results = []
console.print("Bencharmk hnswlib as comparison.")
def benchmark_hnswlib(distance_type, dim, ef_construction, M):
result = BenchmarkResult(distance_type, dim, ef_construction, M, 0, 0, 0, 0)
result = BenchmarkResult(distance_type, dim, ef_construction, M, 0, 0, 0, 0, "hnswlib")
hnswlib_index = hnswlib.Index(space=distance_type, dim=dim)
hnswlib_index.init_index(max_elements=NUM_ELEMENTS, ef_construction=ef_construction, M=M)

Expand Down Expand Up @@ -265,7 +267,7 @@ def search():
console.print("Bencharmk vectorlite brute force(select rowid from my_table order by vector_distance(query_vector, embedding, 'l2')) as comparison.")

def benchmark_brute_force(dim: int):
benchmark_result = BenchmarkResult("l2", dim, None, None, None, 0, 0, 0)
benchmark_result = BenchmarkResult("l2", dim, None, None, None, 0, 0, 0, "vectorLite_brute_force")
table_name = f"table_vectorlite_bf_{dim}"
cursor.execute(
f"create table {table_name}(rowid integer primary key, embedding blob)"
Expand Down Expand Up @@ -328,7 +330,7 @@ def search():
vss_benchmark_results = []

def benchmark_sqlite_vss(dim: int):
benchmark_result = BenchmarkResult("l2", dim, None, None, None, 0, 0, 0)
benchmark_result = BenchmarkResult("l2", dim, None, None, None, 0, 0, 0, "vectorLite")
table_name = f"table_vss_{dim}"
cursor.execute(
f"create virtual table {table_name} using vss0(embedding({dim}))"
Expand Down Expand Up @@ -388,7 +390,7 @@ def search():
conn.load_extension(sqlite_vec.loadable_path())

def benchmark_sqlite_vec(dim: int):
benchmark_result = BenchmarkResult("l2", dim, None, None, None, 0, 0, 0)
benchmark_result = BenchmarkResult("l2", dim, None, None, None, 0, 0, 0, "vectorLite")
table_name = f"table_vec_{dim}"
cursor.execute(
f"create virtual table {table_name} using vec0(rowid integer primary key, embedding float[{dim}])"
Expand Down Expand Up @@ -437,6 +439,105 @@ def search():
vec_result_table = ResultTable(vec_benchmark_results)
console.print(vec_result_table)

benchmark_milvus_lite = os.environ.get("BENCHMARK_MILVUS_LITE", "0") != "0"
if benchmark_milvus_lite and (platform.system().lower() == "linux" or platform.system().lower() == "darwin"):
import os
from pymilvus import MilvusClient
import numpy as np

client = MilvusClient("./milvus_lite_demo.db")

def milvus_insert(client, collection_name, data):
client.insert(collection_name=collection_name, data=data)

def milvus_insert_many(client, collection_name, dim):
insert_data = [{"id": i, "embedding": data[dim][i].tolist()} for i in range(NUM_ELEMENTS)]
client.insert(collection_name=collection_name, data=insert_data)

def milvus_search(client, collection_name, distance_type, search_data):
res = client.search(
collection_name=collection_name,
data=[search_data],
search_params={"metric_type": distance_type.upper()}, # Search parameters
)
rowids = [result['id'] for result in res[0]]
return rowids

def milvus_create_table(client, collection_name, distance_type, dim):
from pymilvus import DataType

schema = client.create_schema(enable_dynamic_field=True)
schema.add_field("id", DataType.INT64, is_primary=True)
schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
index_params = client.prepare_index_params()

index_params.add_index(
field_name="embedding",
metric_type=distance_type.upper(),
)
client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params,
dimension=dim # The vectors we will use in this demo has 384 dimensions
)

res = client.get_load_state(
collection_name=collection_name
)
# print(client.describe_index(collection_name,"embedding"))
# print(client.describe_index(collection_name,"id"))

console.print("Bencharmk milvuslite.")
benchmark_milvus_results = []
def benchmark_milvus(distance_type, dim):
result = BenchmarkResult(distance_type=distance_type, dim=dim, insert_time_us=0, search_time_us=0, recall_rate=0, product="milvusLite", ef_construction=None, M=None, ef_search=None)
collection_name = f"collection_{distance_type}_{dim}"

milvus_create_table(client=client, collection_name=collection_name, distance_type=distance_type, dim=dim)

# measure insert time
insert_time_us, _ = timeit(
transactional(lambda: milvus_insert_many(client=client, collection_name=collection_name, dim=dim))
)

result.insert_time_us = insert_time_us / NUM_ELEMENTS
plot_data_for_insertion[dim].append(PlotData(result.insert_time_us, f"milvuslite_{distance_type}"))


def search():
result = []
for i in range(NUM_QUERIES):
result.append(
milvus_search(client=client, collection_name=collection_name, distance_type=distance_type, search_data=query_data_for_milvus[dim][i])
)
return result

search_time_us, results = timeit(search)
# console.log(results)
recall_rate = np.mean(
[
np.intersect1d(results[i], correct_labels[distance_type][dim][i]).size
/ k
for i in range(NUM_QUERIES)
]
)
result = dataclasses.replace(
result,
search_time_us=search_time_us / NUM_QUERIES,
recall_rate=recall_rate,
)

benchmark_milvus_results.append(result)
plot_data_for_query[dim].append(PlotData(result.search_time_us, f"milvuslite_{distance_type}"))
client.drop_collection(collection_name=collection_name)

for distance_type in distance_types:
for dim in DIMS:
benchmark_milvus(distance_type, dim)

console.print(ResultTable(benchmark_milvus_results))

import plot
def plot_figures():
vector_insertion_columns = ["dim"] + [plot_data.column for plot_data in plot_data_for_insertion[DIMS[0]]]
Expand Down
3 changes: 2 additions & 1 deletion benchmark/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ apsw>=3.45
rich>=13.7
hnswlib>=0.8
matplotlib
pandas
pandas
pymilvus

0 comments on commit a045798

Please sign in to comment.