Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to validate read ts for all RPC requests #1513

Merged
merged 11 commits into from
Dec 9, 2024
5 changes: 3 additions & 2 deletions internal/locate/region_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import (
"github.com/tikv/client-go/v2/internal/apicodec"
"github.com/tikv/client-go/v2/internal/mockstore/mocktikv"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/tikvrpc"
pd "github.com/tikv/pd/client"
uatomic "go.uber.org/atomic"
Expand Down Expand Up @@ -1333,7 +1334,7 @@ func (s *testRegionCacheSuite) TestRegionEpochOnTiFlash() {
s.Equal(ctxTiFlash.Peer.Id, s.peer1)
ctxTiFlash.Peer.Role = metapb.PeerRole_Learner
r := ctxTiFlash.Meta
reqSend := NewRegionRequestSender(s.cache, nil)
reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{})
regionErr := &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{CurrentRegions: []*metapb.Region{r}}}
reqSend.onRegionError(s.bo, ctxTiFlash, nil, regionErr)

Expand Down Expand Up @@ -1970,7 +1971,7 @@ func (s *testRegionCacheSuite) TestShouldNotRetryFlashback() {
ctx, err := s.cache.GetTiKVRPCContext(retry.NewBackofferWithVars(context.Background(), 100, nil), loc.Region, kv.ReplicaReadLeader, 0)
s.NotNil(ctx)
s.NoError(err)
reqSend := NewRegionRequestSender(s.cache, nil)
reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{})
shouldRetry, err := reqSend.onRegionError(s.bo, ctx, nil, &errorpb.Error{FlashbackInProgress: &errorpb.FlashbackInProgress{}})
s.Error(err)
s.False(shouldRetry)
Expand Down
30 changes: 26 additions & 4 deletions internal/locate/region_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ import (
"context"
"fmt"
"maps"
"math"
"math/rand"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/tikv/client-go/v2/oracle"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -107,6 +109,7 @@ type RegionRequestSender struct {
regionCache *RegionCache
apiVersion kvrpcpb.APIVersion
client client.Client
readTSValidator oracle.ReadTSValidator
storeAddr string
rpcError error
replicaSelector *replicaSelector
Expand Down Expand Up @@ -377,11 +380,12 @@ func (s *ReplicaAccessStats) String() string {
}

// NewRegionRequestSender creates a new sender.
func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender {
func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender {
return &RegionRequestSender{
regionCache: regionCache,
apiVersion: regionCache.codec.GetAPIVersion(),
client: client,
regionCache: regionCache,
apiVersion: regionCache.codec.GetAPIVersion(),
client: client,
readTSValidator: readTSValidator,
}
}

Expand Down Expand Up @@ -766,6 +770,10 @@ func (s *RegionRequestSender) SendReqCtx(
}
}

if err = s.validateReadTS(bo.GetCtx(), req); err != nil {
return nil, nil, 0, err
}

// If the MaxExecutionDurationMs is not set yet, we set it to be the RPC timeout duration
// so TiKV can give up the requests whose response TiDB cannot receive due to timeout.
if req.Context.MaxExecutionDurationMs == 0 {
Expand Down Expand Up @@ -1756,6 +1764,20 @@ func (s *RegionRequestSender) onRegionError(
return false, nil
}

func (s *RegionRequestSender) validateReadTS(ctx context.Context, req *tikvrpc.Request) error {
var readTS uint64
switch req.Type {
case tikvrpc.CmdGet, tikvrpc.CmdScan, tikvrpc.CmdBatchGet, tikvrpc.CmdCop, tikvrpc.CmdCopStream, tikvrpc.CmdBatchCop, tikvrpc.CmdScanLock, tikvrpc.CmdBufferBatchGet:
readTS = req.GetStartTS()
default:
return nil
}
if readTS == math.MaxUint64 {
return nil
}
return s.readTSValidator.ValidateReadTS(ctx, readTS, req.StaleRead, &oracle.Option{TxnScope: req.TxnScope})
}

type staleReadMetricsCollector struct {
}

Expand Down
10 changes: 5 additions & 5 deletions internal/locate/region_request3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (s *testRegionRequestToThreeStoresSuite) SetupTest() {
s.cache = NewRegionCache(pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})

s.NoError(failpoint.Enable("tikvclient/doNotRecoverStoreHealthCheckPanic", "return"))
}
Expand Down Expand Up @@ -185,7 +185,7 @@ func (s *testRegionRequestToThreeStoresSuite) TestSwitchPeerWhenNoLeaderErrorWit
s.Nil(err)
s.NotNil(location)
bo := retry.NewBackoffer(context.Background(), 1000)
resp, _, _, err := NewRegionRequestSender(s.cache, cli).SendReqCtx(bo, req, location.Region, time.Second, tikvrpc.TiKV)
resp, _, _, err := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{}).SendReqCtx(bo, req, location.Region, time.Second, tikvrpc.TiKV)
s.Nil(err)
s.NotNil(resp)
regionErr, err := resp.GetRegionError()
Expand All @@ -208,7 +208,7 @@ func (s *testRegionRequestToThreeStoresSuite) loadAndGetLeaderStore() (*Store, s
}

