diff --git a/common/log/tag/tags.go b/common/log/tag/tags.go index e8373d1dedc..40f301432a6 100644 --- a/common/log/tag/tags.go +++ b/common/log/tag/tags.go @@ -984,6 +984,10 @@ func BuildId(buildId string) ZapTag { return NewStringTag("build-id", buildId) } +func UserDataVersion(v int64) ZapTag { + return NewInt64("user-data-version", v) +} + func Cause(cause string) ZapTag { return NewStringTag("cause", cause) } diff --git a/service/matching/config.go b/service/matching/config.go index e7a595a3429..0420880ccdc 100644 --- a/service/matching/config.go +++ b/service/matching/config.go @@ -139,6 +139,7 @@ type ( GetUserDataLongPollTimeout dynamicconfig.DurationPropertyFn GetUserDataMinWaitTime time.Duration + GetUserDataReturnBudget time.Duration // taskWriter configuration OutstandingTaskAppendsThreshold func() int @@ -298,6 +299,7 @@ func newTaskQueueConfig(tq *tqid.TaskQueue, config *Config, ns namespace.Name) * }, GetUserDataLongPollTimeout: config.GetUserDataLongPollTimeout, GetUserDataMinWaitTime: 1 * time.Second, + GetUserDataReturnBudget: returnEmptyTaskTimeBudget, OutstandingTaskAppendsThreshold: func() int { return config.OutstandingTaskAppendsThreshold(ns.String(), taskQueueName, taskType) }, diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index 631f9ce28b7..0565395bbdc 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -1461,51 +1461,11 @@ func (e *matchingEngineImpl) GetTaskQueueUserData( if err != nil { return nil, err } - version := req.GetLastKnownUserDataVersion() - if version < 0 { - return nil, serviceerror.NewInvalidArgument("last_known_user_data_version must not be negative") - } - if req.WaitNewData { - var cancel context.CancelFunc - ctx, cancel = newChildContext(ctx, e.config.GetUserDataLongPollTimeout(), returnEmptyTaskTimeBudget) - defer cancel() // mark alive so that it doesn't unload while a child partition is doing a long poll pm.MarkAlive() } - - for { - resp := &matchingservice.GetTaskQueueUserDataResponse{} - userData, userDataChanged, err := pm.GetUserDataManager().GetUserData() - if errors.Is(err, errTaskQueueClosed) { - // If we're closing, return a success with no data, as if the request expired. We shouldn't - // close due to idleness (because of the MarkAlive above), so we're probably closing due to a - // change of ownership. The caller will retry and be redirected to the new owner. - return resp, nil - } else if err != nil { - return nil, err - } - if req.WaitNewData && userData.GetVersion() == version { - // long-poll: wait for data to change/appear - select { - case <-ctx.Done(): - return resp, nil - case <-userDataChanged: - continue - } - } - if userData != nil { - if userData.Version > version { - resp.UserData = userData - } else if userData.Version < version { - // This is highly unlikely but may happen due to an edge case in during ownership transfer. - // We rely on client retries in this case to let the system eventually self-heal. - return nil, serviceerror.NewInvalidArgument( - "requested task queue user data for version greater than known version") - } - } - return resp, nil - } + return pm.GetUserDataManager().HandleGetUserDataRequest(ctx, req) } func (e *matchingEngineImpl) ApplyTaskQueueUserDataReplicationEvent( diff --git a/service/matching/task_queue_partition_manager_test.go b/service/matching/task_queue_partition_manager_test.go index 12c6304cc27..8750ca7a2ac 100644 --- a/service/matching/task_queue_partition_manager_test.go +++ b/service/matching/task_queue_partition_manager_test.go @@ -37,6 +37,8 @@ import ( enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" taskqueuepb "go.temporal.io/api/taskqueue/v1" + + "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/api/matchingservicemock/v1" "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/api/taskqueue/v1" @@ -588,6 +590,10 @@ func (m *mockUserDataManager) UpdateUserData(_ context.Context, _ UserDataUpdate return nil } +func (m *mockUserDataManager) HandleGetUserDataRequest(ctx context.Context, req *matchingservice.GetTaskQueueUserDataRequest) (*matchingservice.GetTaskQueueUserDataResponse, error) { + panic("unused") +} + func (m *mockUserDataManager) updateVersioningData(data *persistence.VersioningData) { m.Lock() defer m.Unlock() diff --git a/service/matching/user_data_manager.go b/service/matching/user_data_manager.go index bd38a199afa..6c37acdce09 100644 --- a/service/matching/user_data_manager.go +++ b/service/matching/user_data_manager.go @@ -37,6 +37,7 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" + "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/future" "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" @@ -66,6 +67,8 @@ type ( // UpdateUserData updates user data for this task queue and replicates across clusters if necessary. // Extra care should be taken to avoid mutating the existing data in the update function. UpdateUserData(ctx context.Context, options UserDataUpdateOptions, updateFn UserDataUpdateFunc) error + // Handles the maybe-long-poll GetUserData RPC. + HandleGetUserDataRequest(ctx context.Context, req *matchingservice.GetTaskQueueUserDataRequest) (*matchingservice.GetTaskQueueUserDataResponse, error) } UserDataUpdateOptions struct { @@ -276,6 +279,10 @@ func (m *userDataManagerImpl) fetchUserData(ctx context.Context) error { WaitNewData: hasFetchedUserData, }) if err != nil { + // don't log on context canceled, produces too much log spam at shutdown + if !common.IsContextCanceledErr(err) { + m.logger.Error("error fetching user data from parent", tag.Error(err)) + } var unimplErr *serviceerror.Unimplemented if errors.As(err, &unimplErr) { // This might happen during a deployment. The older version couldn't have had any user data, @@ -292,6 +299,9 @@ func (m *userDataManagerImpl) fetchUserData(ctx context.Context) error { // nil inner fields. if res.GetUserData() != nil { m.setUserDataForNonOwningPartition(res.GetUserData()) + m.logNewUserData("fetched user data from parent", res.GetUserData()) + } else { + m.logger.Debug("fetched user data from parent, no change") } hasFetchedUserData = true m.setUserDataState(userDataEnabled, nil) @@ -339,6 +349,7 @@ func (m *userDataManagerImpl) loadUserDataFromDB(ctx context.Context) error { m.lock.Lock() defer m.lock.Unlock() m.setUserDataLocked(response.UserData) + m.logNewUserData("loaded user data from db", response.UserData) return nil } @@ -349,6 +360,9 @@ func (m *userDataManagerImpl) UpdateUserData(ctx context.Context, options UserDa if m.store == nil { return errUserDataNoMutateNonRoot } + if err := m.WaitUntilInitialized(ctx); err != nil { + return err + } newData, shouldReplicate, err := m.updateUserData(ctx, updateFn, options.KnownVersion, options.TaskQueueLimitPerBuildId) if err != nil { return err @@ -413,6 +427,7 @@ func (m *userDataManagerImpl) updateUserData( } updatedUserData, shouldReplicate, err := updateFn(preUpdateData) if err != nil { + m.logger.Error("user data update function failed", tag.Error(err)) return nil, false, err } @@ -441,14 +456,86 @@ func (m *userDataManagerImpl) updateUserData( BuildIdsAdded: added, BuildIdsRemoved: removed, }) - var updatedVersionedData *persistencespb.VersionedTaskQueueUserData - if err == nil { - updatedVersionedData = &persistencespb.VersionedTaskQueueUserData{Version: preUpdateVersion + 1, Data: updatedUserData} - m.setUserDataLocked(updatedVersionedData) + if err != nil { + m.logger.Error("failed to push new user data to owning matching node for namespace", tag.Error(err)) + return nil, false, err } + + updatedVersionedData := &persistencespb.VersionedTaskQueueUserData{Version: preUpdateVersion + 1, Data: updatedUserData} + m.logNewUserData("modified user data", updatedVersionedData) + m.setUserDataLocked(updatedVersionedData) + return updatedVersionedData, shouldReplicate, err } +func (m *userDataManagerImpl) HandleGetUserDataRequest( + ctx context.Context, + req *matchingservice.GetTaskQueueUserDataRequest, +) (*matchingservice.GetTaskQueueUserDataResponse, error) { + version := req.GetLastKnownUserDataVersion() + if version < 0 { + return nil, serviceerror.NewInvalidArgument("last_known_user_data_version must not be negative") + } + + if req.WaitNewData { + var cancel context.CancelFunc + ctx, cancel = newChildContext(ctx, m.config.GetUserDataLongPollTimeout(), m.config.GetUserDataReturnBudget) + defer cancel() + } + + for { + resp := &matchingservice.GetTaskQueueUserDataResponse{} + userData, userDataChanged, err := m.GetUserData() + if errors.Is(err, errTaskQueueClosed) { + // If we're closing, return a success with no data, as if the request expired. We shouldn't + // close due to idleness (because of the MarkAlive above), so we're probably closing due to a + // change of ownership. The caller will retry and be redirected to the new owner. + m.logger.Debug("returning empty user data (closing)", tag.NewBoolTag("long-poll", req.WaitNewData)) + return resp, nil + } else if err != nil { + return nil, err + } + if req.WaitNewData && userData.GetVersion() == version { + // long-poll: wait for data to change/appear + select { + case <-ctx.Done(): + m.logger.Debug("returning empty user data (expired)", + tag.NewBoolTag("long-poll", req.WaitNewData), + tag.NewInt64("request-known-version", version), + tag.UserDataVersion(userData.GetVersion()), + ) + return resp, nil + case <-userDataChanged: + m.logger.Debug("user data changed while blocked in long poll") + continue + } + } + if userData != nil { + if userData.Version > version { + resp.UserData = userData + m.logger.Info("returning user data", + tag.NewBoolTag("long-poll", req.WaitNewData), + tag.NewInt64("request-known-version", version), + tag.UserDataVersion(userData.Version), + ) + } else if userData.Version < version { + // This is highly unlikely but may happen due to an edge case in during ownership transfer. + // We rely on client retries in this case to let the system eventually self-heal. + m.logger.Error("requested task queue user data for version greater than known version", + tag.NewInt64("request-known-version", version), + tag.UserDataVersion(userData.Version), + ) + return nil, serviceerror.NewInvalidArgument( + "requested task queue user data for version greater than known version") + } + } else { + m.logger.Debug("returning empty user data (no data)", tag.NewBoolTag("long-poll", req.WaitNewData)) + } + return resp, nil + } + +} + func (m *userDataManagerImpl) setUserDataForNonOwningPartition(userData *persistencespb.VersionedTaskQueueUserData) { m.lock.Lock() defer m.lock.Unlock() @@ -459,3 +546,10 @@ func (m *userDataManagerImpl) callerInfoContext(ctx context.Context) context.Con ns, _ := m.namespaceRegistry.GetNamespaceName(namespace.ID(m.partition.NamespaceId())) return headers.SetCallerInfo(ctx, headers.NewBackgroundCallerInfo(ns.String())) } + +func (m *userDataManagerImpl) logNewUserData(message string, data *persistencespb.VersionedTaskQueueUserData) { + m.logger.Info(message, + tag.UserDataVersion(data.GetVersion()), + tag.Timestamp(hybrid_logical_clock.UTC(data.GetData().GetClock())), + ) +} diff --git a/service/matching/user_data_manager_test.go b/service/matching/user_data_manager_test.go index 363e3838c3e..27ab3173e6b 100644 --- a/service/matching/user_data_manager_test.go +++ b/service/matching/user_data_manager_test.go @@ -26,10 +26,13 @@ package matching import ( "context" + "fmt" + "math/rand" "testing" "time" "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" @@ -38,8 +41,10 @@ import ( "go.temporal.io/server/common/backoff" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/tqid" "go.uber.org/mock/gomock" "google.golang.org/grpc" ) @@ -64,7 +69,6 @@ func TestUserData_LoadOnInit(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() dbq := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 0) tqCfg := defaultTqmTestOpts(controller) @@ -97,7 +101,6 @@ func TestUserData_LoadOnInit_OnlyOnceWhenNoData(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() dbq := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 0) tqCfg := defaultTqmTestOpts(controller) @@ -133,7 +136,6 @@ func TestUserData_FetchesOnInit(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() dbq := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 1) tqCfg := defaultTqmTestOpts(controller) @@ -172,7 +174,6 @@ func TestUserData_FetchesAndFetchesAgain(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() // note: using activity here dbq := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_ACTIVITY, 1) @@ -240,7 +241,6 @@ func TestUserData_RetriesFetchOnUnavailable(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() dbq := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 1) tqCfg := defaultTqmTestOpts(controller) @@ -312,7 +312,6 @@ func TestUserData_RetriesFetchOnUnImplemented(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() dbq := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 1) tqCfg := defaultTqmTestOpts(controller) @@ -386,7 +385,6 @@ func TestUserData_FetchesUpTree(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() taskQueue := newTestTaskQueue(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW) dbq := UnversionedQueueKey(taskQueue.NormalPartition(31)) @@ -426,7 +424,6 @@ func TestUserData_FetchesActivityToWorkflow(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() // note: activity root dbq := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_ACTIVITY, 0) @@ -465,7 +462,6 @@ func TestUserData_FetchesStickyToNormal(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() tqCfg := defaultTqmTestOpts(controller) @@ -508,7 +504,6 @@ func TestUserData_UpdateOnNonRootFails(t *testing.T) { t.Parallel() controller := gomock.NewController(t) - defer controller.Finish() ctx := context.Background() subTqId := newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 1) @@ -535,3 +530,88 @@ func TestUserData_UpdateOnNonRootFails(t *testing.T) { func newTestUnversionedPhysicalQueueKey(namespaceId string, name string, taskType enumspb.TaskQueueType, partition int) *PhysicalTaskQueueKey { return UnversionedQueueKey(newTestTaskQueue(namespaceId, name, taskType).NormalPartition(partition)) } + +func TestUserData_Propagation(t *testing.T) { + t.Parallel() + + const N = 7 + + ctx := context.Background() + controller := gomock.NewController(t) + opts := defaultTqmTestOpts(controller) + + keys := make([]*PhysicalTaskQueueKey, N) + for i := range keys { + keys[i] = newTestUnversionedPhysicalQueueKey(defaultNamespaceId, defaultRootTqID, enumspb.TASK_QUEUE_TYPE_WORKFLOW, i) + } + + managers := make([]*userDataManagerImpl, N) + var tm *testTaskManager + for i := range managers { + optsi := *opts // share config and mock client + optsi.dbq = keys[i] + managers[i] = createUserDataManager(t, controller, &optsi) + if i == 0 { + // only the root uses persistence + tm = managers[0].store.(*testTaskManager) + } + // use two levels + managers[i].config.ForwarderMaxChildrenPerNode = dynamicconfig.GetIntPropertyFn(3) + // override timeouts to run much faster + managers[i].config.GetUserDataLongPollTimeout = dynamicconfig.GetDurationPropertyFn(100 * time.Millisecond) + managers[i].config.GetUserDataMinWaitTime = 10 * time.Millisecond + managers[i].config.GetUserDataReturnBudget = 10 * time.Millisecond + managers[i].config.GetUserDataRetryPolicy = backoff.NewExponentialRetryPolicy(100 * time.Millisecond).WithMaximumInterval(1 * time.Second) + managers[i].logger = log.With(managers[i].logger, tag.HostID(fmt.Sprintf("%d", i))) + } + + // hook up "rpcs" + opts.matchingClientMock.EXPECT().GetTaskQueueUserData(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, req *matchingservice.GetTaskQueueUserDataRequest, opts ...grpc.CallOption) (*matchingservice.GetTaskQueueUserDataResponse, error) { + // inject failures + if rand.Float64() < 0.1 { + return nil, serviceerror.NewUnavailable("timeout") + } + p, err := tqid.NormalPartitionFromRpcName(req.TaskQueue, req.NamespaceId, req.TaskQueueType) + require.NoError(t, err) + require.Equal(t, enumspb.TASK_QUEUE_TYPE_WORKFLOW, p.TaskType()) + res, err := managers[p.PartitionId()].HandleGetUserDataRequest(ctx, req) + return res, err + }, + ).AnyTimes() + opts.matchingClientMock.EXPECT().UpdateTaskQueueUserData(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, req *matchingservice.UpdateTaskQueueUserDataRequest, opts ...grpc.CallOption) (*matchingservice.UpdateTaskQueueUserDataResponse, error) { + err := tm.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{ + NamespaceID: req.NamespaceId, + TaskQueue: req.TaskQueue, + UserData: req.UserData, + BuildIdsAdded: req.BuildIdsAdded, + BuildIdsRemoved: req.BuildIdsRemoved, + }) + return &matchingservice.UpdateTaskQueueUserDataResponse{}, err + }, + ).AnyTimes() + + defer time.Sleep(50 * time.Millisecond) // extra buffer to let goroutines exit after manager.Stop() + for i := range managers { + managers[i].Start() + defer managers[i].Stop() + } + + const iters = 5 + for iter := 0; iter < iters; iter++ { + err := managers[0].UpdateUserData(ctx, UserDataUpdateOptions{}, func(data *persistencespb.TaskQueueUserData) (*persistencespb.TaskQueueUserData, bool, error) { + return data, false, nil + }) + require.NoError(t, err) + start := time.Now() + require.EventuallyWithT(t, func(c *assert.CollectT) { + for i := 1; i < N; i++ { + d, _, err := managers[i].GetUserData() + assert.NoError(c, err, "number", i) + assert.Equal(c, iter+1, int(d.GetVersion()), "number", i) + } + }, 5*time.Second, 10*time.Millisecond, "failed to propagate") + t.Log("Propagation time:", time.Since(start)) + } +}