Skip to content

Commit d1703f2

Browse files
committed
accept engine_args
1 parent e560081 commit d1703f2

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

llama_iris/vectorstore.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
VectorStoreQueryResult,
1010
)
1111
from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict
12+
from sqlalchemy.orm.session import close_all_sessions
1213

1314

1415
_logger = logging.getLogger(__name__)
@@ -60,6 +61,7 @@ class IRISVectorStore(BasePydanticVectorStore):
6061
embed_dim: int
6162
perform_setup: bool
6263
debug: bool
64+
engine_args: Optional[dict]
6365

6466
_base: Any = PrivateAttr()
6567
_table_class: Any = PrivateAttr()
@@ -76,6 +78,7 @@ def __init__(
7678
embed_dim: int = 1536,
7779
perform_setup: bool = True,
7880
debug: bool = False,
81+
engine_args: Optional[dict] = None,
7982
) -> None:
8083
table_name = table_name.lower()
8184
schema_name = schema_name.lower()
@@ -87,13 +90,14 @@ def __init__(
8790
embed_dim=embed_dim,
8891
perform_setup=perform_setup,
8992
debug=debug,
93+
engine_args=engine_args or {},
9094
)
9195

9296
async def close(self) -> None:
9397
if not self._is_initialized:
9498
return
9599

96-
self._session.close_all()
100+
close_all_sessions()
97101
self._engine.dispose()
98102

99103
@classmethod
@@ -114,6 +118,7 @@ def from_params(
114118
embed_dim: int = 1536,
115119
perform_setup: bool = True,
116120
debug: bool = False,
121+
engine_args: Optional[dict] = None,
117122
) -> "IRISVectorStore":
118123
"""Return connection string from database parameters."""
119124
conn_str = (
@@ -127,6 +132,7 @@ def from_params(
127132
embed_dim=embed_dim,
128133
perform_setup=perform_setup,
129134
debug=debug,
135+
engine_args=engine_args,
130136
)
131137

132138
@property
@@ -139,7 +145,7 @@ def _connect(self) -> Any:
139145
from sqlalchemy import create_engine
140146
from sqlalchemy.orm import sessionmaker
141147

142-
self._engine = create_engine(self.connection_string, echo=self.debug)
148+
self._engine = create_engine(self.connection_string, echo=self.debug, **self.engine_args)
143149
self._session = sessionmaker(self._engine)
144150
with self._engine.connect() as conn:
145151
self._native_vector = conn.dialect.supports_vectors

0 commit comments

Comments
 (0)