diff --git a/README.md b/README.md index ebede7c..2220b8c 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,6 @@ Also, formatting for your code will be prepared(so, you don't need to use `gofmt Use additional options `-rm-unused` to remove unused imports and `-set-alias` to rewrite import aliases for versioned packages or for packages with additional prefix/suffix(example: `opentracing "github.com/opentracing/opentracing-go"`). `-local` - will create group for local imports. Values should be comma-separated. -*The last two options (`-rm-unused` & `-set-alias`) will work only for projects using Go Modules.* ## Configuration: ### Cmd diff --git a/main.go b/main.go index 927a5fa..9e5dffe 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,7 @@ const ( filePathArg = "file-path" versionArg = "version" removeUnusedImportsArg = "rm-unused" - setAlias = "set-alias" + setAliasArg = "set-alias" localPkgPrefixesArg = "local" ) @@ -66,7 +66,7 @@ func init() { ) shouldSetAlias = flag.Bool( - setAlias, + setAliasArg, false, "Set alias for versioned package names, like 'github.com/go-pg/pg/v9'. "+ "In this case import will be set as 'pg \"github.com/go-pg/pg/v9\"'. Optional parameter.", diff --git a/pkg/module/error.go b/pkg/module/error.go index 0a91625..762e441 100644 --- a/pkg/module/error.go +++ b/pkg/module/error.go @@ -1,11 +1,13 @@ package module +// UndefinedModuleError will appear on absent go.mod type UndefinedModuleError struct{} func (e *UndefinedModuleError) Error() string { return "module is undefined" } +// PathIsNotSetError will appear if any directory or file is not set for searching go.mod type PathIsNotSetError struct{} func (e *PathIsNotSetError) Error() string { diff --git a/pkg/module/error_test.go b/pkg/module/error_test.go new file mode 100644 index 0000000..86228d5 --- /dev/null +++ b/pkg/module/error_test.go @@ -0,0 +1,44 @@ +package module + +import "testing" + +func TestPathIsNotSetError_Error(t *testing.T) { + tests := []struct { + name string + want string + }{ + { + name: "success", + want: "path is not set", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &PathIsNotSetError{} + if got := e.Error(); got != tt.want { + t.Errorf("Error() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUndefinedModuleError_Error(t *testing.T) { + tests := []struct { + name string + want string + }{ + { + name: "success", + want: "module is undefined", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &UndefinedModuleError{} + if got := e.Error(); got != tt.want { + t.Errorf("Error() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/module/module.go b/pkg/module/module.go index eeeb027..7986c54 100644 --- a/pkg/module/module.go +++ b/pkg/module/module.go @@ -10,6 +10,7 @@ import ( const goModFilename = "go.mod" +// Name reads module value from ./go.mod func Name(goModRootPath string) (string, error) { goModFile := filepath.Join(goModRootPath, goModFilename) @@ -30,6 +31,8 @@ func Name(goModRootPath string) (string, error) { return "", &UndefinedModuleError{} } +// GoModRootPath in case of any directory or file of the project will return root dir of the project where go.mod file +// is exist func GoModRootPath(path string) (string, error) { if path == "" { return "", &PathIsNotSetError{} diff --git a/pkg/module/module_test.go b/pkg/module/module_test.go index 9ab535b..b35cdb4 100644 --- a/pkg/module/module_test.go +++ b/pkg/module/module_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestName(t *testing.T) { +func TestGoModRootPathAndName(t *testing.T) { type args struct { dir string } @@ -68,3 +68,80 @@ func TestName(t *testing.T) { }) } } + +func TestName(t *testing.T) { + type args struct { + goModRootPath string + } + tests := []struct { + name string + prepareFn func() + args args + want string + wantErr bool + }{ + { + name: "read empty go.mod", + prepareFn: func() { + const f = "/tmp/go.mod" + + if _, err := os.Stat(f); os.IsExist(err) { + if err := os.Remove(f); err != nil { + panic(err) + } + } + + _, err := os.Create(f) + if err != nil { + panic(err) + } + }, + args: args{ + goModRootPath: "/tmp", + }, + want: "", + wantErr: true, + }, + { + name: "check failed parsing of go.mod", + prepareFn: func() { + const f = "/tmp/go.mod" + + if _, err := os.Stat(f); os.IsExist(err) { + if err := os.Remove(f); err != nil { + panic(err) + } + } + + file, err := os.Create(f) + if err != nil { + panic(err) + } + + if _, err := file.WriteString("mod test"); err != nil { + panic(err) + } + }, + args: args{ + goModRootPath: "/tmp", + }, + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + tt.prepareFn() + + got, err := Name(tt.args.goModRootPath) + if (err != nil) != tt.wantErr { + t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Name() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/reviser/reviser.go b/reviser/reviser.go index 27b48dc..674d79a 100644 --- a/reviser/reviser.go +++ b/reviser/reviser.go @@ -182,7 +182,6 @@ func generateFile(fset *token.FileSet, file *ast.File) ([]byte, error) { return buffer.Bytes(), nil } -// TODO: fix gocyclo func fixImports( f *ast.File, stdImports, generalImports, projectLocalPkgs, projectImports []string, @@ -191,80 +190,94 @@ func fixImports( var importsPositions []*importPosition for _, decl := range f.Decls { - switch decl.(type) { - case *ast.GenDecl: - dd := decl.(*ast.GenDecl) - if dd.Tok == token.IMPORT { - importsPositions = append( - importsPositions, &importPosition{ - Start: dd.Pos(), - End: dd.End(), - }, - ) - - var specs []ast.Spec - - linesCounter := len(stdImports) - for _, stdImport := range stdImports { - spec := &ast.ImportSpec{ - Path: &ast.BasicLit{Value: importWithComment(stdImport, commentsMetadata), Kind: dd.Tok}, - } - specs = append(specs, spec) + dd, ok := decl.(*ast.GenDecl) + if !ok { + continue + } - linesCounter-- + if dd.Tok != token.IMPORT { + continue + } - if linesCounter == 0 && (len(generalImports) > 0 || len(projectLocalPkgs) > 0 || len(projectImports) > 0) { - spec = &ast.ImportSpec{Path: &ast.BasicLit{Value: "", Kind: token.STRING}} + importsPositions = append( + importsPositions, &importPosition{ + Start: dd.Pos(), + End: dd.End(), + }, + ) - specs = append(specs, spec) - } - } + dd.Specs = rebuildImports(dd.Tok, commentsMetadata, stdImports, generalImports, projectLocalPkgs, projectImports) + } - linesCounter = len(generalImports) - for _, generalImport := range generalImports { - spec := &ast.ImportSpec{ - Path: &ast.BasicLit{Value: importWithComment(generalImport, commentsMetadata), Kind: dd.Tok}, - } - specs = append(specs, spec) + clearImportDocs(f, importsPositions) +} - linesCounter-- +func rebuildImports( + tok token.Token, + commentsMetadata map[string]*commentsMetadata, + stdImports []string, + generalImports []string, + projectLocalPkgs []string, + projectImports []string, +) []ast.Spec { + var specs []ast.Spec + + linesCounter := len(stdImports) + for _, stdImport := range stdImports { + spec := &ast.ImportSpec{ + Path: &ast.BasicLit{Value: importWithComment(stdImport, commentsMetadata), Kind: tok}, + } + specs = append(specs, spec) - if linesCounter == 0 && (len(projectLocalPkgs) > 0 || len(projectImports) > 0) { - spec = &ast.ImportSpec{Path: &ast.BasicLit{Value: "", Kind: token.STRING}} + linesCounter-- - specs = append(specs, spec) - } - } + if linesCounter == 0 && (len(generalImports) > 0 || len(projectLocalPkgs) > 0 || len(projectImports) > 0) { + spec = &ast.ImportSpec{Path: &ast.BasicLit{Value: "", Kind: token.STRING}} - linesCounter = len(projectLocalPkgs) - for _, projectLocalPkg := range projectLocalPkgs { - spec := &ast.ImportSpec{ - Path: &ast.BasicLit{Value: importWithComment(projectLocalPkg, commentsMetadata), Kind: dd.Tok}, - } - specs = append(specs, spec) + specs = append(specs, spec) + } + } - linesCounter-- + linesCounter = len(generalImports) + for _, generalImport := range generalImports { + spec := &ast.ImportSpec{ + Path: &ast.BasicLit{Value: importWithComment(generalImport, commentsMetadata), Kind: tok}, + } + specs = append(specs, spec) - if linesCounter == 0 && len(projectImports) > 0 { - spec = &ast.ImportSpec{Path: &ast.BasicLit{Value: "", Kind: token.STRING}} + linesCounter-- - specs = append(specs, spec) - } - } + if linesCounter == 0 && (len(projectLocalPkgs) > 0 || len(projectImports) > 0) { + spec = &ast.ImportSpec{Path: &ast.BasicLit{Value: "", Kind: token.STRING}} - for _, projectImport := range projectImports { - spec := &ast.ImportSpec{ - Path: &ast.BasicLit{Value: importWithComment(projectImport, commentsMetadata), Kind: dd.Tok}, - } - specs = append(specs, spec) - } + specs = append(specs, spec) + } + } - dd.Specs = specs - } + linesCounter = len(projectLocalPkgs) + for _, projectLocalPkg := range projectLocalPkgs { + spec := &ast.ImportSpec{ + Path: &ast.BasicLit{Value: importWithComment(projectLocalPkg, commentsMetadata), Kind: tok}, + } + specs = append(specs, spec) + + linesCounter-- + + if linesCounter == 0 && len(projectImports) > 0 { + spec = &ast.ImportSpec{Path: &ast.BasicLit{Value: "", Kind: token.STRING}} + + specs = append(specs, spec) } } - clearImportDocs(f, importsPositions) + for _, projectImport := range projectImports { + spec := &ast.ImportSpec{ + Path: &ast.BasicLit{Value: importWithComment(projectImport, commentsMetadata), Kind: tok}, + } + specs = append(specs, spec) + } + + return specs } func clearImportDocs(f *ast.File, importsPositions []*importPosition) {