func (s *testRegionRequestToThreeStoresSuite) TestForwarding() {
sender := NewRegionRequestSender(s.cache, s.regionRequestSender.client)
sender := NewRegionRequestSender(s.cache, s.regionRequestSender.client, oracle.NoopReadTSValidator{})
sender.regionCache.enableForwarding = true

// First get the leader's addr from region cache
Expand Down Expand Up @@ -1144,7 +1144,7 @@ func (s *testRegionRequestToThreeStoresSuite) TestSendReqFirstTimeout() {
}
resetStats := func() {
reqTargetAddrs = make(map[string]struct{})
s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient)
s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient, oracle.NoopReadTSValidator{})
s.regionRequestSender.Stats = NewRegionRequestRuntimeStats()
}

Expand Down Expand Up @@ -1375,7 +1375,7 @@ func (s *testRegionRequestToThreeStoresSuite) TestStaleReadTryFollowerAfterTimeo
}
return &tikvrpc.Response{Resp: &kvrpcpb.GetResponse{Value: []byte("value")}}, nil
}}
s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient)
s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient, oracle.NoopReadTSValidator{})
s.regionRequestSender.Stats = NewRegionRequestRuntimeStats()
getLocFn := func() *KeyLocation {
loc, err := s.regionRequestSender.regionCache.LocateKey(bo, []byte("a"))
Expand Down
3 changes: 2 additions & 1 deletion internal/locate/region_request_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/tikv/client-go/v2/internal/mockstore/mocktikv"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/tikvrpc"
"github.com/tikv/client-go/v2/util"
)
Expand Down Expand Up @@ -77,7 +78,7 @@ func (s *testRegionCacheStaleReadSuite) SetupTest() {
s.cache = NewRegionCache(pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})
s.setClient()
s.injection = testRegionCacheFSMSuiteInjection{
unavailableStoreIDs: make(map[uint64]struct{}),
Expand Down
15 changes: 8 additions & 7 deletions internal/locate/region_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import (
"github.com/tikv/client-go/v2/internal/client"
"github.com/tikv/client-go/v2/internal/client/mockserver"
"github.com/tikv/client-go/v2/internal/mockstore/mocktikv"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/tikvrpc"
pderr "github.com/tikv/pd/client/errs"
"google.golang.org/grpc"
Expand Down Expand Up @@ -93,7 +94,7 @@ func (s *testRegionRequestToSingleStoreSuite) SetupTest() {
s.cache = NewRegionCache(pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})

s.NoError(failpoint.Enable("tikvclient/doNotRecoverStoreHealthCheckPanic", "return"))
}
Expand Down Expand Up @@ -601,7 +602,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa
}()

cli := client.NewRPCClient()
sender := NewRegionRequestSender(s.cache, cli)
sender := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{})
req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{
Key: []byte("key"),
Value: []byte("value"),
Expand All @@ -622,7 +623,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa
Client: client.NewRPCClient(),
redirectAddr: addr,
}
sender = NewRegionRequestSender(s.cache, client1)
sender = NewRegionRequestSender(s.cache, client1, oracle.NoopReadTSValidator{})
sender.SendReq(s.bo, req, region.Region, 3*time.Second)

// cleanup
Expand Down Expand Up @@ -806,7 +807,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() {
cancel()
}()
req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1})
regionRequestSender := NewRegionRequestSender(s.cache, fnClient)
regionRequestSender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{})
reachable.injectConstantLiveness(regionRequestSender.regionCache.stores)
regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort)
}
Expand Down Expand Up @@ -850,19 +851,19 @@ type emptyClient struct {

func (s *testRegionRequestToSingleStoreSuite) TestClientExt() {
var cli client.Client = client.NewRPCClient()
sender := NewRegionRequestSender(s.cache, cli)
sender := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{})
s.NotNil(sender.client)
s.NotNil(sender.getClientExt())
cli.Close()

cli = &emptyClient{}
sender = NewRegionRequestSender(s.cache, cli)
sender = NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{})
s.NotNil(sender.client)
s.Nil(sender.getClientExt())
}

