diff --git a/oracle/oracles/local.go b/oracle/oracles/local.go index e916286ac3..8da336f5f2 100644 --- a/oracle/oracles/local.go +++ b/oracle/oracles/local.go @@ -36,11 +36,9 @@ package oracles import ( "context" - "math" "sync" "time" - "github.com/pingcap/errors" "github.com/tikv/client-go/v2/oracle" ) @@ -138,22 +136,6 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) } func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { - if readTS == math.MaxUint64 { - if isStaleRead { - return oracle.ErrLatestStaleRead{} - } - return nil - } - - currentTS, err := l.GetTimestamp(ctx, opt) - if err != nil { - return errors.Errorf("fail to validate read timestamp: %v", err) - } - if currentTS < readTS { - return oracle.ErrFutureTSRead{ - ReadTS: readTS, - CurrentTS: currentTS, - } - } + // local oracle is not supposed to be used return nil } diff --git a/oracle/oracles/mock.go b/oracle/oracles/mock.go index da8874d5c8..662766cec3 100644 --- a/oracle/oracles/mock.go +++ b/oracle/oracles/mock.go @@ -36,7 +36,6 @@ package oracles import ( "context" - "math" "sync" "time" @@ -128,23 +127,6 @@ func (o *MockOracle) SetLowResolutionTimestampUpdateInterval(time.Duration) erro } func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { - if readTS == math.MaxUint64 { - if isStaleRead { - return oracle.ErrLatestStaleRead{} - } - return nil - } - - currentTS, err := o.GetTimestamp(ctx, opt) - if err != nil { - return errors.Errorf("fail to validate read timestamp: %v", err) - } - if currentTS < readTS { - return oracle.ErrFutureTSRead{ - ReadTS: readTS, - CurrentTS: currentTS, - } - } return nil } diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 805d1b5c7a..a3f906bab1 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -620,7 +620,22 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op } } +// ValidateReadTSForTidbSnapshot is a flag in context, indicating whether the read ts is for tidb_snapshot. +// This is a special approach for release branches to minimize code changes to reduce risks. +type ValidateReadTSForTidbSnapshot struct{} + func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) { + // For a mistake we've seen + if readTS >= math.MaxInt64 && readTS < math.MaxUint64 { + return errors.Errorf("MaxInt64 <= readTS < MaxUint64, readTS=%v", readTS) + } + + // For release branches, only check stale reads and reads using `tidb_snapshot` + forTidbSnapshot := ctx.Value(ValidateReadTSForTidbSnapshot{}) != nil + if !forTidbSnapshot && !isStaleRead { + return nil + } + if readTS == math.MaxUint64 { if isStaleRead { return oracle.ErrLatestStaleRead{} diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 25345f3b85..30d7acbae0 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -237,56 +237,6 @@ func TestAdaptiveUpdateTSInterval(t *testing.T) { assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) } -func TestValidateReadTS(t *testing.T) { - testImpl := func(staleRead bool) { - pdClient := MockPdClient{} - o, err := NewPdOracle(&pdClient, &PDOracleOptions{ - UpdateInterval: time.Second * 2, - }) - assert.NoError(t, err) - defer o.Close() - - ctx := context.Background() - opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} - - // Always returns error for MaxUint64 - err = o.ValidateReadTS(ctx, math.MaxUint64, staleRead, opt) - if staleRead { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - - ts, err := o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - assert.GreaterOrEqual(t, ts, uint64(1)) - - err = o.ValidateReadTS(ctx, 1, staleRead, opt) - assert.NoError(t, err) - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to - // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. - err = o.ValidateReadTS(ctx, ts+1, staleRead, opt) - assert.NoError(t, err) - // It can't pass if the readTS is newer than previous ts + 2. - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - err = o.ValidateReadTS(ctx, ts+2, staleRead, opt) - assert.Error(t, err) - - // Simulate other PD clients requests a timestamp. - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - pdClient.logicalTimestamp.Add(2) - err = o.ValidateReadTS(ctx, ts+3, staleRead, opt) - assert.NoError(t, err) - } - - testImpl(true) - testImpl(false) -} - type MockPDClientWithPause struct { MockPdClient mu sync.Mutex @@ -436,9 +386,9 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) { assert.NoError(t, err) mustNoNotify() - // It loads `ts + 1` from the mock PD, and the check cannot pass. + // It loads `ts + 1` from the mock PD, and the check is skipped. err = o.ValidateReadTS(ctx, ts+2, false, opt) - assert.Error(t, err) + assert.NoError(t, err) mustNoNotify() // Do the check again. It loads `ts + 2` from the mock PD, and the check passes.