diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 1b5bd1495..21dbab413 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -362,7 +362,6 @@ func (env *testWorkflowEnvironmentImpl) newTestWorkflowEnvironmentForChild(param childEnv.startedHandler = startedHandler childEnv.testWorkflowEnvironmentShared = env.testWorkflowEnvironmentShared childEnv.workerOptions = env.workerOptions - childEnv.workflowInterceptors = env.workflowInterceptors childEnv.workerOptions.DataConverter = params.dataConverter childEnv.workflowInterceptors = env.workflowInterceptors childEnv.registry = env.registry diff --git a/internal/workflow_testsuite.go b/internal/workflow_testsuite.go index 9796a4a9e..19e5ae3ee 100644 --- a/internal/workflow_testsuite.go +++ b/internal/workflow_testsuite.go @@ -240,8 +240,8 @@ func (t *TestWorkflowEnvironment) OnActivity(activity interface{}, args ...inter panic(err) } fnName := getActivityFunctionName(t.impl.registry, activity) - t.impl.registry.RegisterActivityWithOptions(activity, RegisterActivityOptions{DisableAlreadyRegisteredCheck: true}) call = t.Mock.On(fnName, args...) + case reflect.String: call = t.Mock.On(activity.(string), args...) default: @@ -289,12 +289,11 @@ func (t *TestWorkflowEnvironment) OnWorkflow(workflow interface{}, args ...inter if alias, ok := t.impl.registry.getWorkflowAlias(fnName); ok { fnName = alias } - t.impl.registry.RegisterWorkflowWithOptions(workflow, RegisterWorkflowOptions{DisableAlreadyRegisteredCheck: true}) call = t.Mock.On(fnName, args...) case reflect.String: call = t.Mock.On(workflow.(string), args...) default: - panic("workflow must be function or string") + panic("activity must be function or string") } return t.wrapCall(call) diff --git a/internal/workflow_testsuite_test.go b/internal/workflow_testsuite_test.go index c9a16d53c..ca1123247 100644 --- a/internal/workflow_testsuite_test.go +++ b/internal/workflow_testsuite_test.go @@ -23,8 +23,6 @@ package internal import ( "context" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" "strings" "testing" "time" @@ -129,120 +127,3 @@ func TestWorkflowReturnNil(t *testing.T) { err := env.GetWorkflowResult(&r) require.NoError(t, err) } - -func HelloWorkflow(ctx Context, name string) (string, error) { - ctx = WithActivityOptions(ctx, ActivityOptions{ - ScheduleToCloseTimeout: time.Hour, - StartToCloseTimeout: time.Hour, - ScheduleToStartTimeout: time.Hour, - }) - var result string - err := ExecuteActivity(ctx, HelloActivity, name).Get(ctx, &result) - return result, err -} - -func HelloActivity(ctx context.Context, name string) (string, error) { - return "Hello " + name + "!", nil -} - -func TestWorkflowMockingWithoutRegistration(t *testing.T) { - testSuite := &WorkflowTestSuite{} - env := testSuite.NewTestWorkflowEnvironment() - env.OnWorkflow(HelloWorkflow, mock.Anything, mock.Anything).Return( - func(ctx Context, person string) (string, error) { - return "Hello " + person + "!", nil - }) - // Workflow is mocked, no activity registration required - env.ExecuteWorkflow(HelloWorkflow, "Cadence") - require.NoError(t, env.GetWorkflowError()) - var result string - err := env.GetWorkflowResult(&result) - require.NoError(t, err) - require.Equal(t, "Hello Cadence!", result) -} - -func TestActivityMockingWithoutRegistration(t *testing.T) { - testSuite := &WorkflowTestSuite{} - env := testSuite.NewTestWorkflowEnvironment() - env.OnActivity(HelloActivity, mock.Anything, mock.Anything).Return( - func(ctx context.Context, person string) (string, error) { - return "Goodbye " + person + "!", nil - }) - // Registration of activity not required - env.RegisterWorkflow(HelloWorkflow) - env.ExecuteWorkflow(HelloWorkflow, "Cadence") - require.NoError(t, env.GetWorkflowError()) - var result string - err := env.GetWorkflowResult(&result) - require.NoError(t, err) - require.Equal(t, "Goodbye Cadence!", result) - -type InterceptorTestSuite struct { - suite.Suite - WorkflowTestSuite - - env *TestWorkflowEnvironment - testFactory InterceptorFactory -} - -type InterceptorFactory struct { - workflowInterceptorInvocationCounter int - childWorkflowInterceptorInvocationCounter int -} - -type Interceptor struct { - WorkflowInterceptorBase - workflowInterceptorInvocationCounter *int - childWorkflowInterceptorInvocationCounter *int -} - -func (i *Interceptor) ExecuteWorkflow(ctx Context, workflowType string, args ...interface{}) []interface{} { - *i.workflowInterceptorInvocationCounter += 1 - return i.Next.ExecuteWorkflow(ctx, workflowType, args...) -} -func (i *Interceptor) ExecuteChildWorkflow(ctx Context, workflowType string, args ...interface{}) ChildWorkflowFuture { - *i.childWorkflowInterceptorInvocationCounter += 1 - return i.Next.ExecuteChildWorkflow(ctx, workflowType, args...) -} - -func (f *InterceptorFactory) NewInterceptor(_ *WorkflowInfo, next WorkflowInterceptor) WorkflowInterceptor { - return &Interceptor{ - WorkflowInterceptorBase: WorkflowInterceptorBase{ - Next: next, - }, - workflowInterceptorInvocationCounter: &f.workflowInterceptorInvocationCounter, - childWorkflowInterceptorInvocationCounter: &f.childWorkflowInterceptorInvocationCounter, - } -} - -func (s *InterceptorTestSuite) SetupTest() { - // Create a test workflow environment with the trace interceptor configured. - s.env = s.NewTestWorkflowEnvironment() - s.testFactory = InterceptorFactory{} - s.env.SetWorkerOptions(WorkerOptions{ - WorkflowInterceptorChainFactories: []WorkflowInterceptorFactory{ - &s.testFactory, - }, - }) -} - -func TestInterceptorTestSuite(t *testing.T) { - suite.Run(t, new(InterceptorTestSuite)) -} - -func (s *InterceptorTestSuite) Test_GeneralInterceptor_IsExecutedOnChildren() { - r := s.Require() - childWf := func(ctx Context) error { - return nil - } - s.env.RegisterWorkflowWithOptions(childWf, RegisterWorkflowOptions{Name: "child"}) - wf := func(ctx Context) error { - return ExecuteChildWorkflow(ctx, childWf).Get(ctx, nil) - } - s.env.RegisterWorkflowWithOptions(wf, RegisterWorkflowOptions{Name: "parent"}) - s.env.ExecuteWorkflow(wf) - r.True(s.env.IsWorkflowCompleted()) - r.NoError(s.env.GetWorkflowError()) - r.Equal(s.testFactory.workflowInterceptorInvocationCounter, 2) - r.Equal(s.testFactory.childWorkflowInterceptorInvocationCounter, 1) -}