feat: Implementing online_read for MilvusOnlineStore (feast-dev#4996)
franciscojavierarceo authored Feb 1, 2025
1 parent 0145e55 commit 92dde13
Showing 9 changed files with 431 additions and 60 deletions.
1 change: 1 addition & 0 deletions docs/SUMMARY.md
@@ -116,6 +116,7 @@
* [Hazelcast](reference/online-stores/hazelcast.md)
* [ScyllaDB](reference/online-stores/scylladb.md)
* [SingleStore](reference/online-stores/singlestore.md)
* [Milvus](reference/online-stores/milvus.md)
* [Registries](reference/registries/README.md)
* [Local](reference/registries/local.md)
* [S3](reference/registries/s3.md)
64 changes: 64 additions & 0 deletions docs/reference/online-stores/milvus.md
@@ -0,0 +1,64 @@
# Milvus online store

## Description

The [Milvus](https://milvus.io/) online store provides support for materializing feature values into Milvus.

* The data model used to store feature values in Milvus is described in more detail [here](../../specs/online\_store\_format.md).

## Getting started
To use this online store, you'll need to install the Milvus extra (along with the dependencies needed for your offline store of choice). E.g.

`pip install 'feast[milvus]'`

You can get started by using any of the other templates (e.g. `feast init -t gcp`, `feast init -t snowflake`, or `feast init -t aws`), and then swapping in Milvus as the online store as seen below in the examples.

## Examples

Connecting to a local Milvus instance:

{% code title="feature_store.yaml" %}
```yaml
project: my_feature_repo
registry: data/registry.db
provider: local
online_store:
  type: milvus
  path: "data/online_store.db"
  connection_string: "localhost:6379"
  embedding_dim: 128
  index_type: "FLAT"
  metric_type: "COSINE"
  username: "username"
  password: "password"
```
{% endcode %}
The full set of configuration options is available in [MilvusOnlineStoreConfig](https://rtd.feast.dev/en/latest/#feast.infra.online_stores.milvus.MilvusOnlineStoreConfig).
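
With `online_read` implemented, feature values written to Milvus can be read back through the standard Python SDK. The sketch below is a minimal illustration, assuming a hypothetical `driver_hourly_stats` feature view keyed on `driver_id`; substitute the feature view, feature, and entity names from your own repository.

```python
from feast import FeatureStore

# Load the repo containing the feature_store.yaml shown above.
store = FeatureStore(repo_path=".")

# Read the latest feature values for one entity key from the Milvus online store.
# The feature view, feature, and entity names are illustrative placeholders.
online_features = store.get_online_features(
    features=[
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
    ],
    entity_rows=[{"driver_id": 1001}],
).to_dict()

print(online_features)
```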

## Functionality Matrix

The set of functionality supported by online stores is described in detail [here](overview.md#functionality).
Below is a matrix indicating which functionality is supported by the Milvus online store.

| | Milvus |
| :-------------------------------------------------------- |:-------|
| write feature values to the online store | yes |
| read feature values from the online store | yes |
| update infrastructure (e.g. tables) in the online store | yes |
| teardown infrastructure (e.g. tables) in the online store | yes |
| generate a plan of infrastructure changes | no |
| support for on-demand transforms | yes |
| readable by Python SDK | yes |
| readable by Java | no |
| readable by Go | no |
| support for entityless feature views | yes |
| support for concurrent writing to the same key | yes |
| support for ttl (time to live) at retrieval | yes |
| support for deleting expired data | yes |
| collocated by feature view | no |
| collocated by feature service | no |
| collocated by entity key | yes |
To compare this set of functionality against other online stores, please see the full [functionality matrix](overview.md#functionality-matrix).
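
The Milvus online store also backs Feast's document/vector retrieval path (`retrieve_online_documents_v2`), which is used in the repository's RAG example. The sketch below is a minimal illustration of a similarity search, assuming a hypothetical `city_embeddings` feature view whose embedding field is named `vector`; the exact argument names (`query`, `top_k`) may differ slightly across Feast versions, so treat them as assumptions and check the SDK reference for your release.

```python
from feast import FeatureStore

store = FeatureStore(repo_path=".")

# A query embedding produced by the same model that embedded the stored documents.
# Its length must match the embedding_dim configured for the Milvus online store (128 above).
query_embedding = [0.0] * 128

# Retrieve the top 3 most similar documents using the configured metric (COSINE above).
# Feature view and field names are placeholders; adapt them to your repository.
context = store.retrieve_online_documents_v2(
    features=[
        "city_embeddings:vector",
        "city_embeddings:item_id",
    ],
    query=query_embedding,
    top_k=3,
).to_df()

print(context)
```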
7 changes: 3 additions & 4 deletions examples/rag/README.md
@@ -37,9 +37,8 @@ The RAG architecture combines retrieval of documents (using vector search) with

3. Materialize features into the online store:

```bash
python -c "from datetime import datetime; from feast import FeatureStore; store = FeatureStore(repo_path='.')"
python -c "store.materialize_incremental(datetime.utcnow())"
```python
store.write_to_online_store(feature_view_name='city_embeddings', df=df)
```
4. Run a query:

@@ -61,7 +60,7 @@ feast apply
store.write_to_online_store(feature_view_name='city_embeddings', df=df)
```

-Inspect retrieved features using Python:
- Inspect retrieved features using Python:
```python
context_data = store.retrieve_online_documents_v2(
features=[
2 changes: 1 addition & 1 deletion examples/rag/milvus-quickstart.ipynb
@@ -461,7 +461,7 @@
}
],
"source": [
"! feast apply "
"! feast apply"
]
},
{
205 changes: 153 additions & 52 deletions sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
@@ -4,7 +4,6 @@

from pydantic import StrictStr
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
@@ -20,13 +19,13 @@
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.online_stores.vector_store import VectorStoreConfig
from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.type_map import (
    PROTO_VALUE_TO_VALUE_TYPE_MAP,
    VALUE_TYPE_TO_PROTO_VALUE_MAP,
    feast_value_type_to_python_type,
)
from feast.types import (
@@ -35,6 +34,7 @@
    ComplexFeastType,
    PrimitiveFeastType,
    ValueType,
    from_feast_type,
)
from feast.utils import (
    _serialize_vector_to_float_list,
@@ -146,9 +146,7 @@ def _get_or_create_collection(
        collection_name = _table_id(config.project, table)
        if collection_name not in self._collections:
            # Create a composite key by combining entity fields
            composite_key_name = (
                "_".join([field.name for field in table.entity_columns]) + "_pk"
            )
            composite_key_name = _get_composite_key_name(table)

            fields = [
                FieldSchema(
@@ -251,9 +249,8 @@ def online_write_batch(
            ).hex()
            # to recover the entity key just run:
            # deserialize_entity_key(bytes.fromhex(entity_key_str), entity_key_serialization_version=3)
            composite_key_name = (
                "_".join([str(value) for value in entity_key.join_keys]) + "_pk"
            )
            composite_key_name = _get_composite_key_name(table)

            timestamp_int = int(to_naive_utc(timestamp).timestamp() * 1e6)
            created_ts_int = (
                int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
@@ -293,8 +290,133 @@ def online_read(
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
        full_feature_names: bool = False,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        raise NotImplementedError
        self.client = self._connect(config)
        collection_name = _table_id(config.project, table)
        collection = self._get_or_create_collection(config, table)

        composite_key_name = _get_composite_key_name(table)

        output_fields = (
            [composite_key_name]
            + (requested_features if requested_features else [])
            + ["created_ts", "event_ts"]
        )
        assert all(
            field in [f["name"] for f in collection["fields"]]
            for field in output_fields
        ), (
            f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema"
        )
        composite_entities = []
        for entity_key in entity_keys:
            entity_key_str = serialize_entity_key(
                entity_key,
                entity_key_serialization_version=config.entity_key_serialization_version,
            ).hex()
            composite_entities.append(entity_key_str)

        query_filter_for_entities = (
            f"{composite_key_name} in ["
            + ", ".join([f"'{e}'" for e in composite_entities])
            + "]"
        )
        self.client.load_collection(collection_name)
        results = self.client.query(
            collection_name=collection_name,
            filter=query_filter_for_entities,
            output_fields=output_fields,
        )
        # Group hits by composite key.
        grouped_hits: Dict[str, Any] = {}
        for hit in results:
            key = hit.get(composite_key_name)
            grouped_hits.setdefault(key, []).append(hit)

        # Map the features to their Feast types.
        feature_name_feast_primitive_type_map = {
            f.name: f.dtype for f in table.features
        }
        # Build a dictionary mapping composite key -> (res_ts, res)
        results_dict: Dict[
            str, Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
        ] = {}

        # Map the values stored as strings back into protobuf ValueProto objects.
        for hit in results:
            key = hit.get(composite_key_name)
            # Only take one hit per composite key (adjust if you need aggregation)
            if key not in results_dict:
                res = {}
                res_ts = None
                for field in output_fields:
                    val = ValueProto()
                    field_value = hit.get(field, None)
                    if field_value is None and ":" in field:
                        _, field_short = field.split(":", 1)
                        field_value = hit.get(field_short)

                    if field in ["created_ts", "event_ts"]:
                        res_ts = datetime.fromtimestamp(field_value / 1e6)
                    elif field == composite_key_name:
                        # We do not return the composite key value
                        pass
                    else:
                        feature_feast_primitive_type = (
                            feature_name_feast_primitive_type_map.get(
                                field, PrimitiveFeastType.INVALID
                            )
                        )
                        feature_fv_dtype = from_feast_type(feature_feast_primitive_type)
                        proto_attr = VALUE_TYPE_TO_PROTO_VALUE_MAP.get(feature_fv_dtype)
                        if proto_attr:
                            if proto_attr == "bytes_val":
                                setattr(val, proto_attr, field_value.encode())
                            elif proto_attr in [
                                "int32_val",
                                "int64_val",
                                "float_val",
                                "double_val",
                            ]:
                                setattr(
                                    val,
                                    proto_attr,
                                    type(getattr(val, proto_attr))(field_value),
                                )
                            elif proto_attr in [
                                "int32_list_val",
                                "int64_list_val",
                                "float_list_val",
                                "double_list_val",
                            ]:
                                setattr(
                                    val,
                                    proto_attr,
                                    list(
                                        map(
                                            type(getattr(val, proto_attr)).__args__[0],
                                            field_value,
                                        )
                                    ),
                                )
                            else:
                                setattr(val, proto_attr, field_value)
                        else:
                            raise ValueError(
                                f"Unsupported ValueType: {feature_feast_primitive_type} for feature {field} with value {field_value}"
                            )
                    # res[field] = val
                    key_to_use = field.split(":", 1)[-1] if ":" in field else field
                    res[key_to_use] = val
                results_dict[key] = (res_ts, res if res else None)

        # Map the results back into a list matching the original order of composite_keys.
        result_list = [
            results_dict.get(key, (None, None)) for key in composite_entities
        ]

        return result_list

    def update(
        self,
@@ -362,11 +484,7 @@ def retrieve_online_documents_v2(
"params": {"nprobe": 10},
}

composite_key_name = (
"_".join([str(field.name) for field in table.entity_columns]) + "_pk"
)
# features_str = ", ".join([f"'{f}'" for f in requested_features])
# expr = f" && feature_name in [{features_str}]"
composite_key_name = _get_composite_key_name(table)

output_fields = (
[composite_key_name]
Expand Down Expand Up @@ -452,6 +570,10 @@ def _table_id(project: str, table: FeatureView) -> str:
    return f"{project}_{table.name}"


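# The composite key below is used as the collection's primary key field: the feature
# view's entity column names are joined with "_" and suffixed with "_pk",
# e.g. entity columns ["user_id", "item_id"] -> "user_id_item_id_pk".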
def _get_composite_key_name(table: FeatureView) -> str:
return "_".join([field.name for field in table.entity_columns]) + "_pk"


def _extract_proto_values_to_dict(
    input_dict: Dict[str, Any],
    vector_cols: List[str],
@@ -462,6 +584,13 @@ def _extract_proto_values_to_dict(
        for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys()
        if k is not None and "list" in k and "string" not in k
    ]
    numeric_types = [
        "double_val",
        "float_val",
        "int32_val",
        "int64_val",
        "bool_val",
    ]
    output_dict = {}
    for feature_name, feature_values in input_dict.items():
        for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
@@ -475,10 +604,18 @@
                        else:
                            vector_values = getattr(feature_values, proto_val_type).val
                    else:
                        if serialize_to_string and proto_val_type != "string_val":
                        if (
                            serialize_to_string
                            and proto_val_type not in ["string_val"] + numeric_types
                        ):
                            vector_values = feature_values.SerializeToString().decode()
                        else:
                            vector_values = getattr(feature_values, proto_val_type)
                            if not isinstance(feature_values, str):
                                vector_values = str(
                                    getattr(feature_values, proto_val_type)
                                )
                            else:
                                vector_values = getattr(feature_values, proto_val_type)
                        output_dict[feature_name] = vector_values
            else:
                if serialize_to_string:
@@ -487,39 +624,3 @@
                output_dict[feature_name] = feature_values

    return output_dict


class MilvusTable(InfraObject):
    """
    A Milvus collection managed by Feast.

    Attributes:
        host: The host of the Milvus server.
        port: The port of the Milvus server.
        name: The name of the collection.
    """

    host: str
    port: int

    def __init__(self, host: str, port: int, name: str):
        super().__init__(name)
        self.host = host
        self.port = port
        self._connect()

    def _connect(self):
        raise NotImplementedError

    def to_infra_object_proto(self) -> InfraObjectProto:
        # Implement serialization if needed
        raise NotImplementedError

    def update(self):
        # Implement update logic if needed
        raise NotImplementedError

    def teardown(self):
        collection = Collection(name=self.name)
        if collection.exists():
            collection.drop()