diff --git a/pkg/domain/domain.go b/pkg/domain/domain.go index 2475723edd572..6e8b1e62eeb9e 100644 --- a/pkg/domain/domain.go +++ b/pkg/domain/domain.go @@ -3483,10 +3483,11 @@ func (do *Domain) planCacheEvictTrigger() { // SetupWorkloadBasedLearningWorker sets up all of the workload based learning workers. func (do *Domain) SetupWorkloadBasedLearningWorker() { wbLearningHandle := workloadlearning.NewWorkloadLearningHandle(do.sysSessionPool) + wbCacheWorker := workloadlearning.NewWLCacheWorker(do.sysSessionPool) // Start the workload based learning worker to analyze the read workload by statement_summary. do.wg.Run( func() { - do.readTableCostWorker(wbLearningHandle) + do.readTableCostWorker(wbLearningHandle, wbCacheWorker) }, "readTableCostWorker", ) @@ -3494,7 +3495,7 @@ func (do *Domain) SetupWorkloadBasedLearningWorker() { } // readTableCostWorker is a background worker that periodically analyze the read path table cost by statement_summary. -func (do *Domain) readTableCostWorker(wbLearningHandle *workloadlearning.Handle) { +func (do *Domain) readTableCostWorker(wbLearningHandle *workloadlearning.Handle, wbCacheWorker *workloadlearning.WLCacheWorker) { // Recover the panic and log the error when worker exit. 
defer util.Recover(metrics.LabelDomain, "readTableCostWorker", nil, false) readTableCostTicker := time.NewTicker(vardef.WorkloadBasedLearningInterval.Load()) @@ -3506,7 +3507,8 @@ func (do *Domain) readTableCostWorker(wbLearningHandle *workloadlearning.Handle) select { case <-readTableCostTicker.C: if vardef.EnableWorkloadBasedLearning.Load() && do.statsOwner.IsOwner() { - wbLearningHandle.HandleReadTableCost(do.InfoSchema()) + wbLearningHandle.HandleTableReadCost(do.InfoSchema()) + wbCacheWorker.UpdateTableReadCostCache() } case <-do.exit: return diff --git a/pkg/workloadlearning/BUILD.bazel b/pkg/workloadlearning/BUILD.bazel index 0413811722eba..88d35379208b8 100644 --- a/pkg/workloadlearning/BUILD.bazel +++ b/pkg/workloadlearning/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "workloadlearning", srcs = [ + "cache.go", "handle.go", "metrics.go", ], @@ -24,8 +25,12 @@ go_library( go_test( name = "workloadlearning_test", timeout = "short", - srcs = ["handle_test.go"], + srcs = [ + "cache_test.go", + "handle_test.go", + ], flaky = True, + shard_count = 3, deps = [ ":workloadlearning", "//pkg/parser/ast", diff --git a/pkg/workloadlearning/cache.go b/pkg/workloadlearning/cache.go new file mode 100644 index 0000000000000..026233397fa87 --- /dev/null +++ b/pkg/workloadlearning/cache.go @@ -0,0 +1,154 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package workloadlearning
+
+import (
+	"context"
+	"encoding/json"
+	"sync"
+
+	"github.com/pingcap/tidb/pkg/kv"
+	"github.com/pingcap/tidb/pkg/sessionctx"
+	"github.com/pingcap/tidb/pkg/util"
+	"github.com/pingcap/tidb/pkg/util/logutil"
+	"go.uber.org/zap"
+)
+
+// TableReadCostCache stores the cached workload learning metrics and the version they were loaded at.
+type TableReadCostCache struct {
+	TableReadCostMetrics map[int64]*TableReadCostMetrics // key: tableID
+	Version              uint64
+}
+
+// WLCacheWorker is the worker that caches all workload-related metrics.
+// Now it is also used to save the cache data of table cost metrics.
+// The embedded RWMutex guards tableReadCostCache.
+type WLCacheWorker struct {
+	sysSessionPool     util.DestroyableSessionPool
+	tableReadCostCache *TableReadCostCache
+	sync.RWMutex
+}
+
+// NewWLCacheWorker creates a new workload learning cache worker that caches workload-related metrics
+// from storage mysql.tidb_workload_values in memory. The cache starts empty at version 0.
+func NewWLCacheWorker(pool util.DestroyableSessionPool) *WLCacheWorker {
+	cache := &TableReadCostCache{
+		TableReadCostMetrics: make(map[int64]*TableReadCostMetrics),
+	}
+	return &WLCacheWorker{sysSessionPool: pool, tableReadCostCache: cache}
+}
+
+// UpdateTableReadCostCache reloads the cache from mysql.tidb_workload_values when storage has a newer version.
+// Query failures are logged and leave the cache untouched; rows that fail to unmarshal are logged and skipped.
+func (cw *WLCacheWorker) UpdateTableReadCostCache() {
+	// Get latest metrics from storage
+	se, err := cw.sysSessionPool.Get()
+	if err != nil {
+		logutil.BgLogger().Warn("Get system session failed when updating table cost cache", zap.Error(err))
+		return
+	}
+	defer func() {
+		if err == nil { // only recycle when no error
+			cw.sysSessionPool.Put(se)
+		} else {
+			// Note: Otherwise, the session will be leaked.
+			cw.sysSessionPool.Destroy(se)
+		}
+	}()
+
+	sctx := se.(sessionctx.Context)
+	exec := sctx.GetRestrictedSQLExecutor()
+	ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnWorkloadLearning)
+
+	// Decide whether the table cost metrics need an update:
+	// get the latest version in the storage.
+	// TODO(elsa): Add the index of (category, type, version) to mysql.tidb_workload_values
+	sql := `SELECT version FROM mysql.tidb_workload_values
+			WHERE category = %? AND type = %?
+			ORDER BY version DESC LIMIT 1`
+	rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql, feedbackCategory, tableReadCost)
+	if err != nil {
+		logutil.ErrVerboseLogger().Warn("Failed to get the latest table cost version", zap.Error(err))
+		return
+	}
+	// Case: no metrics belongs to this feedback category and type
+	if len(rows) != 1 {
+		logutil.BgLogger().Warn("The result of latest table cost version query is not 1",
+			zap.Int("result_rows", len(rows)))
+		return
+	}
+	// Skip the reload unless storage is newer; the cached version is read under the lock that guards its writes.
+	latestVersionInStorage := rows[0].GetUint64(0)
+	cw.RLock()
+	cachedVersion := cw.tableReadCostCache.Version
+	cw.RUnlock()
+	if latestVersionInStorage <= cachedVersion {
+		logutil.BgLogger().Info("The latest table cost version in storage is the same as the cached version, no need to update",
+			zap.Uint64("latest_version_in_storage", latestVersionInStorage),
+			zap.Uint64("cached_version", cachedVersion))
+		return
+	}
+
+	// Get the latest table cost of metrics
+	sql = `SELECT table_id, value FROM mysql.tidb_workload_values
+			WHERE category = %? AND type = %? AND version = %?`
+	rows, _, err = exec.ExecRestrictedSQL(ctx, nil, sql, feedbackCategory, tableReadCost, latestVersionInStorage)
+	if err != nil {
+		logutil.ErrVerboseLogger().Warn("Failed to get the latest table cost metrics", zap.Error(err))
+		return
+	}
+	newMetrics := make(map[int64]*TableReadCostMetrics, len(rows))
+	for _, row := range rows {
+		tableID := row.GetInt64(0)
+		value := row.GetJSON(1).String()
+
+		metric := &TableReadCostMetrics{}
+		// err is shadowed on purpose: a bad row is skipped without marking the whole refresh as failed.
+		if err := json.Unmarshal([]byte(value), metric); err != nil {
+			logutil.ErrVerboseLogger().Warn("Failed to unmarshal table cost metrics",
+				zap.Int64("table_id", tableID),
+				zap.String("value", value),
+				zap.Error(err))
+			continue
+		}
+		newMetrics[tableID] = metric
+	}
+
+	// Update cache atomically
+	cw.updateTableReadCostCacheWithMetrics(newMetrics, latestVersionInStorage)
+}
+
+// updateTableReadCostCacheWithMetrics swaps in the new metrics map and version under the write lock.
+func (cw *WLCacheWorker) updateTableReadCostCacheWithMetrics(newMetrics map[int64]*TableReadCostMetrics,
+	latestVersionInStorage uint64) {
+	cw.Lock()
+	defer cw.Unlock()
+	cw.tableReadCostCache.TableReadCostMetrics = newMetrics
+	cw.tableReadCostCache.Version = latestVersionInStorage
+}
+
+// GetTableReadCostMetrics returns a copy of the cached metrics for the given table ID,
+// or nil when nothing is cached for that table.
+func (cw *WLCacheWorker) GetTableReadCostMetrics(tableID int64) *TableReadCostMetrics {
+	cw.RLock()
+	defer cw.RUnlock()
+	metric, exists := cw.tableReadCostCache.TableReadCostMetrics[tableID]
+	if !exists {
+		return nil
+	}
+	// Copy the whole struct (all fields are plain values) so DbName and TableName are not dropped.
+	copied := *metric
+	return &copied
+}
diff --git a/pkg/workloadlearning/cache_test.go b/pkg/workloadlearning/cache_test.go
new file mode 100644
index 0000000000000..2a3e0e19cb950
--- /dev/null
+++ b/pkg/workloadlearning/cache_test.go
@@ -0,0 +1,79 @@
+// Copyright 2025 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package workloadlearning_test
+
+import (
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/pingcap/tidb/pkg/parser/ast"
+	"github.com/pingcap/tidb/pkg/testkit"
+	"github.com/pingcap/tidb/pkg/workloadlearning"
+	"github.com/stretchr/testify/require"
+)
+
+// TestUpdateTableCostCache saves one table's metrics, refreshes the cache worker and checks the cached values.
+func TestUpdateTableCostCache(t *testing.T) {
+	store, dom := testkit.CreateMockStoreAndDomain(t)
+	tk := testkit.NewTestKit(t, store)
+
+	// Create test table and insert test metrics
+	tk.MustExec(`use test`)
+	tk.MustExec("create table test (a int, b int, index idx(a))")
+
+	// Create test metrics
+	readTableCostMetrics := &workloadlearning.TableReadCostMetrics{
+		DbName:        ast.CIStr{O: "test", L: "test"},
+		TableName:     ast.CIStr{O: "test", L: "test"},
+		TableScanTime: 10.0,
+		TableMemUsage: 10.0,
+		ReadFrequency: 10,
+		TableReadCost: 1.0,
+	}
+	tableCostMetrics := map[ast.CIStr]*workloadlearning.TableReadCostMetrics{
+		{O: "test", L: "test"}: readTableCostMetrics,
+	}
+
+	// Save metrics to storage through a workload learning handle
+	handle := workloadlearning.NewWorkloadLearningHandle(dom.SysSessionPool())
+	handle.SaveTableReadCostMetrics(tableCostMetrics, time.Now(), time.Now(), dom.InfoSchema())
+
+	// Create cache worker and test UpdateTableReadCostCache
+	worker := workloadlearning.NewWLCacheWorker(dom.SysSessionPool())
+	worker.UpdateTableReadCostCache()
+
+	// Get table ID for verification; fail loudly instead of panicking or silently using ID 0
+	rs := tk.MustQuery("select tidb_table_id from information_schema.tables where table_schema = 'test' and table_name = 'test'")
+	require.Len(t, rs.Rows(), 1)
+	tableID, err := strconv.ParseInt(rs.Rows()[0][0].(string), 10, 64)
+	require.NoError(t, err)
+
+	// Verify cached metrics
+	metrics := worker.GetTableReadCostMetrics(tableID)
+	require.NotNil(t, metrics)
+	require.Equal(t, 10.0, metrics.TableScanTime)
+	require.Equal(t, 10.0, metrics.TableMemUsage)
+	require.Equal(t, int64(10), metrics.ReadFrequency)
+	require.Equal(t, 1.0, metrics.TableReadCost)
+}
+
+// TestGetTableReadCacheMetricsWithNoData checks that a worker that never loaded metrics returns nil.
+func TestGetTableReadCacheMetricsWithNoData(t *testing.T) {
+	_, dom := testkit.CreateMockStoreAndDomain(t)
+	worker := workloadlearning.NewWLCacheWorker(dom.SysSessionPool())
+	result := worker.GetTableReadCostMetrics(1)
+	require.Nil(t, result)
+}
diff --git a/pkg/workloadlearning/handle.go b/pkg/workloadlearning/handle.go index c0afe41d6fcdd..cf5614709d801 100644 --- a/pkg/workloadlearning/handle.go +++ b/pkg/workloadlearning/handle.go @@ -45,21 +45,21 @@ const ( ) const ( // The type of workload-based learning - tableCostType = "TableCost" + tableReadCost = "TableReadCost" ) // Handle The entry point for all workload-based learning related tasks type Handle struct { - sysSessionPool util.SessionPool + sysSessionPool util.DestroyableSessionPool } // NewWorkloadLearningHandle Create a new WorkloadLearningHandle // WorkloadLearningHandle is Singleton pattern -func NewWorkloadLearningHandle(pool util.SessionPool) *Handle { +func NewWorkloadLearningHandle(pool util.DestroyableSessionPool) *Handle { return &Handle{pool} } -// HandleReadTableCost Start a new round of analysis of all historical read queries. +// HandleTableReadCost Start a new round of analysis of all historical table read queries. // According to abstracted table cost metrics, calculate the percentage of read scan time and memory usage for each table. // The result will be saved to the table "mysql.tidb_workload_values". 
// Dataflow @@ -72,7 +72,7 @@ func NewWorkloadLearningHandle(pool util.SessionPool) *Handle { // // 4. Calculate table cost for each table, table cost = table scan time / total scan time + table mem usage / total mem usage // 5. Save all table cost metrics[per table](scan time, table cost, etc) to table "mysql.tidb_workload_values" -func (handle *Handle) HandleReadTableCost(infoSchema infoschema.InfoSchema) { +func (handle *Handle) HandleTableReadCost(infoSchema infoschema.InfoSchema) { // step1: abstract middle table cost metrics from every record in statement_summary middleMetrics, startTime, endTime := handle.analyzeBasedOnStatementStats() if len(middleMetrics) == 0 { @@ -80,7 +80,7 @@ func (handle *Handle) HandleReadTableCost(infoSchema infoschema.InfoSchema) { } // step2: group by tablename, sum(table-scan-time), sum(table-mem-usage), sum(read-frequency) // step3: calculate the total scan time and total memory usage - tableNameToMetrics := make(map[ast.CIStr]*ReadTableCostMetrics) + tableNameToMetrics := make(map[ast.CIStr]*TableReadCostMetrics) totalScanTime := 0.0 totalMemUsage := 0.0 for _, middleMetric := range middleMetrics { @@ -100,28 +100,28 @@ func (handle *Handle) HandleReadTableCost(infoSchema infoschema.InfoSchema) { } // step4: calculate the percentage of scan time and memory usage for each table for _, metric := range tableNameToMetrics { - metric.TableCost = metric.TableScanTime/totalScanTime + metric.TableMemUsage/totalMemUsage + metric.TableReadCost = metric.TableScanTime/totalScanTime + metric.TableMemUsage/totalMemUsage } // step5: save the table cost metrics to table "mysql.tidb_workload_values" - handle.SaveReadTableCostMetrics(tableNameToMetrics, startTime, endTime, infoSchema) + handle.SaveTableReadCostMetrics(tableNameToMetrics, startTime, endTime, infoSchema) } -func (*Handle) analyzeBasedOnStatementSummary() []*ReadTableCostMetrics { +func (*Handle) analyzeBasedOnStatementSummary() []*TableReadCostMetrics { // step1: get all record 
from statement_summary // step2: abstract table cost metrics from each record return nil } // TODO -func (*Handle) analyzeBasedOnStatementStats() ([]*ReadTableCostMetrics, time.Time, time.Time) { +func (*Handle) analyzeBasedOnStatementStats() ([]*TableReadCostMetrics, time.Time, time.Time) { // step1: get all record from statement_stats // step2: abstract table cost metrics from each record // TODO change the mock value return nil, time.Now(), time.Now() } -// SaveReadTableCostMetrics table cost metrics, workload-based start and end time, version, -func (handle *Handle) SaveReadTableCostMetrics(metrics map[ast.CIStr]*ReadTableCostMetrics, +// SaveTableReadCostMetrics table cost metrics, workload-based start and end time, version, +func (handle *Handle) SaveTableReadCostMetrics(metrics map[ast.CIStr]*TableReadCostMetrics, _, _ time.Time, infoSchema infoschema.InfoSchema) { // TODO save the workload job info such as start end time into workload_jobs table // step1: create a new session, context, txn for saving table cost metrics @@ -130,8 +130,14 @@ func (handle *Handle) SaveReadTableCostMetrics(metrics map[ast.CIStr]*ReadTableC logutil.BgLogger().Warn("get system session failed when saving table cost metrics", zap.Error(err)) return } - // TODO to destroy the error session instead of put it back to the pool - defer handle.sysSessionPool.Put(se) + defer func() { + if err == nil { // only recycle when no error + handle.sysSessionPool.Put(se) + } else { + // Note: Otherwise, the session will be leaked. 
+ handle.sysSessionPool.Destroy(se) + } + }() sctx := se.(sessionctx.Context) exec := sctx.GetRestrictedSQLExecutor() ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnWorkloadLearning) @@ -164,7 +170,7 @@ func (handle *Handle) SaveReadTableCostMetrics(metrics map[ast.CIStr]*ReadTableC zap.Float64("table_scan_time", metric.TableScanTime), zap.Float64("table_mem_usage", metric.TableMemUsage), zap.Int64("read_frequency", metric.ReadFrequency), - zap.Float64("table_cost", metric.TableCost), + zap.Float64("table_read_cost", metric.TableReadCost), zap.Error(err)) continue } @@ -176,12 +182,12 @@ func (handle *Handle) SaveReadTableCostMetrics(metrics map[ast.CIStr]*ReadTableC zap.Float64("table_scan_time", metric.TableScanTime), zap.Float64("table_mem_usage", metric.TableMemUsage), zap.Int64("read_frequency", metric.ReadFrequency), - zap.Float64("table_cost", metric.TableCost), + zap.Float64("table_read_cost", metric.TableReadCost), zap.Error(err)) continue } sqlescape.MustFormatSQL(sql, "(%?, %?, %?, %?, %?)", - version, feedbackCategory, tableCostType, tbl.Meta().ID, json.RawMessage(metricBytes)) + version, feedbackCategory, tableReadCost, tbl.Meta().ID, json.RawMessage(metricBytes)) // TODO check the txn record limit if i%batchInsertSize == batchInsertSize-1 { _, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) diff --git a/pkg/workloadlearning/handle_test.go b/pkg/workloadlearning/handle_test.go index 8aa93f3c1c029..90630d453abab 100644 --- a/pkg/workloadlearning/handle_test.go +++ b/pkg/workloadlearning/handle_test.go @@ -30,19 +30,19 @@ func TestSaveReadTableCostMetrics(t *testing.T) { tk.MustExec(`use test`) tk.MustExec("create table test (a int, b int, index idx(a))") // mock a table cost metrics - readTableCostMetrics := &workloadlearning.ReadTableCostMetrics{ + readTableCostMetrics := &workloadlearning.TableReadCostMetrics{ DbName: ast.CIStr{O: "test", L: "test"}, TableName: ast.CIStr{O: "test", L: "test"}, TableScanTime: 10.0, 
TableMemUsage: 10.0, ReadFrequency: 10, - TableCost: 1.0, + TableReadCost: 1.0, } - tableCostMetrics := map[ast.CIStr]*workloadlearning.ReadTableCostMetrics{ + tableCostMetrics := map[ast.CIStr]*workloadlearning.TableReadCostMetrics{ {O: "test", L: "test"}: readTableCostMetrics, } handle := workloadlearning.NewWorkloadLearningHandle(dom.SysSessionPool()) - handle.SaveReadTableCostMetrics(tableCostMetrics, time.Now(), time.Now(), dom.InfoSchema()) + handle.SaveTableReadCostMetrics(tableCostMetrics, time.Now(), time.Now(), dom.InfoSchema()) // check the result result := tk.MustQuery("select * from mysql.tidb_workload_values").Rows() diff --git a/pkg/workloadlearning/metrics.go b/pkg/workloadlearning/metrics.go index 3d769a59f7c09..396918ab86762 100644 --- a/pkg/workloadlearning/metrics.go +++ b/pkg/workloadlearning/metrics.go @@ -16,9 +16,10 @@ package workloadlearning import "github.com/pingcap/tidb/pkg/parser/ast" -// ReadTableCostMetrics is used to indicate the intermediate status and results analyzed through read workload -// for function "HandleReadTableCost". -type ReadTableCostMetrics struct { +// TableReadCostMetrics is used to indicate the intermediate status and results analyzed through table read workload +// for function "HandleTableReadCost". +type TableReadCostMetrics struct { + // TODO(Elsa): Add the json tag for the field DbName ast.CIStr TableName ast.CIStr // TableScanTime[t] = sum(scan-time * readFrequency) of all records in statement_summary where table-name = t @@ -27,7 +28,7 @@ type ReadTableCostMetrics struct { TableMemUsage float64 // ReadFrequency[t] = sum(read-frequency) of all records in statement_summary where table-name = t ReadFrequency int64 - // TableCost[t] = TableScanTime[t] / totalScanTime + TableMemUsage[t] / totalMemUsage + // TableReadCost[t] = TableScanTime[t] / totalScanTime + TableMemUsage[t] / totalMemUsage // range between 0 ~ 2 - TableCost float64 + TableReadCost float64 }