diff --git a/common/domain/replication_queue_test.go b/common/domain/replication_queue_test.go index a795401f18a..c96da87ae74 100644 --- a/common/domain/replication_queue_test.go +++ b/common/domain/replication_queue_test.go @@ -21,7 +21,9 @@ package domain import ( + "bytes" "context" + "encoding/binary" "errors" "testing" @@ -32,6 +34,10 @@ import ( "github.com/uber/cadence/common/types" ) +const ( + preambleVersion0 byte = 0x59 +) + func TestReplicationQueueImpl_Publish(t *testing.T) { tests := []struct { name string @@ -69,7 +75,6 @@ func TestReplicationQueueImpl_Publish(t *testing.T) { } else { assert.NoError(t, err) } - ctrl.Finish() }) } } @@ -111,7 +116,6 @@ func TestReplicationQueueImpl_PublishToDLQ(t *testing.T) { } else { assert.NoError(t, err) } - ctrl.Finish() }) } } @@ -122,7 +126,6 @@ func TestGetReplicationMessages(t *testing.T) { name string lastID int64 maxCount int - task *types.ReplicationTask wantErr bool setupMock func(q *persistence.MockQueueManager) }{ @@ -130,7 +133,6 @@ func TestGetReplicationMessages(t *testing.T) { name: "successful message retrieval", lastID: 100, maxCount: 10, - task: &types.ReplicationTask{}, wantErr: false, setupMock: func(q *persistence.MockQueueManager) { q.EXPECT().ReadMessages(gomock.Any(), gomock.Eq(int64(100)), gomock.Eq(10)).Return(persistence.QueueMessageList{}, nil) @@ -160,7 +162,6 @@ func TestGetReplicationMessages(t *testing.T) { } else { assert.NoError(t, err) } - ctrl.Finish() }) } } @@ -206,7 +207,111 @@ func TestUpdateAckLevel(t *testing.T) { } else { assert.NoError(t, err) } - ctrl.Finish() + }) + } +} + +func TestReplicationQueueImpl_GetAckLevels(t *testing.T) { + tests := []struct { + name string + want map[string]int64 + wantErr bool + setupMock func(q *persistence.MockQueueManager) + }{ + { + name: "successful ack levels retrieval", + want: map[string]int64{"testCluster": 100}, + wantErr: false, + setupMock: func(q *persistence.MockQueueManager) { + q.EXPECT().GetAckLevels(gomock.Any()).Return(map[string]int64{"testCluster": 100}, nil) + }, + }, + { + name: "ack levels retrieval fails", + wantErr: true, + setupMock: func(q *persistence.MockQueueManager) { + q.EXPECT().GetAckLevels(gomock.Any()).Return(nil, errors.New("retrieval error")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockQueue := persistence.NewMockQueueManager(ctrl) + rq := NewReplicationQueue(mockQueue, "testCluster", nil, nil) + tt.setupMock(mockQueue) + got, err := rq.GetAckLevels(context.Background()) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func mockEncodeReplicationTask(sourceTaskID int64) ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(preambleVersion0) + binary.Write(&buf, binary.BigEndian, sourceTaskID) + return buf.Bytes(), nil +} + +func TestGetMessagesFromDLQ(t *testing.T) { + tests := []struct { + name string + firstID int64 + lastID int64 + pageSize int + pageToken []byte + taskID int64 + wantErr bool + }{ + { + name: "successful message retrieval", + firstID: 100, + lastID: 200, + pageSize: 10, + pageToken: []byte("token"), + taskID: 12345, + wantErr: false, + }, + { + name: "read messages fails", + firstID: 100, + lastID: 200, + pageSize: 10, + pageToken: []byte("token"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockQueue := persistence.NewMockQueueManager(ctrl) + rq := NewReplicationQueue(mockQueue, "testCluster", nil, nil) + + if !tt.wantErr { + encodedData, _ := mockEncodeReplicationTask(tt.taskID) + messages := []*persistence.QueueMessage{ + {ID: 1, Payload: encodedData}, + } + mockQueue.EXPECT().ReadMessagesFromDLQ(gomock.Any(), tt.firstID, tt.lastID, tt.pageSize, tt.pageToken).Return(messages, []byte("nextToken"), nil) + } else { + mockQueue.EXPECT().ReadMessagesFromDLQ(gomock.Any(), tt.firstID, tt.lastID, tt.pageSize, tt.pageToken).Return(nil, nil, errors.New("read error")) + } + + replicationTasks, token, err := rq.GetMessagesFromDLQ(context.Background(), tt.firstID, tt.lastID, tt.pageSize, tt.pageToken) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, replicationTasks, 1, "Expected one replication task to be returned") + assert.Equal(t, []byte("nextToken"), token, "Expected token to match 'nextToken'") + } }) } }