Skip to content

Commit

Permalink
smarter detection for the go.mod file location
Browse files Browse the repository at this point in the history
  • Loading branch information
ryancurrah committed Apr 4, 2020
1 parent 490d032 commit e8acbed
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions gomodguard.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package gomodguard

import (
"bytes"
"fmt"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"os/exec"
"strings"

"golang.org/x/mod/modfile"
Expand All @@ -14,7 +17,7 @@ import (
var (
blockedReasonNotInAllowedList = "import of package `%s` is blocked because the module is not in the allowed modules list."
blockedReasonInBlockedList = "import of package `%s` is blocked because the module is in the blocked modules list."
goModFile = "go.mod"
goModFilename = "go.mod"
)

// Recommendations are alternative modules to use and a reason why.
Expand Down Expand Up @@ -191,18 +194,16 @@ type Processor struct {

// NewProcessor will create a Processor to lint blocked packages.
func NewProcessor(config Configuration, logger *log.Logger) (*Processor, error) {
moddata, err := ioutil.ReadFile(goModFile)
goModFileBytes, err := loadGoModFile()
if err != nil {
errMsg := fmt.Sprintf("unable to read go.mod file: %s", err)
logger.Printf(errMsg)
errMsg := fmt.Sprintf("unable to read %s file: %s", goModFilename, err)

return nil, fmt.Errorf(errMsg)
}

mfile, err := modfile.Parse(goModFile, moddata, nil)
mfile, err := modfile.Parse(goModFilename, goModFileBytes, nil)
if err != nil {
errMsg := fmt.Sprintf("unable to parse go.mod file: %s", err)
logger.Printf(errMsg)
errMsg := fmt.Sprintf("unable to parse %s file: %s", goModFilename, err)

return nil, fmt.Errorf(errMsg)
}
Expand Down Expand Up @@ -231,7 +232,8 @@ func (p *Processor) ProcessFiles(filenames []string) []Result {
pluralModuleMsg = ""
}

p.logger.Printf("info: found `%d` blocked module%s in the go.mod file, %+v", len(p.blockedModulesFromModFile), pluralModuleMsg, p.blockedModulesFromModFile)
p.logger.Printf("info: found `%d` blocked module%s in the %s file, %+v",
len(p.blockedModulesFromModFile), pluralModuleMsg, goModFilename, p.blockedModulesFromModFile)

for _, filename := range filenames {
data, err := ioutil.ReadFile(filename)
Expand Down Expand Up @@ -344,3 +346,23 @@ func (p *Processor) isBlockedPackageFromModFile(pkg string) bool {

return false
}

func loadGoModFile() ([]byte, error) {
cmd := exec.Command("go", "list", "-m", "-f", "{{.GoMod}}")
stdout, _ := cmd.StdoutPipe()
_ = cmd.Start()

goModFileLocation := ""

if stdout != nil {
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(stdout)
goModFileLocation = strings.TrimSpace(buf.String())
}

if _, err := os.Stat(goModFileLocation); os.IsNotExist(err) {
return ioutil.ReadFile(goModFilename)
}

return ioutil.ReadFile(goModFileLocation)
}

0 comments on commit e8acbed

Please sign in to comment.