Skip to content

Commit

Permalink
Validate ts only for stale read (tikv#1592)
Browse files Browse the repository at this point in the history
ref pingcap/tidb#59402

Signed-off-by: ekexium <eke@fastmail.com>

fix test

Signed-off-by: ekexium <eke@fastmail.com>
  • Loading branch information
ekexium committed Feb 26, 2025
1 parent d561c89 commit 460b877
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 89 deletions.
20 changes: 1 addition & 19 deletions oracle/oracles/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,9 @@ package oracles

import (
"context"
"math"
"sync"
"time"

"github.com/pingcap/errors"
"github.com/tikv/client-go/v2/oracle"
)

Expand Down Expand Up @@ -138,22 +136,6 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error)
}

func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

currentTS, err := l.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if currentTS < readTS {
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
// local oracle is not supposed to be used
return nil
}
18 changes: 0 additions & 18 deletions oracle/oracles/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ package oracles

import (
"context"
"math"
"sync"
"time"

Expand Down Expand Up @@ -128,23 +127,6 @@ func (o *MockOracle) SetLowResolutionTimestampUpdateInterval(time.Duration) erro
}

func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

currentTS, err := o.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if currentTS < readTS {
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
return nil
}

Expand Down
15 changes: 15 additions & 0 deletions oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,22 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op
}
}

// ValidateReadTSForTidbSnapshot is a flag in context, indicating whether the read ts is for tidb_snapshot.
// This is a special approach for release branches to minimize code changes to reduce risks.
type ValidateReadTSForTidbSnapshot struct{}

func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) {
// For a mistake we've seen
if readTS >= math.MaxInt64 && readTS < math.MaxUint64 {
return errors.Errorf("MaxInt64 <= readTS < MaxUint64, readTS=%v", readTS)
}

// For release branches, only check stale reads and reads using `tidb_snapshot`
forTidbSnapshot := ctx.Value(ValidateReadTSForTidbSnapshot{}) != nil
if !forTidbSnapshot && !isStaleRead {
return nil
}

if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
Expand Down
54 changes: 2 additions & 52 deletions oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,56 +237,6 @@ func TestAdaptiveUpdateTSInterval(t *testing.T) {
assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state)
}

func TestValidateReadTS(t *testing.T) {
testImpl := func(staleRead bool) {
pdClient := MockPdClient{}
o, err := NewPdOracle(&pdClient, &PDOracleOptions{
UpdateInterval: time.Second * 2,
})
assert.NoError(t, err)
defer o.Close()

ctx := context.Background()
opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope}

// Always returns error for MaxUint64
err = o.ValidateReadTS(ctx, math.MaxUint64, staleRead, opt)
if staleRead {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}

ts, err := o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
assert.GreaterOrEqual(t, ts, uint64(1))

err = o.ValidateReadTS(ctx, 1, staleRead, opt)
assert.NoError(t, err)
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
// The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to
// the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass.
err = o.ValidateReadTS(ctx, ts+1, staleRead, opt)
assert.NoError(t, err)
// It can't pass if the readTS is newer than previous ts + 2.
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
err = o.ValidateReadTS(ctx, ts+2, staleRead, opt)
assert.Error(t, err)

// Simulate other PD clients requests a timestamp.
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
pdClient.logicalTimestamp.Add(2)
err = o.ValidateReadTS(ctx, ts+3, staleRead, opt)
assert.NoError(t, err)
}

testImpl(true)
testImpl(false)
}

type MockPDClientWithPause struct {
MockPdClient
mu sync.Mutex
Expand Down Expand Up @@ -436,9 +386,9 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) {
assert.NoError(t, err)
mustNoNotify()

// It loads `ts + 1` from the mock PD, and the check cannot pass.
// It loads `ts + 1` from the mock PD, and the check is skipped.
err = o.ValidateReadTS(ctx, ts+2, false, opt)
assert.Error(t, err)
assert.NoError(t, err)
mustNoNotify()

// Do the check again. It loads `ts + 2` from the mock PD, and the check passes.
Expand Down

0 comments on commit 460b877

Please sign in to comment.