Skip to content

Commit fd5b2dd

Browse files
bysomeone33cn
authored andcommitted
[[FIX]] fix merge iterator reverse list(#1211)
1 parent 89ce5be commit fd5b2dd

File tree

2 files changed

+148
-48
lines changed

2 files changed

+148
-48
lines changed

common/db/merge_iter.go

+22-48
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func (i *mergedIterator) Rewind() bool {
7676
}
7777
}
7878
i.dir = dirSOI
79-
return i.next(false)
79+
return i.selectKey()
8080
}
8181

8282
func (i *mergedIterator) Seek(key []byte) bool {
@@ -97,35 +97,33 @@ func (i *mergedIterator) Seek(key []byte) bool {
9797
}
9898
}
9999
i.dir = dirSOI
100-
if i.next(!i.reverse) {
100+
if i.selectKey() {
101101
i.dir = dirSeek
102102
return true
103103
}
104104
i.dir = dirSOI
105105
return false
106106
}
107107

108-
func (i *mergedIterator) compare(tkey []byte, key []byte, ignoreReverse bool) int {
109-
if ignoreReverse {
110-
return i.cmp.Compare(tkey, key)
111-
}
112-
if tkey == nil && key != nil {
108+
func (i *mergedIterator) compare(key1 []byte, key2 []byte) int {
109+
110+
if key1 == nil && key2 != nil {
113111
return 1
114112
}
115-
if tkey != nil && key == nil {
113+
if key1 != nil && key2 == nil {
116114
return -1
117115
}
118-
result := i.cmp.Compare(tkey, key)
116+
result := i.cmp.Compare(key1, key2)
119117
if i.reverse {
120118
return -result
121119
}
122120
return result
123121
}
124122

125-
func (i *mergedIterator) next(ignoreReverse bool) bool {
123+
func (i *mergedIterator) selectKey() bool {
126124
var key []byte
127125
for x, tkey := range i.keys {
128-
if tkey != nil && (key == nil || i.compare(tkey, key, ignoreReverse) < 0) {
126+
if tkey != nil && (key == nil || i.compare(tkey, key) < 0) {
129127
key = tkey
130128
i.index = x
131129
}
@@ -141,64 +139,40 @@ func (i *mergedIterator) next(ignoreReverse bool) bool {
141139
return true
142140
}
143141

142+
// Next next key
144143
func (i *mergedIterator) Next() bool {
145144
for {
146-
ok, isrewind := i.nextInternal()
147-
if !ok {
148-
break
149-
}
150-
if isrewind {
151-
return true
145+
146+
if !i.next() {
147+
return false
152148
}
153-
if i.compare(i.Key(), i.prevKey, true) != 0 {
149+
150+
if i.compare(i.Key(), i.prevKey) != 0 {
154151
i.prevKey = cloneByte(i.Key())
155152
return true
156153
}
157154
}
158-
return false
159155
}
160156

161-
func (i *mergedIterator) nextInternal() (bool, bool) {
157+
func (i *mergedIterator) next() bool {
162158
if i.dir == dirEOI || i.err != nil {
163-
return false, false
159+
return false
164160
} else if i.dir == dirReleased {
165161
i.err = ErrIterReleased
166-
return false, false
167-
}
168-
switch i.dir {
169-
case dirSOI:
170-
return i.Rewind(), true
171-
case dirSeek:
172-
if !i.reverse {
173-
break
174-
}
175-
key := append([]byte{}, i.keys[i.index]...)
176-
for x, iter := range i.iters {
177-
if x == i.index {
178-
continue
179-
}
180-
seek := iter.Seek(key)
181-
switch {
182-
case seek && iter.Next(), !seek && iter.Rewind():
183-
i.keys[x] = assertKey(iter.Key())
184-
case i.iterErr(iter):
185-
return false, false
186-
default:
187-
i.keys[x] = nil
188-
}
189-
}
162+
return false
190163
}
164+
191165
x := i.index
192166
iter := i.iters[x]
193167
switch {
194168
case iter.Next():
195169
i.keys[x] = assertKey(iter.Key())
196170
case i.iterErr(iter):
197-
return false, false
171+
return false
198172
default:
199173
i.keys[x] = nil
200174
}
201-
return i.next(false), false
175+
return i.selectKey()
202176
}
203177

204178
func (i *mergedIterator) Key() []byte {
@@ -247,7 +221,7 @@ func (i *mergedIterator) Error() error {
247221
//
248222
// If strict is true the any 'corruption errors' (i.e errors.IsCorrupted(err) == true)
249223
// won't be ignored and will halt 'merged iterator', otherwise the iterator will
250-
// continue to the next 'input iterator'.
224+
// continue to the selectKey 'input iterator'.
251225
func NewMergedIterator(iters []Iterator) Iterator {
252226
reverse := true
253227
if len(iters) >= 2 {

common/db/merge_iter_test.go

+126
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package db
22

33
import (
4+
"fmt"
45
"io/ioutil"
56
"os"
67
"testing"
@@ -46,6 +47,14 @@ func newGoLevelDB(t *testing.T) (DB, string) {
4647
return db, dir
4748
}
4849

50+
func newGoBadgerDB(t *testing.T) (DB, string) {
51+
dir, err := ioutil.TempDir("", "badgerdb")
52+
assert.Nil(t, err)
53+
db, err := NewGoBadgerDB("test", dir, 16)
54+
assert.Nil(t, err)
55+
return db, dir
56+
}
57+
4958
func TestMergeIterSeek1(t *testing.T) {
5059
db1 := newGoMemDB(t)
5160
db1.Set([]byte("1"), []byte("1"))
@@ -279,3 +288,120 @@ func TestIterSearch(t *testing.T) {
279288
assert.Equal(t, "db2-key-3", string(list0[0]))
280289
assert.Equal(t, "db2-key-4", string(list0[1]))
281290
}
291+
292+
func TestMergeIterList(t *testing.T) {
293+
levelDB, dir := newGoLevelDB(t)
294+
testMergeIterList(t, newGoMemDB(t), newGoMemDB(t), levelDB)
295+
_ = os.RemoveAll(dir)
296+
badgerDB, dir := newGoBadgerDB(t)
297+
testMergeIterList(t, newGoMemDB(t), newGoMemDB(t), badgerDB)
298+
_ = os.RemoveAll(dir)
299+
levelDB, dir1 := newGoLevelDB(t)
300+
badgerDB, dir2 := newGoBadgerDB(t)
301+
testMergeIterList(t, badgerDB, levelDB, newGoMemDB(t))
302+
_ = os.RemoveAll(dir1)
303+
_ = os.RemoveAll(dir2)
304+
}
305+
306+
func testMergeIterList(t *testing.T, db1, db2, db3 DB) {
307+
308+
for i := 0; i < 10; i++ {
309+
db3.Set([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("%d", i)))
310+
}
311+
//合并以后:
312+
db := NewMergedIteratorDB([]IteratorDB{db1, db2, db3})
313+
it := NewListHelper(db)
314+
315+
//key9 ~ key1
316+
listAll := func(totalCount int, direction int32) [][]byte {
317+
var values [][]byte
318+
var primary []byte
319+
for i := 0; i < 3; i++ {
320+
data := it.List([]byte("key"), primary, 4, direction)
321+
values = append(values, data...)
322+
primary = []byte(fmt.Sprintf("key%s", data[len(data)-1]))
323+
}
324+
assert.Equal(t, totalCount, len(values))
325+
return values
326+
}
327+
328+
values := listAll(10, ListDESC)
329+
for i, val := range values {
330+
assert.Equal(t, []byte(fmt.Sprintf("%d", 9-i)), val)
331+
}
332+
values = listAll(10, ListASC)
333+
for i, val := range values {
334+
assert.Equal(t, []byte(fmt.Sprintf("%d", i)), val)
335+
}
336+
337+
// db2数据覆盖
338+
db2.Set([]byte("key3"), []byte("33"))
339+
values = listAll(10, ListDESC)
340+
for i, val := range values {
341+
value := []byte(fmt.Sprintf("%d", 9-i))
342+
if i == 6 {
343+
value = []byte("33")
344+
}
345+
assert.Equal(t, value, val)
346+
}
347+
values = listAll(10, ListASC)
348+
for i, val := range values {
349+
value := []byte(fmt.Sprintf("%d", i))
350+
if i == 3 {
351+
value = []byte("33")
352+
}
353+
assert.Equal(t, value, val)
354+
}
355+
356+
// db1数据覆盖
357+
db1.Set([]byte("key3"), []byte("333"))
358+
db1.Set([]byte("key5"), []byte("555"))
359+
values = listAll(10, ListDESC)
360+
for i, val := range values {
361+
value := []byte(fmt.Sprintf("%d", 9-i))
362+
if i == 4 {
363+
value = []byte("555")
364+
}
365+
if i == 6 {
366+
value = []byte("333")
367+
}
368+
assert.Equal(t, value, val)
369+
}
370+
values = listAll(10, ListASC)
371+
for i, val := range values {
372+
value := []byte(fmt.Sprintf("%d", i))
373+
if i == 5 {
374+
value = []byte("555")
375+
}
376+
if i == 3 {
377+
value = []byte("333")
378+
}
379+
assert.Equal(t, value, val)
380+
}
381+
382+
// 新增key
383+
db1.Set([]byte("key91"), []byte("10"))
384+
db2.Set([]byte("key92"), []byte("11"))
385+
values = listAll(12, ListDESC)
386+
for i, val := range values {
387+
value := []byte(fmt.Sprintf("%d", 11-i))
388+
if i == 6 {
389+
value = []byte("555")
390+
}
391+
if i == 8 {
392+
value = []byte("333")
393+
}
394+
assert.Equal(t, value, val)
395+
}
396+
values = listAll(12, ListASC)
397+
for i, val := range values {
398+
value := []byte(fmt.Sprintf("%d", i))
399+
if i == 5 {
400+
value = []byte("555")
401+
}
402+
if i == 3 {
403+
value = []byte("333")
404+
}
405+
assert.Equal(t, value, val)
406+
}
407+
}

0 commit comments

Comments
 (0)