diff --git a/internal/artifact.go b/internal/artifact.go index f761bdf..edd85a6 100644 --- a/internal/artifact.go +++ b/internal/artifact.go @@ -8,7 +8,6 @@ import ( "go/types" "log" "os/exec" - "regexp" "strings" "sync" @@ -88,31 +87,6 @@ func Arch() *Artifact { }) return arch } - -func PkgPattern(path string) (*regexp.Regexp, error) { - p := `^(?:[a-zA-Z]+(?:\.[a-zA-Z]+)*|\.\.\.)$` - re := regexp.MustCompile(p) - for _, seg := range strings.Split(path, "/") { - if len(seg) > 0 && !re.MatchString(seg) { - return nil, fmt.Errorf("invalid package paths: %s", path) - } - } - path = strings.TrimSuffix(path, "/") - path = strings.TrimPrefix(path, "/") - path = strings.ReplaceAll(path, "...", ".*") - return regexp.MustCompile(fmt.Sprintf("%s$", path)), nil -} - -func PkgPatters(paths ...string) []*regexp.Regexp { - return lo.Map(paths, func(path string, _ int) *regexp.Regexp { - reg, err := PkgPattern(path) - if err != nil { - log.Fatal(err) - } - return reg - }) -} - func parse(pkg *packages.Package, mode ParseMode) *Package { archPkg := &Package{raw: pkg} typPkg := pkg.Types diff --git a/internal/artifact_test.go b/internal/artifact_test.go index e3c79a9..0b42541 100644 --- a/internal/artifact_test.go +++ b/internal/artifact_test.go @@ -7,61 +7,6 @@ import ( "testing" ) -func Test_pattern(t *testing.T) { - tests := []struct { - name string - path string - wantErr bool - }{ - { - name: "valid one dot", - path: "github.com/kcmvp/archunit", - wantErr: false, - }, - { - name: "invalid one dot", - path: "github.com/./kcmvp/archunit", - wantErr: true, - }, - { - name: "valid-two-dots", - path: "git.hub.com/kcmvp/archunit", - wantErr: false, - }, - { - name: "invalid with two dots", - path: "github.com/../kcmvp/archunit", - wantErr: true, - }, - { - name: "invalid-two-dots", - path: "github..com/kcmvp/archunit", - wantErr: true, - }, - { - name: "invalid-two-dots", - path: "githubcom/../kcmvp/archunit", - wantErr: true, - }, - { - name: "valid three dots", - path: "githubcom/.../kcmvp/archunit", - wantErr: false, - }, - { - name: "invalid three more dots", - path: "githubcom/..../kcmvp/archunit", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := PkgPattern(tt.path) - assert.Equal(t, tt.wantErr, err != nil) - }) - } -} - func TestAllConstants(t *testing.T) { tests := []struct { pkg string @@ -103,8 +48,6 @@ func TestPackage_Functions(t *testing.T) { pkg: "github.com/kcmvp/archunit/internal", funcs: []string{ "Arch", - "PkgPattern", - "PkgPatters", "parse", }, imports: []string{ @@ -113,7 +56,6 @@ func TestPackage_Functions(t *testing.T) { "golang.org/x/tools/go/packages", "log", "go/types", - "regexp", "github.com/samber/lo", "strings", "sync", @@ -139,6 +81,7 @@ func TestPackage_Functions(t *testing.T) { "TypesWith", "Packages", "AllPackages", + "ScopePattern", }, imports: []string{ "fmt", @@ -199,7 +142,7 @@ func TestPackage_Functions(t *testing.T) { } func TestAllSource(t *testing.T) { - assert.Equal(t, 20, len(Arch().GoFiles())) + assert.Equal(t, 21, len(Arch().GoFiles())) } func TestMethodsOfType(t *testing.T) { diff --git a/layer.go b/layer.go index 1cbad8a..558606b 100644 --- a/layer.go +++ b/layer.go @@ -55,13 +55,16 @@ func ConstantsShouldBeDefinedInOneFileByPackage() error { return nil } -func Layer(pkgPaths ...string) ArchLayer { - patterns := internal.PkgPatters(pkgPaths...) +func Layer(pkgPaths ...string) (ArchLayer, error) { + patterns, err := ScopePattern(pkgPaths...) + if err != nil { + return nil, err + } return lo.Filter(internal.Arch().Packages(), func(pkg *internal.Package, _ int) bool { return lo.ContainsBy(patterns, func(pattern *regexp.Regexp) bool { return pattern.MatchString(pkg.ID()) }) - }) + }), nil } func (layer ArchLayer) Name() string { @@ -79,22 +82,28 @@ func (layer ArchLayer) Name() string { return fmt.Sprintf("%v", left) } -func (layer ArchLayer) Exclude(pkgPaths ...string) ArchLayer { - patterns := internal.PkgPatters(pkgPaths...) +func (layer ArchLayer) Exclude(pkgPaths ...string) (ArchLayer, error) { + patterns, err := ScopePattern(pkgPaths...) + if err != nil { + return nil, err + } return lo.Filter(layer, func(pkg *internal.Package, _ int) bool { return lo.NoneBy(patterns, func(pattern *regexp.Regexp) bool { return pattern.MatchString(pkg.ID()) }) - }) + }), nil } -func (layer ArchLayer) Sub(name string, paths ...string) ArchLayer { - patterns := internal.PkgPatters(paths...) +func (layer ArchLayer) Sub(name string, paths ...string) (ArchLayer, error) { + patterns, err := ScopePattern(paths...) + if err != nil { + return nil, err + } return lo.Filter(layer, func(pkg *internal.Package, _ int) bool { return lo.SomeBy(patterns, func(pattern *regexp.Regexp) bool { return pattern.MatchString(pkg.ID()) }) - }) + }), nil } func (layer ArchLayer) Packages() ArchPackage { @@ -137,8 +146,11 @@ func (layer ArchLayer) Files() FileSet { }) } -func (layer ArchLayer) FilesInPackages(paths ...string) FileSet { - patterns := internal.PkgPatters(paths...) +func (layer ArchLayer) FilesInPackages(paths ...string) (FileSet, error) { + patterns, err := ScopePattern(paths...) + if err != nil { + return nil, err + } return lo.FilterMap(layer, func(pkg *internal.Package, _ int) (PackageFile, bool) { if lo.SomeBy(patterns, func(reg *regexp.Regexp) bool { return reg.MatchString(pkg.ID()) @@ -146,7 +158,7 @@ func (layer ArchLayer) FilesInPackages(paths ...string) FileSet { return PackageFile{A: pkg.ID(), B: pkg.GoFiles()}, true } return PackageFile{}, false - }) + }), nil } func (layer ArchLayer) ShouldNotReferLayers(layers ...ArchLayer) error { @@ -161,7 +173,11 @@ func (layer ArchLayer) ShouldNotReferLayers(layers ...ArchLayer) error { } func (layer ArchLayer) ShouldNotReferPackages(paths ...string) error { - return layer.ShouldNotReferLayers(Layer(paths...)) + l, err := Layer(paths...) + if err != nil { + return err + } + return layer.ShouldNotReferLayers(l) } func (layer ArchLayer) ShouldOnlyReferLayers(layers ...ArchLayer) error { @@ -174,7 +190,11 @@ func (layer ArchLayer) ShouldOnlyReferLayers(layers ...ArchLayer) error { } func (layer ArchLayer) ShouldOnlyReferPackages(paths ...string) error { - return layer.ShouldOnlyReferLayers(Layer(paths...)) + l, err := Layer(paths...) + if err != nil { + return err + } + return layer.ShouldOnlyReferLayers(l) } func (layer ArchLayer) ShouldBeOnlyReferredByLayers(layers ...ArchLayer) error { @@ -184,8 +204,11 @@ func (layer ArchLayer) ShouldBeOnlyReferredByLayers(layers ...ArchLayer) error { } func (layer ArchLayer) ShouldBeOnlyReferredByPackages(paths ...string) error { - layer1 := Layer(paths...) - return layer.ShouldBeOnlyReferredByLayers(layer1) + l, err := Layer(paths...) + if err != nil { + return err + } + return layer.ShouldBeOnlyReferredByLayers(l) } func (layer ArchLayer) DepthShouldLessThan(depth int) error { diff --git a/layer_test.go b/layer_test.go index 1a4be63..c0b0ece 100644 --- a/layer_test.go +++ b/layer_test.go @@ -8,7 +8,7 @@ import ( "testing" ) -func TestPackages(t *testing.T) { +func TestLayerPackages(t *testing.T) { tests := []struct { name string paths []string @@ -41,10 +41,10 @@ func TestPackages(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - layer := Layer(test.paths...) + layer, _ := Layer(test.paths...) assert.Equal(t, test.size1, len(layer.packages())) if len(test.except) > 0 { - layer = layer.Exclude(test.except...) + layer, _ = layer.Exclude(test.except...) assert.Equal(t, test.size2, len(layer.packages())) } }) @@ -63,22 +63,22 @@ func TestLayer_Sub(t *testing.T) { name: "ext sub", paths: []string{".../service/..."}, sub: []string{".../ext/"}, - size1: 4, + size1: 5, size2: 1, }, { name: "ext sub", paths: []string{".../service/..."}, sub: []string{".../ext/..."}, - size1: 4, - size2: 2, + size1: 5, + size2: 3, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - layer := Layer(tt.paths...) + layer, _ := Layer(tt.paths...) assert.Equal(t, tt.size1, len(layer.packages())) - layer = layer.Sub(tt.name, tt.sub...) + layer, _ = layer.Sub(tt.name, tt.sub...) assert.Equal(t, tt.size2, len(layer.packages())) }) } @@ -97,7 +97,7 @@ func TestConstantsShouldBeDefinedInOneFileByPackage(t *testing.T) { } func TestLayPackages(t *testing.T) { - layer := Layer("sample/controller", "sample/controller/...") + layer, _ := Layer("sample/controller/...") assert.ElementsMatch(t, []string{"github.com/kcmvp/archunit/internal/sample/controller", "github.com/kcmvp/archunit/internal/sample/controller/module1"}, layer.packages()) assert.ElementsMatch(t, layer.Imports(), @@ -112,10 +112,10 @@ func TestLayPackages(t *testing.T) { } func TestLayer_Refer(t *testing.T) { - controller := Layer("sample/controller", "sample/controller/...") - model := Layer("sample/model") - service := Layer("sample/service", "sample/service/...") - repository := Layer("sample/repository", "sample/repository/...") + controller, _ := Layer("sample/controller", "sample/controller/...") + model, _ := Layer("sample/model") + service, _ := Layer("sample/service", "sample/service/...") + repository, _ := Layer("sample/repository", "sample/repository/...") assert.NoError(t, controller.ShouldNotReferLayers(model)) assert.NoError(t, controller.ShouldNotReferPackages("sample/model")) assert.Errorf(t, controller.ShouldNotReferLayers(repository), "controller should not refer repository") diff --git a/package.go b/package.go index b68eb0e..4662cb1 100644 --- a/package.go +++ b/package.go @@ -15,13 +15,13 @@ func AllPackages() ArchPackage { return internal.Arch().Packages() } -func Packages(paths ...string) ArchPackage { - patterns := internal.PkgPatters(paths...) +func Packages(paths ...string) (ArchPackage, error) { + patterns, err := ScopePattern(paths...) return lo.Filter(AllPackages(), func(pkg *internal.Package, _ int) bool { return lo.ContainsBy(patterns, func(pattern *regexp.Regexp) bool { return pattern.MatchString(pkg.ID()) }) - }) + }), err } func (archPkg ArchPackage) ID() []string { @@ -104,7 +104,11 @@ func (archPkg ArchPackage) ShouldNotRefer(referred ...ArchPackage) error { } func (archPkg ArchPackage) ShouldNotReferPkgPaths(paths ...string) error { - return archPkg.ShouldNotRefer(Packages(paths...)) + pkgs, err := Packages(paths...) + if err != nil { + return err + } + return archPkg.ShouldNotRefer(pkgs) } func (archPkg ArchPackage) ShouldBeOnlyReferredByPackages(referrings ...ArchPackage) error { @@ -134,11 +138,17 @@ func (archPkg ArchPackage) ShouldOnlyReferPackages(referred ...ArchPackage) erro } func (archPkg ArchPackage) ShouldOnlyReferPkgPaths(paths ...string) error { - pkg := Packages(paths...) + pkg, err := Packages(paths...) + if err != nil { + return err + } return archPkg.ShouldOnlyReferPackages(pkg) } func (archPkg ArchPackage) ShouldBeOnlyReferredByPkgPaths(paths ...string) error { - pkg := Packages(paths...) + pkg, err := Packages(paths...) + if err != nil { + return err + } return archPkg.ShouldBeOnlyReferredByPackages(pkg) } diff --git a/package_test.go b/package_test.go index beb8f0d..49676bc 100644 --- a/package_test.go +++ b/package_test.go @@ -26,12 +26,12 @@ func TestPackageNameShould(t *testing.T) { pkgs := AllPackages() err := pkgs.NameShould(BeLowerCase) assert.NoError(t, err) - err = pkgs.NameShould((BeUpperCase)) + err = pkgs.NameShould(BeUpperCase) assert.Error(t, err) } func TestPackage(t *testing.T) { - pkgs := Packages("internal/sample/...") + pkgs, _ := Packages("internal/sample/...") assert.Equal(t, 12, len(pkgs)) assert.Equal(t, 12, len(pkgs.ID())) assert.Equal(t, 12, len(pkgs.Files())) @@ -48,10 +48,10 @@ func TestPackage(t *testing.T) { } func TestPackage_Ref(t *testing.T) { - controller := Packages("sample/controller", "sample/controller/...") - model := Packages("sample/model") - service := Packages("sample/service", "sample/service/...") - repository := Packages("sample/repository", "sample/repository/...") + controller, _ := Packages("sample/controller", "sample/controller/...") + model, _ := Packages("sample/model") + service, _ := Packages("sample/service", "sample/service/...") + repository, _ := Packages("sample/repository", "sample/repository/...") assert.NoError(t, controller.ShouldNotRefer(model)) assert.NoError(t, controller.ShouldNotReferPkgPaths("sample/model")) assert.Errorf(t, controller.ShouldNotRefer(repository), "controller should not refer repository") diff --git a/scope.go b/scope.go new file mode 100644 index 0000000..186776b --- /dev/null +++ b/scope.go @@ -0,0 +1,28 @@ +package archunit + +import ( + "fmt" + "regexp" + "strings" + + "github.com/samber/lo" +) + +func ScopePattern(paths ...string) ([]*regexp.Regexp, error) { + pps := lo.FlatMap(paths, func(item string, _ int) []string { + path := strings.TrimPrefix(strings.TrimSuffix(item, "/"), "/") + return lo.Union([]string{path, strings.TrimSuffix(path, "/...")}) + }) + pattern := `^(?:[a-zA-Z]+(?:\.[a-zA-Z]+)*|\.\.\.)$` + re := regexp.MustCompile(pattern) + for _, path := range pps { + for _, seg := range strings.Split(path, "/") { + if len(seg) > 0 && !re.MatchString(seg) { + return nil, fmt.Errorf("invalid package paths: %s", path) + } + } + } + return lo.Map(pps, func(path string, _ int) *regexp.Regexp { + return regexp.MustCompile(fmt.Sprintf("%s$", strings.ReplaceAll(path, "...", ".*"))) + }), nil +} diff --git a/scope_test.go b/scope_test.go new file mode 100644 index 0000000..12c891e --- /dev/null +++ b/scope_test.go @@ -0,0 +1,72 @@ +package archunit + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_scope_pattern(t *testing.T) { + tests := []struct { + name string + path string + size int + wantErr bool + }{ + { + name: "valid one dot", + path: "github.com/kcmvp/archunit", + size: 1, + wantErr: false, + }, + { + name: "invalid one dot", + path: "github.com/./kcmvp/archunit", + wantErr: true, + }, + { + name: "valid-two-dots", + path: "git.hub.com/kcmvp/archunit", + size: 1, + wantErr: false, + }, + { + name: "invalid with two dots", + path: "github.com/../kcmvp/archunit", + wantErr: true, + }, + { + name: "invalid-two-dots", + path: "github..com/kcmvp/archunit", + wantErr: true, + }, + { + name: "invalid-two-dots", + path: "githubcom/../kcmvp/archunit", + wantErr: true, + }, + { + name: "valid three dots", + path: "githubcom/.../kcmvp/archunit", + size: 1, + wantErr: false, + }, + { + name: "valid three dots multiple", + path: "githubcom/.../kcmvp/archunit/...", + size: 2, + wantErr: false, + }, + { + name: "invalid three more dots", + path: "githubcom/..../kcmvp/archunit", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + patterns, err := ScopePattern(tt.path) + assert.Equal(t, tt.wantErr, err != nil) + assert.Equal(t, tt.size, len(patterns)) + }) + } +}