func (s *testRegionRequestToSingleStoreSuite) TestRegionRequestSenderString() {
sender := NewRegionRequestSender(s.cache, &fnClient{})
sender := NewRegionRequestSender(s.cache, &fnClient{}, oracle.NoopReadTSValidator{})
loc, err := s.cache.LocateRegionByID(s.bo, s.region)
s.Nil(err)
// invalid region cache before sending request.
Expand Down
8 changes: 4 additions & 4 deletions internal/locate/replica_selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2793,7 +2793,7 @@ func (ca *replicaSelectorAccessPathCase) run(s *testReplicaSelectorSuite) *Regio
Value: []byte("hello world"),
}}, nil
}}
sender := NewRegionRequestSender(s.cache, fnClient)
sender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{})
req, opts, timeout := ca.buildRequest(s)
beforeRun(s, ca)
rc := s.getRegion()
Expand Down Expand Up @@ -3220,7 +3220,7 @@ func TestTiKVClientReadTimeout(t *testing.T) {
req.ReplicaReadType = kv.ReplicaReadLeader
req.MaxExecutionDurationMs = 1
bo := retry.NewBackofferWithVars(context.Background(), 2000, nil)
sender := NewRegionRequestSender(s.cache, fnClient)
sender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{})
resp, _, err := sender.SendReq(bo, req, rc.VerID(), time.Millisecond)
s.Nil(err)
s.NotNil(resp)
Expand All @@ -3234,7 +3234,7 @@ func TestTiKVClientReadTimeout(t *testing.T) {
}, accessPath)
// clear max execution duration for retry.
req.MaxExecutionDurationMs = 0
sender = NewRegionRequestSender(s.cache, fnClient)
sender = NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{})
resp, _, err = sender.SendReq(bo, req, rc.VerID(), time.Second) // use a longer timeout.
s.Nil(err)
s.NotNil(resp)
Expand Down Expand Up @@ -3310,7 +3310,7 @@ func BenchmarkReplicaSelector(b *testing.B) {
if err != nil {
b.Fail()
}
sender := NewRegionRequestSender(cache, fnClient)
sender := NewRegionRequestSender(cache, fnClient, oracle.NoopReadTSValidator{})
sender.SendReqCtx(bo, req, loc.Region, client.ReadTimeoutShort, tikvrpc.TiKV)
}
}
22 changes: 18 additions & 4 deletions oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,17 @@ type Oracle interface {
// GetAllTSOKeyspaceGroupMinTS gets a minimum timestamp from all TSO keyspace groups.
GetAllTSOKeyspaceGroupMinTS(ctx context.Context) (uint64, error)

// ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts
// that has been allocated by the oracle, so that it's safe to use this ts to perform snapshot read, stale read,
// etc.
ReadTSValidator
}

// ReadTSValidator is the interface for providing the ability for verifying whether a timestamp is safe to be used
// for readings, as part of the `Oracle` interface.
type ReadTSValidator interface {
// ValidateReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts
// that has been allocated by the oracle, so that it's safe to use this ts to perform read operations.
// Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot
// has been GCed.
ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error
ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error
}

// Future is a future which promises to return a timestamp.
Expand Down Expand Up @@ -125,3 +130,12 @@ func GoTimeToTS(t time.Time) uint64 {
func GoTimeToLowerLimitStartTS(now time.Time, maxTxnTimeUse int64) uint64 {
return GoTimeToTS(now.Add(-time.Duration(maxTxnTimeUse) * time.Millisecond))
}

// NoopReadTSValidator is a dummy implementation of ReadTSValidator that always let the validation pass.
// Only use this when using RPCs that are not related to ts (e.g. rawkv), or in tests where `Oracle` is not available
// and the validation is not necessary.
type NoopReadTSValidator struct{}

func (NoopReadTSValidator) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error {
return nil
}
2 changes: 1 addition & 1 deletion oracle/oracles/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,6 @@ func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
case *pdOracle:
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{})
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts})
lastTSPointer.Store(&lastTSO{tso: ts, arrival: oracle.GetTimeFromTS(ts)})
}
}
2 changes: 1 addition & 1 deletion oracle/oracles/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error)
return l.getExternalTimestamp(ctx)
}

func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error {
func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
currentTS, err := l.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion oracle/oracles/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (o *MockOracle) SetLowResolutionTimestampUpdateInterval(time.Duration) erro
return nil
}

func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error {
func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
currentTS, err := o.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
Expand Down
Loading
Loading