From 5c3df39f4addb29e2b795bf2ceaed33d745d3752 Mon Sep 17 00:00:00 2001 From: noobHappylife <64898326+noobHappylife@users.noreply.github.com> Date: Fri, 29 Nov 2024 12:05:32 +0800 Subject: [PATCH] Fix embedding size and correctly filter default models (#31) * correctly filter default model * minor typing fix * fix chat mode streaming with non-multiturn columns * minor fix to ollama model, update to qwen2.5 3B * update changelog --------- Co-authored-by: deafnv --- CHANGELOG.md | 19 + clients/python/src/jamaibase/protocol.py | 33 +- .../tests/oss/gen_table/test_table_ops.py | 45 +- docker/compose.cpu.ollama.yml | 4 +- docker/ollama.yml | 2 +- scripts/remove_cloud_modules.ps1 | 1 + scripts/remove_cloud_modules.sh | 1 + services/api/pyproject.toml | 2 +- services/api/src/owl/configs/models_aipc.json | 2 +- .../api/src/owl/configs/models_ollama.json | 12 +- services/api/src/owl/db/gen_table.py | 150 ++--- services/api/src/owl/llm.py | 48 +- services/api/src/owl/protocol.py | 11 +- services/api/src/owl/utils/auth.py | 94 ++- .../lib/components/preset/SearchBar.svelte | 12 +- .../tables/(sub)/ColumnDropdown.svelte | 9 +- .../tables/(sub)/ColumnSettings.svelte | 7 + .../tables/(sub)/TablePagination.svelte | 8 +- .../lib/components/tables/ActionTable.svelte | 11 +- .../lib/components/tables/ChatTable.svelte | 11 +- .../components/tables/KnowledgeTable.svelte | 11 +- .../app/src/lib/components/ui/button/index.ts | 2 +- services/app/src/lib/constants.ts | 2 +- .../src/lib/icons/MultiturnChatIcon.svelte | 40 +- .../app/src/lib/icons/RegenerateIcon.svelte | 30 +- services/app/src/lib/types.ts | 4 + .../(components)/ActionsDropdown.svelte | 10 +- .../(components)/GenerateButton.svelte | 15 +- .../[table_id]/+page@project.svelte | 10 +- .../chat-table/[table_id]/+page.ts | 31 +- .../[table_id]/+page@project.svelte | 301 +++++---- .../chat-table/[table_id]/ChatMode.svelte | 603 +++++++++--------- .../chat-table/[table_id]/ModeToggle.svelte | 23 +- .../[table_id]/+page@project.svelte | 11 +- services/app/tests/pages/table.page.ts | 10 +- services/app/tests/tables/chatTable.spec.ts | 10 - 36 files changed, 827 insertions(+), 768 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0657313..1ccb399 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,25 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - The version number mentioned here refers to the cloud version. For each release, all SDKs will have the same major and minor version, but their patch version may differ. For example, latest Python SDK might be `v0.2.0` whereas TS SDK might be `v0.2.1`, but both will be compatible with release `v0.2`. +## [Unreleased] + +Backend - owl (API server) + +- Fix bge-small embedding size (1024 -> 384) +- Correctly filter models at auth level +- Fix ollama model deployment config + +Frontend + +- Added support for multiple multiturn columns in Chat table chat view. +- Added multiturn chat toggle to column settings. + +Docker + +- Added Mac Apple Silicon `compose.mac.yml` +- Update `ollama.yml` to use Qwen2.5 3B +- Fix ollama default config + ## [v0.3.1] (2024-11-26) This is a bug fix release for frontend code. SDKs are not affected. diff --git a/clients/python/src/jamaibase/protocol.py b/clients/python/src/jamaibase/protocol.py index 125f548..31b8e70 100644 --- a/clients/python/src/jamaibase/protocol.py +++ b/clients/python/src/jamaibase/protocol.py @@ -1719,46 +1719,27 @@ class ActionTableSchemaCreate(TableSchemaCreate): class AddActionColumnSchema(ActionTableSchemaCreate): + # TODO: Deprecate this pass class KnowledgeTableSchemaCreate(TableSchemaCreate): + # TODO: Maybe deprecate this and use EmbedGenConfig instead ? embedding_model: str - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - num_text_cols = sum(c.id.lower() in ("text", "title", "file id") for c in self.cols) - if num_text_cols != 0: - raise ValueError("Schema cannot contain column names: 'Text', 'Title', 'File ID'.") - return self - class AddKnowledgeColumnSchema(TableSchemaCreate): - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - num_text_cols = sum(c.id.lower() in ("text", "title", "file id") for c in self.cols) - if num_text_cols != 0: - raise ValueError("Schema cannot contain column names: 'Text', 'Title', 'File ID'.") - return self + # TODO: Deprecate this + pass class ChatTableSchemaCreate(TableSchemaCreate): - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - num_text_cols = sum(c.id.lower() in ("user", "ai") for c in self.cols) - if num_text_cols != 2: - raise ValueError("Schema must contain column names: 'User' and 'AI'.") - return self + pass class AddChatColumnSchema(TableSchemaCreate): - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - return self + # TODO: Deprecate this + pass class TableMeta(TableBase): diff --git a/clients/python/tests/oss/gen_table/test_table_ops.py b/clients/python/tests/oss/gen_table/test_table_ops.py index 6b3116e..6e5fea8 100644 --- a/clients/python/tests/oss/gen_table/test_table_ops.py +++ b/clients/python/tests/oss/gen_table/test_table_ops.py @@ -21,6 +21,7 @@ "str": '"Arrival" is a 2016 science fiction film. "Arrival" è un film di fantascienza del 2016. 「Arrival」は2016年のSF映画です。', } KT_FIXED_COLUMN_IDS = ["Title", "Title Embed", "Text", "Text Embed", "File ID"] +CT_FIXED_COLUMN_IDS = ["User"] TABLE_ID_A = "table_a" TABLE_ID_B = "table_b" @@ -1430,6 +1431,21 @@ def test_kt_drop_invalid_columns(client_cls: Type[JamAI]): ) +@flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) +@pytest.mark.parametrize("client_cls", CLIENT_CLS) +def test_ct_drop_invalid_columns(client_cls: Type[JamAI]): + table_type = "chat" + jamai = client_cls() + with _create_table(jamai, table_type) as table: + assert isinstance(table, p.TableMetaResponse) + for col in CT_FIXED_COLUMN_IDS: + with pytest.raises(RuntimeError): + jamai.table.drop_columns( + table_type, + p.ColumnDropRequest(table_id=table.id, column_names=[col]), + ) + + @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @pytest.mark.parametrize("client_cls", CLIENT_CLS) @pytest.mark.parametrize("table_type", TABLE_TYPES) @@ -1450,7 +1466,7 @@ def test_rename_columns( assert isinstance(table, p.TableMetaResponse) assert all(isinstance(c, p.ColumnSchema) for c in table.cols) # Test rename on empty table - table = jamai.rename_columns( + table = jamai.table.rename_columns( table_type, p.ColumnRenameRequest(table_id=table.id, column_map=dict(y="z")), ) @@ -1475,7 +1491,7 @@ def test_rename_columns( _add_row(jamai, table_type, False, data=dict(x="True", z="")) # Test rename table with data # Test also auto gen config reference update - table = jamai.rename_columns( + table = jamai.table.rename_columns( table_type, p.ColumnRenameRequest(table_id=table.id, column_map=dict(x="a")), ) @@ -1503,14 +1519,14 @@ def test_rename_columns( # Repeated new column names with pytest.raises(RuntimeError): - jamai.rename_columns( + jamai.table.rename_columns( table_type, p.ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="b")), ) # Overlapping new and old column names with pytest.raises(RuntimeError): - jamai.rename_columns( + jamai.table.rename_columns( table_type, p.ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="a")), ) @@ -1525,7 +1541,22 @@ def test_kt_rename_invalid_columns(client_cls: Type[JamAI]): assert isinstance(table, p.TableMetaResponse) for col in KT_FIXED_COLUMN_IDS: with pytest.raises(RuntimeError): - jamai.rename_columns( + jamai.table.rename_columns( + table_type, + p.ColumnRenameRequest(table_id=table.id, column_map={col: col}), + ) + + +@flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) +@pytest.mark.parametrize("client_cls", CLIENT_CLS) +def test_ct_rename_invalid_columns(client_cls: Type[JamAI]): + table_type = "chat" + jamai = client_cls() + with _create_table(jamai, table_type) as table: + assert isinstance(table, p.TableMetaResponse) + for col in CT_FIXED_COLUMN_IDS: + with pytest.raises(RuntimeError): + jamai.table.rename_columns( table_type, p.ColumnRenameRequest(table_id=table.id, column_map={col: col}), ) @@ -1582,7 +1613,7 @@ def test_reorder_columns( cols = [c.id for c in table.cols] assert cols == expected_order, cols # Test reorder empty table - table = jamai.reorder_columns( + table = jamai.table.reorder_columns( table_type, p.ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), ) @@ -1692,7 +1723,7 @@ def test_reorder_columns_invalid( else: raise ValueError(f"Invalid table type: {table_type}") with pytest.raises(RuntimeError, match="referenced an invalid source column"): - jamai.reorder_columns( + jamai.table.reorder_columns( table_type, p.ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), ) diff --git a/docker/compose.cpu.ollama.yml b/docker/compose.cpu.ollama.yml index 4fed391..1b9749e 100644 --- a/docker/compose.cpu.ollama.yml +++ b/docker/compose.cpu.ollama.yml @@ -29,12 +29,12 @@ services: echo 'ollama serve did not start in time'; \ exit 1; \ fi; \ - ollama pull phi3.5 && ollama cp phi3.5 microsoft/Phi3.5-mini-instruct; \ + ollama pull qwen2.5:3b && ollama cp qwen2.5:3b Qwen/Qwen2.5-3B-Instruct; \ tail -f /dev/null", ] restart: unless-stopped healthcheck: - test: ["CMD", "sh", "-c", "ollama show microsoft/Phi3.5-mini-instruct || exit 1"] + test: ["CMD", "sh", "-c", "ollama show Qwen/Qwen2.5-3B-Instruct || exit 1"] interval: 20s timeout: 2s retries: 20 diff --git a/docker/ollama.yml b/docker/ollama.yml index ce53422..8c0cc58 100644 --- a/docker/ollama.yml +++ b/docker/ollama.yml @@ -1,4 +1,4 @@ services: owl: environment: - - OWL_MODELS_CONFIG="models_ollama.json" + - OWL_MODELS_CONFIG=models_ollama.json diff --git a/scripts/remove_cloud_modules.ps1 b/scripts/remove_cloud_modules.ps1 index 8cbddcd..2a86c15 100644 --- a/scripts/remove_cloud_modules.ps1 +++ b/scripts/remove_cloud_modules.ps1 @@ -19,3 +19,4 @@ function quiet_rm($item) quiet_rm "services/app/ecosystem.config.cjs" quiet_rm "services/appecosystem.json" quiet_rm ".github/workflows/trigger-push-gh-image.yml" +quiet_rm ".github/workflows/ci.cloud.yml" \ No newline at end of file diff --git a/scripts/remove_cloud_modules.sh b/scripts/remove_cloud_modules.sh index a827108..0ef410f 100644 --- a/scripts/remove_cloud_modules.sh +++ b/scripts/remove_cloud_modules.sh @@ -9,3 +9,4 @@ find . -type d -name "(cloud)" -exec rm -rf {} + rm -f services/app/ecosystem.config.cjs rm -f services/app/ecosystem.json rm -f .github/workflows/trigger-push-gh-image.yml +rm -f .github/workflows/ci.cloud.yml diff --git a/services/api/pyproject.toml b/services/api/pyproject.toml index 45069f4..61ef746 100644 --- a/services/api/pyproject.toml +++ b/services/api/pyproject.toml @@ -132,7 +132,7 @@ dependencies = [ "sqlmodel~=0.0.21", "srsly~=2.4.8", # starlette 0.38.3 and 0.38.4 seem to have issues with background tasks - "starlette==0.38.2", + "starlette~=0.41.3", "stripe~=9.12.0", "tantivy~=0.22.0", "tenacity~=8.5.0", diff --git a/services/api/src/owl/configs/models_aipc.json b/services/api/src/owl/configs/models_aipc.json index c7e8b74..3ba623c 100644 --- a/services/api/src/owl/configs/models_aipc.json +++ b/services/api/src/owl/configs/models_aipc.json @@ -132,7 +132,7 @@ "id": "ellm/BAAI/bge-small-en-v1.5", "name": "ELLM BAAI BGE Small EN v1.5", "context_length": 512, - "embedding_size": 1024, + "embedding_size": 384, "languages": ["mul"], "capabilities": ["embed"], "deployments": [ diff --git a/services/api/src/owl/configs/models_ollama.json b/services/api/src/owl/configs/models_ollama.json index 6705fc2..00f1ed0 100644 --- a/services/api/src/owl/configs/models_ollama.json +++ b/services/api/src/owl/configs/models_ollama.json @@ -43,15 +43,15 @@ ] }, { - "id": "ellm/microsoft/Phi3.5-mini-instruct", - "name": "ELLM Phi3.5 mini instruct (3.8B)", - "context_length": 131072, + "id": "ellm/Qwen/Qwen2.5-3B-Instruct", + "name": "ELLM Qwen2.5 (3B)", + "context_length": 32000, "languages": ["en"], "capabilities": ["chat"], "deployments": [ { - "litellm_id": "ollama_chat/microsoft/Phi3.5-mini-instruct", - "api_base": "http://ollama:11434", + "litellm_id": "openai/Qwen/Qwen2.5-3B-Instruct", + "api_base": "http://ollama:11434/v1", "provider": "ellm" } ] @@ -62,7 +62,7 @@ "id": "ellm/BAAI/bge-small-en-v1.5", "name": "ELLM BAAI BGE Small EN v1.5", "context_length": 512, - "embedding_size": 1024, + "embedding_size": 384, "languages": ["mul"], "capabilities": ["embed"], "deployments": [ diff --git a/services/api/src/owl/db/gen_table.py b/services/api/src/owl/db/gen_table.py index 223540f..abf1c0d 100644 --- a/services/api/src/owl/db/gen_table.py +++ b/services/api/src/owl/db/gen_table.py @@ -205,6 +205,9 @@ def create_table( ) -> tuple[LanceTable, TableMeta]: if not isinstance(schema, TableSchema): raise TypeError("`schema` must be an instance of `TableSchema`.") + fixed_cols = set(c.lower() for c in self.FIXED_COLUMN_IDS) + if len(fixed_cols.intersection(set(c.id.lower() for c in schema.cols))) != len(fixed_cols): + raise BadInputError(f"Schema must contain fixed columns: {self.FIXED_COLUMN_IDS}") return self._create_table( session=session, schema=schema, @@ -573,15 +576,12 @@ def drop_columns( if not isinstance(column_names, list): raise TypeError("`column_names` must be a list.") if self.has_state_col_names(column_names): - raise make_validation_error( - ValueError("Cannot drop state columns."), - loc=("body", "column_names"), - ) + raise BadInputError("Cannot drop state columns.") if self.has_info_col_names(column_names): - raise make_validation_error( - ValueError('Cannot drop "ID" or "Updated at".'), - loc=("body", "column_names"), - ) + raise BadInputError('Cannot drop "ID" or "Updated at".') + fixed_cols = set(c.lower() for c in self.FIXED_COLUMN_IDS) + if len(fixed_cols.intersection(set(c.lower() for c in column_names))) > 0: + raise BadInputError(f"Cannot drop fixed columns: {self.FIXED_COLUMN_IDS}") with self.lock(table_id): # Get table metadata @@ -619,37 +619,25 @@ def rename_columns( ) -> TableMeta: new_col_names = set(column_map.values()) if self.has_state_col_names(column_map.keys()): - raise make_validation_error( - ValueError("Cannot rename state columns."), - loc=("body", "column_map"), - ) + raise BadInputError("Cannot rename state columns.") if self.has_info_col_names(column_map.keys()): - raise make_validation_error( - ValueError('Cannot rename "ID" or "Updated at".'), - loc=("body", "column_map"), - ) + raise BadInputError('Cannot rename "ID" or "Updated at".') + fixed_cols = set(c.lower() for c in self.FIXED_COLUMN_IDS) + if len(fixed_cols.intersection(set(c.lower() for c in column_map))) > 0: + raise BadInputError(f"Cannot rename fixed columns: {self.FIXED_COLUMN_IDS}") if len(new_col_names) != len(column_map): - raise make_validation_error( - ValueError("`column_map` contains repeated new column names."), - loc=("body", "column_map"), - ) + raise BadInputError("`column_map` contains repeated new column names.") if not all(re.match(COL_NAME_PATTERN, v) for v in column_map.values()): - raise make_validation_error( - ValueError("`column_map` contains invalid new column names."), - loc=("body", "column_map"), - ) + raise BadInputError("`column_map` contains invalid new column names.") meta = self.open_meta(session, table_id) col_names = set(c.id for c in meta.cols_schema) overlap_col_names = col_names.intersection(new_col_names) if len(overlap_col_names) > 0: - raise make_validation_error( - ValueError( - ( - "`column_map` contains new column names that " - f"overlap with existing column names: {overlap_col_names}" - ) - ), - loc=("body", "column_map"), + raise BadInputError( + ( + "`column_map` contains new column names that " + f"overlap with existing column names: {overlap_col_names}" + ) ) not_found = set(column_map.keys()) - col_names if len(not_found) > 0: @@ -1892,69 +1880,9 @@ def add_columns( # raise TableSchemaFixedError("Knowledge Table contains data, cannot add columns.") return super().add_columns(session, schema) - @override - def drop_columns( - self, - session: Session, - table_id: TableName, - col_names: list[ColName], - ) -> tuple[LanceTable, TableMeta]: - """ - Drops one or more input or output column. - - Args: - session (Session): SQLAlchemy session. - table_id (str): The ID of the table. - col_names (list[str]): List of column ID to drop. - - Raises: - TypeError: If `col_names` is not a list. - ResourceNotFoundError: If the table is not found. - ResourceNotFoundError: If any of the columns is not found. - - Returns: - table (LanceTable): Lance table. - meta (TableMeta): Table metadata. - """ - fixed_col_ids = [i.lower() for i in self.FIXED_COLUMN_IDS] - if sum(n.lower() in fixed_col_ids for n in col_names) > 0: - cols = ", ".join(f'"{c}"' for c in self.FIXED_COLUMN_IDS) - raise TableSchemaFixedError(f"Cannot drop {cols}.") - return super().drop_columns(session, table_id, col_names) - - @override - def rename_columns( - self, - session: Session, - table_id: TableName, - column_map: dict[ColName, ColName], - ) -> TableMeta: - fixed_col_ids = [i.lower() for i in self.FIXED_COLUMN_IDS] - if sum(n.lower() in fixed_col_ids for n in column_map) > 0: - cols = ", ".join(f'"{c}"' for c in self.FIXED_COLUMN_IDS) - raise TableSchemaFixedError(f"Cannot rename {cols}.") - return super().rename_columns(session, table_id, column_map) - - @override - def update_rows( - self, - session: Session, - table_id: TableName, - where: str | None = None, - *, - values: dict | None = None, - ) -> Self: - # Validate data - return super().update_rows( - session=session, - table_id=table_id, - where=where, - values=values, - ) - class ChatTable(GenerativeTable): - FIXED_COLUMN_IDS = ["User", "AI"] + FIXED_COLUMN_IDS = ["User"] @override def create_table( @@ -1966,11 +1894,16 @@ def create_table( ) -> tuple[LanceTable, TableMeta]: if not isinstance(schema, ChatTableSchemaCreate): raise TypeError("`schema` must be an instance of `ChatTableSchemaCreate`.") + num_chat_cols = len([c for c in schema.cols if c.gen_config and c.gen_config.multi_turn]) + if num_chat_cols == 0: + raise BadInputError("The table must have at least one multi-turn column.") return super().create_table(session, schema, remove_state_cols, add_info_state_cols) @override def add_columns( - self, session: Session, schema: AddChatColumnSchema + self, + session: Session, + schema: AddChatColumnSchema, ) -> tuple[LanceTable, TableMeta]: """ Adds one or more input or output column. @@ -2000,7 +1933,7 @@ def drop_columns( self, session: Session, table_id: TableName, - col_names: list[ColName], + column_names: list[ColName], ) -> tuple[LanceTable, TableMeta]: """ Drops one or more input or output column. @@ -2008,10 +1941,10 @@ def drop_columns( Args: session (Session): SQLAlchemy session. table_id (str): The ID of the table. - col_names (list[str]): List of column ID to drop. + column_names (list[str]): List of column ID to drop. Raises: - TypeError: If `col_names` is not a list. + TypeError: If `column_names` is not a list. ResourceNotFoundError: If the table is not found. ResourceNotFoundError: If any of the columns is not found. @@ -2019,23 +1952,28 @@ def drop_columns( table (LanceTable): Lance table. meta (TableMeta): Table metadata. """ - if sum(n.lower() in ("user", "ai") for n in col_names) > 0: - raise make_validation_error( - ValueError('Cannot drop "User" or "AI".'), - loc=("body", "column_names"), - ) with self.create_session() as session: meta = self.open_meta(session, table_id) if meta.parent_id is not None: raise TableSchemaFixedError("Unable to drop columns from a conversation table.") - return super().drop_columns(session, table_id, col_names) + num_chat_cols = len( + [ + c + for c in meta.cols_schema + if c.id not in column_names and c.gen_config and c.gen_config.multi_turn + ] + ) + if num_chat_cols == 0: + raise BadInputError("The table must have at least one multi-turn column.") + return super().drop_columns(session, table_id, column_names) @override def rename_columns( - self, session: Session, table_id: TableName, column_map: dict[ColName, ColName] + self, + session: Session, + table_id: TableName, + column_map: dict[ColName, ColName], ) -> TableMeta: - if sum(n.lower() in ("user", "ai") for n in column_map) > 0: - raise TableSchemaFixedError('Cannot rename "User" or "AI".') with self.create_session() as session: meta = self.open_meta(session, table_id) if meta.parent_id is not None: diff --git a/services/api/src/owl/llm.py b/services/api/src/owl/llm.py index 43881a4..c8458fd 100644 --- a/services/api/src/owl/llm.py +++ b/services/api/src/owl/llm.py @@ -32,15 +32,12 @@ ChatRole, Chunk, CompletionUsage, - EmbeddingModelConfig, ExternalKeys, - LLMModelConfig, ModelInfo, ModelInfoResponse, ModelListConfig, RAGParams, References, - RerankingModelConfig, ) from owl.utils import mask_content, mask_string, select_external_api_key @@ -207,54 +204,13 @@ def _map_and_log_exception( self._log_exception(model, messages, api_key, **hyperparams) return UnexpectedError(err_mssg) - def _get_valid_deployments( - self, - model: LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig, - valid_providers: list[str], - ): - valid_deployments = [] - for deployment in model.deployments: - if deployment.provider in valid_providers: - valid_deployments.append(deployment) - return valid_deployments - def model_info( self, model: str = "", capabilities: list[str] | None = None, ) -> ModelInfoResponse: - all_models: ModelListConfig = self.request.state.all_models - # define all possible api providers - available_providers = [ - "openai", - "anthropic", - "together_ai", - "cohere", - "sambanova", - "cerebras", - "hyperbolic", - ] - # remove providers without credentials - available_providers = [ - provider - for provider in available_providers - if getattr(self.external_keys, provider) != "" - ] - - # add custom and ellm providers as allow no credentials - available_providers.extend( - [ - "custom", - "ellm", - ] - ) - models = [] - # Iterate over the llm, embed, rerank list - for m in all_models.models: - valid_deployments = self._get_valid_deployments(m, available_providers) - if len(valid_deployments) > 0: - m.deployments = valid_deployments - models.append(m) + model_list: ModelListConfig = self.request.state.all_models + models = model_list.models # Filter by name if model != "": models = [m for m in models if m.id == model] diff --git a/services/api/src/owl/protocol.py b/services/api/src/owl/protocol.py index 9d4e8c2..ea5b232 100644 --- a/services/api/src/owl/protocol.py +++ b/services/api/src/owl/protocol.py @@ -1761,16 +1761,7 @@ def check_gen_configs(self) -> Self: class ChatTableSchemaCreate(TableSchemaCreate): - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - num_text_cols = sum(c.id.lower() in ("user", "ai") for c in self.cols) - if num_text_cols != 2: - raise ValueError("Schema must contain column names: 'User' and 'AI'.") - for c in self.cols: - if c.id.lower() == "ai": - c.gen_config.multi_turn = True - return self + pass class AddChatColumnSchema(TableSchemaCreate): diff --git a/services/api/src/owl/utils/auth.py b/services/api/src/owl/utils/auth.py index 48047a9..dc3ab57 100644 --- a/services/api/src/owl/utils/auth.py +++ b/services/api/src/owl/utils/auth.py @@ -1,3 +1,4 @@ +from functools import lru_cache from secrets import compare_digest from typing import Annotated, AsyncGenerator @@ -15,7 +16,16 @@ UnexpectedError, UpgradeTierError, ) -from jamaibase.protocol import OrganizationRead, PATRead, ProjectRead, UserRead +from jamaibase.protocol import ( + EmbeddingModelConfig, + LLMModelConfig, + ModelDeploymentConfig, + OrganizationRead, + PATRead, + ProjectRead, + RerankingModelConfig, + UserRead, +) from owl.billing import BillingManager from owl.configs.manager import CONFIG, ENV_CONFIG from owl.protocol import ExternalKeys, ModelListConfig @@ -236,6 +246,77 @@ async def auth_user_cloud( auth_user = auth_user_oss if ENV_CONFIG.is_oss else auth_user_cloud +def _get_valid_deployments( + model: LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig, + valid_providers: list[str], +) -> list[ModelDeploymentConfig]: + valid_deployments = [] + for deployment in model.deployments: + if deployment.provider in valid_providers: + valid_deployments.append(deployment) + return valid_deployments + + +@lru_cache(maxsize=64) +def _get_valid_modellistconfig(all_models: str, external_keys: str) -> ModelListConfig: + all_models = ModelListConfig.model_validate_json(all_models) + external_keys = ExternalKeys.model_validate_json(external_keys) + # define all possible api providers + available_providers = [ + "openai", + "anthropic", + "together_ai", + "cohere", + "sambanova", + "cerebras", + "hyperbolic", + ] + # remove providers without credentials + available_providers = [ + provider for provider in available_providers if getattr(external_keys, provider) != "" + ] + # add custom and ellm providers as allow no credentials + available_providers.extend( + [ + "custom", + "ellm", + ] + ) + + # Initialize lists to hold valid models + valid_llm_models = [] + valid_embed_models = [] + valid_rerank_models = [] + + # Iterate over the llm, embed, rerank list + for m in all_models.llm_models: + valid_deployments = _get_valid_deployments(m, available_providers) + if len(valid_deployments) > 0: + m.deployments = valid_deployments + valid_llm_models.append(m) + + for m in all_models.embed_models: + valid_deployments = _get_valid_deployments(m, available_providers) + if len(valid_deployments) > 0: + m.deployments = valid_deployments + valid_embed_models.append(m) + + for m in all_models.rerank_models: + valid_deployments = _get_valid_deployments(m, available_providers) + if len(valid_deployments) > 0: + m.deployments = valid_deployments + valid_rerank_models.append(m) + + # Create a new ModelListConfig with the valid models + valid_model_list_config = ModelListConfig( + llm_models=valid_llm_models, + embed_models=valid_embed_models, + rerank_models=valid_rerank_models, + ) + + return valid_model_list_config + + async def auth_user_project_oss( request: Request, project_id: Annotated[ @@ -255,7 +336,10 @@ async def auth_user_project_oss( request.state.project_id = project.id request.state.external_keys = _get_external_keys(organization) request.state.org_models = ModelListConfig.model_validate(organization.models) - request.state.all_models = request.state.org_models + CONFIG.get_model_config() + all_models = request.state.org_models + CONFIG.get_model_config() + request.state.all_models = _get_valid_modellistconfig( + all_models.model_dump_json(), request.state.external_keys.model_dump_json() + ) request.state.billing = BillingManager(request=request) yield project @@ -293,8 +377,10 @@ async def auth_user_project_cloud( request.state.project_id = project.id request.state.external_keys = _get_external_keys(organization) request.state.org_models = ModelListConfig.model_validate(organization.models) - request.state.all_models = request.state.org_models + CONFIG.get_model_config() - + all_models = request.state.org_models + CONFIG.get_model_config() + request.state.all_models = _get_valid_modellistconfig( + all_models.model_dump_json(), request.state.external_keys.model_dump_json() + ) # Check if token is provided bearer_token = bearer_token.split("Bearer ") if len(bearer_token) < 2 or bearer_token[1].strip() == "": diff --git a/services/app/src/lib/components/preset/SearchBar.svelte b/services/app/src/lib/components/preset/SearchBar.svelte index 8b8d317..9a024cf 100644 --- a/services/app/src/lib/components/preset/SearchBar.svelte +++ b/services/app/src/lib/components/preset/SearchBar.svelte @@ -15,11 +15,11 @@
@@ -28,7 +28,7 @@ bind:value={searchQuery} aria-label={label} {placeholder} - class="pl-8 sm:pl-9 pr-4 py-[6.5px] sm:py-2 w-full text-xs sm:text-sm placeholder:text-[#667085] bg-transparent focus-visible:outline-none {!searchQuery + class="pl-8 sm:pl-9 pr-4 py-[6.5px] sm:py-2 w-full text-xs sm:text-sm placeholder:text-[#98A2B3] bg-transparent focus-visible:outline-none {!searchQuery ? 'cursor-pointer focus-visible:cursor-text' : ''} peer" /> @@ -39,7 +39,7 @@
{:else} {/if} diff --git a/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte b/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte index c84e4cb..b15ce22 100644 --- a/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte +++ b/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte @@ -104,9 +104,6 @@ //? Revert back to original value genTableRows.revert(originalValues); } else { - //Delete all data except for inputs - genTableRows.clearOutputs(tableData, toRegenRowIds); - const reader = response.body!.pipeThrough(new TextDecoderStream()).getReader(); let isStreaming = true; @@ -210,7 +207,7 @@ variant="ghost" on:click={(e) => e.stopPropagation()} title="Column actions" - class="flex-[0_0_auto] ml-auto p-0 h-7 w-7 aspect-square" + class="flex-[0_0_auto] !z-0 ml-auto p-0 h-7 w-7 aspect-square" > @@ -235,7 +232,7 @@ {#if !readonly && (tableType !== 'chat' || !chatTableStaticCols.includes(column.id)) && (tableType !== 'knowledge' || !knowledgeTableStaticCols.includes(column.id))} - + {/if} { diff --git a/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte b/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte index 70b8ba9..49250ac 100644 --- a/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte +++ b/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte @@ -277,6 +277,13 @@ > {isColumnSettingsOpen.column?.dtype} + + {#if isColumnSettingsOpen.column.gen_config?.object === 'gen_config.llm' && isColumnSettingsOpen.column.gen_config.multi_turn} +
+
+ +
+ {/if} diff --git a/services/app/src/lib/components/tables/(sub)/TablePagination.svelte b/services/app/src/lib/components/tables/(sub)/TablePagination.svelte index a0a32fb..63e972b 100644 --- a/services/app/src/lib/components/tables/(sub)/TablePagination.svelte +++ b/services/app/src/lib/components/tables/(sub)/TablePagination.svelte @@ -2,7 +2,7 @@ import { page } from '$app/stores'; import { genTableRows } from '$lib/components/tables/tablesStore'; import * as constants from '$lib/constants'; - import type { GenTable, GenTableCol, GenTableRow } from '$lib/types'; + import type { GenTable, GenTableCol } from '$lib/types'; import { Skeleton } from '$lib/components/ui/skeleton'; import * as Pagination from '$lib/components/ui/pagination'; @@ -78,7 +78,7 @@ @@ -104,7 +104,7 @@ style={currentPage === page.value ? 'background: #E4E7EC; pointer-events: none;' : ''} - class="inline-flex items-center justify-center rounded-sm {pageFontSize} font-medium whitespace-nowrap ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 border border-input bg-background hover:bg-accent hover:text-accent-foreground h-6 w-6" + class="inline-flex items-center justify-center {pageFontSize} text-[#475467] font-medium whitespace-nowrap ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 rounded-full disabled:pointer-events-none disabled:opacity-50 hover:bg-[#F2F4F7] h-6 w-6" > {page.value} @@ -117,7 +117,7 @@ diff --git a/services/app/src/lib/components/tables/ActionTable.svelte b/services/app/src/lib/components/tables/ActionTable.svelte index 2763a2f..c35230b 100644 --- a/services/app/src/lib/components/tables/ActionTable.svelte +++ b/services/app/src/lib/components/tables/ActionTable.svelte @@ -2,7 +2,6 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; import { onDestroy } from 'svelte'; import { page } from '$app/stores'; - import { invalidate } from '$app/navigation'; import GripVertical from 'lucide-svelte/icons/grip-vertical'; import { genTableRows } from '$lib/components/tables/tablesStore'; import { isValidUri } from '$lib/utils'; @@ -48,18 +47,12 @@ export let isColumnSettingsOpen: { column: any; showMenu: boolean }; export let isDeletingColumn: string | null; export let readonly = false; + export let refetchTable: (hideColumnSettings?: boolean) => Promise; let rowThumbs: { [rowID: string]: { [colID: string]: { value: string; url: string } } } = {}; let isDeletingFile: { rowID: string; columnID: string; fileUri?: string } | null = null; let uploadController: AbortController | undefined = undefined; - async function refetchTable() { - //? Don't refetch while streaming - if (Object.keys(streamingRows).length === 0) { - await invalidate('action-table:slug'); - } - } - //? Expanding ID and Updated at columns let focusedCol: string | null = null; @@ -109,7 +102,7 @@ }); if (response.ok) { - invalidate('action-table:slug'); + refetchTable(); tableData = { ...tableData, cols: tableData.cols.map((col) => diff --git a/services/app/src/lib/components/tables/ChatTable.svelte b/services/app/src/lib/components/tables/ChatTable.svelte index c1f157e..f937898 100644 --- a/services/app/src/lib/components/tables/ChatTable.svelte +++ b/services/app/src/lib/components/tables/ChatTable.svelte @@ -2,7 +2,6 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; import { onDestroy } from 'svelte'; import { page } from '$app/stores'; - import { invalidate } from '$app/navigation'; import GripVertical from 'lucide-svelte/icons/grip-vertical'; import { genTableRows } from '$lib/components/tables/tablesStore'; import { isValidUri } from '$lib/utils'; @@ -48,18 +47,12 @@ export let isColumnSettingsOpen: { column: any; showMenu: boolean }; export let isDeletingColumn: string | null; export let readonly = false; + export let refetchTable: (hideColumnSettings?: boolean) => Promise; let rowThumbs: { [rowID: string]: { [colID: string]: { value: string; url: string } } } = {}; let isDeletingFile: { rowID: string; columnID: string; fileUri?: string } | null = null; let uploadController: AbortController | undefined = undefined; - async function refetchTable() { - //? Don't refetch while streaming - if (Object.keys(streamingRows).length === 0) { - await invalidate('chat-table:slug'); - } - } - //? Expanding ID and Updated at columns let focusedCol: string | null = null; @@ -109,7 +102,7 @@ }); if (response.ok) { - invalidate('chat-table:slug'); + refetchTable(); tableData = { ...tableData, cols: tableData.cols.map((col) => diff --git a/services/app/src/lib/components/tables/KnowledgeTable.svelte b/services/app/src/lib/components/tables/KnowledgeTable.svelte index ea989b7..ebc345c 100644 --- a/services/app/src/lib/components/tables/KnowledgeTable.svelte +++ b/services/app/src/lib/components/tables/KnowledgeTable.svelte @@ -2,7 +2,6 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; import { onDestroy } from 'svelte'; import { page } from '$app/stores'; - import { invalidate } from '$app/navigation'; import GripVertical from 'lucide-svelte/icons/grip-vertical'; import { genTableRows } from '$lib/components/tables/tablesStore'; import { isValidUri } from '$lib/utils'; @@ -48,18 +47,12 @@ export let isColumnSettingsOpen: { column: any; showMenu: boolean }; export let isDeletingColumn: string | null; export let readonly = false; + export let refetchTable: (hideColumnSettings?: boolean) => Promise; let rowThumbs: { [rowID: string]: { [colID: string]: { value: string; url: string } } } = {}; let isDeletingFile: { rowID: string; columnID: string; fileUri?: string } | null = null; let uploadController: AbortController | undefined = undefined; - async function refetchTable() { - //? Don't refetch while streaming - if (Object.keys(streamingRows).length === 0) { - await invalidate('knowledge-table:slug'); - } - } - //? Expanding ID and Updated at columns let focusedCol: string | null = null; @@ -108,7 +101,7 @@ }); if (response.ok) { - invalidate('knowledge-table:slug'); + refetchTable(); tableData = { ...tableData, cols: tableData.cols.map((col) => diff --git a/services/app/src/lib/components/ui/button/index.ts b/services/app/src/lib/components/ui/button/index.ts index ec5b4cb..ab87447 100644 --- a/services/app/src/lib/components/ui/button/index.ts +++ b/services/app/src/lib/components/ui/button/index.ts @@ -3,7 +3,7 @@ import { tv, type VariantProps } from 'tailwind-variants'; import type { Button as ButtonPrimitive } from 'bits-ui'; const buttonVariants = tv({ - base: 'inline-flex items-center justify-center rounded-md text-sm font-medium whitespace-nowrap ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 rounded-full disabled:pointer-events-none disabled:opacity-50', + base: 'inline-flex items-center justify-center text-sm font-medium whitespace-nowrap ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 rounded-full disabled:pointer-events-none disabled:opacity-50', variants: { variant: { default: diff --git a/services/app/src/lib/constants.ts b/services/app/src/lib/constants.ts index 22ae8a0..fe4308e 100644 --- a/services/app/src/lib/constants.ts +++ b/services/app/src/lib/constants.ts @@ -21,7 +21,7 @@ export const knowledgeTableStaticCols = [ 'Text Embed', 'File ID' ]; -export const chatTableStaticCols = ['ID', 'Updated at', 'User', 'AI']; +export const chatTableStaticCols = ['ID', 'Updated at', 'User']; export const knowledgeTableEmbedCols = ['Title Embed', 'Text Embed']; export const knowledgeTableFiletypes = [ '.csv', diff --git a/services/app/src/lib/icons/MultiturnChatIcon.svelte b/services/app/src/lib/icons/MultiturnChatIcon.svelte index d21606f..0b358d4 100644 --- a/services/app/src/lib/icons/MultiturnChatIcon.svelte +++ b/services/app/src/lib/icons/MultiturnChatIcon.svelte @@ -1,19 +1,31 @@ - - - - +{#if filled} + + + +{:else} + + + + +{/if} diff --git a/services/app/src/lib/icons/RegenerateIcon.svelte b/services/app/src/lib/icons/RegenerateIcon.svelte index b75e299..482c13b 100644 --- a/services/app/src/lib/icons/RegenerateIcon.svelte +++ b/services/app/src/lib/icons/RegenerateIcon.svelte @@ -1,15 +1,31 @@ - + + + diff --git a/services/app/src/lib/types.ts b/services/app/src/lib/types.ts index 9b6b8cd..457bdcd 100644 --- a/services/app/src/lib/types.ts +++ b/services/app/src/lib/types.ts @@ -99,6 +99,10 @@ export type ChatRequest = { user: string; }; +type ThreadObj = ChatRequest['messages'][number] & { column_id: string }; +type ThreadErr = { error?: number; message: any } & { column_id: string }; +export type Thread = ThreadObj | ThreadErr; + export interface GenTableCol { id: string; dtype: (typeof genTableDTypes)[number]; diff --git a/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte b/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte index 1419055..fcbb3d5 100644 --- a/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte @@ -254,7 +254,7 @@ builders={[builder]} variant="ghost" title="Table actions" - class="p-0 h-8 sm:h-9 w-auto aspect-square" + class="p-0 h-8 sm:h-9 w-auto aspect-square bg-[#F2F4F7] hover:bg-[#E4E7EC]" > @@ -268,12 +268,12 @@ Order by - Last modified + Created
@@ -282,17 +279,17 @@
@@ -300,8 +297,8 @@ diff --git a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ModeToggle.svelte b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ModeToggle.svelte index 48ed68c..f564b49 100644 --- a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ModeToggle.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ModeToggle.svelte @@ -45,13 +45,21 @@ className )} > - + + + + @@ -61,8 +69,9 @@ ? 'translate-x-1' : 'translate-x-[36px] sm:translate-x-[40px]'}" > -
+
diff --git a/services/app/src/routes/(main)/project/[project_id]/knowledge-table/[table_id]/+page@project.svelte b/services/app/src/routes/(main)/project/[project_id]/knowledge-table/[table_id]/+page@project.svelte index 385a29b..ccb902f 100644 --- a/services/app/src/routes/(main)/project/[project_id]/knowledge-table/[table_id]/+page@project.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/knowledge-table/[table_id]/+page@project.svelte @@ -366,12 +366,12 @@ Chunk Editor --> {:else} - +
- - + +
- + {/if}
@@ -383,8 +383,9 @@ bind:selectedRows bind:isColumnSettingsOpen bind:isDeletingColumn - {streamingRows} + bind:streamingRows {table} + {refetchTable} /> {#if !tableError} diff --git a/services/app/tests/pages/table.page.ts b/services/app/tests/pages/table.page.ts index 5ed537a..0c1116c 100644 --- a/services/app/tests/pages/table.page.ts +++ b/services/app/tests/pages/table.page.ts @@ -205,10 +205,12 @@ export class TablePage extends LayoutPage { await newColDialog.getByLabel('Column ID').fill(`transient-${type}-column`); await newColDialog.getByTestId('datatype-select-btn').click(); - await newColDialog - .getByTestId('datatype-select-btn') - .locator('div[role="option"]', { hasText: datatype }) - .click(); + if (type === 'input') { + await newColDialog + .getByTestId('datatype-select-btn') + .locator('div[role="option"]', { hasText: datatype }) + .click(); + } if (type === 'output') { await newColDialog.getByLabel('Customize prompt').fill('Hello, what is your favorite food?'); } diff --git a/services/app/tests/tables/chatTable.spec.ts b/services/app/tests/tables/chatTable.spec.ts index f71495b..6a10a8d 100644 --- a/services/app/tests/tables/chatTable.spec.ts +++ b/services/app/tests/tables/chatTable.spec.ts @@ -97,16 +97,6 @@ test.describe('Chat Table Page Basic', () => { await expect(chatMessages.last()).toHaveAttribute('data-streaming', 'true'); await expect(chatMessages.last()).not.toHaveAttribute('data-streaming'); }); - - test('can regenerate message', async ({ page }) => { - const chatMessages = page.getByTestId('chat-message'); - // const originalText = await chatMessages.last().innerText(); - await page.getByTestId('stop-regen-btn').click(); - - await expect(chatMessages.last()).toHaveAttribute('data-streaming', 'true'); - await expect(chatMessages.last()).not.toHaveAttribute('data-streaming'); - // await expect(chatMessages.last()).not.toHaveText(originalText ?? ''); - }); }); test('can update column config, and persist', async ({ tablePage }) => {