From 71191ea027c71b6e407af8f9579350163e543f8b Mon Sep 17 00:00:00 2001 From: Ben Sherman Date: Fri, 10 Jan 2025 13:20:00 -0800 Subject: [PATCH] chore(dev): table rows query returns absolute index of each row in parent table --- .../Browse3/datasets/DatasetEditorContext.tsx | 2 +- .../Browse3/datasets/EditableDatasetView.tsx | 24 +--- .../traceServerClientTypes.ts | 1 + .../clickhouse_trace_server_batched.py | 4 +- weave/trace_server/table_query_builder.py | 22 ++-- weave/trace_server/trace_server_interface.py | 123 +++++++++++++----- 6 files changed, 114 insertions(+), 62 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/DatasetEditorContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/DatasetEditorContext.tsx index 5a0bd561e75c..141c22895dfe 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/DatasetEditorContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/DatasetEditorContext.tsx @@ -71,7 +71,7 @@ export const DatasetEditProvider: React.FC = ({ const processRowUpdate = useCallback( (newRow: DatasetRow, oldRow: DatasetRow): DatasetRow => { const changedField = Object.keys(newRow).find( - key => newRow[key] !== oldRow[key] && key !== 'id' + key => newRow[key] !== oldRow[key] ); if (changedField) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx index 94a811c3ca58..86eb5d9439b8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/EditableDatasetView.tsx @@ -149,15 +149,6 @@ export const EditableDatasetView: FC = ({ pageSize: 50, }); - // Reset sort model and pagination if we enter edit mode with sorting applied. - useEffect(() => { - if (isEditing && sortModel.length > 0) { - setPaginationModel({page: 0, pageSize: 50}); - setSortModel([]); - setSortBy([]); - } - }, [isEditing, sortModel]); - const sharedRef = useContext(WeaveCHTableSourceRefContext); const history = useHistory(); @@ -294,7 +285,6 @@ export const EditableDatasetView: FC = ({ setAddedRows(prev => { const updatedMap = new Map(prev); const newId = `${ADDED_ROW_ID_PREFIX}${uuidv4()}`; - console.log(initialFields); const newRow = { ___weave: { id: newId, @@ -309,16 +299,14 @@ export const EditableDatasetView: FC = ({ const rows = useMemo(() => { if (fetchQueryLoaded) { - return loadedRows.map((row, i) => { + return loadedRows.map(row => { const digest = row.digest; - const absoluteIndex = - i + paginationModel.pageSize * paginationModel.page; - const editedRow = editedCellsMap.get(absoluteIndex); + const editedRow = editedCellsMap.get(row.original_index); const value = flattenObjectPreservingWeaveTypes(row.val); return { ___weave: { - id: `${digest}_${absoluteIndex}`, - index: absoluteIndex, + id: `${digest}_${row.original_index}`, + index: row.original_index, isNew: false, }, ...(editedRow ? {...value, ...editedRow} : value), @@ -326,7 +314,7 @@ export const EditableDatasetView: FC = ({ }); } return []; - }, [loadedRows, fetchQueryLoaded, editedCellsMap, paginationModel]); + }, [loadedRows, fetchQueryLoaded, editedCellsMap]); const combinedRows = useMemo(() => { if ( @@ -412,7 +400,7 @@ export const EditableDatasetView: FC = ({ headerName: field as string, flex: 1, editable: isEditing, - sortable: !isEditing, + sortable: true, filterable: false, renderCell: (params: GridRenderCellParams) => { const editedRow = editedCellsMap.get(params.row.___weave?.index); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index 7c89efd44196..3f8f156099b8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -284,6 +284,7 @@ export type TraceTableQueryRes = { rows: Array<{ digest: string; val: any; + original_index?: number; }>; }; diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index dc7d10e92398..eb1839c38335 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -919,7 +919,9 @@ def _table_query_stream( res = self._query_stream(query, parameters=pb.get_params()) for row in res: - yield tsi.TableRowSchema(digest=row[0], val=json.loads(row[1])) + yield tsi.TableRowSchema( + digest=row[0], val=json.loads(row[1]), original_index=row[2] + ) def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes: parameters: dict[str, Any] = { diff --git a/weave/trace_server/table_query_builder.py b/weave/trace_server/table_query_builder.py index e5b44acb97d2..d77d94e25f66 100644 --- a/weave/trace_server/table_query_builder.py +++ b/weave/trace_server/table_query_builder.py @@ -36,21 +36,22 @@ def make_natural_sort_table_query( row_digests_selection = f"arraySlice({row_digests_selection}, 1 + {{{pb.add_param(offset)}: Int64}}, {{{pb.add_param(limit)}: Int64}})" query = f""" - SELECT DISTINCT tr.digest, tr.val_dump, t.row_order + SELECT DISTINCT tr.digest, tr.val_dump, t.original_index + {{{pb.add_param(offset or 0)}: Int64}} - 1 as original_index FROM table_rows tr INNER JOIN ( - SELECT row_digest, row_number() OVER () AS row_order + SELECT row_digest, original_index FROM ( - SELECT {row_digests_selection} as row_digests + SELECT {row_digests_selection} as row_digests, + arrayEnumerate(row_digests) as original_indices FROM tables WHERE project_id = {{{project_id_name}: String}} AND digest = {{{digest_name}: String}} LIMIT 1 ) - ARRAY JOIN row_digests AS row_digest + ARRAY JOIN row_digests AS row_digest, original_indices AS original_index ) AS t ON tr.digest = t.row_digest WHERE tr.project_id = {{{project_id_name}: String}} - ORDER BY row_order ASC + ORDER BY original_index ASC """ return query @@ -88,20 +89,21 @@ def make_standard_table_query( ) query = f""" - SELECT tr.digest, tr.val_dump, tr.row_order FROM + SELECT tr.digest, tr.val_dump, tr.original_index FROM ( - SELECT DISTINCT tr.digest, tr.val_dump, t.row_order + SELECT DISTINCT tr.digest, tr.val_dump, t.original_index FROM table_rows tr INNER JOIN ( - SELECT row_digest, row_number() OVER () AS row_order + SELECT row_digest, original_index - 1 as original_index FROM ( - SELECT row_digests + SELECT row_digests, + arrayEnumerate(row_digests) as original_indices FROM tables WHERE project_id = {{{project_id_name}: String}} AND digest = {{{digest_name}: String}} LIMIT 1 ) - ARRAY JOIN row_digests AS row_digest + ARRAY JOIN row_digests AS row_digest, original_indices AS original_index ) AS t ON tr.digest = t.row_digest WHERE tr.project_id = {{{project_id_name}: String}} {sql_safe_filter_clause} diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 699e1e128c63..20d0fb109a38 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -601,6 +601,7 @@ class TableUpdateRes(BaseModel): class TableRowSchema(BaseModel): digest: str val: Any + original_index: Optional[int] = None class TableCreateRes(BaseModel): @@ -896,47 +897,105 @@ def ensure_project_exists( return EnsureProjectExistsRes(project_name=project) # Call API - def call_start(self, req: CallStartReq) -> CallStartRes: ... - def call_end(self, req: CallEndReq) -> CallEndRes: ... - def call_read(self, req: CallReadReq) -> CallReadRes: ... - def calls_query(self, req: CallsQueryReq) -> CallsQueryRes: ... - def calls_query_stream(self, req: CallsQueryReq) -> Iterator[CallSchema]: ... - def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes: ... - def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes: ... - def call_update(self, req: CallUpdateReq) -> CallUpdateRes: ... + def call_start(self, req: CallStartReq) -> CallStartRes: + ... + + def call_end(self, req: CallEndReq) -> CallEndRes: + ... + + def call_read(self, req: CallReadReq) -> CallReadRes: + ... + + def calls_query(self, req: CallsQueryReq) -> CallsQueryRes: + ... + + def calls_query_stream(self, req: CallsQueryReq) -> Iterator[CallSchema]: + ... + + def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes: + ... + + def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes: + ... + + def call_update(self, req: CallUpdateReq) -> CallUpdateRes: + ... # Op API - def op_create(self, req: OpCreateReq) -> OpCreateRes: ... - def op_read(self, req: OpReadReq) -> OpReadRes: ... - def ops_query(self, req: OpQueryReq) -> OpQueryRes: ... + def op_create(self, req: OpCreateReq) -> OpCreateRes: + ... + + def op_read(self, req: OpReadReq) -> OpReadRes: + ... + + def ops_query(self, req: OpQueryReq) -> OpQueryRes: + ... # Cost API - def cost_create(self, req: CostCreateReq) -> CostCreateRes: ... - def cost_query(self, req: CostQueryReq) -> CostQueryRes: ... - def cost_purge(self, req: CostPurgeReq) -> CostPurgeRes: ... + def cost_create(self, req: CostCreateReq) -> CostCreateRes: + ... + + def cost_query(self, req: CostQueryReq) -> CostQueryRes: + ... + + def cost_purge(self, req: CostPurgeReq) -> CostPurgeRes: + ... # Obj API - def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: ... - def obj_read(self, req: ObjReadReq) -> ObjReadRes: ... - def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ... - def obj_delete(self, req: ObjDeleteReq) -> ObjDeleteRes: ... - def table_create(self, req: TableCreateReq) -> TableCreateRes: ... - def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ... - def table_query(self, req: TableQueryReq) -> TableQueryRes: ... - def table_query_stream(self, req: TableQueryReq) -> Iterator[TableRowSchema]: ... - def table_query_stats(self, req: TableQueryStatsReq) -> TableQueryStatsRes: ... - def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ... - def file_create(self, req: FileCreateReq) -> FileCreateRes: ... - def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... - def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... - def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... - def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... - def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ... + def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: + ... + + def obj_read(self, req: ObjReadReq) -> ObjReadRes: + ... + + def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: + ... + + def obj_delete(self, req: ObjDeleteReq) -> ObjDeleteRes: + ... + + def table_create(self, req: TableCreateReq) -> TableCreateRes: + ... + + def table_update(self, req: TableUpdateReq) -> TableUpdateRes: + ... + + def table_query(self, req: TableQueryReq) -> TableQueryRes: + ... + + def table_query_stream(self, req: TableQueryReq) -> Iterator[TableRowSchema]: + ... + + def table_query_stats(self, req: TableQueryStatsReq) -> TableQueryStatsRes: + ... + + def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: + ... + + def file_create(self, req: FileCreateReq) -> FileCreateRes: + ... + + def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: + ... + + def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: + ... + + def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: + ... + + def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: + ... + + def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: + ... # Action API def actions_execute_batch( self, req: ActionsExecuteBatchReq - ) -> ActionsExecuteBatchRes: ... + ) -> ActionsExecuteBatchRes: + ... # Execute LLM API - def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: ... + def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: + ...