From 801c2a55a7de3e7b9a92c84faa0012ab8116b476 Mon Sep 17 00:00:00 2001 From: MyonKeminta <9948422+MyonKeminta@users.noreply.github.com> Date: Thu, 19 Dec 2024 19:08:12 +0800 Subject: [PATCH] This is an automated cherry-pick of #58054 Signed-off-by: ti-chi-bot --- DEPS.bzl | 11 + ddl/column_change_test.go | 72 +- ddl/column_test.go | 82 +- ddl/db_integration_test.go | 51 +- ddl/ddl_worker_test.go | 4 +- ddl/index_change_test.go | 7 +- executor/set.go | 6 + go.mod | 8 + go.sum | 13 + pkg/executor/test/executor/executor_test.go | 3090 +++++++++++++++++++ pkg/executor/test/writetest/BUILD.bazel | 34 + pkg/executor/test/writetest/write_test.go | 548 ++++ pkg/sessionctx/context.go | 252 ++ pkg/store/copr/batch_coprocessor.go | 1603 ++++++++++ pkg/store/copr/mpp.go | 346 +++ pkg/util/mock/BUILD.bazel | 71 + planner/core/planbuilder.go | 8 + sessiontxn/staleread/processor.go | 4 + sessiontxn/staleread/util.go | 17 + store/copr/BUILD.bazel | 1 + store/copr/batch_request_sender.go | 5 +- util/cmp/compare_test.go | 9 + util/mock/context.go | 86 +- 23 files changed, 6297 insertions(+), 31 deletions(-) create mode 100644 pkg/executor/test/executor/executor_test.go create mode 100644 pkg/executor/test/writetest/BUILD.bazel create mode 100644 pkg/executor/test/writetest/write_test.go create mode 100644 pkg/sessionctx/context.go create mode 100644 pkg/store/copr/batch_coprocessor.go create mode 100644 pkg/store/copr/mpp.go create mode 100644 pkg/util/mock/BUILD.bazel diff --git a/DEPS.bzl b/DEPS.bzl index 042f4025e5db0..fcb16d2735e39 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -3603,8 +3603,19 @@ def go_deps(): name = "com_github_tikv_client_go_v2", build_file_proto_mode = "disable_global", importpath = "github.com/tikv/client-go/v2", +<<<<<<< HEAD sum = "h1:0YcirnuxtXC9eQRb231im1M5w/n7JFuOo0IgE/K9ffM=", version = "v2.0.4-0.20241125064444-5f59e4e34c62", +======= + sha256 = "844684ee6ae7decc5cadcab3f95c526b66878f8401c71cf82af68ec0cc5257d5", + strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20241209094930-06d7f4b9233b", + urls = [ + "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241209094930-06d7f4b9233b.zip", + "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241209094930-06d7f4b9233b.zip", + "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241209094930-06d7f4b9233b.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20241209094930-06d7f4b9233b.zip", + ], +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)) ) go_repository( name = "com_github_tikv_pd_client", diff --git a/ddl/column_change_test.go b/ddl/column_change_test.go index 4528564d2f231..77cd09738691c 100644 --- a/ddl/column_change_test.go +++ b/ddl/column_change_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/pingcap/errors" +<<<<<<< HEAD:ddl/column_change_test.go "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -35,6 +36,21 @@ import ( "github.com/pingcap/tidb/testkit/external" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/mock" +======= + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/external" + "github.com/pingcap/tidb/pkg/testkit/testfailpoint" + "github.com/pingcap/tidb/pkg/types" +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go "github.com/stretchr/testify/require" ) @@ -47,10 +63,14 @@ func TestColumnAdd(t *testing.T) { tk.MustExec("create table t (c1 int, c2 int);") tk.MustExec("insert t values (1, 2);") +<<<<<<< HEAD:ddl/column_change_test.go d := dom.DDL() tc := &ddl.TestDDLCallback{Do: dom} ct := testNewContext(store) +======= + ct := testNewContext(t, store) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go // set up hook var ( deleteOnlyTable table.Table @@ -128,8 +148,14 @@ func TestColumnAdd(t *testing.T) { } else { return } +<<<<<<< HEAD:ddl/column_change_test.go sess := testNewContext(store) err := sessiontxn.NewTxn(context.Background(), sess) +======= + first = false + sess := testNewContext(t, store) + txn, err := newTxn(sess) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go require.NoError(t, err) _, err = writeOnlyTable.AddRecord(sess, types.MakeDatums(10, 10)) require.NoError(t, err) @@ -224,7 +250,15 @@ func checkAddWriteOnly(ctx sessionctx.Context, deleteOnlyTable, writeOnlyTable t if err != nil { return errors.Trace(err) } +<<<<<<< HEAD:ddl/column_change_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + if err != nil { + return errors.Trace(err) + } + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go if err != nil { return errors.Trace(err) } @@ -262,7 +296,15 @@ func checkAddWriteOnly(ctx sessionctx.Context, deleteOnlyTable, writeOnlyTable t if err != nil { return errors.Trace(err) } +<<<<<<< HEAD:ddl/column_change_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + if err != nil { + return errors.Trace(err) + } + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go if err != nil { return errors.Trace(err) } @@ -279,7 +321,15 @@ func checkAddWriteOnly(ctx sessionctx.Context, deleteOnlyTable, writeOnlyTable t if err != nil { return errors.Trace(err) } +<<<<<<< HEAD:ddl/column_change_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + if err != nil { + return errors.Trace(err) + } + _, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go if err != nil { return errors.Trace(err) } @@ -309,7 +359,15 @@ func checkAddPublic(sctx sessionctx.Context, writeOnlyTable, publicTable table.T if err != nil { return errors.Trace(err) } +<<<<<<< HEAD:ddl/column_change_test.go err = sessiontxn.NewTxn(ctx, sctx) +======= + err = txn.Commit(context.Background()) + if err != nil { + return errors.Trace(err) + } + txn, err = newTxn(sctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go if err != nil { return errors.Trace(err) } @@ -326,7 +384,15 @@ func checkAddPublic(sctx sessionctx.Context, writeOnlyTable, publicTable table.T if err != nil { return errors.Trace(err) } +<<<<<<< HEAD:ddl/column_change_test.go err = sessiontxn.NewTxn(ctx, sctx) +======= + err = txn.Commit(context.Background()) + if err != nil { + return errors.Trace(err) + } + _, err = newTxn(sctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_change_test.go if err != nil { return errors.Trace(err) } @@ -432,8 +498,6 @@ func testCheckJobDone(t *testing.T, store kv.Storage, jobID int64, isAdd bool) { } } -func testNewContext(store kv.Storage) sessionctx.Context { - ctx := mock.NewContext() - ctx.Store = store - return ctx +func testNewContext(t *testing.T, store kv.Storage) sessionctx.Context { + return testkit.NewSession(t, store) } diff --git a/ddl/column_test.go b/ddl/column_test.go index 53fb6a6fcd39f..49653a7377148 100644 --- a/ddl/column_test.go +++ b/ddl/column_test.go @@ -167,8 +167,13 @@ func TestColumnBasic(t *testing.T) { tk.MustExec(fmt.Sprintf("insert into t1 values(%d, %d, %d)", i, 10*i, 100*i)) } +<<<<<<< HEAD:ddl/column_test.go ctx := testNewContext(store) err := sessiontxn.NewTxn(context.Background(), ctx) +======= + ctx := testNewContext(t, store) + txn, err := newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) var tableID int64 @@ -214,7 +219,13 @@ func TestColumnBasic(t *testing.T) { h, err := tbl.AddRecord(ctx, types.MakeDatums(11, 12, 13, 14)) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + require.NoError(t, err) + _, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) values, err := tables.RowWithCols(tbl, ctx, h, tbl.Cols()) require.NoError(t, err) @@ -385,7 +396,13 @@ func checkDeleteOnlyColumn(t *testing.T, ctx sessionctx.Context, tableID int64, newRow := types.MakeDatums(int64(11), int64(22), int64(33)) newHandle, err := tbl.AddRecord(ctx, newRow) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) rows := [][]types.Datum{row, newRow} @@ -407,7 +424,13 @@ func checkDeleteOnlyColumn(t *testing.T, ctx sessionctx.Context, tableID int64, err = tbl.RemoveRecord(ctx, newHandle, newRow) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) i = 0 err = tables.IterRecords(tbl, ctx, tbl.Cols(), func(_ kv.Handle, data []types.Datum, cols []*table.Column) (bool, error) { @@ -447,7 +470,13 @@ func checkWriteOnlyColumn(t *testing.T, ctx sessionctx.Context, tableID int64, h newRow := types.MakeDatums(int64(11), int64(22), int64(33)) newHandle, err := tbl.AddRecord(ctx, newRow) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) rows := [][]types.Datum{row, newRow} @@ -469,7 +498,14 @@ func checkWriteOnlyColumn(t *testing.T, ctx sessionctx.Context, tableID int64, h err = tbl.RemoveRecord(ctx, newHandle, newRow) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + + err = txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) i = 0 @@ -507,7 +543,13 @@ func checkReorganizationColumn(t *testing.T, ctx sessionctx.Context, tableID int newRow := types.MakeDatums(int64(11), int64(22), int64(33)) newHandle, err := tbl.AddRecord(ctx, newRow) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) rows := [][]types.Datum{row, newRow} @@ -530,7 +572,13 @@ func checkReorganizationColumn(t *testing.T, ctx sessionctx.Context, tableID int err = tbl.RemoveRecord(ctx, newHandle, newRow) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) i = 0 @@ -573,7 +621,13 @@ func checkPublicColumn(t *testing.T, ctx sessionctx.Context, tableID int64, newC } handle, err := tbl.AddRecord(ctx, newRow) require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + err = txn.Commit(context.Background()) + require.NoError(t, err) + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) rows := [][]types.Datum{updatedRow, newRow} @@ -593,8 +647,14 @@ func checkPublicColumn(t *testing.T, ctx sessionctx.Context, tableID int64, newC err = tbl.RemoveRecord(ctx, handle, newRow) require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) +<<<<<<< HEAD:ddl/column_test.go err = sessiontxn.NewTxn(context.Background(), ctx) +======= + txn, err = newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) i = 0 @@ -610,8 +670,13 @@ func checkPublicColumn(t *testing.T, ctx sessionctx.Context, tableID int64, newC require.NoError(t, err) } +<<<<<<< HEAD:ddl/column_test.go func checkAddColumn(t *testing.T, state model.SchemaState, tableID int64, handle kv.Handle, newCol *table.Column, oldRow []types.Datum, columnValue interface{}, dom *domain.Domain, store kv.Storage, columnCnt int) { ctx := testNewContext(store) +======= +func checkAddColumn(t *testing.T, state model.SchemaState, tableID int64, handle kv.Handle, newCol *table.Column, oldRow []types.Datum, columnValue any, dom *domain.Domain, store kv.Storage, columnCnt int) { + ctx := testNewContext(t, store) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go switch state { case model.StateNone: checkNoneColumn(t, ctx, tableID, handle, newCol, columnValue, dom) @@ -655,8 +720,13 @@ func TestAddColumn(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) +<<<<<<< HEAD:ddl/column_test.go ctx := testNewContext(store) err := sessiontxn.NewTxn(context.Background(), ctx) +======= + ctx := testNewContext(t, store) + txn, err := newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) oldRow := types.MakeDatums(int64(1), int64(2), int64(3)) handle, err := tbl.AddRecord(ctx, oldRow) @@ -728,8 +798,13 @@ func TestAddColumns(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) +<<<<<<< HEAD:ddl/column_test.go ctx := testNewContext(store) err := sessiontxn.NewTxn(context.Background(), ctx) +======= + ctx := testNewContext(t, store) + txn, err := newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) oldRow := types.MakeDatums(int64(1), int64(2), int64(3)) handle, err := tbl.AddRecord(ctx, oldRow) @@ -791,7 +866,7 @@ func TestDropColumnInColumnTest(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) - ctx := testNewContext(store) + ctx := testNewContext(t, store) colName := "c4" defaultColValue := int64(4) row := types.MakeDatums(int64(1), int64(2), int64(3)) @@ -852,8 +927,13 @@ func TestDropColumns(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) +<<<<<<< HEAD:ddl/column_test.go ctx := testNewContext(store) err := sessiontxn.NewTxn(context.Background(), ctx) +======= + ctx := testNewContext(t, store) + txn, err := newTxn(ctx) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/column_test.go require.NoError(t, err) colNames := []string{"c3", "c4"} diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index 4c949d56e7f77..d81f5c99496a0 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/pingcap/errors" +<<<<<<< HEAD:ddl/db_integration_test.go _ "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" @@ -54,6 +55,35 @@ import ( "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tidb/util/mock" +======= + _ "github.com/pingcap/tidb/pkg/autoid_service" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/schematracker" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/charset" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/store/mockstore" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/external" + "github.com/pingcap/tidb/pkg/testkit/testfailpoint" + "github.com/pingcap/tidb/pkg/util/collate" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/db_integration_test.go "github.com/stretchr/testify/require" ) @@ -830,11 +860,10 @@ func TestChangingTableCharset(t *testing.T) { tblInfo.Charset = "" tblInfo.Collate = "" updateTableInfo := func(tblInfo *model.TableInfo) { - mockCtx := mock.NewContext() - mockCtx.Store = store - err := sessiontxn.NewTxn(context.Background(), mockCtx) + ctx := testkit.NewSession(t, store) + err := sessiontxn.NewTxn(context.Background(), ctx) require.NoError(t, err) - txn, err := mockCtx.Txn(true) + txn, err := ctx.Txn(true) require.NoError(t, err) mt := meta.NewMeta(txn) @@ -1076,11 +1105,10 @@ func TestCaseInsensitiveCharsetAndCollate(t *testing.T) { tblInfo.Charset = "UTF8MB4" updateTableInfo := func(tblInfo *model.TableInfo) { - mockCtx := mock.NewContext() - mockCtx.Store = store - err := sessiontxn.NewTxn(context.Background(), mockCtx) + sctx := testkit.NewSession(t, store) + err := sessiontxn.NewTxn(context.Background(), sctx) require.NoError(t, err) - txn, err := mockCtx.Txn(true) + txn, err := sctx.Txn(true) require.NoError(t, err) mt := meta.NewMeta(txn) require.True(t, ok) @@ -1918,11 +1946,10 @@ func TestTreatOldVersionUTF8AsUTF8MB4(t *testing.T) { tblInfo.Version = model.TableInfoVersion0 tblInfo.Columns[0].Version = model.ColumnInfoVersion0 updateTableInfo := func(tblInfo *model.TableInfo) { - mockCtx := mock.NewContext() - mockCtx.Store = store - err := sessiontxn.NewTxn(context.Background(), mockCtx) + sctx := testkit.NewSession(t, store) + err := sessiontxn.NewTxn(context.Background(), sctx) require.NoError(t, err) - txn, err := mockCtx.Txn(true) + txn, err := sctx.Txn(true) require.NoError(t, err) mt := meta.NewMeta(txn) require.True(t, ok) diff --git a/ddl/ddl_worker_test.go b/ddl/ddl_worker_test.go index e07d1661f7d99..c9a009d45abea 100644 --- a/ddl/ddl_worker_test.go +++ b/ddl/ddl_worker_test.go @@ -51,7 +51,7 @@ func TestInvalidDDLJob(t *testing.T) { BinlogInfo: &model.HistoryInfo{}, Args: []interface{}{}, } - ctx := testNewContext(store) + ctx := testNewContext(t, store) ctx.SetValue(sessionctx.QueryString, "skip") err := dom.DDL().DoDDLJob(ctx, job) require.Equal(t, err.Error(), "[ddl:8204]invalid ddl job type: none") @@ -59,7 +59,7 @@ func TestInvalidDDLJob(t *testing.T) { func TestAddBatchJobError(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, testLease) - ctx := testNewContext(store) + ctx := testNewContext(t, store) require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/ddl/mockAddBatchDDLJobsErr", `return(true)`)) // Test the job runner should not hang forever. diff --git a/ddl/index_change_test.go b/ddl/index_change_test.go index f9dcc99154dc5..985eb627fcf97 100644 --- a/ddl/index_change_test.go +++ b/ddl/index_change_test.go @@ -55,8 +55,13 @@ func TestIndexChange(t *testing.T) { if job.SchemaState == prevState { return } +<<<<<<< HEAD:ddl/index_change_test.go jobID = job.ID ctx1 := testNewContext(store) +======= + jobID.Store(job.ID) + ctx1 := testNewContext(t, store) +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/ddl/index_change_test.go prevState = job.SchemaState require.NoError(t, dom.Reload()) tbl, exist := dom.InfoSchema().TableByID(job.TableID) @@ -105,7 +110,7 @@ func TestIndexChange(t *testing.T) { require.NoError(t, dom.Reload()) tbl, exist := dom.InfoSchema().TableByID(job.TableID) require.True(t, exist) - ctx1 := testNewContext(store) + ctx1 := testNewContext(t, store) switch job.SchemaState { case model.StateWriteOnly: writeOnlyTable = tbl diff --git a/executor/set.go b/executor/set.go index 75e4938d41725..5d70972c4306a 100644 --- a/executor/set.go +++ b/executor/set.go @@ -197,10 +197,16 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres newSnapshotTS := getSnapshotTSByName() newSnapshotIsSet := newSnapshotTS > 0 && newSnapshotTS != oldSnapshotTS if newSnapshotIsSet { +<<<<<<< HEAD:executor/set.go if name == variable.TiDBTxnReadTS { err = sessionctx.ValidateStaleReadTS(ctx, e.ctx, newSnapshotTS) } else { err = sessionctx.ValidateSnapshotReadTS(ctx, e.ctx, newSnapshotTS) +======= + isStaleRead := name == variable.TiDBTxnReadTS + err = sessionctx.ValidateSnapshotReadTS(ctx, e.Ctx().GetStore(), newSnapshotTS, isStaleRead) + if name != variable.TiDBTxnReadTS { +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/executor/set.go // Also check gc safe point for snapshot read. // We don't check snapshot with gc safe point for read_ts // Client-go will automatically check the snapshotTS with gc safe point. It's unnecessary to check gc safe point during set executor. diff --git a/go.mod b/go.mod index 6945c5367ea86..4b0a24c2da866 100644 --- a/go.mod +++ b/go.mod @@ -90,10 +90,18 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tdakkota/asciicheck v0.1.1 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 +<<<<<<< HEAD github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62 github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05 github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 github.com/twmb/murmur3 v1.1.3 +======= + github.com/tidwall/btree v1.7.0 + github.com/tikv/client-go/v2 v2.0.8-0.20241209094930-06d7f4b9233b + github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 + github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a + github.com/twmb/murmur3 v1.1.6 +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)) github.com/uber/jaeger-client-go v2.22.1+incompatible github.com/vbauerster/mpb/v7 v7.5.3 github.com/wangjohn/quickselect v0.0.0-20161129230411-ed8402a42d5f diff --git a/go.sum b/go.sum index 3e24ce3c05608..4779e8ecd6c10 100644 --- a/go.sum +++ b/go.sum @@ -948,12 +948,25 @@ github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3 h1:f+jULpR github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3/go.mod h1:ON8b8w4BN/kE1EOhwT0o+d62W65a6aPw1nouo9LMgyY= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJfDRtkanvQPiooDH8HvJ2FBh+iKT/OmiQQ= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfKggNGDuadAa0LElHrByyrz4JPZ9fFx6Gs7nx7ZZU= +<<<<<<< HEAD github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62 h1:0YcirnuxtXC9eQRb231im1M5w/n7JFuOo0IgE/K9ffM= github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62/go.mod h1:mmVCLP2OqWvQJPOIevQPZvGphzh/oq9vv8J5LDfpadQ= github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05 h1:e4hLUKfgfPeJPZwOfU+/I/03G0sn6IZqVcbX/5o+hvM= github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05/go.mod h1:MLIl+d2WbOF4A3U88WKtyXrQQW417wZDDvBcq2IW9bQ= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 h1:kl4KhGNsJIbDHS9/4U9yQo1UcPQM0kOMJHn29EoH/Ro= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144/go.mod h1:Qimiffbc6q9tBWlVV6x0P9sat/ao1xEkREYPPj9hphk= +======= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= +github.com/tikv/client-go/v2 v2.0.8-0.20241209094930-06d7f4b9233b h1:x8E2J8UuUa2ysUkgVfNGgiXxZ9nfqBpQ43PBLwmCitU= +github.com/tikv/client-go/v2 v2.0.8-0.20241209094930-06d7f4b9233b/go.mod h1:NI2GfVlB9n7DsIGCxrKcD4psrcuFNEV8m1BgyzK1Amc= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 h1:oAYc4m5Eu1OY9ogJ103VO47AYPHvhtzbUPD8L8B67Qk= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31/go.mod h1:W5a0sDadwUpI9k8p7M77d3jo253ZHdmua+u4Ho4Xw8U= +github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a h1:A6uKudFIfAEpoPdaal3aSqGxBzLyU8TqyXImLwo6dIo= +github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a/go.mod h1:mkjARE7Yr8qU23YcGMSALbIxTQ9r9QBVahQOBRfU460= +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)) github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs= github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw= github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk= diff --git a/pkg/executor/test/executor/executor_test.go b/pkg/executor/test/executor/executor_test.go new file mode 100644 index 0000000000000..d0afcc68c03e5 --- /dev/null +++ b/pkg/executor/test/executor/executor_test.go @@ -0,0 +1,3090 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "archive/zip" + "context" + "fmt" + "math/rand" + "os" + "path/filepath" + "reflect" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/executor" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/resolve" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/store/mockstore" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/testdata" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/pingcap/tidb/pkg/util/replayer" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tipb/go-tipb" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/testutils" +) + +func checkFileName(s string) bool { + files := []string{ + "config.toml", + "debug_trace/debug_trace0.json", + "meta.txt", + "stats/test.t_dump_single.json", + "schema/test.t_dump_single.schema.txt", + "schema/schema_meta.txt", + "table_tiflash_replica.txt", + "variables.toml", + "session_bindings.sql", + "global_bindings.sql", + "sql/sql0.sql", + "explain.txt", + "statsMem/test.t_dump_single.txt", + "sql_meta.toml", + } + for _, f := range files { + if strings.Compare(f, s) == 0 { + return true + } + } + return false +} + +func TestPlanReplayer(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/infoschema/mockTiFlashStoreCount", `return(true)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/infoschema/mockTiFlashStoreCount")) + }() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, index idx_a(a))") + tk.MustExec("alter table t set tiflash replica 1") + tk.MustQuery("plan replayer dump explain select * from t where a=10") + defer os.RemoveAll(replayer.GetPlanReplayerDirName()) + tk.MustQuery("plan replayer dump explain select /*+ read_from_storage(tiflash[t]) */ * from t") + + tk.MustExec("create table t1 (a int)") + tk.MustExec("create table t2 (a int)") + tk.MustExec("create definer=`root`@`127.0.0.1` view v1 as select * from t1") + tk.MustExec("create definer=`root`@`127.0.0.1` view v2 as select * from v1") + tk.MustQuery("plan replayer dump explain with tmp as (select a from t1 group by t1.a) select * from tmp, t2 where t2.a=tmp.a;") + tk.MustQuery("plan replayer dump explain select * from t1 where t1.a > (with cte1 as (select 1) select count(1) from cte1);") + tk.MustQuery("plan replayer dump explain select * from v1") + tk.MustQuery("plan replayer dump explain select * from v2") + require.True(t, len(tk.Session().GetSessionVars().LastPlanReplayerToken) > 0) + + // clear the status table and assert + tk.MustExec("delete from mysql.plan_replayer_status") + tk.MustQuery("plan replayer dump explain select * from v2") + token := tk.Session().GetSessionVars().LastPlanReplayerToken + rows := tk.MustQuery(fmt.Sprintf("select * from mysql.plan_replayer_status where token = '%v'", token)).Rows() + require.Len(t, rows, 1) +} + +func TestPlanReplayerCaptureSEM(t *testing.T) { + originSEM := config.GetGlobalConfig().Security.EnableSEM + defer func() { + config.GetGlobalConfig().Security.EnableSEM = originSEM + }() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("plan replayer capture '123' '123';") + tk.MustExec("create table t(id int)") + tk.MustQuery("plan replayer dump explain select * from t") + defer os.RemoveAll(replayer.GetPlanReplayerDirName()) + tk.MustQuery("select count(*) from mysql.plan_replayer_status").Check(testkit.Rows("1")) +} + +func TestPlanReplayerCapture(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("plan replayer capture '123' '123';") + tk.MustQuery("select sql_digest, plan_digest from mysql.plan_replayer_task;").Check(testkit.Rows("123 123")) + tk.MustGetErrMsg("plan replayer capture '123' '123';", "plan replayer capture task already exists") + tk.MustExec("plan replayer capture remove '123' '123'") + tk.MustQuery("select count(*) from mysql.plan_replayer_task;").Check(testkit.Rows("0")) + tk.MustExec("create table t(id int)") + tk.MustExec("prepare stmt from 'update t set id = ? where id = ? + 1';") + tk.MustExec("SET @number = 5;") + tk.MustExec("execute stmt using @number,@number") + _, sqlDigest := tk.Session().GetSessionVars().StmtCtx.SQLDigest() + _, planDigest := tk.Session().GetSessionVars().StmtCtx.GetPlanDigest() + tk.MustExec("SET @@tidb_enable_plan_replayer_capture = ON;") + tk.MustExec("SET @@global.tidb_enable_historical_stats_for_capture='ON'") + tk.MustExec(fmt.Sprintf("plan replayer capture '%v' '%v'", sqlDigest.String(), planDigest.String())) + err := dom.GetPlanReplayerHandle().CollectPlanReplayerTask() + require.NoError(t, err) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/shouldDumpStats", "return(true)")) + defer require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/shouldDumpStats")) + tk.MustExec("execute stmt using @number,@number") + task := dom.GetPlanReplayerHandle().DrainTask() + require.NotNil(t, task) +} + +func TestPlanReplayerContinuesCapture(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("set @@global.tidb_enable_historical_stats='OFF'") + _, err := tk.Exec("set @@global.tidb_enable_plan_replayer_continuous_capture='ON'") + require.Error(t, err) + require.Equal(t, err.Error(), "tidb_enable_historical_stats should be enabled before enabling tidb_enable_plan_replayer_continuous_capture") + + tk.MustExec("set @@global.tidb_enable_historical_stats='ON'") + tk.MustExec("set @@global.tidb_enable_plan_replayer_continuous_capture='ON'") + + prHandle := dom.GetPlanReplayerHandle() + tk.MustExec("delete from mysql.plan_replayer_status;") + tk.MustExec("use test") + tk.MustExec("create table t(id int);") + tk.MustExec("set @@tidb_enable_plan_replayer_continuous_capture = 'ON'") + tk.MustQuery("select * from t;") + task := prHandle.DrainTask() + require.NotNil(t, task) + worker := prHandle.GetWorker() + success := worker.HandleTask(task) + defer os.RemoveAll(replayer.GetPlanReplayerDirName()) + require.True(t, success) + tk.MustQuery("select count(*) from mysql.plan_replayer_status").Check(testkit.Rows("1")) +} + +func TestPlanReplayerDumpSingle(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t_dump_single") + tk.MustExec("create table t_dump_single(a int)") + res := tk.MustQuery("plan replayer dump explain select * from t_dump_single") + defer os.RemoveAll(replayer.GetPlanReplayerDirName()) + path := testdata.ConvertRowsToStrings(res.Rows()) + + reader, err := zip.OpenReader(filepath.Join(replayer.GetPlanReplayerDirName(), path[0])) + require.NoError(t, err) + defer func() { require.NoError(t, reader.Close()) }() + for _, file := range reader.File { + require.True(t, checkFileName(file.Name), file.Name) + } +} + +func TestTimezonePushDown(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (ts timestamp)") + defer tk.MustExec("drop table t") + tk.MustExec(`insert into t values ("2018-09-13 10:02:06")`) + + systemTZ := timeutil.SystemLocation() + require.NotEqual(t, "System", systemTZ.String()) + require.NotEqual(t, "Local", systemTZ.String()) + ctx := context.Background() + count := 0 + ctx1 := context.WithValue(ctx, "CheckSelectRequestHook", func(req *kv.Request) { + count++ + dagReq := new(tipb.DAGRequest) + require.NoError(t, proto.Unmarshal(req.Data, dagReq)) + require.Equal(t, systemTZ.String(), dagReq.GetTimeZoneName()) + }) + rs, err := tk.Session().Execute(ctx1, `select * from t where ts = "2018-09-13 10:02:06"`) + require.NoError(t, err) + rs[0].Close() + + tk.MustExec(`set time_zone="System"`) + rs, err = tk.Session().Execute(ctx1, `select * from t where ts = "2018-09-13 10:02:06"`) + require.NoError(t, err) + rs[0].Close() + + require.Equal(t, 2, count) // Make sure the hook function is called. +} + +func TestNotFillCacheFlag(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (id int primary key)") + tk.MustExec("insert into t values (1)") + + tests := []struct { + sql string + expect bool + }{ + {"select SQL_NO_CACHE * from t", true}, + {"select SQL_CACHE * from t", false}, + {"select * from t", false}, + } + count := 0 + ctx := context.Background() + for _, test := range tests { + ctx1 := context.WithValue(ctx, "CheckSelectRequestHook", func(req *kv.Request) { + count++ + comment := fmt.Sprintf("sql=%s, expect=%v, get=%v", test.sql, test.expect, req.NotFillCache) + require.Equal(t, test.expect, req.NotFillCache, comment) + }) + rs, err := tk.Session().Execute(ctx1, test.sql) + require.NoError(t, err) + tk.ResultSetToResult(rs[0], fmt.Sprintf("sql: %v", test.sql)) + } + require.Equal(t, len(tests), count) // Make sure the hook function is called. +} + +func TestCheckIndex(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + + ctx := testkit.NewSession(t, store) + se, err := session.CreateSession4Test(store) + require.NoError(t, err) + defer se.Close() + + _, err = se.Execute(context.Background(), "create database test_admin") + require.NoError(t, err) + _, err = se.Execute(context.Background(), "use test_admin") + require.NoError(t, err) + _, err = se.Execute(context.Background(), "create table t (pk int primary key, c int default 1, c1 int default 1, unique key c(c))") + require.NoError(t, err) + + is := dom.InfoSchema() + db := pmodel.NewCIStr("test_admin") + dbInfo, ok := is.SchemaByName(db) + require.True(t, ok) + + tblName := pmodel.NewCIStr("t") + tbl, err := is.TableByName(context.Background(), db, tblName) + require.NoError(t, err) + tbInfo := tbl.Meta() + + alloc := autoid.NewAllocator(dom, dbInfo.ID, tbInfo.ID, false, autoid.RowIDAllocType) + tb, err := tables.TableFromMeta(autoid.NewAllocators(false, alloc), tbInfo) + require.NoError(t, err) + + _, err = se.Execute(context.Background(), "admin check index t c") + require.NoError(t, err) + + _, err = se.Execute(context.Background(), "admin check index t C") + require.NoError(t, err) + + // set data to: + // index data (handle, data): (1, 10), (2, 20) + // table data (handle, data): (1, 10), (2, 20) + recordVal1 := types.MakeDatums(int64(1), int64(10), int64(11)) + recordVal2 := types.MakeDatums(int64(2), int64(20), int64(21)) + require.NoError(t, sessiontxn.NewTxn(context.Background(), ctx)) + txn, err := ctx.Txn(true) + require.NoError(t, err) + _, err = tb.AddRecord(ctx.GetTableCtx(), txn, recordVal1) + require.NoError(t, err) + _, err = tb.AddRecord(ctx.GetTableCtx(), txn, recordVal2) + require.NoError(t, err) + require.NoError(t, txn.Commit(context.Background())) + + mockCtx := mock.NewContext() + idx := tb.Indices()[0] + + _, err = se.Execute(context.Background(), "admin check index t idx_inexistent") + require.Error(t, err) + require.Contains(t, err.Error(), "not exist") + + // set data to: + // index data (handle, data): (1, 10), (2, 20), (3, 30) + // table data (handle, data): (1, 10), (2, 20), (4, 40) + txn, err = store.Begin() + require.NoError(t, err) + _, err = idx.Create(mockCtx.GetTableCtx(), txn, types.MakeDatums(int64(30)), kv.IntHandle(3), nil) + require.NoError(t, err) + key := tablecodec.EncodeRowKey(tb.Meta().ID, kv.IntHandle(4).Encoded()) + setColValue(t, txn, key, types.NewDatum(int64(40))) + err = txn.Commit(context.Background()) + require.NoError(t, err) + _, err = se.Execute(context.Background(), "admin check index t c") + require.Error(t, err) + require.Equal(t, "[admin:8223]data inconsistency in table: t, index: c, handle: 3, index-values:\"handle: 3, values: [KindInt64 30]\" != record-values:\"\"", err.Error()) + + // set data to: + // index data (handle, data): (1, 10), (2, 20), (3, 30), (4, 40) + // table data (handle, data): (1, 10), (2, 20), (4, 40) + txn, err = store.Begin() + require.NoError(t, err) + _, err = idx.Create(mockCtx.GetTableCtx(), txn, types.MakeDatums(int64(40)), kv.IntHandle(4), nil) + require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) + _, err = se.Execute(context.Background(), "admin check index t c") + require.Error(t, err) + require.EqualError(t, err, "[admin:8223]data inconsistency in table: t, index: c, handle: 3, index-values:\"handle: 3, values: [KindInt64 30]\" != record-values:\"\"") + + // set data to: + // index data (handle, data): (1, 10), (4, 40) + // table data (handle, data): (1, 10), (2, 20), (4, 40) + txn, err = store.Begin() + require.NoError(t, err) + err = idx.Delete(mockCtx.GetTableCtx(), txn, types.MakeDatums(int64(30)), kv.IntHandle(3)) + require.NoError(t, err) + err = idx.Delete(mockCtx.GetTableCtx(), txn, types.MakeDatums(int64(20)), kv.IntHandle(2)) + require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) + _, err = se.Execute(context.Background(), "admin check index t c") + require.Error(t, err) + require.EqualError(t, err, "[admin:8223]data inconsistency in table: t, index: c, handle: 2, index-values:\"\" != record-values:\"handle: 2, values: [KindInt64 20]\"") + + // TODO: pass the case below: + // set data to: + // index data (handle, data): (1, 10), (4, 40), (2, 30) + // table data (handle, data): (1, 10), (2, 20), (4, 40) +} + +func setColValue(t *testing.T, txn kv.Transaction, key kv.Key, v types.Datum) { + row := []types.Datum{v, {}} + colIDs := []int64{2, 3} + sc := stmtctx.NewStmtCtxWithTimeZone(time.Local) + rd := rowcodec.Encoder{Enable: true} + value, err := tablecodec.EncodeRow(sc.TimeZone(), row, colIDs, nil, nil, nil, &rd) + require.NoError(t, err) + err = txn.Set(key, value) + require.NoError(t, err) +} + +func TestTimestampDefaultValueTimeZone(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("set time_zone = '+08:00'") + tk.MustExec(`create table t (a int, b timestamp default "2019-01-17 14:46:14")`) + tk.MustExec("insert into t set a=1") + r := tk.MustQuery(`show create table t`) + r.Check(testkit.Rows("t CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + " `b` timestamp DEFAULT '2019-01-17 14:46:14'\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + tk.MustExec("set time_zone = '+00:00'") + tk.MustExec("insert into t set a=2") + r = tk.MustQuery(`show create table t`) + r.Check(testkit.Rows("t CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + " `b` timestamp DEFAULT '2019-01-17 06:46:14'\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + r = tk.MustQuery(`select a,b from t order by a`) + r.Check(testkit.Rows("1 2019-01-17 06:46:14", "2 2019-01-17 06:46:14")) + // Test the column's version is greater than ColumnInfoVersion1. + is := domain.GetDomain(tk.Session()).InfoSchema() + require.NotNil(t, is) + tb, err := is.TableByName(context.Background(), pmodel.NewCIStr("test"), pmodel.NewCIStr("t")) + require.NoError(t, err) + tb.Cols()[1].Version = model.ColumnInfoVersion1 + 1 + tk.MustExec("insert into t set a=3") + r = tk.MustQuery(`select a,b from t order by a`) + r.Check(testkit.Rows("1 2019-01-17 06:46:14", "2 2019-01-17 06:46:14", "3 2019-01-17 06:46:14")) + tk.MustExec("delete from t where a=3") + // Change time zone back. + tk.MustExec("set time_zone = '+08:00'") + r = tk.MustQuery(`select a,b from t order by a`) + r.Check(testkit.Rows("1 2019-01-17 14:46:14", "2 2019-01-17 14:46:14")) + tk.MustExec("set time_zone = '-08:00'") + r = tk.MustQuery(`show create table t`) + r.Check(testkit.Rows("t CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + " `b` timestamp DEFAULT '2019-01-16 22:46:14'\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // test zero default value in multiple time zone. + defer tk.MustExec(fmt.Sprintf("set @@sql_mode='%s'", tk.MustQuery("select @@sql_mode").Rows()[0][0])) + tk.MustExec("set @@sql_mode='STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION';") + tk.MustExec("drop table if exists t") + tk.MustExec("set time_zone = '+08:00'") + tk.MustExec(`create table t (a int, b timestamp default "0000-00-00 00")`) + tk.MustExec("insert into t set a=1") + r = tk.MustQuery(`show create table t`) + r.Check(testkit.Rows("t CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + " `b` timestamp DEFAULT '0000-00-00 00:00:00'\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + tk.MustExec("set time_zone = '+00:00'") + tk.MustExec("insert into t set a=2") + r = tk.MustQuery(`show create table t`) + r.Check(testkit.Rows("t CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + " `b` timestamp DEFAULT '0000-00-00 00:00:00'\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + tk.MustExec("set time_zone = '-08:00'") + tk.MustExec("insert into t set a=3") + r = tk.MustQuery(`show create table t`) + r.Check(testkit.Rows("t CREATE TABLE `t` (\n" + " `a` int(11) DEFAULT NULL,\n" + " `b` timestamp DEFAULT '0000-00-00 00:00:00'\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + r = tk.MustQuery(`select a,b from t order by a`) + r.Check(testkit.Rows("1 0000-00-00 00:00:00", "2 0000-00-00 00:00:00", "3 0000-00-00 00:00:00")) + + // test add timestamp column default current_timestamp. + tk.MustExec(`drop table if exists t`) + tk.MustExec(`set time_zone = 'Asia/Shanghai'`) + tk.MustExec(`create table t (a int)`) + tk.MustExec(`insert into t set a=1`) + tk.MustExec(`alter table t add column b timestamp not null default current_timestamp;`) + timeIn8 := tk.MustQuery("select b from t").Rows()[0][0] + tk.MustExec(`set time_zone = '+00:00'`) + timeIn0 := tk.MustQuery("select b from t").Rows()[0][0] + require.NotEqual(t, timeIn8, timeIn0) + datumTimeIn8, err := expression.GetTimeValue(tk.Session().GetExprCtx(), timeIn8, mysql.TypeTimestamp, 0, nil) + require.NoError(t, err) + tIn8To0 := datumTimeIn8.GetMysqlTime() + timeZoneIn8, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + err = tIn8To0.ConvertTimeZone(timeZoneIn8, time.UTC) + require.NoError(t, err) + require.Equal(t, tIn8To0.String(), timeIn0) + + // test add index. + tk.MustExec(`alter table t add index(b);`) + tk.MustExec("admin check table t") + tk.MustExec(`set time_zone = '+05:00'`) + tk.MustExec("admin check table t") + + // 1. add a timestamp general column + // 2. add the index + tk.MustExec(`drop table if exists t`) + // change timezone + tk.MustExec(`set time_zone = 'Asia/Shanghai'`) + tk.MustExec(`create table t(a timestamp default current_timestamp)`) + tk.MustExec(`insert into t set a="20220413154712"`) + tk.MustExec(`alter table t add column b timestamp as (a+1) virtual;`) + // change timezone + tk.MustExec(`set time_zone = '+05:00'`) + tk.MustExec(`insert into t set a="20220413154840"`) + tk.MustExec(`alter table t add index(b);`) + tk.MustExec("admin check table t") + tk.MustExec(`set time_zone = '-03:00'`) + tk.MustExec("admin check table t") +} + +func TestTiDBCurrentTS(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) + tk.MustExec("begin") + rows := tk.MustQuery("select @@tidb_current_ts").Rows() + tsStr := rows[0][0].(string) + txn, err := tk.Session().Txn(true) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%d", txn.StartTS()), tsStr) + tk.MustExec("begin") + rows = tk.MustQuery("select @@tidb_current_ts").Rows() + newTsStr := rows[0][0].(string) + txn, err = tk.Session().Txn(true) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%d", txn.StartTS()), newTsStr) + require.NotEqual(t, tsStr, newTsStr) + tk.MustExec("commit") + tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) + + err = tk.ExecToErr("set @@tidb_current_ts = '1'") + require.True(t, terror.ErrorEqual(err, variable.ErrIncorrectScope), fmt.Sprintf("err: %v", err)) +} + +func TestTiDBLastTxnInfo(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int primary key)") + tk.MustQuery("select @@tidb_last_txn_info").Check(testkit.Rows("")) + + tk.MustExec("insert into t values (1)") + rows1 := tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows() + require.Greater(t, rows1[0][0].(string), "0") + require.Less(t, rows1[0][0].(string), rows1[0][1].(string)) + + tk.MustExec("begin") + tk.MustQuery("select a from t where a = 1").Check(testkit.Rows("1")) + rows2 := tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts'), @@tidb_current_ts").Rows() + tk.MustExec("commit") + rows3 := tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows() + require.Equal(t, rows1[0][0], rows2[0][0]) + require.Equal(t, rows1[0][1], rows2[0][1]) + require.Equal(t, rows1[0][0], rows3[0][0]) + require.Equal(t, rows1[0][1], rows3[0][1]) + require.Less(t, rows2[0][1], rows2[0][2]) + + tk.MustExec("begin") + tk.MustExec("update t set a = a + 1 where a = 1") + rows4 := tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts'), @@tidb_current_ts").Rows() + tk.MustExec("commit") + rows5 := tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows() + require.Equal(t, rows1[0][0], rows4[0][0]) + require.Equal(t, rows1[0][1], rows4[0][1]) + require.Equal(t, rows5[0][0], rows4[0][2]) + require.Less(t, rows4[0][1], rows4[0][2]) + require.Less(t, rows4[0][2], rows5[0][1]) + + tk.MustExec("begin") + tk.MustExec("update t set a = a + 1 where a = 2") + tk.MustExec("rollback") + rows6 := tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows() + require.Equal(t, rows5[0][0], rows6[0][0]) + require.Equal(t, rows5[0][1], rows6[0][1]) + + tk.MustExec("begin optimistic") + tk.MustExec("insert into t values (2)") + err := tk.ExecToErr("commit") + require.Error(t, err) + rows7 := tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts'), json_extract(@@tidb_last_txn_info, '$.error')").Rows() + require.Greater(t, rows7[0][0], rows5[0][0]) + require.Equal(t, "0", rows7[0][1]) + require.Contains(t, err.Error(), rows7[0][1]) + + err = tk.ExecToErr("set @@tidb_last_txn_info = '{}'") + require.True(t, terror.ErrorEqual(err, variable.ErrIncorrectScope), fmt.Sprintf("err: %v", err)) +} + +func TestTiDBLastQueryInfo(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int primary key, v int)") + tk.MustQuery("select json_extract(@@tidb_last_query_info, '$.start_ts'), json_extract(@@tidb_last_query_info, '$.start_ts')").Check(testkit.Rows("0 0")) + + toUint64 := func(str any) uint64 { + res, err := strconv.ParseUint(str.(string), 10, 64) + require.NoError(t, err) + return res + } + + tk.MustExec("select * from t") + rows := tk.MustQuery("select json_extract(@@tidb_last_query_info, '$.start_ts'), json_extract(@@tidb_last_query_info, '$.for_update_ts')").Rows() + require.Greater(t, toUint64(rows[0][0]), uint64(0)) + require.Equal(t, rows[0][1], rows[0][0]) + + tk.MustExec("insert into t values (1, 10)") + rows = tk.MustQuery("select json_extract(@@tidb_last_query_info, '$.start_ts'), json_extract(@@tidb_last_query_info, '$.for_update_ts')").Rows() + require.Greater(t, toUint64(rows[0][0]), uint64(0)) + require.Equal(t, rows[0][1], rows[0][0]) + // tidb_last_txn_info is still valid after checking query info. + rows = tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.start_ts'), json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows() + require.Greater(t, toUint64(rows[0][0]), uint64(0)) + require.Less(t, rows[0][0].(string), rows[0][1].(string)) + + tk.MustExec("begin pessimistic") + tk.MustExec("select * from t") + rows = tk.MustQuery("select json_extract(@@tidb_last_query_info, '$.start_ts'), json_extract(@@tidb_last_query_info, '$.for_update_ts')").Rows() + require.Greater(t, toUint64(rows[0][0]), uint64(0)) + require.Equal(t, rows[0][1], rows[0][0]) + + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + tk2.MustExec("update t set v = 11 where a = 1") + + tk.MustExec("select * from t") + rows = tk.MustQuery("select json_extract(@@tidb_last_query_info, '$.start_ts'), json_extract(@@tidb_last_query_info, '$.for_update_ts')").Rows() + require.Greater(t, toUint64(rows[0][0]), uint64(0)) + require.Equal(t, rows[0][1], rows[0][0]) + + tk.MustExec("update t set v = 12 where a = 1") + rows = tk.MustQuery("select json_extract(@@tidb_last_query_info, '$.start_ts'), json_extract(@@tidb_last_query_info, '$.for_update_ts')").Rows() + require.Greater(t, toUint64(rows[0][0]), uint64(0)) + require.Less(t, toUint64(rows[0][0]), toUint64(rows[0][1])) + + tk.MustExec("commit") + + tk.MustExec("set transaction isolation level read committed") + tk.MustExec("begin pessimistic") + tk.MustExec("select * from t") + rows = tk.MustQuery("select json_extract(@@tidb_last_query_info, '$.start_ts'), json_extract(@@tidb_last_query_info, '$.for_update_ts')").Rows() + require.Greater(t, toUint64(rows[0][0]), uint64(0)) + require.Less(t, toUint64(rows[0][0]), toUint64(rows[0][1])) + + tk.MustExec("rollback") +} + +func TestPartitionHashCode(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t(c1 bigint, c2 bigint, c3 bigint, primary key(c1)) partition by hash (c1) partitions 4;`) + var wg util.WaitGroupWrapper + for i := 0; i < 5; i++ { + wg.Run(func() { + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + for i := 0; i < 5; i++ { + tk1.MustExec("select * from t") + } + }) + } + wg.Wait() +} + +func TestPrevStmtDesensitization(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec(fmt.Sprintf("set @@session.%v=1", variable.TiDBRedactLog)) + defer tk.MustExec(fmt.Sprintf("set @@session.%v=0", variable.TiDBRedactLog)) + tk.MustExec("create table t (a int, unique key (a))") + tk.MustExec("begin") + tk.MustExec("insert into t values (1),(2)") + require.Equal(t, "insert into `t` values ( ... )", tk.Session().GetSessionVars().PrevStmt.String()) + tk.MustGetErrMsg("insert into t values (1)", `[kv:1062]Duplicate entry '?' for key 't.a'`) +} + +func TestIssue19148(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a decimal(16, 2));") + tk.MustExec("select * from t where a > any_value(a);") + is := domain.GetDomain(tk.Session()).InfoSchema() + tblInfo, err := is.TableByName(context.Background(), pmodel.NewCIStr("test"), pmodel.NewCIStr("t")) + require.NoError(t, err) + require.Zero(t, tblInfo.Meta().Columns[0].GetFlag()) +} + +func TestOOMActionPriority(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t0") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("drop table if exists t3") + tk.MustExec("drop table if exists t4") + tk.MustExec("create table t0(a int)") + tk.MustExec("insert into t0 values(1)") + tk.MustExec("create table t1(a int)") + tk.MustExec("insert into t1 values(1)") + tk.MustExec("create table t2(a int)") + tk.MustExec("insert into t2 values(1)") + tk.MustExec("create table t3(a int)") + tk.MustExec("insert into t3 values(1)") + tk.MustExec("create table t4(a int)") + tk.MustExec("insert into t4 values(1)") + tk.MustQuery("select * from t0 join t1 join t2 join t3 join t4 order by t0.a").Check(testkit.Rows("1 1 1 1 1")) + action := tk.Session().GetSessionVars().StmtCtx.MemTracker.GetFallbackForTest(true) + // All actions are finished and removed. + require.Equal(t, action.GetPriority(), int64(memory.DefLogPriority)) +} + +// Test invoke Close without invoking Open before for each operators. +func TestUnreasonablyClose(t *testing.T) { + store := testkit.CreateMockStore(t) + + is := infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockSignedTable(), plannercore.MockUnsignedTable()}) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set tidb_cost_model_version=2") + // To enable the shuffleExec operator. + tk.MustExec("set @@tidb_merge_join_concurrency=4") + + var opsNeedsCovered = []base.PhysicalPlan{ + &plannercore.PhysicalHashJoin{}, + &plannercore.PhysicalMergeJoin{}, + &plannercore.PhysicalIndexJoin{}, + &plannercore.PhysicalIndexHashJoin{}, + &plannercore.PhysicalTableReader{}, + &plannercore.PhysicalIndexReader{}, + &plannercore.PhysicalIndexLookUpReader{}, + &plannercore.PhysicalIndexMergeReader{}, + &plannercore.PhysicalApply{}, + &plannercore.PhysicalHashAgg{}, + &plannercore.PhysicalStreamAgg{}, + &plannercore.PhysicalLimit{}, + &plannercore.PhysicalSort{}, + &plannercore.PhysicalTopN{}, + &plannercore.PhysicalCTE{}, + &plannercore.PhysicalCTETable{}, + &plannercore.PhysicalMaxOneRow{}, + &plannercore.PhysicalProjection{}, + &plannercore.PhysicalSelection{}, + &plannercore.PhysicalTableDual{}, + &plannercore.PhysicalWindow{}, + &plannercore.PhysicalShuffle{}, + &plannercore.PhysicalUnionAll{}, + } + + opsNeedsCoveredMask := uint64(1< t1.a) AS a from t as t1) t", + "select /*+ hash_agg() */ count(f) from t group by a", + "select /*+ stream_agg() */ count(f) from t", + "select * from t order by a, f", + "select * from t order by a, f limit 1", + "select * from t limit 1", + "select (select t1.a from t t1 where t1.a > t2.a) as a from t t2;", + "select a + 1 from t", + "select count(*) a from t having a > 1", + "select * from t where a = 1.1", + "with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 0) select * from cte1", + "select /*+use_index_merge(t, c_d_e, f)*/ * from t where c < 1 or f > 2", + "select sum(f) over (partition by f) from t", + "select /*+ merge_join(t1)*/ * from t t1 join t t2 on t1.d = t2.d", + "select a from t union all select a from t", + } { + comment := fmt.Sprintf("case:%v sql:%s", i, tc) + stmt, err := p.ParseOneStmt(tc, "", "") + require.NoError(t, err, comment) + err = sessiontxn.NewTxn(context.Background(), tk.Session()) + require.NoError(t, err, comment) + + err = sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), stmt) + require.NoError(t, err, comment) + + executorBuilder := executor.NewMockExecutorBuilderForTest(tk.Session(), is) + + nodeW := resolve.NewNodeW(stmt) + p, _, _ := planner.Optimize(context.TODO(), tk.Session(), nodeW, is) + require.NotNil(t, p) + + // This for loop level traverses the plan tree to get which operators are covered. + var hasCTE bool + for child := []base.PhysicalPlan{p.(base.PhysicalPlan)}; len(child) != 0; { + newChild := make([]base.PhysicalPlan, 0, len(child)) + for _, ch := range child { + found := false + for k, t := range opsNeedsCovered { + if reflect.TypeOf(t) == reflect.TypeOf(ch) { + opsAlreadyCoveredMask |= 1 << k + found = true + break + } + } + require.True(t, found, fmt.Sprintf("case: %v sql: %s operator %v is not registered in opsNeedsCoveredMask", i, tc, reflect.TypeOf(ch))) + switch x := ch.(type) { + case *plannercore.PhysicalCTE: + newChild = append(newChild, x.RecurPlan) + newChild = append(newChild, x.SeedPlan) + hasCTE = true + continue + case *plannercore.PhysicalShuffle: + newChild = append(newChild, x.DataSources...) + newChild = append(newChild, x.Tails...) + continue + } + newChild = append(newChild, ch.Children()...) + } + child = newChild + } + + if hasCTE { + // Normally CTEStorages will be setup in ResetContextOfStmt. + // But the following case call e.Close() directly, instead of calling session.ExecStmt(), which calls ResetContextOfStmt. + // So need to setup CTEStorages manually. + tk.Session().GetSessionVars().StmtCtx.CTEStorageMap = map[int]*executor.CTEStorages{} + } + e := executorBuilder.Build(p) + + func() { + defer func() { + r := recover() + buf := make([]byte, 4096) + stackSize := runtime.Stack(buf, false) + buf = buf[:stackSize] + require.Nil(t, r, fmt.Sprintf("case: %v\n sql: %s\n error stack: %v", i, tc, string(buf))) + }() + require.NoError(t, e.Close(), comment) + }() + } + // The following code is used to make sure all the operators registered + // in opsNeedsCoveredMask are covered. + commentBuf := strings.Builder{} + if opsAlreadyCoveredMask != opsNeedsCoveredMask { + for i := range opsNeedsCovered { + if opsAlreadyCoveredMask&(1< t1.a) AS a from t as t1) t;") + require.Contains(t, result.Rows()[1][0], "Apply") + var ( + ind int + flag bool + ) + value := (result.Rows()[1][5]).(string) + for ind = 0; ind < len(value)-5; ind++ { + if value[ind:ind+5] == "cache" { + flag = true + break + } + } + require.True(t, flag) + require.Equal(t, "cache:ON, cacheHitRatio:88.889%", value[ind:]) + + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int, index idx(a));") + tk.MustExec("insert into t values (1),(2),(3),(4),(5),(6),(7),(8),(9);") + tk.MustExec("analyze table t;") + result = tk.MustQuery("explain analyze SELECT count(a) FROM (SELECT (SELECT min(a) FROM t as t2 WHERE t2.a > t1.a) AS a from t as t1) t;") + require.Contains(t, result.Rows()[1][0], "Apply") + flag = false + value = (result.Rows()[1][5]).(string) + for ind = 0; ind < len(value)-5; ind++ { + if value[ind:ind+5] == "cache" { + flag = true + break + } + } + require.True(t, flag) + require.Equal(t, "cache:OFF", value[ind:]) +} + +func TestCollectDMLRuntimeStats(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (a int, b int, unique index (a))") + + testSQLs := []string{ + "insert ignore into t1 values (5,5);", + "insert into t1 values (5,5) on duplicate key update a=a+1;", + "replace into t1 values (5,6),(6,7)", + "update t1 set a=a+1 where a=6;", + } + + getRootStats := func() string { + info := tk.Session().ShowProcess() + require.NotNil(t, info) + p, ok := info.Plan.(base.Plan) + require.True(t, ok) + stats := tk.Session().GetSessionVars().StmtCtx.RuntimeStatsColl.GetRootStats(p.ID()) + return stats.String() + } + for _, sql := range testSQLs { + tk.MustExec(sql) + require.Regexp(t, "time.*loops.*Get.*num_rpc.*total_time.*", getRootStats()) + } + + // Test for lock keys stats. + tk.MustExec("begin pessimistic") + tk.MustExec("update t1 set b=b+1") + require.Regexp(t, "time.*lock_keys.*time.* region.* keys.* lock_rpc:.* rpc_count.*", getRootStats()) + tk.MustExec("rollback") + + tk.MustExec("begin pessimistic") + tk.MustQuery("select * from t1 for update").Check(testkit.Rows("5 6", "7 7")) + require.Regexp(t, "time.*lock_keys.*time.* region.* keys.* lock_rpc:.* rpc_count.*", getRootStats()) + tk.MustExec("rollback") + + tk.MustExec("begin pessimistic") + tk.MustExec("insert ignore into t1 values (9,9)") + require.Regexp(t, "time:.*, loops:.*, prepare:.*, check_insert: {total_time:.*, mem_insert_time:.*, prefetch:.*, rpc:{BatchGet:{num_rpc:.*, total_time:.*}}}.*", getRootStats()) + tk.MustExec("rollback") + + tk.MustExec("begin pessimistic") + tk.MustExec("insert into t1 values (10,10) on duplicate key update a=a+1") + require.Regexp(t, "time:.*, loops:.*, prepare:.*, check_insert: {total_time:.*, mem_insert_time:.*, prefetch:.*, rpc:{BatchGet:{num_rpc:.*, total_time:.*}.*", getRootStats()) + tk.MustExec("rollback") + + tk.MustExec("begin pessimistic") + tk.MustExec("insert into t1 values (1,2)") + require.Regexp(t, "time:.*, loops:.*, prepare:.*, insert:.*", getRootStats()) + tk.MustExec("rollback") + + tk.MustExec("begin pessimistic") + tk.MustExec("insert ignore into t1 values(11,11) on duplicate key update `a`=`a`+1") + require.Regexp(t, "time:.*, loops:.*, prepare:.*, check_insert: {total_time:.*, mem_insert_time:.*, prefetch:.*, rpc:.*}", getRootStats()) + tk.MustExec("rollback") + + tk.MustExec("begin pessimistic") + tk.MustExec("replace into t1 values (1,4)") + require.Regexp(t, "time:.*, loops:.*, prefetch:.*, rpc:.*", getRootStats()) + tk.MustExec("rollback") +} + +func TestTableSampleTemporaryTable(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. + safePointName := "tikv_gc_safe_point" + safePointValue := "20160102-15:04:05 -0700" + safePointComment := "All versions after safe point can be accessed. (DO NOT EDIT)" + updateSafePoint := fmt.Sprintf(`INSERT INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + ON DUPLICATE KEY + UPDATE variable_value = '%[2]s', comment = '%[3]s'`, safePointName, safePointValue, safePointComment) + tk.MustExec(updateSafePoint) + + tk.MustExec("use test") + tk.MustExec("drop table if exists tmp1") + tk.MustExec("create global temporary table tmp1 " + + "(id int not null primary key, code int not null, value int default null, unique key code(code))" + + "on commit delete rows") + + tk.MustExec("use test") + tk.MustExec("drop table if exists tmp2") + tk.MustExec("create temporary table tmp2 (id int not null primary key, code int not null, value int default null, unique key code(code));") + + // sleep 1us to make test stale + time.Sleep(time.Microsecond) + + // test tablesample return empty for global temporary table + tk.MustQuery("select * from tmp1 tablesample regions()").Check(testkit.Rows()) + + tk.MustExec("begin") + tk.MustExec("insert into tmp1 values (1, 1, 1)") + tk.MustQuery("select * from tmp1 tablesample regions()").Check(testkit.Rows()) + tk.MustExec("commit") + + // tablesample for global temporary table should not return error for compatibility of tools like dumpling + tk.MustExec("set @@tidb_snapshot=NOW(6)") + tk.MustQuery("select * from tmp1 tablesample regions()").Check(testkit.Rows()) + + tk.MustExec("begin") + tk.MustQuery("select * from tmp1 tablesample regions()").Check(testkit.Rows()) + tk.MustExec("commit") + tk.MustExec("set @@tidb_snapshot=''") + + // test tablesample returns error for local temporary table + tk.MustGetErrMsg("select * from tmp2 tablesample regions()", "TABLESAMPLE clause can not be applied to local temporary tables") + + tk.MustExec("begin") + tk.MustExec("insert into tmp2 values (1, 1, 1)") + tk.MustGetErrMsg("select * from tmp2 tablesample regions()", "TABLESAMPLE clause can not be applied to local temporary tables") + tk.MustExec("commit") + tk.MustGetErrMsg("select * from tmp2 tablesample regions()", "TABLESAMPLE clause can not be applied to local temporary tables") +} + +func TestGetResultRowsCount(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (a int)") + for i := 1; i <= 10; i++ { + tk.MustExec(fmt.Sprintf("insert into t values (%v)", i)) + } + cases := []struct { + sql string + row int64 + }{ + {"select * from t", 10}, + {"select * from t where a < 0", 0}, + {"select * from t where a <= 3", 3}, + {"insert into t values (11)", 0}, + {"replace into t values (12)", 0}, + {"update t set a=13 where a=12", 0}, + } + + for _, ca := range cases { + if strings.HasPrefix(ca.sql, "select") { + tk.MustQuery(ca.sql) + } else { + tk.MustExec(ca.sql) + } + info := tk.Session().ShowProcess() + require.NotNil(t, info) + p, ok := info.Plan.(base.Plan) + require.True(t, ok) + cnt := executor.GetResultRowsCount(tk.Session().GetSessionVars().StmtCtx, p) + require.Equal(t, ca.row, cnt, fmt.Sprintf("sql: %v", ca.sql)) + } +} + +func TestAdminShowDDLJobs(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create database if not exists test_admin_show_ddl_jobs") + tk.MustExec("use test_admin_show_ddl_jobs") + tk.MustExec("create table t (a int);") + + re := tk.MustQuery("admin show ddl jobs 1") + row := re.Rows()[0] + require.Equal(t, "test_admin_show_ddl_jobs", row[1]) + jobID, err := strconv.Atoi(row[0].(string)) + require.NoError(t, err) + + job, err := ddl.GetHistoryJobByID(tk.Session(), int64(jobID)) + require.NoError(t, err) + require.NotNil(t, job) + // Test for compatibility. Old TiDB version doesn't have SchemaName field, and the BinlogInfo maybe nil. + // See PR: 11561. + job.BinlogInfo = nil + job.SchemaName = "" + err = sessiontxn.NewTxnInStmt(context.Background(), tk.Session()) + require.NoError(t, err) + txn, err := tk.Session().Txn(true) + require.NoError(t, err) + err = meta.NewMutator(txn).AddHistoryDDLJob(job, true) + require.NoError(t, err) + tk.Session().StmtCommit(context.Background()) + + re = tk.MustQuery("admin show ddl jobs 1") + row = re.Rows()[0] + require.Equal(t, "test_admin_show_ddl_jobs", row[1]) + + re = tk.MustQuery("admin show ddl jobs 1 where job_type='create table'") + row = re.Rows()[0] + require.Equal(t, "test_admin_show_ddl_jobs", row[1]) + require.Equal(t, "", row[10]) + + // Test the START_TIME and END_TIME field. + tk.MustExec(`set @@time_zone = 'Asia/Shanghai'`) + re = tk.MustQuery("admin show ddl jobs where end_time is not NULL") + row = re.Rows()[0] + createTime, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, row[8].(string)) + require.NoError(t, err) + startTime, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, row[9].(string)) + require.NoError(t, err) + endTime, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, row[10].(string)) + require.NoError(t, err) + tk.MustExec(`set @@time_zone = 'Europe/Amsterdam'`) + re = tk.MustQuery("admin show ddl jobs where end_time is not NULL") + row2 := re.Rows()[0] + require.NotEqual(t, row[8], row2[8]) + require.NotEqual(t, row[9], row2[9]) + require.NotEqual(t, row[10], row2[10]) + createTime2, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, row2[8].(string)) + require.NoError(t, err) + startTime2, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, row2[9].(string)) + require.NoError(t, err) + endTime2, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, row2[10].(string)) + require.NoError(t, err) + loc, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + loc2, err := time.LoadLocation("Europe/Amsterdam") + require.NoError(t, err) + tt, err := createTime.GoTime(loc) + require.NoError(t, err) + t2, err := createTime2.GoTime(loc2) + require.NoError(t, err) + require.Equal(t, t2.In(time.UTC), tt.In(time.UTC)) + tt, err = startTime.GoTime(loc) + require.NoError(t, err) + t2, err = startTime2.GoTime(loc2) + require.NoError(t, err) + require.Equal(t, t2.In(time.UTC), tt.In(time.UTC)) + tt, err = endTime.GoTime(loc) + require.NoError(t, err) + t2, err = endTime2.GoTime(loc2) + require.NoError(t, err) + require.Equal(t, t2.In(time.UTC), tt.In(time.UTC)) +} + +func TestAdminShowDDLJobsInfo(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + // Test for issue: https://github.com/pingcap/tidb/issues/29915 + tk.MustExec("create placement policy x followers=4;") + tk.MustExec("create placement policy y " + + "PRIMARY_REGION=\"cn-east-1\" " + + "REGIONS=\"cn-east-1, cn-east-2\" " + + "FOLLOWERS=2") + tk.MustExec("create database if not exists test_admin_show_ddl_jobs") + tk.MustExec("use test_admin_show_ddl_jobs") + + tk.MustExec("create table t (a int);") + tk.MustExec("create table t1 (a int);") + + tk.MustExec("alter table t placement policy x;") + require.Equal(t, "alter table placement", tk.MustQuery("admin show ddl jobs 1").Rows()[0][3]) + + tk.MustExec("rename table t to tt, t1 to tt1") + require.Equal(t, "rename tables", tk.MustQuery("admin show ddl jobs 1").Rows()[0][3]) + + tk.MustExec("create table tt2 (c int) PARTITION BY RANGE (c) " + + "(PARTITION p0 VALUES LESS THAN (6)," + + "PARTITION p1 VALUES LESS THAN (11)," + + "PARTITION p2 VALUES LESS THAN (16)," + + "PARTITION p3 VALUES LESS THAN (21));") + tk.MustExec("alter table tt2 partition p0 placement policy y") + require.Equal(t, "alter table partition placement", tk.MustQuery("admin show ddl jobs 1").Rows()[0][3]) + + tk.MustExec("alter table tt1 cache") + require.Equal(t, "alter table cache", tk.MustQuery("admin show ddl jobs 1").Rows()[0][3]) + tk.MustExec("alter table tt1 nocache") + require.Equal(t, "alter table nocache", tk.MustQuery("admin show ddl jobs 1").Rows()[0][3]) +} + +func TestUnion2(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + testSQL := `drop table if exists union_test; create table union_test(id int);` + tk.MustExec(testSQL) + + testSQL = `drop table if exists union_test;` + tk.MustExec(testSQL) + testSQL = `create table union_test(id int);` + tk.MustExec(testSQL) + testSQL = `insert union_test values (1),(2)` + tk.MustExec(testSQL) + + testSQL = `select * from (select id from union_test union select id from union_test) t order by id;` + r := tk.MustQuery(testSQL) + r.Check(testkit.Rows("1", "2")) + + r = tk.MustQuery("select 1 union all select 1") + r.Check(testkit.Rows("1", "1")) + + r = tk.MustQuery("select 1 union all select 1 union select 1") + r.Check(testkit.Rows("1")) + + r = tk.MustQuery("select 1 as a union (select 2) order by a limit 1") + r.Check(testkit.Rows("1")) + + r = tk.MustQuery("select 1 as a union (select 2) order by a limit 1, 1") + r.Check(testkit.Rows("2")) + + r = tk.MustQuery("select id from union_test union all (select 1) order by id desc") + r.Check(testkit.Rows("2", "1", "1")) + + r = tk.MustQuery("select id as a from union_test union (select 1) order by a desc") + r.Check(testkit.Rows("2", "1")) + + r = tk.MustQuery(`select null as a union (select "abc") order by a`) + r.Check(testkit.Rows("", "abc")) + + r = tk.MustQuery(`select "abc" as a union (select 1) order by a`) + r.Check(testkit.Rows("1", "abc")) + + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (c int, d int)") + tk.MustExec("insert t1 values (NULL, 1)") + tk.MustExec("insert t1 values (1, 1)") + tk.MustExec("insert t1 values (1, 2)") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t2 (c int, d int)") + tk.MustExec("insert t2 values (1, 3)") + tk.MustExec("insert t2 values (1, 1)") + tk.MustExec("drop table if exists t3") + tk.MustExec("create table t3 (c int, d int)") + tk.MustExec("insert t3 values (3, 2)") + tk.MustExec("insert t3 values (4, 3)") + r = tk.MustQuery(`select sum(c1), c2 from (select c c1, d c2 from t1 union all select d c1, c c2 from t2 union all select c c1, d c2 from t3) x group by c2 order by c2`) + r.Check(testkit.Rows("5 1", "4 2", "4 3")) + + tk.MustExec("drop table if exists t1, t2, t3") + tk.MustExec("create table t1 (a int primary key)") + tk.MustExec("create table t2 (a int primary key)") + tk.MustExec("create table t3 (a int primary key)") + tk.MustExec("insert t1 values (7), (8)") + tk.MustExec("insert t2 values (1), (9)") + tk.MustExec("insert t3 values (2), (3)") + r = tk.MustQuery("select * from t1 union all select * from t2 union all (select * from t3) order by a limit 2") + r.Check(testkit.Rows("1", "2")) + + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (a int)") + tk.MustExec("create table t2 (a int)") + tk.MustExec("insert t1 values (2), (1)") + tk.MustExec("insert t2 values (3), (4)") + r = tk.MustQuery("select * from t1 union all (select * from t2) order by a limit 1") + r.Check(testkit.Rows("1")) + r = tk.MustQuery("select (select * from t1 where a != t.a union all (select * from t2 where a != t.a) order by a limit 1) from t1 t") + r.Check(testkit.Rows("1", "2")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (id int unsigned primary key auto_increment, c1 int, c2 int, index c1_c2 (c1, c2))") + tk.MustExec("insert into t (c1, c2) values (1, 1)") + tk.MustExec("insert into t (c1, c2) values (1, 2)") + tk.MustExec("insert into t (c1, c2) values (2, 3)") + r = tk.MustQuery("select * from (select * from t where t.c1 = 1 union select * from t where t.id = 1) s order by s.id") + r.Check(testkit.Rows("1 1 1", "2 1 2")) + + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t (f1 DATE)") + tk.MustExec("INSERT INTO t VALUES ('1978-11-26')") + r = tk.MustQuery("SELECT f1+0 FROM t UNION SELECT f1+0 FROM t") + r.Check(testkit.Rows("19781126")) + + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t (a int, b int)") + tk.MustExec("INSERT INTO t VALUES ('1', '1')") + r = tk.MustQuery("select b from (SELECT * FROM t UNION ALL SELECT a, b FROM t order by a) t") + r.Check(testkit.Rows("1", "1")) + + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t (a DECIMAL(4,2))") + tk.MustExec("INSERT INTO t VALUE(12.34)") + r = tk.MustQuery("SELECT 1 AS c UNION select a FROM t") + r.Sort().Check(testkit.Rows("1.00", "12.34")) + + // #issue3771 + r = tk.MustQuery("SELECT 'a' UNION SELECT CONCAT('a', -4)") + r.Sort().Check(testkit.Rows("a", "a-4")) + + // test race + tk.MustQuery("SELECT @x:=0 UNION ALL SELECT @x:=0 UNION ALL SELECT @x") + + // test field tp + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("CREATE TABLE t1 (a date)") + tk.MustExec("CREATE TABLE t2 (a date)") + tk.MustExec("SELECT a from t1 UNION select a FROM t2") + tk.MustQuery("show create table t1").Check(testkit.Rows("t1 CREATE TABLE `t1` (\n" + " `a` date DEFAULT NULL\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // Move from session test. + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (c double);") + tk.MustExec("create table t2 (c double);") + tk.MustExec("insert into t1 value (73);") + tk.MustExec("insert into t2 value (930);") + // If set unspecified column flen to 0, it will cause bug in union. + // This test is used to prevent the bug reappear. + tk.MustQuery("select c from t1 union (select c from t2) order by c").Check(testkit.Rows("73", "930")) + + // issue 5703 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a date)") + tk.MustExec("insert into t value ('2017-01-01'), ('2017-01-02')") + r = tk.MustQuery("(select a from t where a < 0) union (select a from t where a > 0) order by a") + r.Check(testkit.Rows("2017-01-01", "2017-01-02")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t value(0),(0)") + tk.MustQuery("select 1 from (select a from t union all select a from t) tmp").Check(testkit.Rows("1", "1", "1", "1")) + tk.MustQuery("select 10 as a from dual union select a from t order by a desc limit 1 ").Check(testkit.Rows("10")) + tk.MustQuery("select -10 as a from dual union select a from t order by a limit 1 ").Check(testkit.Rows("-10")) + tk.MustQuery("select count(1) from (select a from t union all select a from t) tmp").Check(testkit.Rows("4")) + + err := tk.ExecToErr("select 1 from (select a from t limit 1 union all select a from t limit 1) tmp") + require.Error(t, err) + terr := errors.Cause(err).(*terror.Error) + require.Equal(t, errors.ErrCode(mysql.ErrWrongUsage), terr.Code()) + + err = tk.ExecToErr("select 1 from (select a from t order by a union all select a from t limit 1) tmp") + require.Error(t, err) + terr = errors.Cause(err).(*terror.Error) + require.Equal(t, errors.ErrCode(mysql.ErrWrongUsage), terr.Code()) + + tk.MustGetDBError("(select a from t order by a) union all select a from t limit 1 union all select a from t limit 1", plannererrors.ErrWrongUsage) + + tk.MustExec("(select a from t limit 1) union all select a from t limit 1") + tk.MustExec("(select a from t order by a) union all select a from t order by a") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t value(1),(2),(3)") + + tk.MustQuery("(select a from t order by a limit 2) union all (select a from t order by a desc limit 2) order by a desc limit 1,2").Check(testkit.Rows("2", "2")) + tk.MustQuery("select a from t union all select a from t order by a desc limit 5").Check(testkit.Rows("3", "3", "2", "2", "1")) + tk.MustQuery("(select a from t order by a desc limit 2) union all select a from t group by a order by a").Check(testkit.Rows("1", "2", "2", "3", "3")) + tk.MustQuery("(select a from t order by a desc limit 2) union all select 33 as a order by a desc limit 2").Check(testkit.Rows("33", "3")) + + tk.MustQuery("select 1 union select 1 union all select 1").Check(testkit.Rows("1", "1")) + tk.MustQuery("select 1 union all select 1 union select 1").Check(testkit.Rows("1")) + + tk.MustExec("drop table if exists t1, t2") + tk.MustExec(`create table t1(a bigint, b bigint);`) + tk.MustExec(`create table t2(a bigint, b bigint);`) + tk.MustExec(`insert into t1 values(1, 1);`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t2 values(1, 1);`) + tk.MustExec(`set @@tidb_init_chunk_size=2;`) + tk.MustExec(`set @@sql_mode="";`) + tk.MustQuery(`select count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("128")) + tk.MustQuery(`select tmp.a, count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("1 128")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("insert into t value(1 ,2)") + tk.MustQuery("select a, b from (select a, 0 as d, b from t union all select a, 0 as d, b from t) test;").Check(testkit.Rows("1 2", "1 2")) + + // #issue 8141 + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(a int, b int)") + tk.MustExec("insert into t1 value(1,2),(1,1),(2,2),(2,2),(3,2),(3,2)") + tk.MustExec("set @@tidb_init_chunk_size=2;") + tk.MustQuery("select count(*) from (select a as c, a as d from t1 union all select a, b from t1) t;").Check(testkit.Rows("12")) + + // #issue 8189 and #issue 8199 + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("CREATE TABLE t1 (a int not null, b char (10) not null)") + tk.MustExec("insert into t1 values(1,'a'),(2,'b'),(3,'c'),(3,'c')") + tk.MustExec("CREATE TABLE t2 (a int not null, b char (10) not null)") + tk.MustExec("insert into t2 values(1,'a'),(2,'b'),(3,'c'),(3,'c')") + tk.MustQuery("select a from t1 union select a from t1 order by (select a+1);").Check(testkit.Rows("1", "2", "3")) + + // #issue 8201 + for i := 0; i < 4; i++ { + tk.MustQuery("SELECT(SELECT 0 AS a FROM dual UNION SELECT 1 AS a FROM dual ORDER BY a ASC LIMIT 1) AS dev").Check(testkit.Rows("0")) + } + + // #issue 8231 + tk.MustExec("drop table if exists t1") + tk.MustExec("CREATE TABLE t1 (uid int(1))") + tk.MustExec("INSERT INTO t1 SELECT 150") + tk.MustQuery("SELECT 'a' UNION SELECT uid FROM t1 order by 1 desc;").Check(testkit.Rows("a", "150")) + + // #issue 8196 + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("CREATE TABLE t1 (a int not null, b char (10) not null)") + tk.MustExec("insert into t1 values(1,'a'),(2,'b'),(3,'c'),(3,'c')") + tk.MustExec("CREATE TABLE t2 (a int not null, b char (10) not null)") + tk.MustExec("insert into t2 values(3,'c'),(4,'d'),(5,'f'),(6,'e')") + tk.MustExec("analyze table t1") + tk.MustExec("analyze table t2") + tk.MustGetErrMsg("(select a,b from t1 limit 2) union all (select a,b from t2 order by a limit 1) order by t1.b", + "[planner:1250]Table 't1' from one of the SELECTs cannot be used in global ORDER clause") + + // #issue 9900 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b decimal(6, 3))") + tk.MustExec("insert into t values(1, 1.000)") + tk.MustQuery("select count(distinct a), sum(distinct a), avg(distinct a) from (select a from t union all select b from t) tmp;").Check(testkit.Rows("1 1.000 1.0000000")) + + // #issue 23832 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a bit(20), b float, c double, d int)") + tk.MustExec("insert into t values(10, 10, 10, 10), (1, -1, 2, -2), (2, -2, 1, 1), (2, 1.1, 2.1, 10.1)") + tk.MustQuery("select a from t union select 10 order by a").Check(testkit.Rows("1", "2", "10")) +} + +func TestUnionLimit(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists union_limit") + tk.MustExec("create table union_limit (id int) partition by hash(id) partitions 30") + for i := 0; i < 60; i++ { + tk.MustExec(fmt.Sprintf("insert into union_limit values (%d)", i)) + } + // Cover the code for worker count limit in the union executor. + tk.MustQuery("select * from union_limit limit 10") +} + +func TestLowResolutionTSORead(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@autocommit=1") + tk.MustExec("use test") + tk.MustExec("create table low_resolution_tso(a int key)") + tk.MustExec("insert low_resolution_tso values (1)") + + // enable low resolution tso + require.False(t, tk.Session().GetSessionVars().UseLowResolutionTSO()) + tk.MustExec("set @@tidb_low_resolution_tso = 'on'") + require.True(t, tk.Session().GetSessionVars().UseLowResolutionTSO()) + + tk.MustQuery("select * from low_resolution_tso") + err := tk.ExecToErr("update low_resolution_tso set a = 2") + require.Error(t, err) + tk.MustExec("set @@tidb_low_resolution_tso = 'off'") + tk.MustExec("update low_resolution_tso set a = 2") + tk.MustQuery("select * from low_resolution_tso").Check(testkit.Rows("2")) + + // Test select for update could not be executed when `tidb_low_resolution_tso` is enabled. + type testCase struct { + optimistic bool + pointGet bool + } + cases := []testCase{ + {true, true}, + {true, false}, + {false, true}, + {false, false}, + } + tk.MustExec("set @@tidb_low_resolution_tso = 'on'") + for _, test := range cases { + if test.optimistic { + tk.MustExec("begin optimistic") + } else { + tk.MustExec("begin") + } + var err error + if test.pointGet { + err = tk.ExecToErr("select * from low_resolution_tso where a = 1 for update") + } else { + err = tk.ExecToErr("select * from low_resolution_tso for update") + } + require.Error(t, err) + tk.MustExec("rollback") + } + tk.MustQuery("select * from low_resolution_tso for update") + tk.MustQuery("select * from low_resolution_tso where a = 1 for update") + + origPessimisticAutoCommit := config.GetGlobalConfig().PessimisticTxn.PessimisticAutoCommit.Load() + config.GetGlobalConfig().PessimisticTxn.PessimisticAutoCommit.Store(true) + defer func() { + config.GetGlobalConfig().PessimisticTxn.PessimisticAutoCommit.Store(origPessimisticAutoCommit) + }() + err = tk.ExecToErr("select * from low_resolution_tso where a = 1 for update") + require.Error(t, err) + err = tk.ExecToErr("select * from low_resolution_tso for update") + require.Error(t, err) +} + +func TestLowResolutionTSOReadScope(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk1 := testkit.NewTestKit(t, store) + require.False(t, tk1.Session().GetSessionVars().UseLowResolutionTSO()) + + tk1.MustExec("set global tidb_low_resolution_tso = 'on'") + tk2 := testkit.NewTestKit(t, store) + require.True(t, tk2.Session().GetSessionVars().UseLowResolutionTSO()) +} + +func TestAdapterStatement(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.Session().GetSessionVars().TxnCtx.InfoSchema = domain.GetDomain(tk.Session()).InfoSchema() + compiler := &executor.Compiler{Ctx: tk.Session()} + s := parser.New() + stmtNode, err := s.ParseOneStmt("select 1", "", "") + require.NoError(t, err) + stmt, err := compiler.Compile(context.TODO(), stmtNode) + require.NoError(t, err) + require.Equal(t, "select 1", stmt.OriginText()) + + gbkSQL := "select '\xb1\xed1'" + stmts, _, err := s.ParseSQL(gbkSQL, parser.CharsetClient("gbk")) + require.NoError(t, err) + stmt, err = compiler.Compile(context.TODO(), stmts[0]) + require.NoError(t, err) + require.Equal(t, "select '表1'", stmt.Text()) + require.Equal(t, gbkSQL, stmt.OriginText()) + + stmtNode, err = s.ParseOneStmt("create table test.t (a int)", "", "") + require.NoError(t, err) + stmt, err = compiler.Compile(context.TODO(), stmtNode) + require.NoError(t, err) + require.Equal(t, "create table test.t (a int)", stmt.OriginText()) +} + +func TestIsPointGet(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use mysql") + ctx := tk.Session().(sessionctx.Context) + tests := map[string]bool{ + "select * from help_topic where name='aaa'": false, + "select 1 from help_topic where name='aaa'": false, + "select * from help_topic where help_topic_id=1": true, + "select * from help_topic where help_category_id=1": false, + } + s := parser.New() + for sqlStr, result := range tests { + stmtNode, err := s.ParseOneStmt(sqlStr, "", "") + require.NoError(t, err) + preprocessorReturn := &plannercore.PreprocessorReturn{} + nodeW := resolve.NewNodeW(stmtNode) + err = plannercore.Preprocess(context.Background(), ctx, nodeW, plannercore.WithPreprocessorReturn(preprocessorReturn)) + require.NoError(t, err) + p, _, err := planner.Optimize(context.TODO(), ctx, nodeW, preprocessorReturn.InfoSchema) + require.NoError(t, err) + ret := plannercore.IsPointGetWithPKOrUniqueKeyByAutoCommit(ctx.GetSessionVars(), p) + require.Equal(t, result, ret) + } +} + +func TestPointGetOrderby(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (i int key)") + require.Equal(t, tk.ExecToErr("select * from t where i = 1 order by j limit 10;").Error(), "[planner:1054]Unknown column 'j' in 'order clause'") +} + +func TestClusteredIndexIsPointGet(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("drop database if exists test_cluster_index_is_point_get;") + tk.MustExec("create database test_cluster_index_is_point_get;") + tk.MustExec("use test_cluster_index_is_point_get;") + + tk.Session().GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a varchar(255), b int, c char(10), primary key (c, a));") + ctx := tk.Session().(sessionctx.Context) + + tests := map[string]bool{ + "select 1 from t where a='x'": false, + "select * from t where c='x'": false, + "select * from t where a='x' and c='x'": true, + "select * from t where a='x' and c='x' and b=1": false, + } + s := parser.New() + for sqlStr, result := range tests { + stmtNode, err := s.ParseOneStmt(sqlStr, "", "") + require.NoError(t, err) + preprocessorReturn := &plannercore.PreprocessorReturn{} + nodeW := resolve.NewNodeW(stmtNode) + err = plannercore.Preprocess(context.Background(), ctx, nodeW, plannercore.WithPreprocessorReturn(preprocessorReturn)) + require.NoError(t, err) + p, _, err := planner.Optimize(context.TODO(), ctx, nodeW, preprocessorReturn.InfoSchema) + require.NoError(t, err) + ret := plannercore.IsPointGetWithPKOrUniqueKeyByAutoCommit(ctx.GetSessionVars(), p) + require.Equal(t, result, ret) + } +} + +func TestColumnName(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (c int, d int)") + // disable only full group by + tk.MustExec("set sql_mode='STRICT_TRANS_TABLES'") + rs, err := tk.Exec("select 1 + c, count(*) from t") + require.NoError(t, err) + fields := rs.Fields() + require.Len(t, fields, 2) + require.Equal(t, "1 + c", fields[0].Column.Name.L) + require.Equal(t, "1 + c", fields[0].ColumnAsName.L) + require.Equal(t, "count(*)", fields[1].Column.Name.L) + require.Equal(t, "count(*)", fields[1].ColumnAsName.L) + require.NoError(t, rs.Close()) + rs, err = tk.Exec("select (c) > all (select c from t) from t") + require.NoError(t, err) + fields = rs.Fields() + require.Len(t, fields, 1) + require.Equal(t, "(c) > all (select c from t)", fields[0].Column.Name.L) + require.Equal(t, "(c) > all (select c from t)", fields[0].ColumnAsName.L) + require.NoError(t, rs.Close()) + tk.MustExec("begin") + tk.MustExec("insert t values(1,1)") + rs, err = tk.Exec("select c d, d c from t") + require.NoError(t, err) + fields = rs.Fields() + require.Len(t, fields, 2) + require.Equal(t, "c", fields[0].Column.Name.L) + require.Equal(t, "d", fields[0].ColumnAsName.L) + require.Equal(t, "d", fields[1].Column.Name.L) + require.Equal(t, "c", fields[1].ColumnAsName.L) + require.NoError(t, rs.Close()) + // Test case for query a column of a table. + // In this case, all attributes have values. + rs, err = tk.Exec("select c as a from t as t2") + require.NoError(t, err) + fields = rs.Fields() + require.Equal(t, "c", fields[0].Column.Name.L) + require.Equal(t, "a", fields[0].ColumnAsName.L) + require.Equal(t, "t", fields[0].Table.Name.L) + require.Equal(t, "t2", fields[0].TableAsName.L) + require.Equal(t, "test", fields[0].DBName.L) + require.Nil(t, rs.Close()) + // Test case for query a expression which only using constant inputs. + // In this case, the table, org_table and database attributes will all be empty. + rs, err = tk.Exec("select hour(1) as a from t as t2") + require.NoError(t, err) + fields = rs.Fields() + require.Equal(t, "a", fields[0].Column.Name.L) + require.Equal(t, "a", fields[0].ColumnAsName.L) + require.Equal(t, "", fields[0].Table.Name.L) + require.Equal(t, "", fields[0].TableAsName.L) + require.Equal(t, "", fields[0].DBName.L) + require.Nil(t, rs.Close()) + // Test case for query a column wrapped with parentheses and unary plus. + // In this case, the column name should be its original name. + rs, err = tk.Exec("select (c), (+c), +(c), +(+(c)), ++c from t") + require.NoError(t, err) + fields = rs.Fields() + for i := 0; i < 5; i++ { + require.Equal(t, "c", fields[i].Column.Name.L) + require.Equal(t, "c", fields[i].ColumnAsName.L) + } + require.Nil(t, rs.Close()) + + // Test issue https://github.com/pingcap/tidb/issues/9639 . + // Both window function and expression appear in final result field. + tk.MustExec("set @@tidb_enable_window_function = 1") + rs, err = tk.Exec("select 1+1, row_number() over() num from t") + require.NoError(t, err) + fields = rs.Fields() + require.Equal(t, "1+1", fields[0].Column.Name.L) + require.Equal(t, "1+1", fields[0].ColumnAsName.L) + require.Equal(t, "num", fields[1].Column.Name.L) + require.Equal(t, "num", fields[1].ColumnAsName.L) + require.Nil(t, rs.Close()) + tk.MustExec("set @@tidb_enable_window_function = 0") + + rs, err = tk.Exec("select if(1,c,c) from t;") + require.NoError(t, err) + fields = rs.Fields() + require.Equal(t, "if(1,c,c)", fields[0].Column.Name.L) + // It's a compatibility issue. Should be empty instead. + require.Equal(t, "if(1,c,c)", fields[0].ColumnAsName.L) + require.Nil(t, rs.Close()) +} + +func TestSelectVar(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (d int)") + tk.MustExec("insert into t values(1), (2), (1)") + // This behavior is different from MySQL. + result := tk.MustQuery("select @a, @a := d+1 from t") + result.Check(testkit.Rows(" 2", "2 3", "3 2")) + // Test for PR #10658. + tk.MustExec("select SQL_BIG_RESULT d from t group by d") + tk.MustExec("select SQL_SMALL_RESULT d from t group by d") + tk.MustExec("select SQL_BUFFER_RESULT d from t group by d") +} + +func TestHistoryRead(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists history_read") + tk.MustExec("create table history_read (a int)") + tk.MustExec("insert history_read values (1)") + + // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. + safePointName := "tikv_gc_safe_point" + safePointValue := "20060102-15:04:05 -0700" + safePointComment := "All versions after safe point can be accessed. (DO NOT EDIT)" + updateSafePoint := fmt.Sprintf(`INSERT INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + ON DUPLICATE KEY + UPDATE variable_value = '%[2]s', comment = '%[3]s'`, safePointName, safePointValue, safePointComment) + tk.MustExec(updateSafePoint) + + // Set snapshot to a time before save point will fail. + _, err := tk.Exec("set @@tidb_snapshot = '2006-01-01 15:04:05.999999'") + require.True(t, terror.ErrorEqual(err, variable.ErrSnapshotTooOld), "err %v", err) + // SnapshotTS Is not updated if check failed. + require.Equal(t, uint64(0), tk.Session().GetSessionVars().SnapshotTS) + + // Setting snapshot to a time in the future will fail. (One day before the 2038 problem) + _, err = tk.Exec("set @@tidb_snapshot = '2038-01-18 03:14:07'") + require.Regexp(t, "cannot set read timestamp to a future time", err) + // SnapshotTS Is not updated if check failed. + require.Equal(t, uint64(0), tk.Session().GetSessionVars().SnapshotTS) + + curVer1, _ := store.CurrentVersion(kv.GlobalTxnScope) + time.Sleep(time.Millisecond) + snapshotTime := time.Now() + time.Sleep(time.Millisecond) + curVer2, _ := store.CurrentVersion(kv.GlobalTxnScope) + tk.MustExec("insert history_read values (2)") + tk.MustQuery("select * from history_read").Check(testkit.Rows("1", "2")) + tk.MustExec("set @@tidb_snapshot = '" + snapshotTime.Format("2006-01-02 15:04:05.999999") + "'") + ctx := tk.Session().(sessionctx.Context) + snapshotTS := ctx.GetSessionVars().SnapshotTS + require.Greater(t, snapshotTS, curVer1.Ver) + require.Less(t, snapshotTS, curVer2.Ver) + tk.MustQuery("select * from history_read").Check(testkit.Rows("1")) + tk.MustExecToErr("insert history_read values (2)") + tk.MustExecToErr("update history_read set a = 3 where a = 1") + tk.MustExecToErr("delete from history_read where a = 1") + tk.MustExec("set @@tidb_snapshot = ''") + tk.MustQuery("select * from history_read").Check(testkit.Rows("1", "2")) + tk.MustExec("insert history_read values (3)") + tk.MustExec("update history_read set a = 4 where a = 3") + tk.MustExec("delete from history_read where a = 1") + + time.Sleep(time.Millisecond) + snapshotTime = time.Now() + time.Sleep(time.Millisecond) + tk.MustExec("alter table history_read add column b int") + tk.MustExec("insert history_read values (8, 8), (9, 9)") + tk.MustQuery("select * from history_read order by a").Check(testkit.Rows("2 ", "4 ", "8 8", "9 9")) + tk.MustExec("set @@tidb_snapshot = '" + snapshotTime.Format("2006-01-02 15:04:05.999999") + "'") + tk.MustQuery("select * from history_read order by a").Check(testkit.Rows("2", "4")) + tsoStr := strconv.FormatUint(oracle.GoTimeToTS(snapshotTime), 10) + + tk.MustExec("set @@tidb_snapshot = '" + tsoStr + "'") + tk.MustQuery("select * from history_read order by a").Check(testkit.Rows("2", "4")) + + tk.MustExec("set @@tidb_snapshot = ''") + tk.MustQuery("select * from history_read order by a").Check(testkit.Rows("2 ", "4 ", "8 8", "9 9")) +} + +func TestHistoryReadInTxn(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. + safePointName := "tikv_gc_safe_point" + safePointValue := "20060102-15:04:05 -0700" + safePointComment := "All versions after safe point can be accessed. (DO NOT EDIT)" + updateSafePoint := fmt.Sprintf(`INSERT INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + ON DUPLICATE KEY + UPDATE variable_value = '%[2]s', comment = '%[3]s'`, safePointName, safePointValue, safePointComment) + tk.MustExec(updateSafePoint) + + tk.MustExec("drop table if exists his_t0, his_t1") + tk.MustExec("create table his_t0(id int primary key, v int)") + tk.MustExec("insert into his_t0 values(1, 10)") + + time.Sleep(time.Millisecond) + tk.MustExec("set @a=now(6)") + time.Sleep(time.Millisecond) + tk.MustExec("create table his_t1(id int primary key, v int)") + tk.MustExec("update his_t0 set v=v+1") + time.Sleep(time.Millisecond) + tk.MustExec("set tidb_snapshot=now(6)") + ts2 := tk.Session().GetSessionVars().SnapshotTS + tk.MustExec("set tidb_snapshot=''") + time.Sleep(time.Millisecond) + tk.MustExec("update his_t0 set v=v+1") + tk.MustExec("insert into his_t1 values(10, 100)") + + init := func(isolation string, setSnapshotBeforeTxn bool) { + if isolation == "none" { + tk.MustExec("set @@tidb_snapshot=@a") + return + } + + if setSnapshotBeforeTxn { + tk.MustExec("set @@tidb_snapshot=@a") + } + + if isolation == "optimistic" { + tk.MustExec("begin optimistic") + } else { + tk.MustExec(fmt.Sprintf("set @@tx_isolation='%s'", isolation)) + tk.MustExec("begin pessimistic") + } + + if !setSnapshotBeforeTxn { + tk.MustExec("set @@tidb_snapshot=@a") + } + } + + for _, isolation := range []string{ + "none", // not start an explicit txn + "optimistic", + "REPEATABLE-READ", + "READ-COMMITTED", + } { + for _, setSnapshotBeforeTxn := range []bool{false, true} { + t.Run(fmt.Sprintf("[%s] setSnapshotBeforeTxn[%v]", isolation, setSnapshotBeforeTxn), func(t *testing.T) { + tk.MustExec("rollback") + tk.MustExec("set @@tidb_snapshot=''") + + init(isolation, setSnapshotBeforeTxn) + // When tidb_snapshot is set, should use the snapshot info schema + tk.MustQuery("show tables like 'his_%'").Check(testkit.Rows("his_t0")) + + // When tidb_snapshot is set, select should use select ts + tk.MustQuery("select * from his_t0").Check(testkit.Rows("1 10")) + tk.MustQuery("select * from his_t0 where id=1").Check(testkit.Rows("1 10")) + + // When tidb_snapshot is set, write statements should not be allowed + if isolation != "none" && isolation != "optimistic" { + notAllowedSQLs := []string{ + "insert into his_t0 values(5, 1)", + "delete from his_t0 where id=1", + "update his_t0 set v=v+1", + "select * from his_t0 for update", + "select * from his_t0 where id=1 for update", + "create table his_t2(id int)", + } + + for _, sql := range notAllowedSQLs { + err := tk.ExecToErr(sql) + require.Errorf(t, err, "can not execute write statement when 'tidb_snapshot' is set") + } + } + + // After `ExecRestrictedSQL` with a specified snapshot and use current session, the original snapshot ts should not be reset + // See issue: https://github.com/pingcap/tidb/issues/34529 + exec := tk.Session().GetRestrictedSQLExecutor() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + rows, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionWithSnapshot(ts2), sqlexec.ExecOptionUseCurSession}, "select * from his_t0 where id=1") + require.NoError(t, err) + require.Equal(t, 1, len(rows)) + require.Equal(t, int64(1), rows[0].GetInt64(0)) + require.Equal(t, int64(11), rows[0].GetInt64(1)) + tk.MustQuery("select * from his_t0 where id=1").Check(testkit.Rows("1 10")) + tk.MustQuery("show tables like 'his_%'").Check(testkit.Rows("his_t0")) + + // CLEAR + tk.MustExec("set @@tidb_snapshot=''") + + // When tidb_snapshot is not set, should use the transaction's info schema + tk.MustQuery("show tables like 'his_%'").Check(testkit.Rows("his_t0", "his_t1")) + + // When tidb_snapshot is not set, select should use the transaction's ts + tk.MustQuery("select * from his_t0").Check(testkit.Rows("1 12")) + tk.MustQuery("select * from his_t0 where id=1").Check(testkit.Rows("1 12")) + tk.MustQuery("select * from his_t1").Check(testkit.Rows("10 100")) + tk.MustQuery("select * from his_t1 where id=10").Check(testkit.Rows("10 100")) + + // When tidb_snapshot is not set, select ... for update should not be effected + tk.MustQuery("select * from his_t0 for update").Check(testkit.Rows("1 12")) + tk.MustQuery("select * from his_t0 where id=1 for update").Check(testkit.Rows("1 12")) + tk.MustQuery("select * from his_t1 for update").Check(testkit.Rows("10 100")) + tk.MustQuery("select * from his_t1 where id=10 for update").Check(testkit.Rows("10 100")) + + tk.MustExec("rollback") + }) + } + } +} + +func TestCurrentTimestampValueSelection(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t,t1") + + tk.MustExec("create table t (id int, t0 timestamp null default current_timestamp, t1 timestamp(1) null default current_timestamp(1), t2 timestamp(2) null default current_timestamp(2) on update current_timestamp(2))") + tk.MustExec("insert into t (id) values (1)") + rs := tk.MustQuery("select t0, t1, t2 from t where id = 1") + t0 := rs.Rows()[0][0].(string) + t1 := rs.Rows()[0][1].(string) + t2 := rs.Rows()[0][2].(string) + require.Equal(t, 1, len(strings.Split(t0, "."))) + require.Equal(t, 1, len(strings.Split(t1, ".")[1])) + require.Equal(t, 2, len(strings.Split(t2, ".")[1])) + tk.MustQuery("select id from t where t0 = ?", t0).Check(testkit.Rows("1")) + tk.MustQuery("select id from t where t1 = ?", t1).Check(testkit.Rows("1")) + tk.MustQuery("select id from t where t2 = ?", t2).Check(testkit.Rows("1")) + time.Sleep(time.Second) + tk.MustExec("update t set t0 = now() where id = 1") + rs = tk.MustQuery("select t2 from t where id = 1") + newT2 := rs.Rows()[0][0].(string) + require.True(t, newT2 != t2) + + tk.MustExec("create table t1 (id int, a timestamp, b timestamp(2), c timestamp(3))") + tk.MustExec("insert into t1 (id, a, b, c) values (1, current_timestamp(2), current_timestamp, current_timestamp(3))") + rs = tk.MustQuery("select a, b, c from t1 where id = 1") + a := rs.Rows()[0][0].(string) + b := rs.Rows()[0][1].(string) + d := rs.Rows()[0][2].(string) + require.Equal(t, 1, len(strings.Split(a, "."))) + require.Equal(t, "00", strings.Split(b, ".")[1]) + require.Equal(t, 3, len(strings.Split(d, ".")[1])) +} + +func TestAdmin(t *testing.T) { + var cluster testutils.Cluster + store := testkit.CreateMockStore(t, mockstore.WithClusterInspector(func(c testutils.Cluster) { + mockstore.BootstrapWithSingleStore(c) + cluster = c + })) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + tk.MustExec("drop table if exists admin_test") + tk.MustExec("create table admin_test (c1 int, c2 int, c3 int default 1, index (c1))") + tk.MustExec("insert admin_test (c1) values (1),(2),(NULL)") + + ctx := context.Background() + // cancel DDL jobs test + r, err := tk.Exec("admin cancel ddl jobs 1") + require.NoError(t, err) + req := r.NewChunk(nil) + err = r.Next(ctx, req) + require.NoError(t, err) + row := req.GetRow(0) + require.Equal(t, 2, row.Len()) + require.Equal(t, "1", row.GetString(0)) + require.Regexp(t, ".*DDL Job:1 not found", row.GetString(1)) + + // show ddl test; + r, err = tk.Exec("admin show ddl") + require.NoError(t, err) + req = r.NewChunk(nil) + err = r.Next(ctx, req) + require.NoError(t, err) + row = req.GetRow(0) + require.Equal(t, 6, row.Len()) + tk = testkit.NewTestKit(t, store) + tk.MustExec("begin") + sess := tk.Session() + ddlInfo, err := ddl.GetDDLInfo(sess) + require.NoError(t, err) + require.Equal(t, ddlInfo.SchemaVer, row.GetInt64(0)) + // TODO: Pass this test. + // rowOwnerInfos := strings.Split(row.Data[1].GetString(), ",") + // ownerInfos := strings.Split(ddlInfo.Owner.String(), ",") + // c.Assert(rowOwnerInfos[0], Equals, ownerInfos[0]) + serverInfo, err := infosync.GetServerInfoByID(ctx, row.GetString(1)) + require.NoError(t, err) + require.Equal(t, serverInfo.IP+":"+strconv.FormatUint(uint64(serverInfo.Port), 10), row.GetString(2)) + require.Equal(t, "", row.GetString(3)) + req = r.NewChunk(nil) + err = r.Next(ctx, req) + require.NoError(t, err) + require.Zero(t, req.NumRows()) + tk.MustExec("rollback") + + // show DDL jobs test + r, err = tk.Exec("admin show ddl jobs") + require.NoError(t, err) + req = r.NewChunk(nil) + err = r.Next(ctx, req) + require.NoError(t, err) + row = req.GetRow(0) + require.Equal(t, 13, row.Len()) + txn, err := store.Begin() + require.NoError(t, err) + historyJobs, err := ddl.GetLastNHistoryDDLJobs(meta.NewMutator(txn), ddl.DefNumHistoryJobs) + require.Greater(t, len(historyJobs), 1) + require.Greater(t, len(row.GetString(1)), 0) + require.NoError(t, err) + require.Equal(t, historyJobs[0].ID, row.GetInt64(0)) + require.NoError(t, err) + + r, err = tk.Exec("admin show ddl jobs 20") + require.NoError(t, err) + req = r.NewChunk(nil) + err = r.Next(ctx, req) + require.NoError(t, err) + row = req.GetRow(0) + require.Equal(t, 13, row.Len()) + require.Equal(t, historyJobs[0].ID, row.GetInt64(0)) + require.NoError(t, err) + + // show DDL job queries test + tk.MustExec("use test") + tk.MustExec("drop table if exists admin_test2") + tk.MustExec("create table admin_test2 (c1 int, c2 int, c3 int default 1, index (c1))") + result := tk.MustQuery(`admin show ddl job queries 1, 1, 1`) + result.Check(testkit.Rows()) + result = tk.MustQuery(`admin show ddl job queries 1, 2, 3, 4`) + result.Check(testkit.Rows()) + historyJobs, err = ddl.GetLastNHistoryDDLJobs(meta.NewMutator(txn), ddl.DefNumHistoryJobs) + result = tk.MustQuery(fmt.Sprintf("admin show ddl job queries %d", historyJobs[0].ID)) + result.Check(testkit.Rows(historyJobs[0].Query)) + require.NoError(t, err) + + // show DDL job queries with range test + tk.MustExec("use test") + tk.MustExec("drop table if exists admin_test2") + tk.MustExec("create table admin_test2 (c1 int, c2 int, c3 int default 1, index (c1))") + tk.MustExec("drop table if exists admin_test3") + tk.MustExec("create table admin_test3 (c1 int, c2 int, c3 int default 1, index (c1))") + tk.MustExec("drop table if exists admin_test4") + tk.MustExec("create table admin_test4 (c1 int, c2 int, c3 int default 1, index (c1))") + tk.MustExec("drop table if exists admin_test5") + tk.MustExec("create table admin_test5 (c1 int, c2 int, c3 int default 1, index (c1))") + tk.MustExec("drop table if exists admin_test6") + tk.MustExec("create table admin_test6 (c1 int, c2 int, c3 int default 1, index (c1))") + tk.MustExec("drop table if exists admin_test7") + tk.MustExec("create table admin_test7 (c1 int, c2 int, c3 int default 1, index (c1))") + tk.MustExec("drop table if exists admin_test8") + tk.MustExec("create table admin_test8 (c1 int, c2 int, c3 int default 1, index (c1))") + historyJobs, err = ddl.GetLastNHistoryDDLJobs(meta.NewMutator(txn), ddl.DefNumHistoryJobs) + result = tk.MustQuery(`admin show ddl job queries limit 3`) + result.Check(testkit.Rows(fmt.Sprintf("%d %s", historyJobs[0].ID, historyJobs[0].Query), fmt.Sprintf("%d %s", historyJobs[1].ID, historyJobs[1].Query), fmt.Sprintf("%d %s", historyJobs[2].ID, historyJobs[2].Query))) + result = tk.MustQuery(`admin show ddl job queries limit 3, 2`) + result.Check(testkit.Rows(fmt.Sprintf("%d %s", historyJobs[3].ID, historyJobs[3].Query), fmt.Sprintf("%d %s", historyJobs[4].ID, historyJobs[4].Query))) + result = tk.MustQuery(`admin show ddl job queries limit 3 offset 2`) + result.Check(testkit.Rows(fmt.Sprintf("%d %s", historyJobs[2].ID, historyJobs[2].Query), fmt.Sprintf("%d %s", historyJobs[3].ID, historyJobs[3].Query), fmt.Sprintf("%d %s", historyJobs[4].ID, historyJobs[4].Query))) + require.NoError(t, err) + + // check situations when `admin show ddl job 20` happens at the same time with new DDLs being executed + var wg sync.WaitGroup + wg.Add(2) + flag := true + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + tk.MustExec("drop table if exists admin_test9") + tk.MustExec("create table admin_test9 (c1 int, c2 int, c3 int default 1, index (c1))") + } + }() + go func() { + // check that the result set has no duplication + defer wg.Done() + for i := 0; i < 10; i++ { + result := tk2.MustQuery(`admin show ddl job queries 20`) + rows := result.Rows() + rowIDs := make(map[string]struct{}) + for _, row := range rows { + rowID := fmt.Sprintf("%v", row[0]) + if _, ok := rowIDs[rowID]; ok { + flag = false + return + } + rowIDs[rowID] = struct{}{} + } + } + }() + wg.Wait() + require.True(t, flag) + + // check situations when `admin show ddl job queries limit 3 offset 2` happens at the same time with new DDLs being executed + var wg2 sync.WaitGroup + wg2.Add(2) + flag = true + go func() { + defer wg2.Done() + for i := 0; i < 10; i++ { + tk.MustExec("drop table if exists admin_test9") + tk.MustExec("create table admin_test9 (c1 int, c2 int, c3 int default 1, index (c1))") + } + }() + go func() { + // check that the result set has no duplication + defer wg2.Done() + for i := 0; i < 10; i++ { + result := tk2.MustQuery(`admin show ddl job queries limit 3 offset 2`) + rows := result.Rows() + rowIDs := make(map[string]struct{}) + for _, row := range rows { + rowID := fmt.Sprintf("%v", row[0]) + if _, ok := rowIDs[rowID]; ok { + flag = false + return + } + rowIDs[rowID] = struct{}{} + } + } + }() + wg2.Wait() + require.True(t, flag) + + // check table test + tk.MustExec("create table admin_test1 (c1 int, c2 int default 1, index (c1))") + tk.MustExec("insert admin_test1 (c1) values (21),(22)") + r, err = tk.Exec("admin check table admin_test, admin_test1") + require.NoError(t, err) + require.Nil(t, r) + // error table name + require.Error(t, tk.ExecToErr("admin check table admin_test_error")) + // different index values + dom := domain.GetDomain(tk.Session()) + is := dom.InfoSchema() + require.NotNil(t, is) + tb, err := is.TableByName(context.Background(), pmodel.NewCIStr("test"), pmodel.NewCIStr("admin_test")) + require.NoError(t, err) + require.Len(t, tb.Indices(), 1) + _, err = tb.Indices()[0].Create(mock.NewContext().GetTableCtx(), txn, types.MakeDatums(int64(10)), kv.IntHandle(1), nil) + require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) + errAdmin := tk.ExecToErr("admin check table admin_test") + require.Error(t, errAdmin) + + if config.CheckTableBeforeDrop { + tk.MustGetErrMsg("drop table admin_test", errAdmin.Error()) + + // Drop inconsistency index. + tk.MustExec("alter table admin_test drop index c1") + tk.MustExec("admin check table admin_test") + } + // checksum table test + tk.MustExec("create table checksum_with_index (id int, count int, PRIMARY KEY(id), KEY(count))") + tk.MustExec("create table checksum_without_index (id int, count int, PRIMARY KEY(id))") + r, err = tk.Exec("admin checksum table checksum_with_index, checksum_without_index") + require.NoError(t, err) + res := tk.ResultSetToResult(r, "admin checksum table") + // Mocktikv returns 1 for every table/index scan, then we will xor the checksums of a table. + // For "checksum_with_index", we have two checksums, so the result will be 1^1 = 0. + // For "checksum_without_index", we only have one checksum, so the result will be 1. + res.Sort().Check(testkit.Rows("test checksum_with_index 0 2 2", "test checksum_without_index 1 1 1")) + + tk.MustExec("drop table if exists t1;") + tk.MustExec("CREATE TABLE t1 (c2 BOOL, PRIMARY KEY (c2));") + tk.MustExec("INSERT INTO t1 SET c2 = '0';") + tk.MustExec("ALTER TABLE t1 ADD COLUMN c3 DATETIME NULL DEFAULT '2668-02-03 17:19:31';") + tk.MustExec("ALTER TABLE t1 ADD INDEX idx2 (c3);") + tk.MustExec("ALTER TABLE t1 ADD COLUMN c4 bit(10) default 127;") + tk.MustExec("ALTER TABLE t1 ADD INDEX idx3 (c4);") + tk.MustExec("admin check table t1;") + + // Test admin show ddl jobs table name after table has been droped. + tk.MustExec("drop table if exists t1;") + re := tk.MustQuery("admin show ddl jobs 1") + rows := re.Rows() + require.Len(t, rows, 1) + require.Equal(t, "t1", rows[0][2]) + + // Test for reverse scan get history ddl jobs when ddl history jobs queue has multiple regions. + txn, err = store.Begin() + require.NoError(t, err) + historyJobs, err = ddl.GetLastNHistoryDDLJobs(meta.NewMutator(txn), 20) + require.NoError(t, err) + + // Split region for history ddl job queues. + m := meta.NewMutator(txn) + startKey := meta.DDLJobHistoryKey(m, 0) + endKey := meta.DDLJobHistoryKey(m, historyJobs[0].ID) + cluster.SplitKeys(startKey, endKey, int(historyJobs[0].ID/5)) + + historyJobs2, err := ddl.GetLastNHistoryDDLJobs(meta.NewMutator(txn), 20) + require.NoError(t, err) + require.Equal(t, historyJobs2, historyJobs) +} + +func TestMaxOneRow(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`drop table if exists t1`) + tk.MustExec(`drop table if exists t2`) + tk.MustExec(`create table t1(a double, b double);`) + tk.MustExec(`create table t2(a double, b double);`) + tk.MustExec(`insert into t1 values(1, 1), (2, 2), (3, 3);`) + tk.MustExec(`insert into t2 values(0, 0);`) + tk.MustExec(`set @@tidb_init_chunk_size=1;`) + rs, err := tk.Exec(`select (select t1.a from t1 where t1.a > t2.a) as a from t2;`) + require.NoError(t, err) + + err = rs.Next(context.TODO(), rs.NewChunk(nil)) + require.Error(t, err) + require.Equal(t, "[executor:1242]Subquery returns more than 1 row", err.Error()) + require.NoError(t, rs.Close()) +} + +func TestIsFastPlan(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(id int primary key, a int)") + + cases := []struct { + sql string + isFastPlan bool + }{ + {"select a from t where id=1", true}, + {"select a+id from t where id=1", true}, + {"select 1", true}, + {"select @@autocommit", true}, + {"set @@autocommit=1", true}, + {"set @a=1", true}, + {"select * from t where a=1", false}, + {"select * from t", false}, + } + + for _, ca := range cases { + if strings.HasPrefix(ca.sql, "select") { + tk.MustQuery(ca.sql) + } else { + tk.MustExec(ca.sql) + } + info := tk.Session().ShowProcess() + require.NotNil(t, info) + p, ok := info.Plan.(base.Plan) + require.True(t, ok) + ok = executor.IsFastPlan(p) + require.Equal(t, ca.isFastPlan, ok) + } +} + +func TestGlobalMemoryControl2(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + + tk0 := testkit.NewTestKit(t, store) + tk0.MustExec("set global tidb_mem_oom_action = 'cancel'") + tk0.MustExec("set global tidb_server_memory_limit = 1 << 30") + tk0.MustExec("set global tidb_server_memory_limit_sess_min_size = 128") + + sm := &testkit.MockSessionManager{ + PS: []*util.ProcessInfo{tk0.Session().ShowProcess()}, + } + dom.ServerMemoryLimitHandle().SetSessionManager(sm) + go dom.ServerMemoryLimitHandle().Run() + + tk0.MustExec("use test") + tk0.MustExec("create table t(a int)") + tk0.MustExec("insert into t select 1") + for i := 1; i <= 8; i++ { + tk0.MustExec("insert into t select * from t") // 256 Lines + } + + var test []int + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + time.Sleep(100 * time.Millisecond) // Make sure the sql is running. + test = make([]int, 128<<20) // Keep 1GB HeapInuse + wg.Done() + }() + sql := "select * from t t1 join t t2 join t t3 on t1.a=t2.a and t1.a=t3.a order by t1.a;" // Need 500MB + require.True(t, exeerrors.ErrMemoryExceedForInstance.Equal(tk0.QueryToErr(sql))) + require.Equal(t, tk0.Session().GetSessionVars().DiskTracker.MaxConsumed(), int64(0)) + wg.Wait() + test[0] = 0 + runtime.GC() +} + +func TestSignalCheckpointForSort(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/sortexec/SignalCheckpointForSort", `return(true)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/sortexec/SignalCheckpointForSort")) + }() + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/chunk/SignalCheckpointForSort", `return(true)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/chunk/SignalCheckpointForSort")) + }() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + defer tk.MustExec("set global tidb_mem_oom_action = DEFAULT") + tk.MustExec("set global tidb_mem_oom_action='CANCEL'") + tk.MustExec("set tidb_mem_quota_query = 100000000") + tk.MustExec("use test") + tk.MustExec("create table t(a int)") + for i := 0; i < 20; i++ { + tk.MustExec(fmt.Sprintf("insert into t values(%d)", i)) + } + tk.Session().GetSessionVars().ConnectionID = 123456 + + err := tk.QueryToErr("select * from t order by a") + require.True(t, exeerrors.ErrMemoryExceedForQuery.Equal(err)) +} + +func TestSessionRootTrackerDetach(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + defer tk.MustExec("set global tidb_mem_oom_action = DEFAULT") + tk.MustExec("set global tidb_mem_oom_action='CANCEL'") + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, index idx(a))") + tk.MustExec("create table t1(a int, c int, index idx(a))") + tk.MustExec("set tidb_mem_quota_query=10") + err := tk.ExecToErr("select /*+hash_join(t1)*/ t.a, t1.a from t use index(idx), t1 use index(idx) where t.a = t1.a") + fmt.Println(err.Error()) + require.True(t, exeerrors.ErrMemoryExceedForQuery.Equal(err)) + tk.MustExec("set tidb_mem_quota_query=1000") + rs, err := tk.Exec("select /*+hash_join(t1)*/ t.a, t1.a from t use index(idx), t1 use index(idx) where t.a = t1.a") + require.NoError(t, err) + require.NotNil(t, tk.Session().GetSessionVars().MemTracker.GetFallbackForTest(false)) + err = rs.Close() + require.NoError(t, err) + require.Nil(t, tk.Session().GetSessionVars().MemTracker.GetFallbackForTest(false)) +} + +func TestProcessInfoOfSubQuery(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (i int, j int);") + var wg sync.WaitGroup + wg.Add(1) + go func() { + tk.MustQuery("select 1, (select sleep(count(1) + 2) from t);") + wg.Done() + }() + time.Sleep(time.Second) + tk2.MustQuery("select 1 from information_schema.processlist where TxnStart != '' and info like 'select%sleep% from t%'").Check(testkit.Rows("1")) + wg.Wait() +} + +func TestIssues49377(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table employee (employee_id int, name varchar(20), dept_id int)") + tk.MustExec("insert into employee values (1, 'Furina', 1), (2, 'Klee', 1), (3, 'Eula', 1), (4, 'Diluc', 2), (5, 'Tartaglia', 2)") + + tk.MustQuery("select 1,1,1 union all ( " + + "(select * from employee where dept_id = 1) " + + "union all " + + "(select * from employee where dept_id = 1 order by employee_id) " + + "order by 1 limit 1 " + + ");").Sort().Check(testkit.Rows("1 1 1", "1 Furina 1")) + + tk.MustQuery("select 1,1,1 union all ( " + + "(select * from employee where dept_id = 1) " + + "union all " + + "(select * from employee where dept_id = 1 order by employee_id) " + + "order by 1" + + ");").Sort().Check(testkit.Rows("1 1 1", "1 Furina 1", "1 Furina 1", "2 Klee 1", "2 Klee 1", "3 Eula 1", "3 Eula 1")) + + tk.MustQuery("select * from employee where dept_id = 1 " + + "union all " + + "(select * from employee where dept_id = 1 order by employee_id) " + + "union all" + + "(" + + "select * from employee where dept_id = 1 " + + "union all " + + "(select * from employee where dept_id = 1 order by employee_id) " + + "limit 1" + + ");").Sort().Check(testkit.Rows("1 Furina 1", "1 Furina 1", "1 Furina 1", "2 Klee 1", "2 Klee 1", "3 Eula 1", "3 Eula 1")) +} + +func TestIssues40463(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test;") + tk.MustExec("CREATE TABLE `4f380f26-9af6-4df8-959d-ad6296eff914` (`f7a9a4be-3728-449b-a5ea-df9b957eec67` enum('bkdv0','9rqy','lw','neud','ym','4nbv','9a7','bpkfo','xtfl','59','6vjj') NOT NULL DEFAULT 'neud', `43ca0135-1650-429b-8887-9eabcae2a234` set('8','5x47','xc','o31','lnz','gs5s','6yam','1','20ea','i','e') NOT NULL DEFAULT 'e', PRIMARY KEY (`f7a9a4be-3728-449b-a5ea-df9b957eec67`,`43ca0135-1650-429b-8887-9eabcae2a234`) /*T![clustered_index] CLUSTERED */) ENGINE=InnoDB DEFAULT CHARSET=ascii COLLATE=ascii_bin;") + tk.MustExec("INSERT INTO `4f380f26-9af6-4df8-959d-ad6296eff914` VALUES ('bkdv0','gs5s'),('lw','20ea'),('neud','8'),('ym','o31'),('4nbv','o31'),('xtfl','e');") + + tk.MustExec("CREATE TABLE `ba35a09f-76f4-40aa-9b48-13154a24bdd2` (`9b2a7138-14a3-4e8f-b29a-720392aad22c` set('zgn','if8yo','e','k7','bav','xj6','lkag','m5','as','ia','l3') DEFAULT 'zgn,if8yo,e,k7,ia,l3',`a60d6b5c-08bd-4a6d-b951-716162d004a5` set('6li6','05jlu','w','l','m','e9r','5q','d0ol','i6ajr','csf','d32') DEFAULT '6li6,05jlu,w,l,m,d0ol,i6ajr,csf,d32',`fb753d37-6252-4bd3-9bd1-0059640e7861` year(4) DEFAULT '2065', UNIQUE KEY `51816c39-27df-4bbe-a0e7-d6b6f54be2a2` (`fb753d37-6252-4bd3-9bd1-0059640e7861`), KEY `b0dfda0a-ffed-4c5b-9a72-4113bc1cbc8e` (`9b2a7138-14a3-4e8f-b29a-720392aad22c`,`fb753d37-6252-4bd3-9bd1-0059640e7861`)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin /*T! SHARD_ROW_ID_BITS=5 */;") + tk.MustExec("insert into `ba35a09f-76f4-40aa-9b48-13154a24bdd2` values ('if8yo', '6li6,05jlu,w,l,m,d0ol,i6ajr,csf,d32', 2065);") + + tk.MustExec("CREATE TABLE `07ccc74e-14c3-4685-bb41-c78a069b1a6d` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae` bigint(20) NOT NULL DEFAULT '-4604789462044748682',`30b19ecf-679f-4ca3-813f-d3c3b8f7da7e` date NOT NULL DEFAULT '5030-11-23',`1c52eaf2-1ebb-4486-9410-dfd00c7c835c` decimal(7,5) DEFAULT '-81.91307',`4b09dfdc-e688-41cb-9ffa-d03071a43077` float DEFAULT '1.7989023',PRIMARY KEY (`30b19ecf-679f-4ca3-813f-d3c3b8f7da7e`,`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`) /*T![clustered_index] CLUSTERED */,KEY `ae7a7637-ca52-443b-8a3f-69694f730cc4` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`),KEY `42640042-8a17-4145-9510-5bb419f83ed9` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`),KEY `839f4f5a-83f3-449b-a7dd-c7d2974d351a` (`30b19ecf-679f-4ca3-813f-d3c3b8f7da7e`),KEY `c474cde1-6fe4-45df-9067-b4e479f84149` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`),KEY `f834d0a9-709e-4ca8-925d-73f48322b70d` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`)) ENGINE=InnoDB DEFAULT CHARSET=gbk COLLATE=gbk_chinese_ci;") + tk.MustExec("set sql_mode=``;") + tk.MustExec("INSERT INTO `07ccc74e-14c3-4685-bb41-c78a069b1a6d` VALUES (616295989348159438,'0000-00-00',1.00000,1.7989023),(2215857492573998768,'1970-02-02',0.00000,1.7989023),(2215857492573998768,'1983-05-13',0.00000,1.7989023),(-2840083604831267906,'1984-01-30',1.00000,1.7989023),(599388718360890339,'1986-09-09',1.00000,1.7989023),(3506764933630033073,'1987-11-22',1.00000,1.7989023),(3506764933630033073,'2002-02-26',1.00000,1.7989023),(3506764933630033073,'2003-05-14',1.00000,1.7989023),(3506764933630033073,'2007-05-16',1.00000,1.7989023),(3506764933630033073,'2017-02-20',1.00000,1.7989023),(3506764933630033073,'2017-08-06',1.00000,1.7989023),(2215857492573998768,'2019-02-18',1.00000,1.7989023),(3506764933630033073,'2020-08-11',1.00000,1.7989023),(3506764933630033073,'2028-06-07',1.00000,1.7989023),(3506764933630033073,'2036-08-16',1.00000,1.7989023);") + + tk.MustQuery("select /*+ use_index_merge( `4f380f26-9af6-4df8-959d-ad6296eff914` ) */ /*+ stream_agg() */ approx_percentile( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` , 77 ) as r0 , `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` as r1 from `4f380f26-9af6-4df8-959d-ad6296eff914` where not( IsNull( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` ) ) and not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` in ( select `8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae` from `07ccc74e-14c3-4685-bb41-c78a069b1a6d` where `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` in ( select `a60d6b5c-08bd-4a6d-b951-716162d004a5` from `ba35a09f-76f4-40aa-9b48-13154a24bdd2` where not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` between 'bpkfo' and '59' ) and not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` in ( select `fb753d37-6252-4bd3-9bd1-0059640e7861` from `ba35a09f-76f4-40aa-9b48-13154a24bdd2` where IsNull( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` ) or not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`43ca0135-1650-429b-8887-9eabcae2a234` in ( select `8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae` from `07ccc74e-14c3-4685-bb41-c78a069b1a6d` where IsNull( `4f380f26-9af6-4df8-959d-ad6296eff914`.`43ca0135-1650-429b-8887-9eabcae2a234` ) and not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` between 'neud' and 'bpkfo' ) ) ) ) ) ) ) ) group by `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67`;") +} + +func TestIssue38756(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("create table t (c1 int)") + tk.MustExec("insert into t values (1), (2), (3)") + tk.MustQuery("SELECT SQRT(1) FROM t").Check(testkit.Rows("1", "1", "1")) + tk.MustQuery("(SELECT DISTINCT SQRT(1) FROM t)").Check(testkit.Rows("1")) + tk.MustQuery("SELECT DISTINCT cast(1 as double) FROM t").Check(testkit.Rows("1")) +} + +func TestIssue50043(t *testing.T) { + testIssue50043WithInitSQL(t, "") +} + +func TestIssue50043WithPipelinedDML(t *testing.T) { + testIssue50043WithInitSQL(t, "set @@tidb_dml_type=bulk") +} + +func testIssue50043WithInitSQL(t *testing.T, initSQL string) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(initSQL) + // Test simplified case by update. + tk.MustExec("use test") + tk.MustExec("create table t (c1 boolean ,c2 decimal ( 37 , 17 ), unique key idx1 (c1 ,c2),unique key idx2 ( c1 ));") + tk.MustExec("insert into t values (0,NULL);") + tk.MustExec("alter table t alter column c2 drop default;") + tk.MustExec("update t set c2 = 5 where c1 = 0;") + tk.MustQuery("select * from t order by c1,c2").Check(testkit.Rows("0 5.00000000000000000")) + + // Test simplified case by insert on duplicate key update. + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (c1 boolean ,c2 decimal ( 37 , 17 ), unique key idx1 (c1 ,c2));") + tk.MustExec("alter table t alter column c2 drop default;") + tk.MustExec("alter table t add unique key idx4 ( c1 );") + tk.MustExec("insert into t values (0, NULL), (1, 1);") + tk.MustExec("insert into t values (0, 2) ,(1, 3) on duplicate key update c2 = 5;") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from t order by c1,c2").Check(testkit.Rows("0 5.00000000000000000", "1 5.00000000000000000")) + + // Test Issue 50043. + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (c1 boolean ,c2 decimal ( 37 , 17 ), unique key idx1 (c1 ,c2));") + tk.MustExec("alter table t alter column c2 drop default;") + tk.MustExec("alter table t add unique key idx4 ( c1 );") + tk.MustExec("insert into t values (0, NULL), (1, 1);") + tk.MustExec("insert ignore into t values (0, 2) ,(1, 3) on duplicate key update c2 = 5, c1 = 0") + tk.MustQuery("select * from t order by c1,c2").Check(testkit.Rows("0 5.00000000000000000", "1 1.00000000000000000")) +} + +func TestIssue51324(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (id int key, a int, b enum('a', 'b'))") + tk.MustGetErrMsg("insert into t values ()", "[table:1364]Field 'id' doesn't have a default value") + tk.MustExec("insert into t set id = 1") + tk.MustExec("insert into t set id = 2, a = NULL, b = NULL") + tk.MustExec("insert into t set id = 3, a = DEFAULT, b = DEFAULT") + tk.MustQuery("select * from t order by id").Check(testkit.Rows("1 ", "2 ", "3 ")) + + tk.MustExec("alter table t alter column a drop default") + tk.MustExec("alter table t alter column b drop default") + tk.MustGetErrMsg("insert into t set id = 4;", "[table:1364]Field 'a' doesn't have a default value") + tk.MustExec("insert into t set id = 5, a = NULL, b = NULL;") + tk.MustGetErrMsg("insert into t set id = 6, a = DEFAULT, b = DEFAULT;", "[table:1364]Field 'a' doesn't have a default value") + tk.MustQuery("select * from t order by id").Check(testkit.Rows("1 ", "2 ", "3 ", "5 ")) + + tk.MustExec("insert ignore into t set id = 4;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'a' doesn't have a default value")) + tk.MustExec("insert ignore into t set id = 6, a = DEFAULT, b = DEFAULT;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'a' doesn't have a default value")) + tk.MustQuery("select * from t order by id").Check(testkit.Rows("1 ", "2 ", "3 ", "4 ", "5 ", "6 ")) + tk.MustExec("update t set id = id + 10") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select * from t order by id").Check(testkit.Rows("11 ", "12 ", "13 ", "14 ", "15 ", "16 ")) + + // Test not null case. + tk.MustExec("drop table t") + tk.MustExec("create table t (id int key, a int not null, b enum('a', 'b') not null)") + tk.MustGetErrMsg("insert into t values ()", "[table:1364]Field 'id' doesn't have a default value") + tk.MustGetErrMsg("insert into t set id = 1", "[table:1364]Field 'a' doesn't have a default value") + tk.MustGetErrMsg("insert into t set id = 2, a = NULL, b = NULL", "[table:1048]Column 'a' cannot be null") + tk.MustGetErrMsg("insert into t set id = 2, a = 2, b = NULL", "[table:1048]Column 'b' cannot be null") + tk.MustGetErrMsg("insert into t set id = 3, a = DEFAULT, b = DEFAULT", "[table:1364]Field 'a' doesn't have a default value") + tk.MustExec("alter table t alter column a drop default") + tk.MustExec("alter table t alter column b drop default") + tk.MustExec("insert ignore into t set id = 4;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'a' doesn't have a default value")) + tk.MustExec("insert ignore into t set id = 5, a = NULL, b = NULL;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1048 Column 'a' cannot be null", "Warning 1048 Column 'b' cannot be null")) + tk.MustExec("insert ignore into t set id = 6, a = 6, b = NULL;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1048 Column 'b' cannot be null")) + tk.MustExec("insert ignore into t set id = 7, a = DEFAULT, b = DEFAULT;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'a' doesn't have a default value")) + tk.MustQuery("select * from t order by id").Check(testkit.Rows("4 0 a", "5 0 ", "6 6 ", "7 0 a")) + + // Test add column with OriginDefaultValue case. + tk.MustExec("drop table t") + tk.MustExec("create table t (id int, unique key idx (id))") + tk.MustExec("insert into t values (1)") + tk.MustExec("alter table t add column a int default 1") + tk.MustExec("alter table t add column b int default null") + tk.MustExec("alter table t add column c int not null") + tk.MustExec("alter table t add column d int not null default 1") + tk.MustExec("insert ignore into t (id) values (2)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'c' doesn't have a default value")) + tk.MustExec("insert ignore into t (id) values (1),(2) on duplicate key update id = id+10") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'c' doesn't have a default value")) + tk.MustExec("alter table t alter column a drop default") + tk.MustExec("alter table t alter column b drop default") + tk.MustExec("alter table t alter column c drop default") + tk.MustExec("alter table t alter column d drop default") + tk.MustExec("insert ignore into t (id) values (3)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'a' doesn't have a default value", "Warning 1364 Field 'b' doesn't have a default value", "Warning 1364 Field 'c' doesn't have a default value", "Warning 1364 Field 'd' doesn't have a default value")) + tk.MustExec("insert ignore into t (id) values (11),(12),(3) on duplicate key update id = id+10") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1364 Field 'a' doesn't have a default value", "Warning 1364 Field 'b' doesn't have a default value", "Warning 1364 Field 'c' doesn't have a default value", "Warning 1364 Field 'd' doesn't have a default value")) + tk.MustQuery("select * from t order by id").Check(testkit.Rows("13 0 0", "21 1 0 1", "22 1 0 1")) +} + +func TestDecimalDivPrecisionIncrement(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (a decimal(3,0), b decimal(3,0))") + tk.MustExec("insert into t values (8, 7), (9, 7)") + tk.MustQuery("select a/b from t").Check(testkit.Rows("1.1429", "1.2857")) + + tk.MustExec("set div_precision_increment = 7") + tk.MustQuery("select a/b from t").Check(testkit.Rows("1.1428571", "1.2857143")) + + tk.MustExec("set div_precision_increment = 30") + tk.MustQuery("select a/b from t").Check(testkit.Rows("1.142857142857142857142857142857", "1.285714285714285714285714285714")) + + tk.MustExec("set div_precision_increment = 4") + tk.MustQuery("select avg(a) from t").Check(testkit.Rows("8.5000")) + + tk.MustExec("set div_precision_increment = 4") + tk.MustQuery("select avg(a/b) from t").Check(testkit.Rows("1.21428571")) + + tk.MustExec("set div_precision_increment = 10") + tk.MustQuery("select avg(a/b) from t").Check(testkit.Rows("1.21428571428571428550")) +} + +func TestIssue48756(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("CREATE TABLE t (id INT, a VARBINARY(20), b BIGINT)") + tk.MustExec(`INSERT INTO t VALUES(1, _binary '2012-05-19 09:06:07', 20120519090607), +(1, _binary '2012-05-19 09:06:07', 20120519090607), +(2, _binary '12012-05-19 09:06:07', 120120519090607), +(2, _binary '12012-05-19 09:06:07', 120120519090607)`) + tk.MustQuery("SELECT SUBTIME(BIT_OR(b), '1 1:1:1.000002') FROM t GROUP BY id").Sort().Check(testkit.Rows( + "2012-05-18 08:05:05.999998", + "", + )) + tk.MustQuery("show warnings").Check(testkit.Rows( + "Warning 1292 Incorrect time value: '120120519090607'", + )) +} + +func TestIssue50308(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a timestamp);") + tk.MustExec("insert ignore into t values(cast('2099-01-01' as date));") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning 1292 Incorrect timestamp value: '2099-01-01' for column 'a' at row 1")) + tk.MustQuery("select * from t;").Check(testkit.Rows("0000-00-00 00:00:00")) + tk.MustExec("delete from t") + tk.MustExec("insert into t values('2000-01-01');") + tk.MustGetErrMsg("update t set a=cast('2099-01-01' as date)", "[types:1292]Incorrect timestamp value: '2099-01-01'") + tk.MustExec("update ignore t set a=cast('2099-01-01' as date);") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning 1292 Incorrect timestamp value: '2099-01-01'")) + tk.MustQuery("select * from t;").Check(testkit.Rows("0000-00-00 00:00:00")) +} + +func TestQueryWithKill(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists tkq;") + tk.MustExec("create table tkq (a int key, b int, index idx_b(b));") + tk.MustExec("insert into tkq values (1,1);") + var wg sync.WaitGroup + ch := make(chan context.CancelFunc, 1024) + testDuration := time.Second * 10 + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + start := time.Now() + for { + ctx, cancel := context.WithCancel(context.Background()) + ch <- cancel + rs, err := tk.ExecWithContext(ctx, "select a from tkq where b = 1;") + if err == nil { + require.NotNil(t, rs) + rows, err := session.ResultSetToStringSlice(ctx, tk.Session(), rs) + if err == nil { + require.Equal(t, 1, len(rows)) + require.Equal(t, 1, len(rows[0])) + require.Equal(t, "1", fmt.Sprintf("%v", rows[0][0])) + } + } + if err != nil { + require.Equal(t, context.Canceled, err) + } + if rs != nil { + rs.Close() + } + if time.Since(start) > testDuration { + return + } + } + }() + } + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case cancel := <-ch: + // mock for random kill query + if len(ch) < 5 { + time.Sleep(time.Duration(rand.Intn(1000)) * time.Nanosecond) + } + cancel() + case <-time.After(time.Second): + return + } + } + }() + wg.Wait() +} diff --git a/pkg/executor/test/writetest/BUILD.bazel b/pkg/executor/test/writetest/BUILD.bazel new file mode 100644 index 0000000000000..9274fe7730b26 --- /dev/null +++ b/pkg/executor/test/writetest/BUILD.bazel @@ -0,0 +1,34 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "writetest_test", + timeout = "short", + srcs = [ + "main_test.go", + "write_test.go", + ], + flaky = True, + shard_count = 9, + deps = [ + "//pkg/config", + "//pkg/errctx", + "//pkg/executor", + "//pkg/kv", + "//pkg/lightning/mydump", + "//pkg/meta/autoid", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/session", + "//pkg/sessionctx", + "//pkg/sessiontxn", + "//pkg/store/mockstore", + "//pkg/table/tables", + "//pkg/testkit", + "//pkg/types", + "//pkg/util", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//tikv", + "@io_opencensus_go//stats/view", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/pkg/executor/test/writetest/write_test.go b/pkg/executor/test/writetest/write_test.go new file mode 100644 index 0000000000000..b1c420e3352bc --- /dev/null +++ b/pkg/executor/test/writetest/write_test.go @@ -0,0 +1,548 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package writetest + +import ( + "context" + "errors" + "fmt" + "io" + "testing" + + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/executor" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/mydump" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/store/mockstore" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestInsertIgnore(t *testing.T) { + store := testkit.CreateMockStore(t) + var cfg kv.InjectionConfig + tk := testkit.NewTestKit(t, kv.NewInjectedStore(store, &cfg)) + tk.MustExec("use test") + testSQL := `drop table if exists t; + create table t (id int PRIMARY KEY AUTO_INCREMENT, c1 int unique key);` + tk.MustExec(testSQL) + testSQL = `insert into t values (1, 2);` + tk.MustExec(testSQL) + require.Empty(t, tk.Session().LastMessage()) + + r := tk.MustQuery("select * from t;") + rowStr := fmt.Sprintf("%v %v", "1", "2") + r.Check(testkit.Rows(rowStr)) + + tk.MustExec("insert ignore into t values (1, 3), (2, 3)") + require.Equal(t, tk.Session().LastMessage(), "Records: 2 Duplicates: 1 Warnings: 1") + r = tk.MustQuery("select * from t;") + rowStr1 := fmt.Sprintf("%v %v", "2", "3") + r.Check(testkit.Rows(rowStr, rowStr1)) + + tk.MustExec("insert ignore into t values (3, 4), (3, 4)") + require.Equal(t, tk.Session().LastMessage(), "Records: 2 Duplicates: 1 Warnings: 1") + r = tk.MustQuery("select * from t;") + rowStr2 := fmt.Sprintf("%v %v", "3", "4") + r.Check(testkit.Rows(rowStr, rowStr1, rowStr2)) + + tk.MustExec("begin") + tk.MustExec("insert ignore into t values (4, 4), (4, 5), (4, 6)") + require.Equal(t, tk.Session().LastMessage(), "Records: 3 Duplicates: 2 Warnings: 2") + r = tk.MustQuery("select * from t;") + rowStr3 := fmt.Sprintf("%v %v", "4", "5") + r.Check(testkit.Rows(rowStr, rowStr1, rowStr2, rowStr3)) + tk.MustExec("commit") + + cfg.SetGetError(errors.New("foo")) + err := tk.ExecToErr("insert ignore into t values (1, 3)") + require.Error(t, err) + cfg.SetGetError(nil) + + // for issue 4268 + testSQL = `drop table if exists t; + create table t (a bigint);` + tk.MustExec(testSQL) + testSQL = "insert ignore into t select '1a';" + err = tk.ExecToErr(testSQL) + require.NoError(t, err) + require.Equal(t, tk.Session().LastMessage(), "Records: 1 Duplicates: 0 Warnings: 1") + r = tk.MustQuery("SHOW WARNINGS") + r.Check(testkit.Rows("Warning 1292 Truncated incorrect DOUBLE value: '1a'")) + testSQL = "insert ignore into t values ('1a')" + err = tk.ExecToErr(testSQL) + require.NoError(t, err) + require.Empty(t, tk.Session().LastMessage()) + r = tk.MustQuery("SHOW WARNINGS") + // TODO: MySQL8.0 reports Warning 1265 Data truncated for column 'a' at row 1 + r.Check(testkit.Rows("Warning 1366 Incorrect bigint value: '1a' for column 'a' at row 1")) + + // for duplicates with warning + testSQL = `drop table if exists t; + create table t(a int primary key, b int);` + tk.MustExec(testSQL) + testSQL = "insert ignore into t values (1,1);" + tk.MustExec(testSQL) + require.Empty(t, tk.Session().LastMessage()) + err = tk.ExecToErr(testSQL) + require.Empty(t, tk.Session().LastMessage()) + require.NoError(t, err) + r = tk.MustQuery("SHOW WARNINGS") + r.Check(testkit.Rows("Warning 1062 Duplicate entry '1' for key 't.PRIMARY'")) + + testSQL = `drop table if exists test; +create table test (i int primary key, j int unique); +begin; +insert into test values (1,1); +insert ignore into test values (2,1); +commit;` + tk.MustExec(testSQL) + testSQL = `select * from test;` + r = tk.MustQuery(testSQL) + r.Check(testkit.Rows("1 1")) + + testSQL = `delete from test; +insert into test values (1, 1); +begin; +delete from test where i = 1; +insert ignore into test values (2, 1); +commit;` + tk.MustExec(testSQL) + testSQL = `select * from test;` + r = tk.MustQuery(testSQL) + r.Check(testkit.Rows("2 1")) + + testSQL = `delete from test; +insert into test values (1, 1); +begin; +update test set i = 2, j = 2 where i = 1; +insert ignore into test values (1, 3); +insert ignore into test values (2, 4); +commit;` + tk.MustExec(testSQL) + testSQL = `select * from test order by i;` + r = tk.MustQuery(testSQL) + r.Check(testkit.Rows("1 3", "2 2")) + + testSQL = `create table badnull (i int not null)` + tk.MustExec(testSQL) + testSQL = `insert ignore into badnull values (null)` + tk.MustExec(testSQL) + require.Empty(t, tk.Session().LastMessage()) + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1048 Column 'i' cannot be null")) + testSQL = `select * from badnull` + tk.MustQuery(testSQL).Check(testkit.Rows("0")) + + tk.MustExec("create table tp (id int) partition by range (id) (partition p0 values less than (1), partition p1 values less than(2))") + tk.MustExec("insert ignore into tp values (1), (3)") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1526 Table has no partition for value 3")) +} + +type testCase struct { + data []byte + expected []string + expectedMsg string +} + +func checkCases( + tests []testCase, + loadSQL string, + t *testing.T, + tk *testkit.TestKit, + ctx sessionctx.Context, + selectSQL, deleteSQL string, +) { + for _, tt := range tests { + var reader io.ReadCloser = mydump.NewStringReader(string(tt.data)) + var readerBuilder executor.LoadDataReaderBuilder = func(_ string) ( + r io.ReadCloser, err error, + ) { + return reader, nil + } + + ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + tk.MustExec(loadSQL) + warnings := tk.Session().GetSessionVars().StmtCtx.GetWarnings() + for _, w := range warnings { + fmt.Printf("warnnig: %#v\n", w.Err.Error()) + } + require.Equal(t, tt.expectedMsg, tk.Session().LastMessage(), tt.expected) + tk.MustQuery(selectSQL).Check(testkit.RowsWithSep("|", tt.expected...)) + tk.MustExec(deleteSQL) + } +} + +func TestLoadDataMissingColumn(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + createSQL := `create table load_data_missing (id int, t timestamp not null)` + tk.MustExec(createSQL) + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_missing" + ctx := tk.Session().(sessionctx.Context) + + deleteSQL := "delete from load_data_missing" + selectSQL := "select id, hour(t), minute(t) from load_data_missing;" + + curTime := types.CurrentTime(mysql.TypeTimestamp) + timeHour := curTime.Hour() + timeMinute := curTime.Minute() + tests := []testCase{ + {[]byte(""), nil, "Records: 0 Deleted: 0 Skipped: 0 Warnings: 0"}, + {[]byte("12\n"), []string{fmt.Sprintf("12|%v|%v", timeHour, timeMinute)}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, + } + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) + + tk.MustExec("alter table load_data_missing add column t2 timestamp null") + curTime = types.CurrentTime(mysql.TypeTimestamp) + timeHour = curTime.Hour() + timeMinute = curTime.Minute() + selectSQL = "select id, hour(t), minute(t), t2 from load_data_missing;" + tests = []testCase{ + {[]byte("12\n"), []string{fmt.Sprintf("12|%v|%v|", timeHour, timeMinute)}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, + } + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) +} + +func TestIssue18681(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + createSQL := `drop table if exists load_data_test; + create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1),e bit(32),f bit(1));` + tk.MustExec(createSQL) + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test" + ctx := tk.Session().(sessionctx.Context) + + deleteSQL := "delete from load_data_test" + selectSQL := "select bin(a), bin(b), bin(c), bin(d), bin(e), bin(f) from load_data_test;" + levels := ctx.GetSessionVars().StmtCtx.ErrLevels() + levels[errctx.ErrGroupDupKey] = errctx.LevelWarn + levels[errctx.ErrGroupBadNull] = errctx.LevelWarn + levels[errctx.ErrGroupNoDefault] = errctx.LevelWarn + + sc := ctx.GetSessionVars().StmtCtx + oldTypeFlags := sc.TypeFlags() + defer func() { + sc.SetTypeFlags(oldTypeFlags) + }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) + tests := []testCase{ + {[]byte("true\tfalse\t0\t1\tb'1'\tb'1'\n"), []string{"1|1|1|1|1100010001001110011000100100111|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 5"}, + } + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) + require.Equal(t, uint16(0), sc.WarningCount()) +} + +func TestIssue34358(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + ctx := tk.Session().(sessionctx.Context) + defer ctx.SetValue(executor.LoadDataVarKey, nil) + + tk.MustExec("use test") + tk.MustExec("drop table if exists load_data_test") + tk.MustExec("create table load_data_test (a varchar(10), b varchar(10))") + + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test ( @v1, " + + "@v2 ) set a = @v1, b = @v2" + checkCases([]testCase{ + {[]byte("\\N\n"), []string{"|"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, + }, loadSQL, t, tk, ctx, "select * from load_data_test", "delete from load_data_test", + ) +} + +func TestLatch(t *testing.T) { + store, err := mockstore.NewMockStore( + // Small latch slot size to make conflicts. + mockstore.WithTxnLocalLatches(64), + ) + require.NoError(t, err) + defer func() { + err := store.Close() + require.NoError(t, err) + }() + + dom, err1 := session.BootstrapSession(store) + require.Nil(t, err1) + defer dom.Close() + + setTxnTk := testkit.NewTestKit(t, store) + setTxnTk.MustExec("set global tidb_txn_mode=''") + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk1.MustExec("drop table if exists t") + tk1.MustExec("create table t (id int)") + tk1.MustExec("set @@tidb_disable_txn_auto_retry = true") + + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + tk1.MustExec("set @@tidb_disable_txn_auto_retry = true") + + fn := func() { + tk1.MustExec("begin") + for i := 0; i < 100; i++ { + tk1.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + tk2.MustExec("begin") + for i := 100; i < 200; i++ { + tk1.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + tk2.MustExec("commit") + } + + // txn1 and txn2 data range do not overlap, using latches should not + // result in txn conflict. + fn() + tk1.MustExec("commit") + + tk1.MustExec("truncate table t") + fn() + tk1.MustExec("commit") + + // Test the error type of latch and it could be retry if TiDB enable the retry. + tk1.MustExec("begin") + tk1.MustExec("update t set id = id + 1") + tk2.MustExec("update t set id = id + 1") + tk1.MustGetDBError("commit", kv.ErrWriteConflictInTiDB) +} + +func TestReplaceLog(t *testing.T) { + store, domain := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table testLog (a int not null primary key, b int unique key);`) + + // Make some dangling index. + ctx := testkit.NewSession(t, store) + is := domain.InfoSchema() + dbName := model.NewCIStr("test") + tblName := model.NewCIStr("testLog") + tbl, err := is.TableByName(context.Background(), dbName, tblName) + require.NoError(t, err) + tblInfo := tbl.Meta() + idxInfo := tblInfo.FindIndexByName("b") + indexOpr := tables.NewIndex(tblInfo.ID, tblInfo, idxInfo) + + txn, err := store.Begin() + require.NoError(t, err) + _, err = indexOpr.Create(ctx.GetTableCtx(), txn, types.MakeDatums(1), kv.IntHandle(1), nil) + require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) + + err = tk.ExecToErr(`replace into testLog values (0, 0), (1, 1);`) + require.Error(t, err) + require.EqualError(t, err, `can not be duplicated row, due to old row not found. handle 1 not found`) + tk.MustQuery(`admin cleanup index testLog b;`).Check(testkit.Rows("1")) +} + +// TestRebaseIfNeeded is for issue 7422. +// There is no need to do the rebase when updating a record if the auto-increment ID not changed. +// This could make the auto ID increasing speed slower. +func TestRebaseIfNeeded(t *testing.T) { + store, domain := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`create table t (a int not null primary key auto_increment, b int unique key);`) + tk.MustExec(`insert into t (b) values (1);`) + + ctx := testkit.NewSession(t, store) + tbl, err := domain.InfoSchema().TableByName(context.Background(), model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + require.Nil(t, sessiontxn.NewTxn(context.Background(), ctx)) + txn, err := ctx.Txn(true) + require.NoError(t, err) + // AddRecord directly here will skip to rebase the auto ID in the insert statement, + // which could simulate another TiDB adds a large auto ID. + _, err = tbl.AddRecord(ctx.GetTableCtx(), txn, types.MakeDatums(30001, 2)) + require.NoError(t, err) + require.NoError(t, txn.Commit(context.Background())) + + tk.MustExec(`update t set b = 3 where a = 30001;`) + tk.MustExec(`insert into t (b) values (4);`) + tk.MustQuery(`select a from t where b = 4;`).Check(testkit.Rows("2")) + + tk.MustExec(`insert into t set b = 3 on duplicate key update a = a;`) + tk.MustExec(`insert into t (b) values (5);`) + tk.MustQuery(`select a from t where b = 5;`).Check(testkit.Rows("4")) + + tk.MustExec(`insert into t set b = 3 on duplicate key update a = a + 1;`) + tk.MustExec(`insert into t (b) values (6);`) + tk.MustQuery(`select a from t where b = 6;`).Check(testkit.Rows("30003")) +} + +func TestDeferConstraintCheckForInsert(t *testing.T) { + store := testkit.CreateMockStore(t) + setTxnTk := testkit.NewTestKit(t, store) + setTxnTk.MustExec("set global tidb_txn_mode=''") + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + + tk.MustExec(`drop table if exists t;create table t (a int primary key, b int);`) + tk.MustExec(`insert into t values (1,2),(2,2)`) + err := tk.ExecToErr("update t set a=a+1 where b=2") + require.Error(t, err) + + tk.MustExec(`drop table if exists t;create table t (i int key);`) + tk.MustExec(`insert t values (1);`) + tk.MustExec(`set tidb_constraint_check_in_place = 1;`) + tk.MustExec(`begin;`) + err = tk.ExecToErr(`insert t values (1);`) + require.Error(t, err) + tk.MustExec(`update t set i = 2 where i = 1;`) + tk.MustExec(`commit;`) + tk.MustQuery(`select * from t;`).Check(testkit.Rows("2")) + + tk.MustExec(`set tidb_constraint_check_in_place = 0;`) + tk.MustExec("replace into t values (1),(2)") + tk.MustExec("begin") + err = tk.ExecToErr("update t set i = 2 where i = 1") + require.Error(t, err) + err = tk.ExecToErr("insert into t values (1) on duplicate key update i = i + 1") + require.Error(t, err) + tk.MustExec("rollback") + + tk.MustExec(`drop table t; create table t (id int primary key, v int unique);`) + tk.MustExec(`insert into t values (1, 1)`) + tk.MustExec(`set tidb_constraint_check_in_place = 1;`) + tk.MustExec(`set @@autocommit = 0;`) + + err = tk.ExecToErr("insert into t values (3, 1)") + require.Error(t, err) + err = tk.ExecToErr("insert into t values (1, 3)") + require.Error(t, err) + tk.MustExec("commit") + + tk.MustExec(`set tidb_constraint_check_in_place = 0;`) + tk.MustExec("insert into t values (3, 1)") + tk.MustExec("insert into t values (1, 3)") + err = tk.ExecToErr("commit") + require.Error(t, err) + + // Cover the temporary table. + for val := range []int{0, 1} { + tk.MustExec("set tidb_constraint_check_in_place = ?", val) + + tk.MustExec("drop table t") + tk.MustExec("create global temporary table t (a int primary key, b int) on commit delete rows") + tk.MustExec("begin") + tk.MustExec("insert into t values (1, 1)") + err = tk.ExecToErr(`insert into t values (1, 3)`) + require.Error(t, err) + tk.MustExec("insert into t values (2, 2)") + err = tk.ExecToErr("update t set a = a + 1 where a = 1") + require.Error(t, err) + err = tk.ExecToErr("insert into t values (1, 3) on duplicated key update a = a + 1") + require.Error(t, err) + tk.MustExec("commit") + + tk.MustExec("drop table t") + tk.MustExec("create global temporary table t (a int, b int unique) on commit delete rows") + tk.MustExec("begin") + tk.MustExec("insert into t values (1, 1)") + err = tk.ExecToErr(`insert into t values (3, 1)`) + require.Error(t, err) + tk.MustExec("insert into t values (2, 2)") + err = tk.ExecToErr("update t set b = b + 1 where a = 1") + require.Error(t, err) + err = tk.ExecToErr("insert into t values (3, 1) on duplicated key update b = b + 1") + require.Error(t, err) + tk.MustExec("commit") + + // cases for temporary table + tk.MustExec("drop table if exists tl") + tk.MustExec("create temporary table tl (a int primary key, b int)") + tk.MustExec("begin") + tk.MustExec("insert into tl values (1, 1)") + err = tk.ExecToErr(`insert into tl values (1, 3)`) + require.Error(t, err) + tk.MustExec("insert into tl values (2, 2)") + err = tk.ExecToErr("update tl set a = a + 1 where a = 1") + require.Error(t, err) + err = tk.ExecToErr("insert into tl values (1, 3) on duplicated key update a = a + 1") + require.Error(t, err) + tk.MustExec("commit") + + tk.MustExec("begin") + tk.MustQuery("select * from tl").Check(testkit.Rows("1 1", "2 2")) + err = tk.ExecToErr(`insert into tl values (1, 3)`) + require.Error(t, err) + err = tk.ExecToErr("update tl set a = a + 1 where a = 1") + require.Error(t, err) + err = tk.ExecToErr("insert into tl values (1, 3) on duplicated key update a = a + 1") + require.Error(t, err) + tk.MustExec("rollback") + + tk.MustExec("drop table tl") + tk.MustExec("create temporary table tl (a int, b int unique)") + tk.MustExec("begin") + tk.MustExec("insert into tl values (1, 1)") + err = tk.ExecToErr(`insert into tl values (3, 1)`) + require.Error(t, err) + tk.MustExec("insert into tl values (2, 2)") + err = tk.ExecToErr("update tl set b = b + 1 where a = 1") + require.Error(t, err) + err = tk.ExecToErr("insert into tl values (3, 1) on duplicated key update b = b + 1") + require.Error(t, err) + tk.MustExec("commit") + + tk.MustExec("begin") + tk.MustQuery("select * from tl").Check(testkit.Rows("1 1", "2 2")) + err = tk.ExecToErr(`insert into tl values (3, 1)`) + require.Error(t, err) + err = tk.ExecToErr("update tl set b = b + 1 where a = 1") + require.Error(t, err) + err = tk.ExecToErr("insert into tl values (3, 1) on duplicated key update b = b + 1") + require.Error(t, err) + tk.MustExec("rollback") + } +} + +func TestPessimisticDeleteYourWrites(t *testing.T) { + store := testkit.CreateMockStore(t) + + session1 := testkit.NewTestKit(t, store) + session1.MustExec("use test") + session2 := testkit.NewTestKit(t, store) + session2.MustExec("use test") + + session1.MustExec("drop table if exists x;") + session1.MustExec("create table x (id int primary key, c int);") + + session1.MustExec("set tidb_txn_mode = 'pessimistic'") + session2.MustExec("set tidb_txn_mode = 'pessimistic'") + + session1.MustExec("begin;") + session1.MustExec("insert into x select 1, 1") + session1.MustExec("delete from x where id = 1") + session2.MustExec("begin;") + var wg util.WaitGroupWrapper + wg.Run(func() { + session2.MustExec("insert into x select 1, 2") + }) + session1.MustExec("commit;") + wg.Wait() + session2.MustExec("commit;") + session2.MustQuery("select * from x").Check(testkit.Rows("1 2")) +} diff --git a/pkg/sessionctx/context.go b/pkg/sessionctx/context.go new file mode 100644 index 0000000000000..8416a4db014e2 --- /dev/null +++ b/pkg/sessionctx/context.go @@ -0,0 +1,252 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessionctx + +import ( + "context" + "iter" + "sync" + + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/expression/exprctx" + "github.com/pingcap/tidb/pkg/extension" + infoschema "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/kv" + tablelock "github.com/pingcap/tidb/pkg/lock/context" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/planner/planctx" + "github.com/pingcap/tidb/pkg/session/cursor" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics/handle/usage/indexusage" + "github.com/pingcap/tidb/pkg/table/tblctx" + "github.com/pingcap/tidb/pkg/util" + contextutil "github.com/pingcap/tidb/pkg/util/context" + rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" + "github.com/pingcap/tidb/pkg/util/sli" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + "github.com/tikv/client-go/v2/oracle" +) + +// SessionStatesHandler is an interface for encoding and decoding session states. +type SessionStatesHandler interface { + // EncodeSessionStates encodes session states into a JSON. + EncodeSessionStates(context.Context, Context, *sessionstates.SessionStates) error + // DecodeSessionStates decodes a map into session states. + DecodeSessionStates(context.Context, Context, *sessionstates.SessionStates) error +} + +// SessionPlanCache is an interface for prepare and non-prepared plan cache +type SessionPlanCache interface { + Get(key string, paramTypes any) (value any, ok bool) + Put(key string, value, paramTypes any) + Delete(key string) + DeleteAll() + Size() int + SetCapacity(capacity uint) error + Close() +} + +// InstancePlanCache represents the instance/node level plan cache. +// Value and Opts should always be *PlanCacheValue and *PlanCacheMatchOpts, use any to avoid cycle-import. +type InstancePlanCache interface { + // Get gets the cached value from the cache according to key and opts. + Get(key string, paramTypes any) (value any, ok bool) + // Put puts the key and value into the cache. + Put(key string, value, paramTypes any) (succ bool) + // All returns all cached values. + // Returned values are read-only, don't modify them. + All() (values []any) + // Evict evicts some cached values. + Evict(evictAll bool) (detailInfo string, numEvicted int) + // Size returns the number of cached values. + Size() int64 + // MemUsage returns the total memory usage of this plan cache. + MemUsage() int64 + // GetLimits returns the soft and hard memory limits of this plan cache. + GetLimits() (softLimit, hardLimit int64) + // SetLimits sets the soft and hard memory limits of this plan cache. + SetLimits(softLimit, hardLimit int64) +} + +// Context is an interface for transaction and executive args environment. +type Context interface { + SessionStatesHandler + contextutil.ValueStoreContext + tablelock.TableLockContext + // RollbackTxn rolls back the current transaction. + RollbackTxn(ctx context.Context) + // CommitTxn commits the current transaction. + // buffered KV changes will be discarded, call StmtCommit if you want to commit them. + CommitTxn(ctx context.Context) error + // Txn returns the current transaction which is created before executing a statement. + // The returned kv.Transaction is not nil, but it maybe pending or invalid. + // If the active parameter is true, call this function will wait for the pending txn + // to become valid. + Txn(active bool) (kv.Transaction, error) + + // GetClient gets a kv.Client. + GetClient() kv.Client + + // GetMPPClient gets a kv.MPPClient. + GetMPPClient() kv.MPPClient + + // Deprecated: the semantics of session.GetInfoSchema() is ambiguous + // If you want to get the infoschema of the current transaction in SQL layer, use sessiontxn.GetTxnManager(ctx).GetTxnInfoSchema() + // If you want to get the latest infoschema use `GetDomainInfoSchema` + GetInfoSchema() infoschema.MetaOnlyInfoSchema + + // GetDomainInfoSchema returns the latest information schema in domain + // Different with `domain.InfoSchema()`, the information schema returned by this method + // includes the temporary table definitions stored in session + GetDomainInfoSchema() infoschema.MetaOnlyInfoSchema + + GetSessionVars() *variable.SessionVars + + // GetSQLExecutor returns the sqlexec.SQLExecutor. + GetSQLExecutor() sqlexec.SQLExecutor + + // GetRestrictedSQLExecutor returns the sqlexec.RestrictedSQLExecutor. + GetRestrictedSQLExecutor() sqlexec.RestrictedSQLExecutor + + // GetExprCtx returns the expression context of the session. + GetExprCtx() exprctx.ExprContext + + // GetTableCtx returns the table.MutateContext + GetTableCtx() tblctx.MutateContext + + // GetPlanCtx gets the plan context of the current session. + GetPlanCtx() planctx.PlanContext + + // GetDistSQLCtx gets the distsql ctx of the current session + GetDistSQLCtx() *distsqlctx.DistSQLContext + + // GetRangerCtx returns the context used in `ranger` related functions + GetRangerCtx() *rangerctx.RangerContext + + // GetBuildPBCtx gets the ctx used in `ToPB` of the current session + GetBuildPBCtx() *planctx.BuildPBContext + + GetSessionManager() util.SessionManager + + // RefreshTxnCtx commits old transaction without retry, + // and creates a new transaction. + // now just for load data and batch insert. + RefreshTxnCtx(context.Context) error + + // GetStore returns the store of session. + GetStore() kv.Storage + + // GetSessionPlanCache returns the session-level cache of the physical plan. + GetSessionPlanCache() SessionPlanCache + + // UpdateColStatsUsage updates the column stats usage. + UpdateColStatsUsage(predicateColumns iter.Seq[model.TableItemID]) + + // HasDirtyContent checks whether there's dirty update on the given table. + HasDirtyContent(tid int64) bool + + // StmtCommit flush all changes by the statement to the underlying transaction. + // it must be called before CommitTxn, else all changes since last StmtCommit + // will be lost. For SQL statement, StmtCommit or StmtRollback is called automatically. + // the "Stmt" not only means SQL statement, but also any KV changes, such as + // meta KV. + StmtCommit(ctx context.Context) + // StmtRollback provides statement level rollback. The parameter `forPessimisticRetry` should be true iff it's used + // for auto-retrying execution of DMLs in pessimistic transactions. + // if error happens when you are handling batch of KV changes since last StmtCommit + // or StmtRollback, and you don't want them to be committed, you must call StmtRollback + // before you start another batch, otherwise, the previous changes might be committed + // unexpectedly. + StmtRollback(ctx context.Context, isForPessimisticRetry bool) + // IsDDLOwner checks whether this session is DDL owner. + IsDDLOwner() bool + // PrepareTSFuture uses to prepare timestamp by future. + PrepareTSFuture(ctx context.Context, future oracle.Future, scope string) error + // GetPreparedTxnFuture returns the TxnFuture if it is valid or pending. + // It returns nil otherwise. + GetPreparedTxnFuture() TxnFuture + // GetTxnWriteThroughputSLI returns the TxnWriteThroughputSLI. + GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI + // GetStmtStats returns stmtstats.StatementStats owned by implementation. + GetStmtStats() *stmtstats.StatementStats + // ShowProcess returns ProcessInfo running in current Context + ShowProcess() *util.ProcessInfo + // GetAdvisoryLock acquires an advisory lock (aka GET_LOCK()). + GetAdvisoryLock(string, int64) error + // IsUsedAdvisoryLock checks for existing locks (aka IS_USED_LOCK()). + IsUsedAdvisoryLock(string) uint64 + // ReleaseAdvisoryLock releases an advisory lock (aka RELEASE_LOCK()). + ReleaseAdvisoryLock(string) bool + // ReleaseAllAdvisoryLocks releases all advisory locks that this session holds. + ReleaseAllAdvisoryLocks() int + // GetExtensions returns the `*extension.SessionExtensions` object + GetExtensions() *extension.SessionExtensions + // InSandBoxMode indicates that this Session is in sandbox mode + // Ref about sandbox mode: https://dev.mysql.com/doc/refman/8.0/en/expired-password-handling.html + InSandBoxMode() bool + // EnableSandBoxMode enable the sandbox mode of this Session + EnableSandBoxMode() + // DisableSandBoxMode enable the sandbox mode of this Session + DisableSandBoxMode() + // ReportUsageStats reports the usage stats to the global collector + ReportUsageStats() + // NewStmtIndexUsageCollector creates a new index usage collector for statement + NewStmtIndexUsageCollector() *indexusage.StmtIndexUsageCollector + // GetCursorTracker returns the cursor tracker of the session + GetCursorTracker() cursor.Tracker + // GetCommitWaitGroup returns the wait group for async commit and secondary lock cleanup background goroutines + GetCommitWaitGroup() *sync.WaitGroup +} + +// TxnFuture is an interface where implementations have a kv.Transaction field and after +// calling Wait of the TxnFuture, the kv.Transaction will become valid. +type TxnFuture interface { + // Wait converts pending txn to valid + Wait(ctx context.Context, sctx Context) (kv.Transaction, error) +} + +type basicCtxType int + +func (t basicCtxType) String() string { + switch t { + case QueryString: + return "query_string" + case Initing: + return "initing" + case LastExecuteDDL: + return "last_execute_ddl" + } + return "unknown" +} + +// Context keys. +const ( + // QueryString is the key for original query string. + QueryString basicCtxType = 1 + // Initing is the key for indicating if the server is running bootstrap or upgrade job. + Initing basicCtxType = 2 + // LastExecuteDDL is the key for whether the session execute a ddl command last time. + LastExecuteDDL basicCtxType = 3 +) + +// ValidateSnapshotReadTS strictly validates that readTS does not exceed the PD timestamp. +// For read requests to the storage, the check can be implicitly performed when sending the RPC request. So this +// function is only needed when it's not proper to delay the check to when RPC requests are being sent (e.g., `BEGIN` +// statements that don't make reading operation immediately). +func ValidateSnapshotReadTS(ctx context.Context, store kv.Storage, readTS uint64, isStaleRead bool) error { + return store.GetOracle().ValidateReadTS(ctx, readTS, isStaleRead, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) +} diff --git a/pkg/store/copr/batch_coprocessor.go b/pkg/store/copr/batch_coprocessor.go new file mode 100644 index 0000000000000..534feae30de08 --- /dev/null +++ b/pkg/store/copr/batch_coprocessor.go @@ -0,0 +1,1603 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package copr + +import ( + "bytes" + "context" + "fmt" + "io" + "math" + "math/rand" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/store/driver/backoff" + derr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/tiflash" + "github.com/pingcap/tidb/pkg/util/tiflashcompute" + "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/twmb/murmur3" + "go.uber.org/zap" +) + +const fetchTopoMaxBackoff = 20000 + +// batchCopTask comprises of multiple copTask that will send to same store. +type batchCopTask struct { + storeAddr string + cmdType tikvrpc.CmdType + ctx *tikv.RPCContext + + regionInfos []RegionInfo // region info for single physical table + // PartitionTableRegions indicates region infos for each partition table, used by scanning partitions in batch. + // Thus, one of `regionInfos` and `PartitionTableRegions` must be nil. + PartitionTableRegions []*coprocessor.TableRegions +} + +type batchCopResponse struct { + pbResp *coprocessor.BatchResponse + detail *CopRuntimeStats + + // batch Cop Response is yet to return startKey. So batchCop cannot retry partially. + startKey kv.Key + err error + respSize int64 + respTime time.Duration +} + +// GetData implements the kv.ResultSubset GetData interface. +func (rs *batchCopResponse) GetData() []byte { + return rs.pbResp.Data +} + +// GetStartKey implements the kv.ResultSubset GetStartKey interface. +func (rs *batchCopResponse) GetStartKey() kv.Key { + return rs.startKey +} + +// GetExecDetails is unavailable currently, because TiFlash has not collected exec details for batch cop. +// TODO: Will fix in near future. +func (rs *batchCopResponse) GetCopRuntimeStats() *CopRuntimeStats { + return rs.detail +} + +// MemSize returns how many bytes of memory this response use +func (rs *batchCopResponse) MemSize() int64 { + if rs.respSize != 0 { + return rs.respSize + } + + // ignore rs.err + rs.respSize += int64(cap(rs.startKey)) + if rs.detail != nil { + rs.respSize += int64(sizeofExecDetails) + } + if rs.pbResp != nil { + // Using a approximate size since it's hard to get a accurate value. + rs.respSize += int64(rs.pbResp.Size()) + } + return rs.respSize +} + +func (rs *batchCopResponse) RespTime() time.Duration { + return rs.respTime +} + +func deepCopyStoreTaskMap(storeTaskMap map[uint64]*batchCopTask, avgRegionsNum int) map[uint64]*batchCopTask { + storeTasks := make(map[uint64]*batchCopTask) + for storeID, task := range storeTaskMap { + t := batchCopTask{ + storeAddr: task.storeAddr, + cmdType: task.cmdType, + ctx: task.ctx, + } + t.regionInfos = make([]RegionInfo, len(task.regionInfos), len(task.regionInfos)+avgRegionsNum) + copy(t.regionInfos, task.regionInfos) + storeTasks[storeID] = &t + } + return storeTasks +} + +func regionTotalCount(storeTasks map[uint64]*batchCopTask, candidateRegionInfos []RegionInfo) int { + count := len(candidateRegionInfos) + for _, task := range storeTasks { + count += len(task.regionInfos) + } + return count +} + +const ( + maxBalanceScore = 100 + balanceScoreThreshold = 85 +) + +// Select at most cnt RegionInfos from candidateRegionInfos that belong to storeID. +// If selected[i] is true, candidateRegionInfos[i] has been selected and should be skip. +// storeID2RegionIndex is a map that key is storeID and value is a region index slice. +// selectRegion use storeID2RegionIndex to find RegionInfos that belong to storeID efficiently. +func selectRegion(storeID uint64, candidateRegionInfos []RegionInfo, selected []bool, storeID2RegionIndex map[uint64][]int, cnt int64, regionInfosBuffer []RegionInfo) []RegionInfo { + regionIndexes, ok := storeID2RegionIndex[storeID] + if !ok { + logutil.BgLogger().Error("selectRegion: storeID2RegionIndex not found", zap.Uint64("storeID", storeID)) + return nil + } + i := 0 + for ; i < len(regionIndexes) && len(regionInfosBuffer) < int(cnt); i++ { + idx := regionIndexes[i] + if selected[idx] { + continue + } + selected[idx] = true + regionInfosBuffer = append(regionInfosBuffer, candidateRegionInfos[idx]) + } + // Remove regions that has been selected. + storeID2RegionIndex[storeID] = regionIndexes[i:] + return regionInfosBuffer +} + +// Higher scores mean more balance: (100 - unblance percentage) +func balanceScore(maxRegionCount, minRegionCount int, balanceContinuousRegionCount int64) int { + if minRegionCount <= 0 { + return math.MinInt32 + } + unbalanceCount := maxRegionCount - minRegionCount + if unbalanceCount <= int(balanceContinuousRegionCount) { + return maxBalanceScore + } + return maxBalanceScore - unbalanceCount*100/minRegionCount +} + +func isBalance(score int) bool { + return score >= balanceScoreThreshold +} + +func checkBatchCopTaskBalance(storeTasks map[uint64]*batchCopTask, balanceContinuousRegionCount int64) (int, []string) { + if len(storeTasks) == 0 { + return 0, []string{} + } + maxRegionCount := 0 + minRegionCount := math.MaxInt32 + balanceInfos := []string{} + for storeID, task := range storeTasks { + cnt := len(task.regionInfos) + if cnt > maxRegionCount { + maxRegionCount = cnt + } + if cnt < minRegionCount { + minRegionCount = cnt + } + balanceInfos = append(balanceInfos, fmt.Sprintf("storeID %d storeAddr %s regionCount %d", storeID, task.storeAddr, cnt)) + } + return balanceScore(maxRegionCount, minRegionCount, balanceContinuousRegionCount), balanceInfos +} + +// balanceBatchCopTaskWithContinuity try to balance `continuous regions` between TiFlash Stores. +// In fact, not absolutely continuous is required, regions' range are closed to store in a TiFlash segment is enough for internal read optimization. +// +// First, the caller should guarantee the candidateRegionInfos is ordered. +// Second, build a storeID2RegionIndex data structure to fastly locate regions of a store (avoid scanning candidateRegionInfos repeatedly). +// Third, each store will take balanceContinuousRegionCount from the sorted candidateRegionInfos. These regions are stored very close to each other in TiFlash. +// Fourth, if the region count is not balance between TiFlash, it may fallback to the original balance logic. +func balanceBatchCopTaskWithContinuity(storeTaskMap map[uint64]*batchCopTask, candidateRegionInfos []RegionInfo, balanceContinuousRegionCount int64) ([]*batchCopTask, int) { + if len(candidateRegionInfos) < 500 { + return nil, 0 + } + funcStart := time.Now() + regionCount := regionTotalCount(storeTaskMap, candidateRegionInfos) + storeTasks := deepCopyStoreTaskMap(storeTaskMap, len(candidateRegionInfos)/len(storeTaskMap)) + + balanceStart := time.Now() + // Build storeID -> region index slice index and we can fastly locate regions of a store. + storeID2RegionIndex := make(map[uint64][]int) + for i, ri := range candidateRegionInfos { + for _, storeID := range ri.AllStores { + if val, ok := storeID2RegionIndex[storeID]; ok { + storeID2RegionIndex[storeID] = append(val, i) + } else { + storeID2RegionIndex[storeID] = []int{i} + } + } + } + + // If selected[i] is true, candidateRegionInfos[i] is selected by a store and should skip it in selectRegion. + selected := make([]bool, len(candidateRegionInfos)) + regionInfosBuffer := make([]RegionInfo, 0, balanceContinuousRegionCount) + for { + totalCount := 0 + selectCountThisRound := 0 + for storeID, task := range storeTasks { + // Each store select balanceContinuousRegionCount regions from candidateRegionInfos. + // Since candidateRegionInfos is sorted, it is very likely that these regions are close to each other in TiFlash. + regionInfosBuffer = selectRegion(storeID, candidateRegionInfos, selected, storeID2RegionIndex, balanceContinuousRegionCount, regionInfosBuffer[:0]) + task.regionInfos = append(task.regionInfos, regionInfosBuffer...) + totalCount += len(task.regionInfos) + selectCountThisRound += len(regionInfosBuffer) + } + if totalCount >= regionCount { + break + } + if selectCountThisRound == 0 { + logutil.BgLogger().Error("selectCandidateRegionInfos fail: some region cannot find relevant store.", zap.Int("regionCount", regionCount), zap.Int("candidateCount", len(candidateRegionInfos))) + return nil, 0 + } + } + balanceEnd := time.Now() + + score, balanceInfos := checkBatchCopTaskBalance(storeTasks, balanceContinuousRegionCount) + if !isBalance(score) { + logutil.BgLogger().Warn("balanceBatchCopTaskWithContinuity is not balance", zap.Int("score", score), zap.Strings("balanceInfos", balanceInfos)) + } + + totalCount := 0 + var res []*batchCopTask + for _, task := range storeTasks { + totalCount += len(task.regionInfos) + if len(task.regionInfos) > 0 { + res = append(res, task) + } + } + if totalCount != regionCount { + logutil.BgLogger().Error("balanceBatchCopTaskWithContinuity error", zap.Int("totalCount", totalCount), zap.Int("regionCount", regionCount)) + return nil, 0 + } + + logutil.BgLogger().Debug("balanceBatchCopTaskWithContinuity time", + zap.Int("candidateRegionCount", len(candidateRegionInfos)), + zap.Int64("balanceContinuousRegionCount", balanceContinuousRegionCount), + zap.Int("balanceScore", score), + zap.Duration("balanceTime", balanceEnd.Sub(balanceStart)), + zap.Duration("totalTime", time.Since(funcStart))) + + return res, score +} + +// balanceBatchCopTask balance the regions between available stores, the basic rule is +// 1. the first region of each original batch cop task belongs to its original store because some +// meta data(like the rpc context) in batchCopTask is related to it +// 2. for the remaining regions: +// if there is only 1 available store, then put the region to the related store +// otherwise, these region will be balance between TiFlash stores. +// +// Currently, there are two balance strategies. +// The first balance strategy: use a greedy algorithm to put it into the store with highest weight. This strategy only consider the region count between TiFlash stores. +// +// The second balance strategy: Not only consider the region count between TiFlash stores, but also try to make the regions' range continuous(stored in TiFlash closely). +// If balanceWithContinuity is true, the second balance strategy is enable. +func balanceBatchCopTask(aliveStores []*tikv.Store, originalTasks []*batchCopTask, balanceWithContinuity bool, balanceContinuousRegionCount int64, allRegionInfos []RegionInfo) []*batchCopTask { + if len(originalTasks) == 0 { + log.Info("Batch cop task balancer got an empty task set.") + return originalTasks + } + storeTaskMap := make(map[uint64]*batchCopTask) + // storeCandidateRegionMap stores all the possible store->region map. Its content is + // store id -> region signature -> region info. We can see it as store id -> region lists. + storeCandidateRegionMap := make(map[uint64]map[string]RegionInfo) + totalRegionCandidateNum := 0 + totalRemainingRegionNum := 0 + + for _, s := range aliveStores { + storeTaskMap[s.StoreID()] = &batchCopTask{ + storeAddr: s.GetAddr(), + cmdType: originalTasks[0].cmdType, + ctx: &tikv.RPCContext{Addr: s.GetAddr(), Store: s}, + } + } + + var candidateRegionInfos []RegionInfo = make([]RegionInfo, 0, len(allRegionInfos)) + for _, ri := range allRegionInfos { + // for each region, figure out the valid store num + validStoreNum := 0 + var validStoreID uint64 + for _, storeID := range ri.AllStores { + if _, ok := storeTaskMap[storeID]; ok { + validStoreNum++ + // original store id might be invalid, so we have to set it again. + validStoreID = storeID + } + } + if validStoreNum == 0 { + logutil.BgLogger().Warn("Meet regions that don't have an available store. Give up balancing") + return originalTasks + } else if validStoreNum == 1 { + // if only one store is valid, just put it to storeTaskMap + storeTaskMap[validStoreID].regionInfos = append(storeTaskMap[validStoreID].regionInfos, ri) + } else { + // if more than one store is valid, put the region + // to store candidate map + totalRegionCandidateNum += validStoreNum + totalRemainingRegionNum++ + candidateRegionInfos = append(candidateRegionInfos, ri) + taskKey := ri.Region.String() + for _, storeID := range ri.AllStores { + if _, validStore := storeTaskMap[storeID]; !validStore { + continue + } + if _, ok := storeCandidateRegionMap[storeID]; !ok { + candidateMap := make(map[string]RegionInfo) + storeCandidateRegionMap[storeID] = candidateMap + } + if _, duplicateRegion := storeCandidateRegionMap[storeID][taskKey]; duplicateRegion { + // duplicated region, should not happen, just give up balance + logutil.BgLogger().Warn("Meet duplicated region info during when trying to balance batch cop task, give up balancing") + return originalTasks + } + storeCandidateRegionMap[storeID][taskKey] = ri + } + } + } + + // If balanceBatchCopTaskWithContinuity failed (not balance or return nil), it will fallback to the original balance logic. + // So storeTaskMap should not be modify. + var contiguousTasks []*batchCopTask = nil + contiguousBalanceScore := 0 + if balanceWithContinuity { + contiguousTasks, contiguousBalanceScore = balanceBatchCopTaskWithContinuity(storeTaskMap, candidateRegionInfos, balanceContinuousRegionCount) + if isBalance(contiguousBalanceScore) && contiguousTasks != nil { + return contiguousTasks + } + } + + if totalRemainingRegionNum > 0 { + avgStorePerRegion := float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + findNextStore := func(candidateStores []uint64) uint64 { + store := uint64(math.MaxUint64) + weightedRegionNum := math.MaxFloat64 + if candidateStores != nil { + for _, storeID := range candidateStores { + if _, validStore := storeCandidateRegionMap[storeID]; !validStore { + continue + } + num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) + if num < weightedRegionNum { + store = storeID + weightedRegionNum = num + } + } + if store != uint64(math.MaxUint64) { + return store + } + } + for storeID := range storeTaskMap { + if _, validStore := storeCandidateRegionMap[storeID]; !validStore { + continue + } + num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) + if num < weightedRegionNum { + store = storeID + weightedRegionNum = num + } + } + return store + } + + store := findNextStore(nil) + for totalRemainingRegionNum > 0 { + if store == uint64(math.MaxUint64) { + break + } + var key string + var ri RegionInfo + for key, ri = range storeCandidateRegionMap[store] { + // get the first region + break + } + storeTaskMap[store].regionInfos = append(storeTaskMap[store].regionInfos, ri) + totalRemainingRegionNum-- + for _, id := range ri.AllStores { + if _, ok := storeCandidateRegionMap[id]; ok { + delete(storeCandidateRegionMap[id], key) + totalRegionCandidateNum-- + if len(storeCandidateRegionMap[id]) == 0 { + delete(storeCandidateRegionMap, id) + } + } + } + if totalRemainingRegionNum > 0 { + avgStorePerRegion = float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + // it is not optimal because we only check the stores that affected by this region, in fact in order + // to find out the store with the lowest weightedRegionNum, all stores should be checked, but I think + // check only the affected stores is more simple and will get a good enough result + store = findNextStore(ri.AllStores) + } + } + if totalRemainingRegionNum > 0 { + logutil.BgLogger().Warn("Some regions are not used when trying to balance batch cop task, give up balancing") + return originalTasks + } + } + + if contiguousTasks != nil { + score, balanceInfos := checkBatchCopTaskBalance(storeTaskMap, balanceContinuousRegionCount) + if !isBalance(score) { + logutil.BgLogger().Warn("Region count is not balance and use contiguousTasks", zap.Int("contiguousBalanceScore", contiguousBalanceScore), zap.Int("score", score), zap.Strings("balanceInfos", balanceInfos)) + return contiguousTasks + } + } + + var ret []*batchCopTask + for _, task := range storeTaskMap { + if len(task.regionInfos) > 0 { + ret = append(ret, task) + } + } + return ret +} + +func buildBatchCopTasksForNonPartitionedTable( + ctx context.Context, + bo *backoff.Backoffer, + store *kvStore, + ranges *KeyRanges, + storeType kv.StoreType, + isMPP bool, + ttl time.Duration, + balanceWithContinuity bool, + balanceContinuousRegionCount int64, + dispatchPolicy tiflashcompute.DispatchPolicy, + tiflashReplicaReadPolicy tiflash.ReplicaRead, + appendWarning func(error)) ([]*batchCopTask, error) { + if config.GetGlobalConfig().DisaggregatedTiFlash { + if config.GetGlobalConfig().UseAutoScaler { + return buildBatchCopTasksConsistentHash(ctx, bo, store, []*KeyRanges{ranges}, storeType, ttl, dispatchPolicy) + } + return buildBatchCopTasksConsistentHashForPD(bo, store, []*KeyRanges{ranges}, storeType, ttl, dispatchPolicy) + } + return buildBatchCopTasksCore(bo, store, []*KeyRanges{ranges}, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount, tiflashReplicaReadPolicy, appendWarning) +} + +func buildBatchCopTasksForPartitionedTable( + ctx context.Context, + bo *backoff.Backoffer, + store *kvStore, + rangesForEachPhysicalTable []*KeyRanges, + storeType kv.StoreType, + isMPP bool, + ttl time.Duration, + balanceWithContinuity bool, + balanceContinuousRegionCount int64, + partitionIDs []int64, + dispatchPolicy tiflashcompute.DispatchPolicy, + tiflashReplicaReadPolicy tiflash.ReplicaRead, + appendWarning func(error)) (batchTasks []*batchCopTask, err error) { + if config.GetGlobalConfig().DisaggregatedTiFlash { + if config.GetGlobalConfig().UseAutoScaler { + batchTasks, err = buildBatchCopTasksConsistentHash(ctx, bo, store, rangesForEachPhysicalTable, storeType, ttl, dispatchPolicy) + } else { + // todo: remove this after AutoScaler is stable. + batchTasks, err = buildBatchCopTasksConsistentHashForPD(bo, store, rangesForEachPhysicalTable, storeType, ttl, dispatchPolicy) + } + } else { + batchTasks, err = buildBatchCopTasksCore(bo, store, rangesForEachPhysicalTable, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount, tiflashReplicaReadPolicy, appendWarning) + } + if err != nil { + return nil, err + } + // generate tableRegions for batchCopTasks + convertRegionInfosToPartitionTableRegions(batchTasks, partitionIDs) + return batchTasks, nil +} + +func filterAliveStoresStr(ctx context.Context, storesStr []string, ttl time.Duration, kvStore *kvStore) (aliveStores []string) { + aliveIdx := filterAliveStoresHelper(ctx, storesStr, ttl, kvStore) + for _, idx := range aliveIdx { + aliveStores = append(aliveStores, storesStr[idx]) + } + return aliveStores +} + +func filterAliveStores(ctx context.Context, stores []*tikv.Store, ttl time.Duration, kvStore *kvStore) (aliveStores []*tikv.Store) { + storesStr := make([]string, 0, len(stores)) + for _, s := range stores { + storesStr = append(storesStr, s.GetAddr()) + } + + aliveIdx := filterAliveStoresHelper(ctx, storesStr, ttl, kvStore) + for _, idx := range aliveIdx { + aliveStores = append(aliveStores, stores[idx]) + } + return aliveStores +} + +func filterAliveStoresHelper(ctx context.Context, stores []string, ttl time.Duration, kvStore *kvStore) (aliveIdx []int) { + var wg sync.WaitGroup + var mu sync.Mutex + wg.Add(len(stores)) + for i := range stores { + go func(idx int) { + defer wg.Done() + s := stores[idx] + + // Check if store is failed already. + if ok := GlobalMPPFailedStoreProber.IsRecovery(ctx, s, ttl); !ok { + return + } + + tikvClient := kvStore.GetTiKVClient() + if ok := detectMPPStore(ctx, tikvClient, s, DetectTimeoutLimit); !ok { + GlobalMPPFailedStoreProber.Add(ctx, s, tikvClient) + return + } + + mu.Lock() + defer mu.Unlock() + aliveIdx = append(aliveIdx, idx) + }(i) + } + wg.Wait() + + logutil.BgLogger().Info("detecting available mpp stores", zap.Any("total", len(stores)), zap.Any("alive", len(aliveIdx))) + return aliveIdx +} + +func getTiFlashComputeRPCContextByConsistentHash(ids []tikv.RegionVerID, storesStr []string) (res []*tikv.RPCContext, err error) { + // Use RendezvousHash + for _, id := range ids { + var maxHash uint32 = 0 + var maxHashStore string = "" + for _, store := range storesStr { + h := murmur3.StringSum32(fmt.Sprintf("%s-%d", store, id.GetID())) + if h > maxHash { + maxHash = h + maxHashStore = store + } + } + rpcCtx := &tikv.RPCContext{ + Region: id, + Addr: maxHashStore, + } + res = append(res, rpcCtx) + } + return res, nil +} + +func getTiFlashComputeRPCContextByRoundRobin(ids []tikv.RegionVerID, storesStr []string) (res []*tikv.RPCContext, err error) { + startIdx := rand.Intn(len(storesStr)) + for _, id := range ids { + rpcCtx := &tikv.RPCContext{ + Region: id, + Addr: storesStr[startIdx%len(storesStr)], + } + + startIdx++ + res = append(res, rpcCtx) + } + return res, nil +} + +// 1. Split range by region location to build copTasks. +// 2. For each copTask build its rpcCtx , the target tiflash_compute node will be chosen using consistent hash. +// 3. All copTasks that will be sent to one tiflash_compute node are put in one batchCopTask. +func buildBatchCopTasksConsistentHash( + ctx context.Context, + bo *backoff.Backoffer, + kvStore *kvStore, + rangesForEachPhysicalTable []*KeyRanges, + storeType kv.StoreType, + ttl time.Duration, + dispatchPolicy tiflashcompute.DispatchPolicy) (res []*batchCopTask, err error) { + failpointCheckWhichPolicy(dispatchPolicy) + start := time.Now() + const cmdType = tikvrpc.CmdBatchCop + cache := kvStore.GetRegionCache() + fetchTopoBo := backoff.NewBackofferWithVars(ctx, fetchTopoMaxBackoff, nil) + + var ( + retryNum int + rangesLen int + storesStr []string + ) + + tasks := make([]*copTask, 0) + regionIDs := make([]tikv.RegionVerID, 0) + + for i, ranges := range rangesForEachPhysicalTable { + rangesLen += ranges.Len() + locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) + if err != nil { + return nil, errors.Trace(err) + } + for _, lo := range locations { + tasks = append(tasks, &copTask{ + region: lo.Location.Region, + ranges: lo.Ranges, + cmdType: cmdType, + storeType: storeType, + partitionIndex: int64(i), + }) + regionIDs = append(regionIDs, lo.Location.Region) + } + } + splitKeyElapsed := time.Since(start) + + fetchTopoStart := time.Now() + for { + retryNum++ + storesStr, err = tiflashcompute.GetGlobalTopoFetcher().FetchAndGetTopo() + if err != nil { + return nil, err + } + storesBefFilter := len(storesStr) + storesStr = filterAliveStoresStr(ctx, storesStr, ttl, kvStore) + logutil.BgLogger().Info("topo filter alive", zap.Any("topo", storesStr)) + if len(storesStr) == 0 { + errMsg := "Cannot find proper topo to dispatch MPPTask: " + if storesBefFilter == 0 { + errMsg += "topo from AutoScaler is empty" + } else { + errMsg += "detect aliveness failed, no alive ComputeNode" + } + retErr := errors.New(errMsg) + logutil.BgLogger().Info("buildBatchCopTasksConsistentHash retry because FetchAndGetTopo return empty topo", zap.Int("retryNum", retryNum)) + if intest.InTest && retryNum > 3 { + return nil, retErr + } + err := fetchTopoBo.Backoff(tikv.BoTiFlashRPC(), retErr) + if err != nil { + return nil, retErr + } + continue + } + break + } + fetchTopoElapsed := time.Since(fetchTopoStart) + + var rpcCtxs []*tikv.RPCContext + if dispatchPolicy == tiflashcompute.DispatchPolicyRR { + rpcCtxs, err = getTiFlashComputeRPCContextByRoundRobin(regionIDs, storesStr) + } else if dispatchPolicy == tiflashcompute.DispatchPolicyConsistentHash { + rpcCtxs, err = getTiFlashComputeRPCContextByConsistentHash(regionIDs, storesStr) + } else { + err = errors.Errorf("unexpected dispatch policy %v", dispatchPolicy) + } + if err != nil { + return nil, err + } + if len(rpcCtxs) != len(tasks) { + return nil, errors.Errorf("length should be equal, len(rpcCtxs): %d, len(tasks): %d", len(rpcCtxs), len(tasks)) + } + taskMap := make(map[string]*batchCopTask) + for i, rpcCtx := range rpcCtxs { + regionInfo := RegionInfo{ + // tasks and rpcCtxs are correspond to each other. + Region: tasks[i].region, + Ranges: tasks[i].ranges, + PartitionIndex: tasks[i].partitionIndex, + } + if batchTask, ok := taskMap[rpcCtx.Addr]; ok { + batchTask.regionInfos = append(batchTask.regionInfos, regionInfo) + } else { + batchTask := &batchCopTask{ + storeAddr: rpcCtx.Addr, + cmdType: cmdType, + ctx: rpcCtx, + regionInfos: []RegionInfo{regionInfo}, + } + taskMap[rpcCtx.Addr] = batchTask + res = append(res, batchTask) + } + } + logutil.BgLogger().Info("buildBatchCopTasksConsistentHash done", + zap.Any("len(tasks)", len(taskMap)), + zap.Any("len(tiflash_compute)", len(storesStr)), + zap.Any("dispatchPolicy", tiflashcompute.GetDispatchPolicy(dispatchPolicy))) + + if log.GetLevel() <= zap.DebugLevel { + debugTaskMap := make(map[string]string, len(taskMap)) + for s, b := range taskMap { + debugTaskMap[s] = fmt.Sprintf("addr: %s; regionInfos: %v", b.storeAddr, b.regionInfos) + } + logutil.BgLogger().Debug("detailed info buildBatchCopTasksConsistentHash", zap.Any("taskMap", debugTaskMap), zap.Any("allStores", storesStr)) + } + + if elapsed := time.Since(start); elapsed > time.Millisecond*500 { + logutil.BgLogger().Warn("buildBatchCopTasksConsistentHash takes too much time", + zap.Duration("total elapsed", elapsed), + zap.Int("retryNum", retryNum), + zap.Duration("splitKeyElapsed", splitKeyElapsed), + zap.Duration("fetchTopoElapsed", fetchTopoElapsed), + zap.Int("range len", rangesLen), + zap.Int("copTaskNum", len(tasks)), + zap.Int("batchCopTaskNum", len(res))) + } + failpointCheckForConsistentHash(res) + return res, nil +} + +func failpointCheckForConsistentHash(tasks []*batchCopTask) { + failpoint.Inject("checkOnlyDispatchToTiFlashComputeNodes", func(val failpoint.Value) { + logutil.BgLogger().Debug("in checkOnlyDispatchToTiFlashComputeNodes") + + // This failpoint will be tested in test-infra case, because we needs setup a cluster. + // All tiflash_compute nodes addrs are stored in val, separated by semicolon. + str := val.(string) + addrs := strings.Split(str, ";") + if len(addrs) < 1 { + err := fmt.Sprintf("unexpected length of tiflash_compute node addrs: %v, %s", len(addrs), str) + panic(err) + } + addrMap := make(map[string]struct{}) + for _, addr := range addrs { + addrMap[addr] = struct{}{} + } + for _, batchTask := range tasks { + if _, ok := addrMap[batchTask.storeAddr]; !ok { + err := errors.Errorf("batchCopTask send to node which is not tiflash_compute: %v(tiflash_compute nodes: %s)", batchTask.storeAddr, str) + panic(err) + } + } + }) +} + +func failpointCheckWhichPolicy(act tiflashcompute.DispatchPolicy) { + failpoint.Inject("testWhichDispatchPolicy", func(exp failpoint.Value) { + expStr := exp.(string) + actStr := tiflashcompute.GetDispatchPolicy(act) + if actStr != expStr { + err := errors.Errorf("tiflash_compute dispatch should be %v, but got %v", expStr, actStr) + panic(err) + } + }) +} + +func filterAllStoresAccordingToTiFlashReplicaRead(allStores []uint64, aliveStores *aliveStoresBundle, policy tiflash.ReplicaRead) (storesMatchedPolicy []uint64, needsCrossZoneAccess bool) { + if policy.IsAllReplicas() { + for _, id := range allStores { + if _, ok := aliveStores.storeIDsInAllZones[id]; ok { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + return + } + // Check whether exists available stores in TiDB zone. If so, we only need to access TiFlash stores in TiDB zone. + for _, id := range allStores { + if _, ok := aliveStores.storeIDsInTiDBZone[id]; ok { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + // If no available stores in TiDB zone, we need to access TiFlash stores in other zones. + if len(storesMatchedPolicy) == 0 { + // needsCrossZoneAccess indicates whether we need to access(directly read or remote read) TiFlash stores in other zones. + needsCrossZoneAccess = true + + if policy == tiflash.ClosestAdaptive { + // If the policy is `ClosestAdaptive`, we can dispatch tasks to the TiFlash stores in other zones. + for _, id := range allStores { + if _, ok := aliveStores.storeIDsInAllZones[id]; ok { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + } else if policy == tiflash.ClosestReplicas { + // If the policy is `ClosestReplicas`, we dispatch tasks to the TiFlash stores in TiDB zone and remote read from other zones. + for id := range aliveStores.storeIDsInTiDBZone { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + } + return +} + +func getAllUsedTiFlashStores(allTiFlashStores []*tikv.Store, allUsedTiFlashStoresMap map[uint64]struct{}) []*tikv.Store { + allUsedTiFlashStores := make([]*tikv.Store, 0, len(allUsedTiFlashStoresMap)) + for _, store := range allTiFlashStores { + _, ok := allUsedTiFlashStoresMap[store.StoreID()] + if ok { + allUsedTiFlashStores = append(allUsedTiFlashStores, store) + } + } + return allUsedTiFlashStores +} + +// getAliveStoresAndStoreIDs gets alive TiFlash stores and their IDs. +// If tiflashReplicaReadPolicy is not all_replicas, it will also return the IDs of the alive TiFlash stores in TiDB zone. +func getAliveStoresAndStoreIDs(ctx context.Context, cache *RegionCache, allUsedTiFlashStoresMap map[uint64]struct{}, ttl time.Duration, store *kvStore, tiflashReplicaReadPolicy tiflash.ReplicaRead, tidbZone string) (aliveStores *aliveStoresBundle) { + aliveStores = new(aliveStoresBundle) + allTiFlashStores := cache.RegionCache.GetTiFlashStores(tikv.LabelFilterNoTiFlashWriteNode) + allUsedTiFlashStores := getAllUsedTiFlashStores(allTiFlashStores, allUsedTiFlashStoresMap) + aliveStores.storesInAllZones = filterAliveStores(ctx, allUsedTiFlashStores, ttl, store) + + if !tiflashReplicaReadPolicy.IsAllReplicas() { + aliveStores.storeIDsInTiDBZone = make(map[uint64]struct{}, len(aliveStores.storesInAllZones)) + for _, as := range aliveStores.storesInAllZones { + // If the `zone` label of the TiFlash store is not set, we treat it as a TiFlash store in other zones. + if tiflashZone, isSet := as.GetLabelValue(placement.DCLabelKey); isSet && tiflashZone == tidbZone { + aliveStores.storeIDsInTiDBZone[as.StoreID()] = struct{}{} + aliveStores.storesInTiDBZone = append(aliveStores.storesInTiDBZone, as) + } + } + } + if !tiflashReplicaReadPolicy.IsClosestReplicas() { + aliveStores.storeIDsInAllZones = make(map[uint64]struct{}, len(aliveStores.storesInAllZones)) + for _, as := range aliveStores.storesInAllZones { + aliveStores.storeIDsInAllZones[as.StoreID()] = struct{}{} + } + } + return aliveStores +} + +// filterAccessibleStoresAndBuildRegionInfo filters the stores that can be accessed according to: +// 1. tiflash_replica_read policy +// 2. whether the store is alive +// After filtering, it will build the RegionInfo. +func filterAccessibleStoresAndBuildRegionInfo( + cache *RegionCache, + allStores []uint64, + bo *Backoffer, + task *copTask, + rpcCtx *tikv.RPCContext, + aliveStores *aliveStoresBundle, + tiflashReplicaReadPolicy tiflash.ReplicaRead, + regionInfoNeedsReloadOnSendFail []RegionInfo, + regionsInOtherZones []uint64, + maxRemoteReadCountAllowed int, + tidbZone string) (regionInfo RegionInfo, _ []RegionInfo, _ []uint64, err error) { + needCrossZoneAccess := false + allStores, needCrossZoneAccess = filterAllStoresAccordingToTiFlashReplicaRead(allStores, aliveStores, tiflashReplicaReadPolicy) + + regionInfo = RegionInfo{ + Region: task.region, + Meta: rpcCtx.Meta, + Ranges: task.ranges, + AllStores: allStores, + PartitionIndex: task.partitionIndex} + + if needCrossZoneAccess { + regionsInOtherZones = append(regionsInOtherZones, task.region.GetID()) + regionInfoNeedsReloadOnSendFail = append(regionInfoNeedsReloadOnSendFail, regionInfo) + if tiflashReplicaReadPolicy.IsClosestReplicas() && len(regionsInOtherZones) > maxRemoteReadCountAllowed { + regionIDErrMsg := "" + for i := 0; i < 3 && i < len(regionsInOtherZones); i++ { + regionIDErrMsg += fmt.Sprintf("%d, ", regionsInOtherZones[i]) + } + err = errors.Errorf( + "no less than %d region(s) can not be accessed by TiFlash in the zone [%s]: %setc", + len(regionsInOtherZones), tidbZone, regionIDErrMsg) + // We need to reload the region cache here to avoid the failure throughout the region cache refresh TTL. + cache.OnSendFailForBatchRegions(bo, rpcCtx.Store, regionInfoNeedsReloadOnSendFail, true, err) + return regionInfo, nil, nil, err + } + } + return regionInfo, regionInfoNeedsReloadOnSendFail, regionsInOtherZones, nil +} + +type aliveStoresBundle struct { + storesInAllZones []*tikv.Store + storeIDsInAllZones map[uint64]struct{} + storesInTiDBZone []*tikv.Store + storeIDsInTiDBZone map[uint64]struct{} +} + +// When `partitionIDs != nil`, it means that buildBatchCopTasksCore is constructing a batch cop tasks for PartitionTableScan. +// At this time, `len(rangesForEachPhysicalTable) == len(partitionIDs)` and `rangesForEachPhysicalTable[i]` is for partition `partitionIDs[i]`. +// Otherwise, `rangesForEachPhysicalTable[0]` indicates the range for the single physical table. +func buildBatchCopTasksCore(bo *backoff.Backoffer, store *kvStore, rangesForEachPhysicalTable []*KeyRanges, storeType kv.StoreType, isMPP bool, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64, tiflashReplicaReadPolicy tiflash.ReplicaRead, appendWarning func(error)) ([]*batchCopTask, error) { + cache := store.GetRegionCache() + start := time.Now() + const cmdType = tikvrpc.CmdBatchCop + rangesLen := 0 + + tidbZone, isTiDBLabelZoneSet := config.GetGlobalConfig().Labels[placement.DCLabelKey] + var ( + aliveStores *aliveStoresBundle + maxRemoteReadCountAllowed int + ) + if !isTiDBLabelZoneSet { + tiflashReplicaReadPolicy = tiflash.AllReplicas + } + + for { + var tasks []*copTask + var tasksForPartitions [][]*copTask = make([][]*copTask, len(rangesForEachPhysicalTable)) + rangesLen = 0 + for i, ranges := range rangesForEachPhysicalTable { + rangesLen += ranges.Len() + locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) + if err != nil { + return nil, errors.Trace(err) + } + tasksForPartitions[i] = make([]*copTask, 0, len(locations)) + for _, lo := range locations { + tasksForPartitions[i] = append(tasksForPartitions[i], &copTask{ + region: lo.Location.Region, + ranges: lo.Ranges, + cmdType: cmdType, + storeType: storeType, + partitionIndex: int64(i), + }) + } + } + if len(tasksForPartitions) == 1 { + tasks = tasksForPartitions[0] + } else { + slices.SortFunc(tasksForPartitions, func(a, b []*copTask) int { + if len(a) == 0 { + return -1 + } + if len(b) == 0 { + return 1 + } + return a[0].ranges.RefAt(0).StartKey.Cmp(b[0].ranges.RefAt(0).StartKey) + }) + // The ranges corresponding to each partiton do not intersect, so we can merge tasks directly + for _, tasksForPartition := range tasksForPartitions { + tasks = append(tasks, tasksForPartition...) + } + } + + rpcCtxs := make([]*tikv.RPCContext, 0, len(tasks)) + usedTiFlashStores := make([][]uint64, 0, len(tasks)) + usedTiFlashStoresMap := make(map[uint64]struct{}, 0) + needRetry := false + for _, task := range tasks { + rpcCtx, err := cache.GetTiFlashRPCContext(bo.TiKVBackoffer(), task.region, isMPP, tikv.LabelFilterNoTiFlashWriteNode) + if err != nil { + return nil, errors.Trace(err) + } + + // When rpcCtx is nil, it's not only attributed to the miss region, but also + // some TiFlash stores crash and can't be recovered. + // That is not an error that can be easily recovered, so we regard this error + // same as rpc error. + if rpcCtx == nil { + needRetry = true + logutil.BgLogger().Info("retry for TiFlash peer with region missing", zap.Uint64("region id", task.region.GetID())) + // Probably all the regions are invalid. Make the loop continue and mark all the regions invalid. + // Then `splitRegion` will reloads these regions. + continue + } + + allStores, _ := cache.GetAllValidTiFlashStores(task.region, rpcCtx.Store, tikv.LabelFilterNoTiFlashWriteNode) + for _, storeID := range allStores { + usedTiFlashStoresMap[storeID] = struct{}{} + } + rpcCtxs = append(rpcCtxs, rpcCtx) + usedTiFlashStores = append(usedTiFlashStores, allStores) + } + + if needRetry { + // As mentioned above, nil rpcCtx is always attributed to failed stores. + // It's equal to long poll the store but get no response. Here we'd better use + // TiFlash error to trigger the TiKV fallback mechanism. + err := bo.Backoff(tikv.BoTiFlashRPC(), errors.New("Cannot find region with TiFlash peer")) + if err != nil { + return nil, errors.Trace(err) + } + continue + } + + aliveStores = getAliveStoresAndStoreIDs(bo.GetCtx(), cache, usedTiFlashStoresMap, ttl, store, tiflashReplicaReadPolicy, tidbZone) + if tiflashReplicaReadPolicy.IsClosestReplicas() { + if len(aliveStores.storeIDsInTiDBZone) == 0 { + return nil, errors.Errorf("There is no region in tidb zone(%s)", tidbZone) + } + maxRemoteReadCountAllowed = len(aliveStores.storeIDsInTiDBZone) * tiflash.MaxRemoteReadCountPerNodeForClosestReplicas + } + + var batchTasks []*batchCopTask + var regionIDsInOtherZones []uint64 + var regionInfosNeedReloadOnSendFail []RegionInfo + var allRegionInfos []RegionInfo + storeTaskMap := make(map[string]*batchCopTask) + storeIDsUnionSetForAllTasks := make(map[uint64]struct{}) + for idx, task := range tasks { + var err error + var regionInfo RegionInfo + regionInfo, regionInfosNeedReloadOnSendFail, regionIDsInOtherZones, err = filterAccessibleStoresAndBuildRegionInfo(cache, usedTiFlashStores[idx], bo, task, rpcCtxs[idx], aliveStores, tiflashReplicaReadPolicy, regionInfosNeedReloadOnSendFail, regionIDsInOtherZones, maxRemoteReadCountAllowed, tidbZone) + if err != nil { + return nil, err + } + if batchCop, ok := storeTaskMap[rpcCtxs[idx].Addr]; ok { + batchCop.regionInfos = append(batchCop.regionInfos, regionInfo) + } else { + batchTask := &batchCopTask{ + storeAddr: rpcCtxs[idx].Addr, + cmdType: cmdType, + ctx: rpcCtxs[idx], + regionInfos: []RegionInfo{regionInfo}, + } + storeTaskMap[rpcCtxs[idx].Addr] = batchTask + } + for _, storeID := range regionInfo.AllStores { + storeIDsUnionSetForAllTasks[storeID] = struct{}{} + } + allRegionInfos = append(allRegionInfos, regionInfo) + } + + if len(regionIDsInOtherZones) != 0 { + warningMsg := fmt.Sprintf("total %d region(s) can not be accessed by TiFlash in the zone [%s]:", len(regionIDsInOtherZones), tidbZone) + regionIDErrMsg := "" + for i := 0; i < 3 && i < len(regionIDsInOtherZones); i++ { + regionIDErrMsg += fmt.Sprintf("%d, ", regionIDsInOtherZones[i]) + } + warningMsg += regionIDErrMsg + "etc" + appendWarning(errors.NewNoStackError(warningMsg)) + } + + for _, task := range storeTaskMap { + batchTasks = append(batchTasks, task) + } + if log.GetLevel() <= zap.DebugLevel { + msg := "Before region balance:" + for _, task := range batchTasks { + msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions," + } + logutil.BgLogger().Debug(msg) + } + balanceStart := time.Now() + storesUnionSetForAllTasks := make([]*tikv.Store, 0, len(storeIDsUnionSetForAllTasks)) + for _, store := range aliveStores.storesInAllZones { + if _, ok := storeIDsUnionSetForAllTasks[store.StoreID()]; ok { + storesUnionSetForAllTasks = append(storesUnionSetForAllTasks, store) + } + } + batchTasks = balanceBatchCopTask(storesUnionSetForAllTasks, batchTasks, balanceWithContinuity, balanceContinuousRegionCount, allRegionInfos) + balanceElapsed := time.Since(balanceStart) + if log.GetLevel() <= zap.DebugLevel { + msg := "After region balance:" + for _, task := range batchTasks { + msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions," + } + logutil.BgLogger().Debug(msg) + } + + if elapsed := time.Since(start); elapsed > time.Millisecond*500 { + logutil.BgLogger().Warn("buildBatchCopTasksCore takes too much time", + zap.Duration("elapsed", elapsed), + zap.Duration("balanceElapsed", balanceElapsed), + zap.Int("range len", rangesLen), + zap.Int("task len", len(batchTasks))) + } + metrics.TxnRegionsNumHistogramWithBatchCoprocessor.Observe(float64(len(batchTasks))) + return batchTasks, nil + } +} + +func convertRegionInfosToPartitionTableRegions(batchTasks []*batchCopTask, partitionIDs []int64) { + for _, copTask := range batchTasks { + tableRegions := make([]*coprocessor.TableRegions, len(partitionIDs)) + // init coprocessor.TableRegions + for j, pid := range partitionIDs { + tableRegions[j] = &coprocessor.TableRegions{ + PhysicalTableId: pid, + } + } + // fill region infos + for _, ri := range copTask.regionInfos { + tableRegions[ri.PartitionIndex].Regions = append(tableRegions[ri.PartitionIndex].Regions, + ri.toCoprocessorRegionInfo()) + } + count := 0 + // clear empty table region + for j := 0; j < len(tableRegions); j++ { + if len(tableRegions[j].Regions) != 0 { + tableRegions[count] = tableRegions[j] + count++ + } + } + copTask.PartitionTableRegions = tableRegions[:count] + copTask.regionInfos = nil + } +} + +func (c *CopClient) sendBatch(ctx context.Context, req *kv.Request, vars *tikv.Variables, option *kv.ClientSendOption) kv.Response { + if req.KeepOrder || req.Desc { + return copErrorResponse{errors.New("batch coprocessor cannot prove keep order or desc property")} + } + ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTs) + bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, vars) + + if req.MaxExecutionTime > 0 { + // If MaxExecutionTime is set, we need to set the deadline for the whole batch coprocessor request context. + ctxWithTimeout, cancel := context.WithTimeout(bo.GetCtx(), time.Duration(req.MaxExecutionTime)*time.Millisecond) + defer cancel() + bo.TiKVBackoffer().SetCtx(ctxWithTimeout) + } + + var tasks []*batchCopTask + var err error + if req.PartitionIDAndRanges != nil { + // For Partition Table Scan + keyRanges := make([]*KeyRanges, 0, len(req.PartitionIDAndRanges)) + partitionIDs := make([]int64, 0, len(req.PartitionIDAndRanges)) + for _, pi := range req.PartitionIDAndRanges { + keyRanges = append(keyRanges, NewKeyRanges(pi.KeyRanges)) + partitionIDs = append(partitionIDs, pi.ID) + } + tasks, err = buildBatchCopTasksForPartitionedTable(ctx, bo, c.store.kvStore, keyRanges, req.StoreType, false, 0, false, 0, partitionIDs, tiflashcompute.DispatchPolicyInvalid, option.TiFlashReplicaRead, option.AppendWarning) + } else { + // TODO: merge the if branch. + ranges := NewKeyRanges(req.KeyRanges.FirstPartitionRange()) + tasks, err = buildBatchCopTasksForNonPartitionedTable(ctx, bo, c.store.kvStore, ranges, req.StoreType, false, 0, false, 0, tiflashcompute.DispatchPolicyInvalid, option.TiFlashReplicaRead, option.AppendWarning) + } + + if err != nil { + return copErrorResponse{err} + } + it := &batchCopIterator{ + store: c.store.kvStore, + req: req, + finishCh: make(chan struct{}), + vars: vars, + rpcCancel: tikv.NewRPCanceller(), + enableCollectExecutionInfo: option.EnableCollectExecutionInfo, + tiflashReplicaReadPolicy: option.TiFlashReplicaRead, + appendWarning: option.AppendWarning, + } + ctx = context.WithValue(ctx, tikv.RPCCancellerCtxKey{}, it.rpcCancel) + it.tasks = tasks + it.respChan = make(chan *batchCopResponse, 2048) + go it.run(ctx) + return it +} + +type batchCopIterator struct { + store *kvStore + req *kv.Request + finishCh chan struct{} + + tasks []*batchCopTask + + // Batch results are stored in respChan. + respChan chan *batchCopResponse + + vars *tikv.Variables + + rpcCancel *tikv.RPCCanceller + + wg sync.WaitGroup + // closed represents when the Close is called. + // There are two cases we need to close the `finishCh` channel, one is when context is done, the other one is + // when the Close is called. we use atomic.CompareAndSwap `closed` to to make sure the channel is not closed twice. + closed uint32 + + enableCollectExecutionInfo bool + tiflashReplicaReadPolicy tiflash.ReplicaRead + appendWarning func(error) +} + +func (b *batchCopIterator) run(ctx context.Context) { + // We run workers for every batch cop. + for _, task := range b.tasks { + b.wg.Add(1) + boMaxSleep := CopNextMaxBackoff + failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { + if value.(bool) { + boMaxSleep = 2 + } + }) + bo := backoff.NewBackofferWithVars(ctx, boMaxSleep, b.vars) + go b.handleTask(ctx, bo, task) + } + b.wg.Wait() + close(b.respChan) +} + +// Next returns next coprocessor result. +// NOTE: Use nil to indicate finish, so if the returned ResultSubset is not nil, reader should continue to call Next(). +func (b *batchCopIterator) Next(ctx context.Context) (kv.ResultSubset, error) { + var ( + resp *batchCopResponse + ok bool + closed bool + ) + + // Get next fetched resp from chan + resp, ok, closed = b.recvFromRespCh(ctx) + if !ok || closed { + return nil, nil + } + + if resp.err != nil { + return nil, errors.Trace(resp.err) + } + + err := b.store.CheckVisibility(b.req.StartTs) + if err != nil { + return nil, errors.Trace(err) + } + return resp, nil +} + +func (b *batchCopIterator) recvFromRespCh(ctx context.Context) (resp *batchCopResponse, ok bool, exit bool) { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + for { + select { + case resp, ok = <-b.respChan: + return + case <-ticker.C: + killed := atomic.LoadUint32(b.vars.Killed) + if killed != 0 { + logutil.Logger(ctx).Info( + "a killed signal is received", + zap.Uint32("signal", killed), + ) + resp = &batchCopResponse{err: derr.ErrQueryInterrupted} + ok = true + return + } + case <-b.finishCh: + exit = true + return + case <-ctx.Done(): + // We select the ctx.Done() in the thread of `Next` instead of in the worker to avoid the cost of `WithCancel`. + if atomic.CompareAndSwapUint32(&b.closed, 0, 1) { + close(b.finishCh) + } + exit = true + return + } + } +} + +// Close releases the resource. +func (b *batchCopIterator) Close() error { + if atomic.CompareAndSwapUint32(&b.closed, 0, 1) { + close(b.finishCh) + } + b.rpcCancel.CancelAll() + b.wg.Wait() + return nil +} + +func (b *batchCopIterator) handleTask(ctx context.Context, bo *Backoffer, task *batchCopTask) { + tasks := []*batchCopTask{task} + for idx := 0; idx < len(tasks); idx++ { + ret, err := b.handleTaskOnce(ctx, bo, tasks[idx]) + if err != nil { + resp := &batchCopResponse{err: errors.Trace(err), detail: new(CopRuntimeStats)} + b.sendToRespCh(resp) + break + } + tasks = append(tasks, ret...) + } + b.wg.Done() +} + +// Merge all ranges and request again. +func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *backoff.Backoffer, batchTask *batchCopTask) ([]*batchCopTask, error) { + if batchTask.regionInfos != nil { + var ranges []kv.KeyRange + for _, ri := range batchTask.regionInfos { + ri.Ranges.Do(func(ran *kv.KeyRange) { + ranges = append(ranges, *ran) + }) + } + // need to make sure the key ranges is sorted + slices.SortFunc(ranges, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + ret, err := buildBatchCopTasksForNonPartitionedTable(ctx, bo, b.store, NewKeyRanges(ranges), b.req.StoreType, false, 0, false, 0, tiflashcompute.DispatchPolicyInvalid, b.tiflashReplicaReadPolicy, b.appendWarning) + return ret, err + } + // Retry Partition Table Scan + keyRanges := make([]*KeyRanges, 0, len(batchTask.PartitionTableRegions)) + pid := make([]int64, 0, len(batchTask.PartitionTableRegions)) + for _, trs := range batchTask.PartitionTableRegions { + pid = append(pid, trs.PhysicalTableId) + ranges := make([]kv.KeyRange, 0, len(trs.Regions)) + for _, ri := range trs.Regions { + for _, ran := range ri.Ranges { + ranges = append(ranges, kv.KeyRange{ + StartKey: ran.Start, + EndKey: ran.End, + }) + } + } + // need to make sure the key ranges is sorted + slices.SortFunc(ranges, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + keyRanges = append(keyRanges, NewKeyRanges(ranges)) + } + ret, err := buildBatchCopTasksForPartitionedTable(ctx, bo, b.store, keyRanges, b.req.StoreType, false, 0, false, 0, pid, tiflashcompute.DispatchPolicyInvalid, b.tiflashReplicaReadPolicy, b.appendWarning) + return ret, err +} + +// TiFlashReadTimeoutUltraLong represents the max time that tiflash request may take, since it may scan many regions for tiflash. +const TiFlashReadTimeoutUltraLong = 3600 * time.Second + +func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *backoff.Backoffer, task *batchCopTask) ([]*batchCopTask, error) { + sender := NewRegionBatchRequestSender(b.store.GetRegionCache(), b.store.GetTiKVClient(), b.store.store.GetOracle(), b.enableCollectExecutionInfo) + var regionInfos = make([]*coprocessor.RegionInfo, 0, len(task.regionInfos)) + for _, ri := range task.regionInfos { + regionInfos = append(regionInfos, ri.toCoprocessorRegionInfo()) + } + + copReq := coprocessor.BatchRequest{ + Tp: b.req.Tp, + StartTs: b.req.StartTs, + Data: b.req.Data, + SchemaVer: b.req.SchemaVar, + Regions: regionInfos, + TableRegions: task.PartitionTableRegions, + ConnectionId: b.req.ConnID, + ConnectionAlias: b.req.ConnAlias, + } + + rgName := b.req.ResourceGroupName + if !variable.EnableResourceControl.Load() { + rgName = "" + } + req := tikvrpc.NewRequest(task.cmdType, &copReq, kvrpcpb.Context{ + IsolationLevel: isolationLevelToPB(b.req.IsolationLevel), + Priority: priorityToPB(b.req.Priority), + NotFillCache: b.req.NotFillCache, + RecordTimeStat: true, + RecordScanStat: true, + TaskId: b.req.TaskID, + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: rgName, + }, + }) + if b.req.ResourceGroupTagger != nil { + b.req.ResourceGroupTagger.Build(req) + } + req.StoreTp = getEndPointType(kv.TiFlash) + + logutil.BgLogger().Debug("send batch request to ", zap.String("req info", req.String()), zap.Int("cop task len", len(task.regionInfos))) + resp, retry, cancel, err := sender.SendReqToAddr(bo, task.ctx, task.regionInfos, req, TiFlashReadTimeoutUltraLong) + // If there are store errors, we should retry for all regions. + if retry { + return b.retryBatchCopTask(ctx, bo, task) + } + if err != nil { + err = derr.ToTiDBErr(err) + return nil, errors.Trace(err) + } + defer cancel() + return nil, b.handleStreamedBatchCopResponse(ctx, bo, resp.Resp.(*tikvrpc.BatchCopStreamResponse), task) +} + +func (b *batchCopIterator) handleStreamedBatchCopResponse(ctx context.Context, bo *Backoffer, response *tikvrpc.BatchCopStreamResponse, task *batchCopTask) (err error) { + defer response.Close() + resp := response.BatchResponse + if resp == nil { + // streaming request returns io.EOF, so the first Response is nil. + return + } + for { + err = b.handleBatchCopResponse(bo, resp, task) + if err != nil { + return errors.Trace(err) + } + resp, err = response.Recv() + if err != nil { + if errors.Cause(err) == io.EOF { + return nil + } + + if err1 := bo.Backoff(tikv.BoTiKVRPC(), errors.Errorf("recv stream response error: %v, task store addr: %s", err, task.storeAddr)); err1 != nil { + return errors.Trace(err) + } + + // No coprocessor.Response for network error, rebuild task based on the last success one. + if errors.Cause(err) == context.Canceled { + logutil.BgLogger().Info("stream recv timeout", zap.Error(err)) + } else { + logutil.BgLogger().Info("stream unknown error", zap.Error(err)) + } + return derr.ErrTiFlashServerTimeout + } + } +} + +func (b *batchCopIterator) handleBatchCopResponse(bo *Backoffer, response *coprocessor.BatchResponse, task *batchCopTask) (err error) { + if otherErr := response.GetOtherError(); otherErr != "" { + err = errors.Errorf("other error: %s", otherErr) + logutil.BgLogger().Warn("other error", + zap.Uint64("txnStartTS", b.req.StartTs), + zap.String("storeAddr", task.storeAddr), + zap.Error(err)) + return errors.Trace(err) + } + + if len(response.RetryRegions) > 0 { + logutil.BgLogger().Info("multiple regions are stale and need to be refreshed", zap.Int("region size", len(response.RetryRegions))) + for idx, retry := range response.RetryRegions { + id := tikv.NewRegionVerID(retry.Id, retry.RegionEpoch.ConfVer, retry.RegionEpoch.Version) + logutil.BgLogger().Info("invalid region because tiflash detected stale region", zap.String("region id", id.String())) + b.store.GetRegionCache().InvalidateCachedRegionWithReason(id, tikv.EpochNotMatch) + if idx >= 10 { + logutil.BgLogger().Info("stale regions are too many, so we omit the rest ones") + break + } + } + return + } + + resp := &batchCopResponse{ + pbResp: response, + detail: new(CopRuntimeStats), + } + + b.handleCollectExecutionInfo(bo, resp, task) + b.sendToRespCh(resp) + + return +} + +func (b *batchCopIterator) sendToRespCh(resp *batchCopResponse) (exit bool) { + select { + case b.respChan <- resp: + case <-b.finishCh: + exit = true + } + return +} + +func (b *batchCopIterator) handleCollectExecutionInfo(bo *Backoffer, resp *batchCopResponse, task *batchCopTask) { + if !b.enableCollectExecutionInfo { + return + } + backoffTimes := bo.GetBackoffTimes() + resp.detail.BackoffTime = time.Duration(bo.GetTotalSleep()) * time.Millisecond + resp.detail.BackoffSleep = make(map[string]time.Duration, len(backoffTimes)) + resp.detail.BackoffTimes = make(map[string]int, len(backoffTimes)) + for backoff := range backoffTimes { + resp.detail.BackoffTimes[backoff] = backoffTimes[backoff] + resp.detail.BackoffSleep[backoff] = time.Duration(bo.GetBackoffSleepMS()[backoff]) * time.Millisecond + } + resp.detail.CalleeAddress = task.storeAddr +} + +// Only called when UseAutoScaler is false. +func buildBatchCopTasksConsistentHashForPD(bo *backoff.Backoffer, + kvStore *kvStore, + rangesForEachPhysicalTable []*KeyRanges, + storeType kv.StoreType, + ttl time.Duration, + dispatchPolicy tiflashcompute.DispatchPolicy) (res []*batchCopTask, err error) { + failpointCheckWhichPolicy(dispatchPolicy) + const cmdType = tikvrpc.CmdBatchCop + var ( + retryNum int + rangesLen int + copTaskNum int + splitKeyElapsed time.Duration + getStoreElapsed time.Duration + ) + cache := kvStore.GetRegionCache() + start := time.Now() + + for { + retryNum++ + rangesLen = 0 + tasks := make([]*copTask, 0) + regionIDs := make([]tikv.RegionVerID, 0) + + splitKeyStart := time.Now() + for i, ranges := range rangesForEachPhysicalTable { + rangesLen += ranges.Len() + locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) + if err != nil { + return nil, errors.Trace(err) + } + for _, lo := range locations { + tasks = append(tasks, &copTask{ + region: lo.Location.Region, + ranges: lo.Ranges, + cmdType: cmdType, + storeType: storeType, + partitionIndex: int64(i), + }) + regionIDs = append(regionIDs, lo.Location.Region) + } + } + splitKeyElapsed += time.Since(splitKeyStart) + + getStoreStart := time.Now() + stores, err := cache.GetTiFlashComputeStores(bo.TiKVBackoffer()) + if err != nil { + return nil, err + } + stores = filterAliveStores(bo.GetCtx(), stores, ttl, kvStore) + if len(stores) == 0 { + return nil, errors.New("tiflash_compute node is unavailable") + } + getStoreElapsed = time.Since(getStoreStart) + + storesStr := make([]string, 0, len(stores)) + for _, s := range stores { + storesStr = append(storesStr, s.GetAddr()) + } + var rpcCtxs []*tikv.RPCContext + if dispatchPolicy == tiflashcompute.DispatchPolicyRR { + rpcCtxs, err = getTiFlashComputeRPCContextByRoundRobin(regionIDs, storesStr) + } else if dispatchPolicy == tiflashcompute.DispatchPolicyConsistentHash { + rpcCtxs, err = getTiFlashComputeRPCContextByConsistentHash(regionIDs, storesStr) + } else { + err = errors.Errorf("unexpected dispatch policy %v", dispatchPolicy) + } + if err != nil { + return nil, err + } + if rpcCtxs == nil { + logutil.BgLogger().Info("buildBatchCopTasksConsistentHashForPD retry because rcpCtx is nil", zap.Int("retryNum", retryNum)) + err := bo.Backoff(tikv.BoTiFlashRPC(), errors.New("Cannot find region with TiFlash peer")) + if err != nil { + return nil, errors.Trace(err) + } + continue + } + if len(rpcCtxs) != len(tasks) { + return nil, errors.Errorf("length should be equal, len(rpcCtxs): %d, len(tasks): %d", len(rpcCtxs), len(tasks)) + } + copTaskNum = len(tasks) + taskMap := make(map[string]*batchCopTask) + for i, rpcCtx := range rpcCtxs { + regionInfo := RegionInfo{ + // tasks and rpcCtxs are correspond to each other. + Region: tasks[i].region, + Ranges: tasks[i].ranges, + PartitionIndex: tasks[i].partitionIndex, + } + if batchTask, ok := taskMap[rpcCtx.Addr]; ok { + batchTask.regionInfos = append(batchTask.regionInfos, regionInfo) + } else { + batchTask := &batchCopTask{ + storeAddr: rpcCtx.Addr, + cmdType: cmdType, + ctx: rpcCtx, + regionInfos: []RegionInfo{regionInfo}, + } + taskMap[rpcCtx.Addr] = batchTask + res = append(res, batchTask) + } + } + logutil.BgLogger().Info("buildBatchCopTasksConsistentHashForPD done", + zap.Any("len(tasks)", len(taskMap)), + zap.Any("len(tiflash_compute)", len(stores)), + zap.Any("dispatchPolicy", tiflashcompute.GetDispatchPolicy(dispatchPolicy))) + if log.GetLevel() <= zap.DebugLevel { + debugTaskMap := make(map[string]string, len(taskMap)) + for s, b := range taskMap { + debugTaskMap[s] = fmt.Sprintf("addr: %s; regionInfos: %v", b.storeAddr, b.regionInfos) + } + logutil.BgLogger().Debug("detailed info buildBatchCopTasksConsistentHashForPD", zap.Any("taskMap", debugTaskMap), zap.Any("allStores", storesStr)) + } + break + } + + if elapsed := time.Since(start); elapsed > time.Millisecond*500 { + logutil.BgLogger().Warn("buildBatchCopTasksConsistentHashForPD takes too much time", + zap.Duration("total elapsed", elapsed), + zap.Int("retryNum", retryNum), + zap.Duration("splitKeyElapsed", splitKeyElapsed), + zap.Duration("getStoreElapsed", getStoreElapsed), + zap.Int("range len", rangesLen), + zap.Int("copTaskNum", copTaskNum), + zap.Int("batchCopTaskNum", len(res))) + } + failpointCheckForConsistentHash(res) + return res, nil +} diff --git a/pkg/store/copr/mpp.go b/pkg/store/copr/mpp.go new file mode 100644 index 0000000000000..d0fdcaa0bd255 --- /dev/null +++ b/pkg/store/copr/mpp.go @@ -0,0 +1,346 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package copr + +import ( + "context" + "strconv" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/mpp" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/store/driver/backoff" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/tiflash" + "github.com/pingcap/tidb/pkg/util/tiflashcompute" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// MPPClient servers MPP requests. +type MPPClient struct { + store *kvStore +} + +type mppStoreCnt struct { + cnt int32 + lastUpdate int64 + initFlag int32 +} + +// GetAddress returns the network address. +func (c *batchCopTask) GetAddress() string { + return c.storeAddr +} + +// ConstructMPPTasks receives ScheduleRequest, which are actually collects of kv ranges. We allocates MPPTaskMeta for them and returns. +func (c *MPPClient) ConstructMPPTasks(ctx context.Context, req *kv.MPPBuildTasksRequest, ttl time.Duration, dispatchPolicy tiflashcompute.DispatchPolicy, tiflashReplicaReadPolicy tiflash.ReplicaRead, appendWarning func(error)) ([]kv.MPPTaskMeta, error) { + ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTS) + bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, nil) + var tasks []*batchCopTask + var err error + if req.PartitionIDAndRanges != nil { + rangesForEachPartition := make([]*KeyRanges, len(req.PartitionIDAndRanges)) + partitionIDs := make([]int64, len(req.PartitionIDAndRanges)) + for i, p := range req.PartitionIDAndRanges { + rangesForEachPartition[i] = NewKeyRanges(p.KeyRanges) + partitionIDs[i] = p.ID + } + tasks, err = buildBatchCopTasksForPartitionedTable(ctx, bo, c.store, rangesForEachPartition, kv.TiFlash, true, ttl, true, 20, partitionIDs, dispatchPolicy, tiflashReplicaReadPolicy, appendWarning) + } else { + if req.KeyRanges == nil { + return nil, errors.New("KeyRanges in MPPBuildTasksRequest is nil") + } + ranges := NewKeyRanges(req.KeyRanges) + tasks, err = buildBatchCopTasksForNonPartitionedTable(ctx, bo, c.store, ranges, kv.TiFlash, true, ttl, true, 20, dispatchPolicy, tiflashReplicaReadPolicy, appendWarning) + } + + if err != nil { + return nil, errors.Trace(err) + } + mppTasks := make([]kv.MPPTaskMeta, 0, len(tasks)) + for _, copTask := range tasks { + mppTasks = append(mppTasks, copTask) + } + return mppTasks, nil +} + +// DispatchMPPTask dispatch mpp task, and returns valid response when retry = false and err is nil +func (c *MPPClient) DispatchMPPTask(param kv.DispatchMPPTaskParam) (resp *mpp.DispatchTaskResponse, retry bool, err error) { + req := param.Req + var regionInfos []*coprocessor.RegionInfo + originalTask, ok := req.Meta.(*batchCopTask) + if ok { + for _, ri := range originalTask.regionInfos { + regionInfos = append(regionInfos, ri.toCoprocessorRegionInfo()) + } + } + + // meta for current task. + taskMeta := &mpp.TaskMeta{StartTs: req.StartTs, QueryTs: req.MppQueryID.QueryTs, LocalQueryId: req.MppQueryID.LocalQueryID, TaskId: req.ID, ServerId: req.MppQueryID.ServerID, + GatherId: req.GatherID, + Address: req.Meta.GetAddress(), + CoordinatorAddress: req.CoordinatorAddress, + ReportExecutionSummary: req.ReportExecutionSummary, + MppVersion: req.MppVersion.ToInt64(), + ResourceGroupName: req.ResourceGroupName, + ConnectionId: req.ConnectionID, + ConnectionAlias: req.ConnectionAlias, + } + + mppReq := &mpp.DispatchTaskRequest{ + Meta: taskMeta, + EncodedPlan: req.Data, + // TODO: This is only an experience value. It's better to be configurable. + Timeout: 60, + SchemaVer: req.SchemaVar, + Regions: regionInfos, + } + if originalTask != nil { + mppReq.TableRegions = originalTask.PartitionTableRegions + if mppReq.TableRegions != nil { + mppReq.Regions = nil + } + } + + wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPTask, mppReq, kvrpcpb.Context{}) + wrappedReq.StoreTp = getEndPointType(kv.TiFlash) + + // TODO: Handle dispatch task response correctly, including retry logic and cancel logic. + var rpcResp *tikvrpc.Response + invalidPDCache := config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler + bo := backoff.NewBackofferWithTikvBo(param.Bo) + + // If copTasks is not empty, we should send request according to region distribution. + // Or else it's the task without region, which always happens in high layer task without table. + // In that case + if originalTask != nil { + sender := NewRegionBatchRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.store.GetOracle(), param.EnableCollectExecutionInfo) + rpcResp, retry, _, err = sender.SendReqToAddr(bo, originalTask.ctx, originalTask.regionInfos, wrappedReq, tikv.ReadTimeoutMedium) + // No matter what the rpc error is, we won't retry the mpp dispatch tasks. + // TODO: If we want to retry, we must redo the plan fragment cutting and task scheduling. + // That's a hard job but we can try it in the future. + if sender.GetRPCError() != nil { + logutil.BgLogger().Warn("mpp dispatch meet io error", zap.String("error", sender.GetRPCError().Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) + if invalidPDCache { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } + err = sender.GetRPCError() + } + } else { + rpcResp, err = c.store.GetTiKVClient().SendRequest(param.Ctx, req.Meta.GetAddress(), wrappedReq, tikv.ReadTimeoutMedium) + if errors.Cause(err) == context.Canceled || status.Code(errors.Cause(err)) == codes.Canceled { + retry = false + } else if err != nil { + if invalidPDCache { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } + if bo.Backoff(tikv.BoTiFlashRPC(), err) == nil { + retry = true + } + } + } + + if err != nil || retry { + return nil, retry, err + } + + realResp := rpcResp.Resp.(*mpp.DispatchTaskResponse) + if realResp.Error != nil { + return realResp, false, nil + } + + if len(realResp.RetryRegions) > 0 { + logutil.BgLogger().Info("TiFlash found " + strconv.Itoa(len(realResp.RetryRegions)) + " stale regions. Only first " + strconv.Itoa(min(10, len(realResp.RetryRegions))) + " regions will be logged if the log level is higher than Debug") + for index, retry := range realResp.RetryRegions { + id := tikv.NewRegionVerID(retry.Id, retry.RegionEpoch.ConfVer, retry.RegionEpoch.Version) + if index < 10 || log.GetLevel() <= zap.DebugLevel { + logutil.BgLogger().Info("invalid region because tiflash detected stale region", zap.String("region id", id.String())) + } + c.store.GetRegionCache().InvalidateCachedRegionWithReason(id, tikv.EpochNotMatch) + } + } + return realResp, retry, err +} + +// CancelMPPTasks cancels mpp tasks +// NOTE: We do not retry here, because retry is helpless when errors result from TiFlash or Network. If errors occur, the execution on TiFlash will finally stop after some minutes. +// This function is exclusively called, and only the first call succeeds sending tasks and setting all tasks as cancelled, while others will not work. +func (c *MPPClient) CancelMPPTasks(param kv.CancelMPPTasksParam) { + usedStoreAddrs := param.StoreAddr + reqs := param.Reqs + if len(usedStoreAddrs) == 0 || len(reqs) == 0 { + return + } + + firstReq := reqs[0] + killReq := &mpp.CancelTaskRequest{ + Meta: &mpp.TaskMeta{StartTs: firstReq.StartTs, GatherId: firstReq.GatherID, QueryTs: firstReq.MppQueryID.QueryTs, LocalQueryId: firstReq.MppQueryID.LocalQueryID, ServerId: firstReq.MppQueryID.ServerID, MppVersion: firstReq.MppVersion.ToInt64(), ResourceGroupName: firstReq.ResourceGroupName}, + } + + wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPCancel, killReq, kvrpcpb.Context{}) + wrappedReq.StoreTp = getEndPointType(kv.TiFlash) + + // send cancel cmd to all stores where tasks run + invalidPDCache := config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler + wg := util.WaitGroupWrapper{} + gotErr := atomic.Bool{} + for addr := range usedStoreAddrs { + storeAddr := addr + wg.Run(func() { + _, err := c.store.GetTiKVClient().SendRequest(context.Background(), storeAddr, wrappedReq, tikv.ReadTimeoutShort) + logutil.BgLogger().Debug("cancel task", zap.Uint64("query id ", firstReq.StartTs), zap.String("on addr", storeAddr), zap.Int64("mpp-version", firstReq.MppVersion.ToInt64())) + if err != nil { + logutil.BgLogger().Error("cancel task error", zap.Error(err), zap.Uint64("query id", firstReq.StartTs), zap.String("on addr", storeAddr), zap.Int64("mpp-version", firstReq.MppVersion.ToInt64())) + if invalidPDCache { + gotErr.CompareAndSwap(false, true) + } + } + }) + } + wg.Wait() + if invalidPDCache && gotErr.Load() { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } +} + +// EstablishMPPConns build a mpp connection to receive data, return valid response when err is nil +func (c *MPPClient) EstablishMPPConns(param kv.EstablishMPPConnsParam) (*tikvrpc.MPPStreamResponse, error) { + req := param.Req + taskMeta := param.TaskMeta + connReq := &mpp.EstablishMPPConnectionRequest{ + SenderMeta: taskMeta, + ReceiverMeta: &mpp.TaskMeta{ + StartTs: req.StartTs, + GatherId: req.GatherID, + QueryTs: req.MppQueryID.QueryTs, + LocalQueryId: req.MppQueryID.LocalQueryID, + ServerId: req.MppQueryID.ServerID, + MppVersion: req.MppVersion.ToInt64(), + TaskId: -1, + ResourceGroupName: req.ResourceGroupName, + }, + } + + var err error + + wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPConn, connReq, kvrpcpb.Context{}) + wrappedReq.StoreTp = getEndPointType(kv.TiFlash) + + // Drain results from root task. + // We don't need to process any special error. When we meet errors, just let it fail. + rpcResp, err := c.store.GetTiKVClient().SendRequest(param.Ctx, req.Meta.GetAddress(), wrappedReq, TiFlashReadTimeoutUltraLong) + + var stream *tikvrpc.MPPStreamResponse + if rpcResp != nil && rpcResp.Resp != nil { + stream = rpcResp.Resp.(*tikvrpc.MPPStreamResponse) + } + + if err != nil { + if stream != nil { + stream.Close() + } + logutil.BgLogger().Warn("establish mpp connection meet error and cannot retry", zap.String("error", err.Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) + if config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } + return nil, err + } + + return stream, nil +} + +// CheckVisibility checks if it is safe to read using given ts. +func (c *MPPClient) CheckVisibility(startTime uint64) error { + return c.store.CheckVisibility(startTime) +} + +func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, TTL int64) (int, error) { + failpoint.Inject("mppStoreCountSetLastUpdateTime", func(value failpoint.Value) { + v, _ := strconv.ParseInt(value.(string), 10, 0) + c.lastUpdate = v + }) + + lastUpdate := atomic.LoadInt64(&c.lastUpdate) + now := time.Now().UnixMicro() + isInit := atomic.LoadInt32(&c.initFlag) != 0 + + if now-lastUpdate < TTL { + if isInit { + return int(atomic.LoadInt32(&c.cnt)), nil + } + } + + failpoint.Inject("mppStoreCountSetLastUpdateTimeP2", func(value failpoint.Value) { + v, _ := strconv.ParseInt(value.(string), 10, 0) + c.lastUpdate = v + }) + + if !atomic.CompareAndSwapInt64(&c.lastUpdate, lastUpdate, now) { + if isInit { + return int(atomic.LoadInt32(&c.cnt)), nil + } + // if has't initialized, always fetch latest mpp store info + } + + // update mpp store cache + cnt := 0 + stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) + + failpoint.Inject("mppStoreCountPDError", func(value failpoint.Value) { + if value.(bool) { + err = errors.New("failed to get mpp store count") + } + }) + + if err != nil { + // always to update cache next time + atomic.StoreInt32(&c.initFlag, 0) + return 0, err + } + for _, s := range stores { + if !tikv.LabelFilterNoTiFlashWriteNode(s.GetLabels()) { + continue + } + cnt += 1 + } + failpoint.Inject("mppStoreCountSetMPPCnt", func(value failpoint.Value) { + cnt = value.(int) + }) + + if !isInit || atomic.LoadInt64(&c.lastUpdate) == now { + atomic.StoreInt32(&c.cnt, int32(cnt)) + atomic.StoreInt32(&c.initFlag, 1) + } + + return cnt, nil +} + +// GetMPPStoreCount returns number of TiFlash stores +func (c *MPPClient) GetMPPStoreCount() (int, error) { + return c.store.mppStoreCnt.getMPPStoreCount(c.store.store.Ctx(), c.store.store.GetPDClient(), 120*1e6 /* TTL 120sec */) +} diff --git a/pkg/util/mock/BUILD.bazel b/pkg/util/mock/BUILD.bazel new file mode 100644 index 0000000000000..28c64a80ad7df --- /dev/null +++ b/pkg/util/mock/BUILD.bazel @@ -0,0 +1,71 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "mock", + srcs = [ + "client.go", + "context.go", + "fortest.go", + "iter.go", + "metrics.go", + "store.go", + ], + importpath = "github.com/pingcap/tidb/pkg/util/mock", + visibility = ["//visibility:public"], + deps = [ + "//pkg/distsql/context", + "//pkg/expression/exprctx", + "//pkg/expression/sessionexpr", + "//pkg/extension", + "//pkg/infoschema/context", + "//pkg/kv", + "//pkg/meta/model", + "//pkg/parser/ast", + "//pkg/parser/model", + "//pkg/parser/terror", + "//pkg/planner/core/resolve", + "//pkg/planner/planctx", + "//pkg/session/cursor", + "//pkg/sessionctx", + "//pkg/sessionctx/sessionstates", + "//pkg/sessionctx/variable", + "//pkg/statistics/handle/usage/indexusage", + "//pkg/table/tblctx", + "//pkg/table/tblsession", + "//pkg/util", + "//pkg/util/chunk", + "//pkg/util/disk", + "//pkg/util/logutil", + "//pkg/util/memory", + "//pkg/util/ranger/context", + "//pkg/util/sli", + "//pkg/util/sqlexec", + "//pkg/util/topsql/stmtstats", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_kvproto//pkg/deadlock", + "@com_github_pingcap_kvproto//pkg/kvrpcpb", + "@com_github_prometheus_client_golang//prometheus", + "@com_github_stretchr_testify//assert", + "@com_github_tikv_client_go_v2//oracle", + "@com_github_tikv_client_go_v2//tikv", + "@org_uber_go_atomic//:atomic", + ], +) + +go_test( + name = "mock_test", + timeout = "short", + srcs = [ + "iter_test.go", + "main_test.go", + "mock_test.go", + ], + embed = [":mock"], + flaky = True, + deps = [ + "//pkg/kv", + "//pkg/testkit/testsetup", + "@com_github_stretchr_testify//assert", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 6030620cad74c..cc5062c9fb248 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -3399,7 +3399,11 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, if err != nil { return nil, err } +<<<<<<< HEAD:planner/core/planbuilder.go if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil { +======= + if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS, true); err != nil { +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/planner/core/planbuilder.go return nil, err } p.StaleTxnStartTS = startTS @@ -3413,7 +3417,11 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, if err != nil { return nil, err } +<<<<<<< HEAD:planner/core/planbuilder.go if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil { +======= + if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS, true); err != nil { +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/planner/core/planbuilder.go return nil, err } p.StaleTxnStartTS = startTS diff --git a/sessiontxn/staleread/processor.go b/sessiontxn/staleread/processor.go index af91ffd1b175e..63d127c912870 100644 --- a/sessiontxn/staleread/processor.go +++ b/sessiontxn/staleread/processor.go @@ -285,7 +285,11 @@ func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *as return 0, err } +<<<<<<< HEAD:sessiontxn/staleread/processor.go if err = sessionctx.ValidateStaleReadTS(ctx, sctx, ts); err != nil { +======= + if err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), ts, true); err != nil { +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/sessiontxn/staleread/processor.go return 0, err } diff --git a/sessiontxn/staleread/util.go b/sessiontxn/staleread/util.go index d2cc7e4863446..d3ea5b8e86731 100644 --- a/sessiontxn/staleread/util.go +++ b/sessiontxn/staleread/util.go @@ -77,8 +77,25 @@ func CalculateTsWithReadStaleness(sctx sessionctx.Context, readStaleness time.Du return 0, err } tsVal := nowVal.Add(readStaleness) +<<<<<<< HEAD:sessiontxn/staleread/util.go minTsVal := expression.GetMinSafeTime(sctx) return oracle.GoTimeToTS(expression.CalAppropriateTime(tsVal, nowVal, minTsVal)), nil +======= + sc := sctx.GetSessionVars().StmtCtx + minSafeTSVal := expression.GetStmtMinSafeTime(sc, sctx.GetStore(), sc.TimeZone()) + calculatedTime := expression.CalAppropriateTime(tsVal, nowVal, minSafeTSVal) + readTS := oracle.GoTimeToTS(calculatedTime) + if calculatedTime.After(minSafeTSVal) { + // If the final calculated exceeds the min safe ts, we are not sure whether the ts is safe to read (note that + // reading with a ts larger than PD's max allocated ts + 1 is unsafe and may break linearizability). + // So in this case, do an extra check on it. + err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), readTS, true) + if err != nil { + return 0, err + } + } + return readTS, nil +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/sessiontxn/staleread/util.go } // IsStmtStaleness indicates whether the current statement is staleness or not diff --git a/store/copr/BUILD.bazel b/store/copr/BUILD.bazel index 1d3238f13e93a..2ef2e7ecdf145 100644 --- a/store/copr/BUILD.bazel +++ b/store/copr/BUILD.bazel @@ -46,6 +46,7 @@ go_library( "@com_github_tikv_client_go_v2//config", "@com_github_tikv_client_go_v2//error", "@com_github_tikv_client_go_v2//metrics", + "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//tikv", "@com_github_tikv_client_go_v2//tikvrpc", "@com_github_tikv_client_go_v2//txnkv/txnlock", diff --git a/store/copr/batch_request_sender.go b/store/copr/batch_request_sender.go index b976d26a59ab3..3552be8d3516b 100644 --- a/store/copr/batch_request_sender.go +++ b/store/copr/batch_request_sender.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/kvproto/pkg/coprocessor" "github.com/pingcap/kvproto/pkg/metapb" tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikvrpc" "google.golang.org/grpc/codes" @@ -55,9 +56,9 @@ type RegionBatchRequestSender struct { } // NewRegionBatchRequestSender creates a RegionBatchRequestSender object. -func NewRegionBatchRequestSender(cache *RegionCache, client tikv.Client, enableCollectExecutionInfo bool) *RegionBatchRequestSender { +func NewRegionBatchRequestSender(cache *RegionCache, client tikv.Client, oracle oracle.Oracle, enableCollectExecutionInfo bool) *RegionBatchRequestSender { return &RegionBatchRequestSender{ - RegionRequestSender: tikv.NewRegionRequestSender(cache.RegionCache, client), + RegionRequestSender: tikv.NewRegionRequestSender(cache.RegionCache, client, oracle), enableCollectExecutionInfo: enableCollectExecutionInfo, } } diff --git a/util/cmp/compare_test.go b/util/cmp/compare_test.go index 63f81008060d5..0cdd4e8078d1a 100644 --- a/util/cmp/compare_test.go +++ b/util/cmp/compare_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" ) +<<<<<<< HEAD:util/cmp/compare_test.go func TestCompare(t *testing.T) { require.Equal(t, -1, Compare(1, 2)) require.Equal(t, 1, Compare(2, 1)) @@ -28,4 +29,12 @@ func TestCompare(t *testing.T) { require.Equal(t, -1, Compare("a", "b")) require.Equal(t, 1, Compare("b", "a")) require.Zero(t, Compare("a", "a")) +======= +// NewContext creates a new mocked sessionctx.Context. +// This function should only be used for testing. +// Avoid using this when you are in a context with a `kv.Storage` instance, especially when you are going to access +// the data in it. Consider using testkit.NewSession(t, store) instead when possible. +func NewContext() *Context { + return newContext() +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/util/mock/fortest.go } diff --git a/util/mock/context.go b/util/mock/context.go index 3445e73d4b603..a14b8665a664c 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/kvrpcpb" +<<<<<<< HEAD:util/mock/context.go "github.com/pingcap/tidb/extension" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/ast" @@ -37,6 +38,36 @@ import ( "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/topsql/stmtstats" "github.com/pingcap/tipb/go-binlog" +======= + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/expression/exprctx" + "github.com/pingcap/tidb/pkg/expression/sessionexpr" + "github.com/pingcap/tidb/pkg/extension" + infoschema "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/ast" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner/core/resolve" + "github.com/pingcap/tidb/pkg/planner/planctx" + "github.com/pingcap/tidb/pkg/session/cursor" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics/handle/usage/indexusage" + "github.com/pingcap/tidb/pkg/table/tblctx" + "github.com/pingcap/tidb/pkg/table/tblsession" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" + "github.com/pingcap/tidb/pkg/util/sli" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" +>>>>>>> 0bf3e019002 (*: Update client-go and verify all read ts (#58054)):pkg/util/mock/context.go "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" ) @@ -67,7 +98,7 @@ type wrapTxn struct { } func (txn *wrapTxn) validOrPending() bool { - return txn.tsFuture != nil || txn.Transaction.Valid() + return txn.tsFuture != nil || (txn.Transaction != nil && txn.Transaction.Valid()) } func (txn *wrapTxn) pending() bool { @@ -173,7 +204,15 @@ func (c *Context) GetSessionVars() *variable.SessionVars { } // Txn implements sessionctx.Context Txn interface. -func (c *Context) Txn(bool) (kv.Transaction, error) { +func (c *Context) Txn(active bool) (kv.Transaction, error) { + if active { + if !c.txn.validOrPending() { + err := c.newTxn(context.Background()) + if err != nil { + return nil, err + } + } + } return &c.txn, nil } @@ -253,10 +292,12 @@ func (c *Context) GetPlanCache(_ bool) sessionctx.PlanCache { return c.pcache } -// NewTxn implements the sessionctx.Context interface. -func (c *Context) NewTxn(context.Context) error { +// newTxn Creates new transaction on the session context. +func (c *Context) newTxn(ctx context.Context) error { if c.Store == nil { - return errors.New("store is not set") + logutil.Logger(ctx).Warn("mock.Context: No store is specified when trying to create new transaction. A fake transaction will be created. Note that this is unrecommended usage.") + c.fakeTxn() + return nil } if c.txn.Valid() { err := c.txn.Commit(c.ctx) @@ -273,14 +314,41 @@ func (c *Context) NewTxn(context.Context) error { return nil } -// NewStaleTxnWithStartTS implements the sessionctx.Context interface. -func (c *Context) NewStaleTxnWithStartTS(ctx context.Context, _ uint64) error { - return c.NewTxn(ctx) +// fakeTxn is used to let some tests pass in the context without an available kv.Storage. Once usages to access +// transactions without a kv.Storage are removed, this type should also be removed. +// New code should never use this. +type fakeTxn struct { + // The inner should always be nil. + kv.Transaction + startTS uint64 +} + +func (t *fakeTxn) StartTS() uint64 { + return t.startTS +} + +func (*fakeTxn) SetDiskFullOpt(_ kvrpcpb.DiskFullOpt) {} + +func (*fakeTxn) SetOption(_ int, _ any) {} + +func (*fakeTxn) Get(ctx context.Context, _ kv.Key) ([]byte, error) { + // Check your implementation if you meet this error. It's dangerous if some calculation relies on the data but the + // read result is faked. + logutil.Logger(ctx).Warn("mock.Context: No store is specified but trying to access data from a transaction.") + return nil, nil +} + +func (*fakeTxn) Valid() bool { return true } + +func (c *Context) fakeTxn() { + c.txn.Transaction = &fakeTxn{ + startTS: 1, + } } // RefreshTxnCtx implements the sessionctx.Context interface. func (c *Context) RefreshTxnCtx(ctx context.Context) error { - return errors.Trace(c.NewTxn(ctx)) + return errors.Trace(c.newTxn(ctx)) } // RefreshVars implements the sessionctx.Context interface.