Skip to content

Commit

Permalink
accept engine_args for #2
Browse files Browse the repository at this point in the history
  • Loading branch information
daimor committed May 9, 2024
1 parent e560081 commit d228b28
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions llama_iris/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict
from sqlalchemy.orm.session import close_all_sessions


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,6 +61,7 @@ class IRISVectorStore(BasePydanticVectorStore):
embed_dim: int
perform_setup: bool
debug: bool
engine_args: Optional[dict]

_base: Any = PrivateAttr()
_table_class: Any = PrivateAttr()
Expand All @@ -76,6 +78,7 @@ def __init__(
embed_dim: int = 1536,
perform_setup: bool = True,
debug: bool = False,
engine_args: Optional[dict] = None,
) -> None:
table_name = table_name.lower()
schema_name = schema_name.lower()
Expand All @@ -87,13 +90,14 @@ def __init__(
embed_dim=embed_dim,
perform_setup=perform_setup,
debug=debug,
engine_args=engine_args or {},
)

async def close(self) -> None:
if not self._is_initialized:
return

self._session.close_all()
close_all_sessions()
self._engine.dispose()

@classmethod
Expand All @@ -114,6 +118,7 @@ def from_params(
embed_dim: int = 1536,
perform_setup: bool = True,
debug: bool = False,
engine_args: Optional[dict] = None,
) -> "IRISVectorStore":
"""Return connection string from database parameters."""
conn_str = (
Expand All @@ -127,6 +132,7 @@ def from_params(
embed_dim=embed_dim,
perform_setup=perform_setup,
debug=debug,
engine_args=engine_args,
)

@property
Expand All @@ -139,7 +145,7 @@ def _connect(self) -> Any:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

self._engine = create_engine(self.connection_string, echo=self.debug)
self._engine = create_engine(self.connection_string, echo=self.debug, **self.engine_args)
self._session = sessionmaker(self._engine)
with self._engine.connect() as conn:
self._native_vector = conn.dialect.supports_vectors
Expand Down

0 comments on commit d228b28

Please sign in to comment.