diff --git a/service/matching/matchingEngine.go b/service/matching/matchingEngine.go index 8f7826df16b..510dd2e768f 100644 --- a/service/matching/matchingEngine.go +++ b/service/matching/matchingEngine.go @@ -702,16 +702,17 @@ func (e *matchingEngineImpl) getTask( return tlMgr.GetTask(ctx, maxDispatchPerSecond) } -func (e *matchingEngineImpl) unloadTaskQueue(id *taskQueueID) { +func (e *matchingEngineImpl) unloadTaskQueue(unloadTQM taskQueueManager) { + queueID := unloadTQM.QueueID() e.taskQueuesLock.Lock() - tlMgr, ok := e.taskQueues[*id] - if ok { - delete(e.taskQueues, *id) + foundTQM, ok := e.taskQueues[*queueID] + if !ok || foundTQM != unloadTQM { + e.taskQueuesLock.Unlock() + return } + delete(e.taskQueues, *queueID) e.taskQueuesLock.Unlock() - if ok { - tlMgr.Stop() - } + foundTQM.Stop() } // Populate the workflow task response based on context and scheduled/started events. diff --git a/service/matching/matchingEngine_test.go b/service/matching/matchingEngine_test.go index 8b871034ca4..54b4fb9a7c5 100644 --- a/service/matching/matchingEngine_test.go +++ b/service/matching/matchingEngine_test.go @@ -253,6 +253,42 @@ func (s *matchingEngineSuite) TestPollWorkflowTaskQueuesEmptyResultWithShortCont s.PollForTasksEmptyResultTest(callContext, enumspb.TASK_QUEUE_TYPE_WORKFLOW) } +func (s *matchingEngineSuite) TestOnlyUnloadMatchingInstance() { + queueID := newTestTaskQueueID( + uuid.New(), + "makeToast", + enumspb.TASK_QUEUE_TYPE_ACTIVITY) + tqm, err := s.matchingEngine.getTaskQueueManager( + queueID, + enumspb.TASK_QUEUE_KIND_NORMAL) + s.Require().NoError(err) + + tqm2, err := newTaskQueueManager( + s.matchingEngine, + queueID, // same queueID as above + enumspb.TASK_QUEUE_KIND_NORMAL, + s.matchingEngine.config) + s.Require().NoError(err) + + // try to unload a different tqm instance with the same taskqueue ID + s.matchingEngine.unloadTaskQueue(tqm2) + + got, err := s.matchingEngine.getTaskQueueManager( + queueID, enumspb.TASK_QUEUE_KIND_NORMAL) + s.Require().NoError(err) + s.Require().Same(tqm, got, + "Unload call with non-matching taskQueueManager should not cause unload") + + // this time unload the right tqm + s.matchingEngine.unloadTaskQueue(tqm) + + got, err = s.matchingEngine.getTaskQueueManager( + queueID, enumspb.TASK_QUEUE_KIND_NORMAL) + s.Require().NoError(err) + s.Require().NotSame(tqm, got, + "Unload call with matching incarnation should have caused unload") +} + func (s *matchingEngineSuite) TestPollWorkflowTaskQueues() { namespaceID := uuid.NewRandom().String() tl := "makeToast" diff --git a/service/matching/taskQueueManager.go b/service/matching/taskQueueManager.go index 65d52403190..8eb1855defc 100644 --- a/service/matching/taskQueueManager.go +++ b/service/matching/taskQueueManager.go @@ -108,6 +108,7 @@ type ( // DescribeTaskQueue returns information about the target task queue DescribeTaskQueue(includeTaskQueueStatus bool) *matchingservice.DescribeTaskQueueResponse String() string + QueueID() *taskQueueID } // Single task queue in memory state @@ -139,7 +140,7 @@ type ( outstandingPollsLock sync.Mutex outstandingPollsMap map[string]context.CancelFunc shutdownCh chan struct{} // Delivers stop to the pump that populates taskBuffer - signalFatalProblem func(id *taskQueueID) + signalFatalProblem func(taskQueueManager) } ) @@ -471,7 +472,7 @@ func (c *taskQueueManagerImpl) completeTask(task *persistencespb.AllocatedTaskIn tag.Error(err), tag.WorkflowTaskQueueName(c.taskQueueID.name), tag.WorkflowTaskQueueType(c.taskQueueID.taskType)) - c.signalFatalProblem(c.taskQueueID) + c.signalFatalProblem(c) return } c.taskReader.Signal() @@ -600,3 +601,7 @@ func (c *taskQueueManagerImpl) tryInitNamespaceAndScope() { c.metricScopeValue.Store(scope) c.namespaceValue.Store(namespace) } + +func (c *taskQueueManagerImpl) QueueID() *taskQueueID { + return c.taskQueueID +} diff --git a/service/matching/taskWriter.go b/service/matching/taskWriter.go index 75dbb66e66d..6d8d0005f76 100644 --- a/service/matching/taskWriter.go +++ b/service/matching/taskWriter.go @@ -201,7 +201,7 @@ func (w *taskWriter) appendTasks( return resp, nil case *persistence.ConditionFailedError: - w.tlMgr.signalFatalProblem(w.tlMgr.taskQueueID) + w.tlMgr.signalFatalProblem(w.tlMgr) return nil, err default: @@ -219,7 +219,7 @@ func (w *taskWriter) taskWriterLoop(ctx context.Context) error { err := w.initReadWriteState(ctx) if err != nil { if w.tlMgr.errShouldUnload(err) { - w.tlMgr.signalFatalProblem(w.tlMgr.taskQueueID) + w.tlMgr.signalFatalProblem(w.tlMgr) } return err } @@ -318,7 +318,7 @@ func (w *taskWriter) allocTaskIDBlock(ctx context.Context, prevBlockEnd int64) ( state, err := w.renewLeaseWithRetry(ctx, persistenceOperationRetryPolicy, common.IsPersistenceTransientError) if err != nil { if w.tlMgr.errShouldUnload(err) { - w.tlMgr.signalFatalProblem(w.taskQueueID) + w.tlMgr.signalFatalProblem(w.tlMgr) return taskIDBlock{}, errShutdown } return taskIDBlock{}, err