diff --git a/controllers/common.go b/controllers/common.go index 7bbf810bf..44c3364b2 100644 --- a/controllers/common.go +++ b/controllers/common.go @@ -7,13 +7,13 @@ import ( "net" "time" + "github.com/go-playground/validator/v10" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mongoconn" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" "golang.org/x/crypto/bcrypt" - "gopkg.in/go-playground/validator.v9" ) func GetPeersList(networkName string) ([]models.PeersResponse, error) { @@ -66,11 +66,11 @@ func ValidateNodeCreate(networkName string, node models.Node) error { empty := node.Address == "" return (empty || isIpv4) }) - _ = v.RegisterValidation("address6_check", func(fl validator.FieldLevel) bool { - isIpv6 := functions.IsIpNet(node.Address6) - empty := node.Address6 == "" - return (empty || isIpv6) - }) + _ = v.RegisterValidation("address6_check", func(fl validator.FieldLevel) bool { + isIpv6 := functions.IsIpNet(node.Address6) + empty := node.Address6 == "" + return (empty || isIpv6) + }) _ = v.RegisterValidation("endpoint_check", func(fl validator.FieldLevel) bool { //var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint) isIp := functions.IsIpNet(node.Endpoint) @@ -126,66 +126,66 @@ func ValidateNodeCreate(networkName string, node models.Node) error { func ValidateNodeUpdate(networkName string, node models.Node) error { - v := validator.New() - _ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool { - isIpv4 := functions.IsIpNet(node.Address) - empty := node.Address == "" - return (empty || isIpv4) - }) - _ = v.RegisterValidation("address6_check", func(fl validator.FieldLevel) bool { - isIpv6 := functions.IsIpNet(node.Address6) - empty := node.Address6 == "" - return (empty || isIpv6) - }) - _ = v.RegisterValidation("endpoint_check", func(fl validator.FieldLevel) bool { - //var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint) - isIp := functions.IsIpNet(node.Address) + v := validator.New() + _ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool { + isIpv4 := functions.IsIpNet(node.Address) + empty := node.Address == "" + return (empty || isIpv4) + }) + _ = v.RegisterValidation("address6_check", func(fl validator.FieldLevel) bool { + isIpv6 := functions.IsIpNet(node.Address6) + empty := node.Address6 == "" + return (empty || isIpv6) + }) + _ = v.RegisterValidation("endpoint_check", func(fl validator.FieldLevel) bool { + //var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint) + isIp := functions.IsIpNet(node.Address) empty := node.Endpoint == "" - return (empty || isIp) - }) - _ = v.RegisterValidation("localaddress_check", func(fl validator.FieldLevel) bool { - //var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint) - isIp := functions.IsIpNet(node.LocalAddress) - empty := node.LocalAddress == "" - return (empty || isIp ) - }) - _ = v.RegisterValidation("macaddress_unique", func(fl validator.FieldLevel) bool { - return true - }) - - _ = v.RegisterValidation("macaddress_valid", func(fl validator.FieldLevel) bool { - _, err := net.ParseMAC(node.MacAddress) - return err == nil - }) - - _ = v.RegisterValidation("name_valid", func(fl validator.FieldLevel) bool { - isvalid := functions.NameInNodeCharSet(node.Name) - return isvalid - }) - - _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool { - _, err := node.GetNetwork() - return err == nil - }) - _ = v.RegisterValidation("pubkey_check", func(fl validator.FieldLevel) bool { - empty := node.PublicKey == "" - isBase64 := functions.IsBase64(node.PublicKey) - return (empty || isBase64) - }) - _ = v.RegisterValidation("password_check", func(fl validator.FieldLevel) bool { - empty := node.Password == "" - goodLength := len(node.Password) > 5 - return (empty || goodLength) - }) - - err := v.Struct(node) - - if err != nil { - for _, e := range err.(validator.ValidationErrors) { - fmt.Println(e) - } - } - return err + return (empty || isIp) + }) + _ = v.RegisterValidation("localaddress_check", func(fl validator.FieldLevel) bool { + //var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint) + isIp := functions.IsIpNet(node.LocalAddress) + empty := node.LocalAddress == "" + return (empty || isIp) + }) + _ = v.RegisterValidation("macaddress_unique", func(fl validator.FieldLevel) bool { + return true + }) + + _ = v.RegisterValidation("macaddress_valid", func(fl validator.FieldLevel) bool { + _, err := net.ParseMAC(node.MacAddress) + return err == nil + }) + + _ = v.RegisterValidation("name_valid", func(fl validator.FieldLevel) bool { + isvalid := functions.NameInNodeCharSet(node.Name) + return isvalid + }) + + _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool { + _, err := node.GetNetwork() + return err == nil + }) + _ = v.RegisterValidation("pubkey_check", func(fl validator.FieldLevel) bool { + empty := node.PublicKey == "" + isBase64 := functions.IsBase64(node.PublicKey) + return (empty || isBase64) + }) + _ = v.RegisterValidation("password_check", func(fl validator.FieldLevel) bool { + empty := node.Password == "" + goodLength := len(node.Password) > 5 + return (empty || goodLength) + }) + + err := v.Struct(node) + + if err != nil { + for _, e := range err.(validator.ValidationErrors) { + fmt.Println(e) + } + } + return err } func UpdateNode(nodechange models.Node, node models.Node) (models.Node, error) { @@ -199,10 +199,10 @@ func UpdateNode(nodechange models.Node, node models.Node) (models.Node, error) { node.Address = nodechange.Address notifynetwork = true } - if nodechange.Address6 != "" { - node.Address6 = nodechange.Address6 - notifynetwork = true - } + if nodechange.Address6 != "" { + node.Address6 = nodechange.Address6 + notifynetwork = true + } if nodechange.Name != "" { node.Name = nodechange.Name } @@ -379,15 +379,15 @@ func CreateNode(node models.Node, networkName string) (models.Node, error) { if err != nil { return node, err } - fmt.Println("Setting node address: " + node.Address) + fmt.Println("Setting node address: " + node.Address) - node.Address6, err = functions.UniqueAddress6(networkName) - if node.Address6 != "" { + node.Address6, err = functions.UniqueAddress6(networkName) + if node.Address6 != "" { fmt.Println("Setting node ipv6 address: " + node.Address6) } if err != nil { - return node, err - } + return node, err + } //IDK why these aren't a part of "set defaults. Pretty dumb. //TODO: This is dumb. Consolidate and fix. diff --git a/controllers/dnsHttpController.go b/controllers/dnsHttpController.go index 3f8ebcb3a..f2df94564 100644 --- a/controllers/dnsHttpController.go +++ b/controllers/dnsHttpController.go @@ -7,14 +7,15 @@ import ( "fmt" "net/http" "time" + + "github.com/go-playground/validator/v10" "github.com/gorilla/mux" - "github.com/txn2/txeh" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mongoconn" + "github.com/txn2/txeh" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" - "gopkg.in/go-playground/validator.v9" ) func dnsHandlers(r *mux.Router) { @@ -32,85 +33,84 @@ func dnsHandlers(r *mux.Router) { //Gets all nodes associated with network, including pending nodes func getNodeDNS(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", "application/json") - var dns []models.DNSEntry - var params = mux.Vars(r) + var dns []models.DNSEntry + var params = mux.Vars(r) dns, err := GetNodeDNS(params["network"]) - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return - } - - //Returns all the nodes in JSON format - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(dns) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + //Returns all the nodes in JSON format + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(dns) } //Gets all nodes associated with network, including pending nodes func getAllDNS(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", "application/json") - var dns []models.DNSEntry + var dns []models.DNSEntry networks, err := functions.ListNetworks() - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return - } - - for _, net := range networks { - netdns, err := GetDNS(net.NetID) - if err != nil { + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + for _, net := range networks { + netdns, err := GetDNS(net.NetID) + if err != nil { returnErrorResponse(w, r, formatError(err, "internal")) - return - } - dns = append(dns, netdns...) - } - - //Returns all the nodes in JSON format - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(dns) -} + return + } + dns = append(dns, netdns...) + } + //Returns all the nodes in JSON format + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(dns) +} -func GetNodeDNS(network string) ([]models.DNSEntry, error){ +func GetNodeDNS(network string) ([]models.DNSEntry, error) { - var dns []models.DNSEntry + var dns []models.DNSEntry - collection := mongoconn.Client.Database("netmaker").Collection("nodes") + collection := mongoconn.Client.Database("netmaker").Collection("nodes") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - filter := bson.M{"network": network} + filter := bson.M{"network": network} - cur, err := collection.Find(ctx, filter, options.Find().SetProjection(bson.M{"_id": 0})) + cur, err := collection.Find(ctx, filter, options.Find().SetProjection(bson.M{"_id": 0})) - if err != nil { - return dns, err - } + if err != nil { + return dns, err + } - defer cancel() + defer cancel() - for cur.Next(context.TODO()) { + for cur.Next(context.TODO()) { - var entry models.DNSEntry + var entry models.DNSEntry - err := cur.Decode(&entry) - if err != nil { - return dns, err - } + err := cur.Decode(&entry) + if err != nil { + return dns, err + } - // add item our array of nodes - dns = append(dns, entry) - } + // add item our array of nodes + dns = append(dns, entry) + } - //TODO: Another fatal error we should take care of. - if err := cur.Err(); err != nil { - return dns, err - } + //TODO: Another fatal error we should take care of. + if err := cur.Err(); err != nil { + return dns, err + } return dns, err } @@ -118,110 +118,110 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error){ //Gets all nodes associated with network, including pending nodes func getCustomDNS(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", "application/json") - var dns []models.DNSEntry - var params = mux.Vars(r) + var dns []models.DNSEntry + var params = mux.Vars(r) - dns, err := GetCustomDNS(params["network"]) - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return - } + dns, err := GetCustomDNS(params["network"]) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } - //Returns all the nodes in JSON format - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(dns) + //Returns all the nodes in JSON format + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(dns) } -func GetCustomDNS(network string) ([]models.DNSEntry, error){ +func GetCustomDNS(network string) ([]models.DNSEntry, error) { - var dns []models.DNSEntry + var dns []models.DNSEntry - collection := mongoconn.Client.Database("netmaker").Collection("dns") + collection := mongoconn.Client.Database("netmaker").Collection("dns") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - filter := bson.M{"network": network} + filter := bson.M{"network": network} - cur, err := collection.Find(ctx, filter, options.Find().SetProjection(bson.M{"_id": 0})) + cur, err := collection.Find(ctx, filter, options.Find().SetProjection(bson.M{"_id": 0})) - if err != nil { - return dns, err - } + if err != nil { + return dns, err + } - defer cancel() + defer cancel() - for cur.Next(context.TODO()) { + for cur.Next(context.TODO()) { - var entry models.DNSEntry + var entry models.DNSEntry - err := cur.Decode(&entry) - if err != nil { - return dns, err - } + err := cur.Decode(&entry) + if err != nil { + return dns, err + } - // add item our array of nodes - dns = append(dns, entry) - } + // add item our array of nodes + dns = append(dns, entry) + } - //TODO: Another fatal error we should take care of. - if err := cur.Err(); err != nil { - return dns, err - } + //TODO: Another fatal error we should take care of. + if err := cur.Err(); err != nil { + return dns, err + } - return dns, err + return dns, err } -func GetDNSEntryNum(domain string, network string) (int, error){ +func GetDNSEntryNum(domain string, network string) (int, error) { - num := 0 + num := 0 - entries, err := GetDNS(network) - if err != nil { - return 0, err - } + entries, err := GetDNS(network) + if err != nil { + return 0, err + } - for i := 0; i < len(entries); i++ { + for i := 0; i < len(entries); i++ { - if domain == entries[i].Name { - num++ - } - } + if domain == entries[i].Name { + num++ + } + } - return num, nil + return num, nil } //Gets all nodes associated with network, including pending nodes func getDNS(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", "application/json") - var dns []models.DNSEntry - var params = mux.Vars(r) + var dns []models.DNSEntry + var params = mux.Vars(r) dns, err := GetDNS(params["network"]) - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(dns) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(dns) } func GetDNS(network string) ([]models.DNSEntry, error) { - var dns []models.DNSEntry - dns, err := GetNodeDNS(network) - if err != nil { - return dns, err - } - customdns, err := GetCustomDNS(network) - if err != nil { - return dns, err - } - - dns = append(dns, customdns...) + var dns []models.DNSEntry + dns, err := GetNodeDNS(network) + if err != nil { + return dns, err + } + customdns, err := GetCustomDNS(network) + if err != nil { + return dns, err + } + + dns = append(dns, customdns...) return dns, err } @@ -229,7 +229,7 @@ func createDNS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") var entry models.DNSEntry - var params = mux.Vars(r) + var params = mux.Vars(r) //get node from body of request _ = json.NewDecoder(r.Body).Decode(&entry) @@ -243,10 +243,10 @@ func createDNS(w http.ResponseWriter, r *http.Request) { entry, err = CreateDNS(entry) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) + returnErrorResponse(w, r, formatError(err, "internal")) return } - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(entry) } @@ -258,9 +258,9 @@ func updateDNS(w http.ResponseWriter, r *http.Request) { var entry models.DNSEntry //start here - entry, err := GetDNSEntry(params["domain"],params["network"]) + entry, err := GetDNSEntry(params["domain"], params["network"]) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) + returnErrorResponse(w, r, formatError(err, "badrequest")) return } @@ -269,21 +269,21 @@ func updateDNS(w http.ResponseWriter, r *http.Request) { // we decode our body request params err = json.NewDecoder(r.Body).Decode(&dnschange) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) + returnErrorResponse(w, r, formatError(err, "badrequest")) return } err = ValidateDNSUpdate(dnschange, entry) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) + returnErrorResponse(w, r, formatError(err, "badrequest")) return } entry, err = UpdateDNS(dnschange, entry) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) + returnErrorResponse(w, r, formatError(err, "badrequest")) return } @@ -291,92 +291,92 @@ func updateDNS(w http.ResponseWriter, r *http.Request) { } func deleteDNS(w http.ResponseWriter, r *http.Request) { - // Set header - w.Header().Set("Content-Type", "application/json") + // Set header + w.Header().Set("Content-Type", "application/json") - // get params - var params = mux.Vars(r) + // get params + var params = mux.Vars(r) - success, err := DeleteDNS(params["domain"], params["network"]) + success, err := DeleteDNS(params["domain"], params["network"]) - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return - } else if !success { - returnErrorResponse(w, r, formatError(errors.New("Delete unsuccessful."), "badrequest")) - return - } + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } else if !success { + returnErrorResponse(w, r, formatError(errors.New("Delete unsuccessful."), "badrequest")) + return + } - json.NewEncoder(w).Encode(params["domain"] + " deleted.") + json.NewEncoder(w).Encode(params["domain"] + " deleted.") } func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) { - // connect db - collection := mongoconn.Client.Database("netmaker").Collection("dns") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + // connect db + collection := mongoconn.Client.Database("netmaker").Collection("dns") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - // insert our node to the node db. - _, err := collection.InsertOne(ctx, entry) + // insert our node to the node db. + _, err := collection.InsertOne(ctx, entry) - defer cancel() + defer cancel() - return entry, err + return entry, err } func GetDNSEntry(domain string, network string) (models.DNSEntry, error) { - var entry models.DNSEntry + var entry models.DNSEntry - collection := mongoconn.Client.Database("netmaker").Collection("dns") + collection := mongoconn.Client.Database("netmaker").Collection("dns") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - filter := bson.M{"name": domain, "network": network} - err := collection.FindOne(ctx, filter, options.FindOne().SetProjection(bson.M{"_id": 0})).Decode(&entry) + filter := bson.M{"name": domain, "network": network} + err := collection.FindOne(ctx, filter, options.FindOne().SetProjection(bson.M{"_id": 0})).Decode(&entry) - defer cancel() + defer cancel() - return entry, err + return entry, err } func UpdateDNS(dnschange models.DNSEntry, entry models.DNSEntry) (models.DNSEntry, error) { - queryDNS := entry.Name - - if dnschange.Name != "" { - entry.Name = dnschange.Name - } - if dnschange.Address != "" { - entry.Address = dnschange.Address - } - //collection := mongoconn.ConnectDB() - collection := mongoconn.Client.Database("netmaker").Collection("dns") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - // Create filter - filter := bson.M{"name": queryDNS} - - // prepare update model. - update := bson.D{ - {"$set", bson.D{ - {"name", entry.Name}, - {"address", entry.Address}, - }}, - } - var dnsupdate models.DNSEntry - - errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&dnsupdate) - if errN != nil { - fmt.Println("Could not update: ") - fmt.Println(errN) - } else { - fmt.Println("DNS Entry updated successfully.") - } - - defer cancel() - - return dnsupdate, errN + queryDNS := entry.Name + + if dnschange.Name != "" { + entry.Name = dnschange.Name + } + if dnschange.Address != "" { + entry.Address = dnschange.Address + } + //collection := mongoconn.ConnectDB() + collection := mongoconn.Client.Database("netmaker").Collection("dns") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + // Create filter + filter := bson.M{"name": queryDNS} + + // prepare update model. + update := bson.D{ + {"$set", bson.D{ + {"name", entry.Name}, + {"address", entry.Address}, + }}, + } + var dnsupdate models.DNSEntry + + errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&dnsupdate) + if errN != nil { + fmt.Println("Could not update: ") + fmt.Println(errN) + } else { + fmt.Println("DNS Entry updated successfully.") + } + + defer cancel() + + return dnsupdate, errN } func DeleteDNS(domain string, network string) (bool, error) { @@ -385,7 +385,7 @@ func DeleteDNS(domain string, network string) (bool, error) { collection := mongoconn.Client.Database("netmaker").Collection("dns") - filter := bson.M{"name": domain, "network": network} + filter := bson.M{"name": domain, "network": network} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -403,31 +403,30 @@ func DeleteDNS(domain string, network string) (bool, error) { } func pushDNS(w http.ResponseWriter, r *http.Request) { - // Set header - w.Header().Set("Content-Type", "application/json") + // Set header + w.Header().Set("Content-Type", "application/json") - err := WriteHosts() + err := WriteHosts() - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return } - json.NewEncoder(w).Encode("DNS Pushed to CoreDNS") + json.NewEncoder(w).Encode("DNS Pushed to CoreDNS") } - func WriteHosts() error { //hostfile, err := txeh.NewHostsDefault() hostfile := txeh.Hosts{} /* - if err != nil { - return err - } + if err != nil { + return err + } */ networks, err := functions.ListNetworks() - if err != nil { - return err - } + if err != nil { + return err + } for _, net := range networks { dns, err := GetDNS(net.NetID) @@ -438,7 +437,7 @@ func WriteHosts() error { hostfile.AddHost(entry.Address, entry.Name+"."+entry.Network) if err != nil { return err - } + } } } err = hostfile.SaveAs("./config/netmaker.hosts") @@ -448,9 +447,9 @@ func WriteHosts() error { func ValidateDNSCreate(entry models.DNSEntry) error { v := validator.New() - fmt.Println("Validating DNS: " + entry.Name) - fmt.Println(" Address: " + entry.Address) - fmt.Println(" Network: " + entry.Network) + fmt.Println("Validating DNS: " + entry.Name) + fmt.Println(" Address: " + entry.Address) + fmt.Println(" Network: " + entry.Network) _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool { num, err := GetDNSEntryNum(entry.Name, entry.Network) @@ -459,19 +458,19 @@ func ValidateDNSCreate(entry models.DNSEntry) error { _ = v.RegisterValidation("name_valid", func(fl validator.FieldLevel) bool { isvalid := functions.NameInDNSCharSet(entry.Name) - notEmptyCheck := len(entry.Name) > 0 + notEmptyCheck := len(entry.Name) > 0 return isvalid && notEmptyCheck }) _ = v.RegisterValidation("address_valid", func(fl validator.FieldLevel) bool { notEmptyCheck := len(entry.Address) > 0 - isIp := functions.IsIpNet(entry.Address) + isIp := functions.IsIpNet(entry.Address) return notEmptyCheck && isIp }) - _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool { - _, err := functions.GetParentNetwork(entry.Network) - return err == nil - }) + _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool { + _, err := functions.GetParentNetwork(entry.Network) + return err == nil + }) err := v.Struct(entry) @@ -485,41 +484,39 @@ func ValidateDNSCreate(entry models.DNSEntry) error { func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error { - v := validator.New() + v := validator.New() - _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool { + _ = v.RegisterValidation("name_unique", func(fl validator.FieldLevel) bool { goodNum := false - num, err := GetDNSEntryNum(entry.Name, entry.Network) + num, err := GetDNSEntryNum(entry.Name, entry.Network) if change.Name != entry.Name { goodNum = num == 0 } else { - goodNum = num == 1 + goodNum = num == 1 } return err == nil && goodNum - }) + }) - _ = v.RegisterValidation("name_valid", func(fl validator.FieldLevel) bool { - isvalid := functions.NameInDNSCharSet(entry.Name) - notEmptyCheck := entry.Name != "" - return isvalid && notEmptyCheck - }) + _ = v.RegisterValidation("name_valid", func(fl validator.FieldLevel) bool { + isvalid := functions.NameInDNSCharSet(entry.Name) + notEmptyCheck := entry.Name != "" + return isvalid && notEmptyCheck + }) - _ = v.RegisterValidation("address_valid", func(fl validator.FieldLevel) bool { + _ = v.RegisterValidation("address_valid", func(fl validator.FieldLevel) bool { isValid := true if entry.Address != "" { isValid = functions.IsIpNet(entry.Address) } return isValid - }) + }) - err := v.Struct(entry) + err := v.Struct(entry) - if err != nil { - for _, e := range err.(validator.ValidationErrors) { - fmt.Println(e) - } - } - return err + if err != nil { + for _, e := range err.(validator.ValidationErrors) { + fmt.Println(e) + } + } + return err } - - diff --git a/controllers/networkHttpController.go b/controllers/networkHttpController.go index cc669b564..0d8493024 100644 --- a/controllers/networkHttpController.go +++ b/controllers/networkHttpController.go @@ -7,17 +7,19 @@ import ( "errors" "fmt" "net/http" + "os" "strings" "time" - "os" + + "github.com/go-playground/validator/v10" "github.com/gorilla/mux" "github.com/gravitl/netmaker/config" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mongoconn" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "gopkg.in/go-playground/validator.v9" ) func networkHandlers(r *mux.Router) { @@ -83,7 +85,7 @@ func securityCheck(next http.Handler) http.HandlerFunc { //Consider a more secure way of setting master key func authenticateMaster(tokenString string) bool { - if tokenString == config.Config.Server.MasterKey || (tokenString == os.Getenv("MASTER_KEY") && tokenString != "") { + if tokenString == config.Config.Server.MasterKey || (tokenString == os.Getenv("MASTER_KEY") && tokenString != "") { return true } return false @@ -104,7 +106,7 @@ func getNetworks(w http.ResponseWriter, r *http.Request) { } } -func validateNetworkUpdate(network models.Network) error { +func ValidateNetworkUpdate(network models.Network) error { v := validator.New() @@ -112,10 +114,10 @@ func validateNetworkUpdate(network models.Network) error { isvalid := fl.Field().String() == "" || functions.IsIpCIDR(fl.Field().String()) return isvalid }) - _ = v.RegisterValidation("addressrange6_valid", func(fl validator.FieldLevel) bool { - isvalid := fl.Field().String() == "" || functions.IsIpCIDR(fl.Field().String()) - return isvalid - }) + _ = v.RegisterValidation("addressrange6_valid", func(fl validator.FieldLevel) bool { + isvalid := fl.Field().String() == "" || functions.IsIpCIDR(fl.Field().String()) + return isvalid + }) _ = v.RegisterValidation("localrange_valid", func(fl validator.FieldLevel) bool { isvalid := fl.Field().String() == "" || functions.IsIpCIDR(fl.Field().String()) @@ -140,34 +142,33 @@ func validateNetworkUpdate(network models.Network) error { return err } -func validateNetworkCreate(network models.Network) error { +func ValidateNetworkCreate(network models.Network) error { v := validator.New() - _ = v.RegisterValidation("addressrange_valid", func(fl validator.FieldLevel) bool { - isvalid := functions.IsIpCIDR(fl.Field().String()) - return isvalid - }) - _ = v.RegisterValidation("addressrange6_valid", func(fl validator.FieldLevel) bool { + // _ = v.RegisterValidation("addressrange_valid", func(fl validator.FieldLevel) bool { + // isvalid := functions.IsIpCIDR(fl.Field().String()) + // return isvalid + // }) + _ = v.RegisterValidation("addressrange6_valid", func(fl validator.FieldLevel) bool { isvalid := true if *network.IsDualStack { isvalid = functions.IsIpCIDR(fl.Field().String()) } return isvalid - }) - - - _ = v.RegisterValidation("localrange_valid", func(fl validator.FieldLevel) bool { - isvalid := fl.Field().String() == "" || functions.IsIpCIDR(fl.Field().String()) - return isvalid }) - + // + // _ = v.RegisterValidation("localrange_valid", func(fl validator.FieldLevel) bool { + // isvalid := fl.Field().String() == "" || functions.IsIpCIDR(fl.Field().String()) + // return isvalid + // }) + // _ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool { isFieldUnique, _ := functions.IsNetworkNameUnique(fl.Field().String()) - inCharSet := functions.NameInNetworkCharSet(fl.Field().String()) - return isFieldUnique && inCharSet + // inCharSet := functions.NameInNetworkCharSet(fl.Field().String()) + return isFieldUnique }) - + // _ = v.RegisterValidation("displayname_unique", func(fl validator.FieldLevel) bool { isFieldUnique, _ := functions.IsNetworkDisplayNameUnique(fl.Field().String()) return isFieldUnique @@ -230,16 +231,16 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - filter := bson.M{"netid": params["networkname"]} - // prepare update model. - update := bson.D{ - {"$set", bson.D{ - {"addressrange", network.AddressRange}, - {"addressrange6", network.AddressRange6}, - {"displayname", network.DisplayName}, - {"defaultlistenport", network.DefaultListenPort}, - {"defaultpostup", network.DefaultPostUp}, - {"defaultpostdown", network.DefaultPostDown}, + filter := bson.M{"netid": params["networkname"]} + // prepare update model. + update := bson.D{ + {"$set", bson.D{ + {"addressrange", network.AddressRange}, + {"addressrange6", network.AddressRange6}, + {"displayname", network.DisplayName}, + {"defaultlistenport", network.DefaultListenPort}, + {"defaultpostup", network.DefaultPostUp}, + {"defaultpostdown", network.DefaultPostDown}, {"defaultkeepalive", network.DefaultKeepalive}, {"keyupdatetimestamp", network.KeyUpdateTimeStamp}, {"defaultsaveconfig", network.DefaultSaveConfig}, @@ -248,8 +249,8 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) { {"networklastmodified", network.NetworkLastModified}, {"allowmanualsignup", network.AllowManualSignUp}, {"checkininterval", network.DefaultCheckInInterval}, - }}, - } + }}, + } err = collection.FindOneAndUpdate(ctx, filter, update).Decode(&network) @@ -317,14 +318,14 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { if networkChange.AddressRange == "" { networkChange.AddressRange = network.AddressRange } - if networkChange.AddressRange6 == "" { - networkChange.AddressRange6 = network.AddressRange6 - } + if networkChange.AddressRange6 == "" { + networkChange.AddressRange6 = network.AddressRange6 + } if networkChange.NetID == "" { networkChange.NetID = network.NetID } - err = validateNetworkUpdate(networkChange) + err = ValidateNetworkUpdate(networkChange) if err != nil { returnErrorResponse(w, r, formatError(err, "badrequest")) return @@ -366,9 +367,9 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { if networkChange.IsLocal != nil { network.IsLocal = networkChange.IsLocal } - if networkChange.IsDualStack != nil { - network.IsDualStack = networkChange.IsDualStack - } + if networkChange.IsDualStack != nil { + network.IsDualStack = networkChange.IsDualStack + } if networkChange.DefaultListenPort != 0 { network.DefaultListenPort = networkChange.DefaultListenPort haschange = true @@ -409,26 +410,26 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { if haschange { network.SetNetworkLastModified() } - // prepare update model. - update := bson.D{ - {"$set", bson.D{ - {"addressrange", network.AddressRange}, - {"addressrange6", network.AddressRange6}, - {"displayname", network.DisplayName}, - {"defaultlistenport", network.DefaultListenPort}, - {"defaultpostup", network.DefaultPostUp}, - {"defaultpostdown", network.DefaultPostDown}, - {"defaultkeepalive", network.DefaultKeepalive}, - {"defaultsaveconfig", network.DefaultSaveConfig}, - {"defaultinterface", network.DefaultInterface}, - {"nodeslastmodified", network.NodesLastModified}, - {"networklastmodified", network.NetworkLastModified}, - {"allowmanualsignup", network.AllowManualSignUp}, - {"localrange", network.LocalRange}, - {"islocal", network.IsLocal}, - {"isdualstack", network.IsDualStack}, - {"checkininterval", network.DefaultCheckInInterval}, - }}, + // prepare update model. + update := bson.D{ + {"$set", bson.D{ + {"addressrange", network.AddressRange}, + {"addressrange6", network.AddressRange6}, + {"displayname", network.DisplayName}, + {"defaultlistenport", network.DefaultListenPort}, + {"defaultpostup", network.DefaultPostUp}, + {"defaultpostdown", network.DefaultPostDown}, + {"defaultkeepalive", network.DefaultKeepalive}, + {"defaultsaveconfig", network.DefaultSaveConfig}, + {"defaultinterface", network.DefaultInterface}, + {"nodeslastmodified", network.NodesLastModified}, + {"networklastmodified", network.NetworkLastModified}, + {"allowmanualsignup", network.AllowManualSignUp}, + {"localrange", network.LocalRange}, + {"islocal", network.IsLocal}, + {"isdualstack", network.IsDualStack}, + {"checkininterval", network.DefaultCheckInInterval}, + }}, } err = collection.FindOneAndUpdate(ctx, filter, update).Decode(&network) @@ -472,36 +473,42 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") var params = mux.Vars(r) + network := params["networkname"] + count, err := DeleteNetwork(network) - nodecount, err := functions.GetNetworkNodeNumber(params["networkname"]) if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return - } else if nodecount > 0 { - errorResponse := models.ErrorResponse{ - Code: http.StatusForbidden, Message: "W1R3: Node check failed. All nodes must be deleted before deleting network.", - } - returnErrorResponse(w, r, errorResponse) + returnErrorResponse(w, r, formatError(err, "badrequest")) return } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(count) +} - collection := mongoconn.Client.Database("netmaker").Collection("networks") +func DeleteNetwork(network string) (*mongo.DeleteResult, error) { + none := &mongo.DeleteResult{} - filter := bson.M{"netid": params["networkname"]} + nodecount, err := functions.GetNetworkNodeNumber(network) + if err != nil { + //returnErrorResponse(w, r, formatError(err, "internal")) + return none, err + } else if nodecount > 0 { + //errorResponse := models.ErrorResponse{ + // Code: http.StatusForbidden, Message: "W1R3: Node check failed. All nodes must be deleted before deleting network.", + //} + //returnErrorResponse(w, r, errorResponse) + return none, errors.New("Node check failed. All nodes must be deleted before deleting network") + } + collection := mongoconn.Client.Database("netmaker").Collection("networks") + filter := bson.M{"netid": network} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - deleteResult, err := collection.DeleteOne(ctx, filter) - defer cancel() - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return + //returnErrorResponse(w, r, formatError(err, "internal")) + return none, err } - - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(deleteResult) + return deleteResult, nil } //Create a network @@ -519,6 +526,16 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { return } + err = CreateNetwork(network) + if err != nil { + returnErrorResponse(w, r, formatError(err, "badrequest")) + return + } + w.WriteHeader(http.StatusOK) + //json.NewEncoder(w).Encode(result) +} + +func CreateNetwork(network models.Network) error { //TODO: Not really doing good validation here. Same as createNode, updateNode, and updateNetwork //Need to implement some better validation across the board @@ -526,15 +543,15 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { falsevar := false network.IsLocal = &falsevar } - if network.IsDualStack == nil { - falsevar := false - network.IsDualStack = &falsevar - } + if network.IsDualStack == nil { + falsevar := false + network.IsDualStack = &falsevar + } - err = validateNetworkCreate(network) + err := ValidateNetworkCreate(network) if err != nil { - returnErrorResponse(w, r, formatError(err, "badrequest")) - return + //returnErrorResponse(w, r, formatError(err, "badrequest")) + return err } network.SetDefaults() network.SetNodesLastModified() @@ -546,15 +563,12 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { // insert our network into the network table result, err := collection.InsertOne(ctx, network) - + fmt.Printf("=========%T, %v\n", result, result) defer cancel() - if err != nil { - returnErrorResponse(w, r, formatError(err, "internal")) - return + return err } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(result) + return nil } // BEGIN KEY MANAGEMENT SECTION diff --git a/controllers/networkHttpController_test.go b/controllers/networkHttpController_test.go new file mode 100644 index 000000000..f8c40d342 --- /dev/null +++ b/controllers/networkHttpController_test.go @@ -0,0 +1,221 @@ +package controller + +import ( + "testing" + + "github.com/gravitl/netmaker/models" + "github.com/stretchr/testify/assert" +) + +type NetworkValidationTestCase struct { + testname string + network models.Network + errMessage string +} + +func TestGetNetworks(t *testing.T) { + //calls functions.ListNetworks --- nothing to be don +} +func TestCreateNetwork(t *testing.T) { +} +func TestGetNetwork(t *testing.T) { +} +func TestUpdateNetwork(t *testing.T) { +} +func TestDeleteNetwork(t *testing.T) { +} +func TestKeyUpdate(t *testing.T) { +} +func TestCreateKey(t *testing.T) { +} +func TestGetKey(t *testing.T) { +} +func TestDeleteKey(t *testing.T) { +} +func TestSecurityCheck(t *testing.T) { +} +func TestValidateNetworkUpdate(t *testing.T) { +} +func TestValidateNetworkCreate(t *testing.T) { + yes := true + no := false + //DeleteNetworks + cases := []NetworkValidationTestCase{ + NetworkValidationTestCase{ + testname: "InvalidAddress", + network: models.Network{ + AddressRange: "10.0.0.256", + NetID: "skynet", + IsDualStack: &no, + }, + errMessage: "Field validation for 'AddressRange' failed on the 'cidr' tag", + }, + NetworkValidationTestCase{ + testname: "BadDisplayName", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + DisplayName: "skynet*", + IsDualStack: &no, + }, + errMessage: "Field validation for 'DisplayName' failed on the 'alphanum' tag", + }, + NetworkValidationTestCase{ + testname: "DisplayNameTooLong", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + DisplayName: "Thisisareallylongdisplaynamethatistoolong", + IsDualStack: &no, + }, + errMessage: "Field validation for 'DisplayName' failed on the 'max' tag", + }, + NetworkValidationTestCase{ + testname: "DisplayNameTooShort", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + DisplayName: "1", + IsDualStack: &no, + }, + errMessage: "Field validation for 'DisplayName' failed on the 'min' tag", + }, + NetworkValidationTestCase{ + testname: "NetIDMissing", + network: models.Network{ + AddressRange: "10.0.0.1/24", + IsDualStack: &no, + }, + errMessage: "Field validation for 'NetID' failed on the 'required' tag", + }, + NetworkValidationTestCase{ + testname: "InvalidNetID", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "contains spaces", + IsDualStack: &no, + }, + errMessage: "Field validation for 'NetID' failed on the 'alphanum' tag", + }, + NetworkValidationTestCase{ + testname: "NetIDTooShort", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "", + IsDualStack: &no, + }, + errMessage: "Field validation for 'NetID' failed on the 'required' tag", + }, + NetworkValidationTestCase{ + testname: "NetIDTooLong", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "LongNetIDName", + IsDualStack: &no, + }, + errMessage: "Field validation for 'NetID' failed on the 'max' tag", + }, + NetworkValidationTestCase{ + testname: "ListenPortTooLow", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + DefaultListenPort: 1023, + IsDualStack: &no, + }, + errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag", + }, + NetworkValidationTestCase{ + testname: "ListenPortTooHigh", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + DefaultListenPort: 65536, + IsDualStack: &no, + }, + errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag", + }, + NetworkValidationTestCase{ + testname: "KeepAliveTooBig", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + DefaultKeepalive: 1010, + IsDualStack: &no, + }, + errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag", + }, + NetworkValidationTestCase{ + testname: "InvalidLocalRange", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + LocalRange: "192.168.0.1", + IsDualStack: &no, + }, + errMessage: "Field validation for 'LocalRange' failed on the 'cidr' tag", + }, + NetworkValidationTestCase{ + testname: "DualStackWithoutIPv6", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + IsDualStack: &yes, + }, + errMessage: "Field validation for 'AddressRange6' failed on the 'addressrange6_valid' tag", + }, + NetworkValidationTestCase{ + testname: "CheckInIntervalTooBig", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + IsDualStack: &no, + DefaultCheckInInterval: 100001, + }, + errMessage: "Field validation for 'DefaultCheckInInterval' failed on the 'max' tag", + }, + NetworkValidationTestCase{ + testname: "CheckInIntervalTooSmall", + network: models.Network{ + AddressRange: "10.0.0.1/24", + NetID: "skynet", + IsDualStack: &no, + DefaultCheckInInterval: 1, + }, + errMessage: "Field validation for 'DefaultCheckInInterval' failed on the 'min' tag", + }, + } + for _, tc := range cases { + t.Run(tc.testname, func(t *testing.T) { + err := ValidateNetworkCreate(tc.network) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), tc.errMessage) + }) + } + t.Run("DuplicateNetID", func(t *testing.T) { + var net1, net2 models.Network + net1.NetID = "skylink" + net1.AddressRange = "10.0.0.1/24" + net1.DisplayName = "mynetwork" + net2.NetID = "skylink" + net2.AddressRange = "10.0.1.1/24" + net2.IsDualStack = &no + + err := CreateNetwork(net1) + assert.Nil(t, err) + err = ValidateNetworkCreate(net2) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Field validation for 'NetID' failed on the 'netid_valid' tag") + }) + t.Run("DuplicateDisplayName", func(t *testing.T) { + var network models.Network + network.NetID = "wirecat" + network.AddressRange = "10.0.100.1/24" + network.IsDualStack = &no + network.DisplayName = "mynetwork" + err := ValidateNetworkCreate(network) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Field validation for 'DisplayName' failed on the 'displayname_unique' tag") + }) + +} diff --git a/controllers/userHttpController.go b/controllers/userHttpController.go index ab7fcd001..64e36af1c 100644 --- a/controllers/userHttpController.go +++ b/controllers/userHttpController.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/go-playground/validator/v10" "github.com/gorilla/mux" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/models" @@ -17,7 +18,6 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "golang.org/x/crypto/bcrypt" - "gopkg.in/go-playground/validator.v9" ) func userHandlers(r *mux.Router) { diff --git a/controllers/userHttpController_test.go b/controllers/userHttpController_test.go index ddfc9609c..5345dee93 100644 --- a/controllers/userHttpController_test.go +++ b/controllers/userHttpController_test.go @@ -33,17 +33,17 @@ func TestHasAdmin(t *testing.T) { assert.Nil(t, err) user := models.User{"admin", "password", true} _, err = CreateUser(user) - assert.Nil(t, err, err) + assert.Nil(t, err) t.Run("AdminExists", func(t *testing.T) { found, err := HasAdmin() - assert.Nil(t, err, err) + assert.Nil(t, err) assert.True(t, found) }) t.Run("NoUser", func(t *testing.T) { _, err := DeleteUser("admin") - assert.Nil(t, err, err) + assert.Nil(t, err) found, err := HasAdmin() - assert.Nil(t, err, err) + assert.Nil(t, err) assert.False(t, found) }) } @@ -52,35 +52,35 @@ func TestCreateUser(t *testing.T) { user := models.User{"admin", "password", true} t.Run("NoUser", func(t *testing.T) { _, err := DeleteUser("admin") - assert.Nil(t, err, err) + assert.Nil(t, err) admin, err := CreateUser(user) - assert.Nil(t, err, err) + assert.Nil(t, err) assert.Equal(t, user.UserName, admin.UserName) }) t.Run("AdminExists", func(t *testing.T) { _, err := CreateUser(user) - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "Admin already Exists", err.Error()) }) } func TestDeleteUser(t *testing.T) { hasadmin, err := HasAdmin() - assert.Nil(t, err, err) + assert.Nil(t, err) if !hasadmin { user := models.User{"admin", "pasword", true} _, err := CreateUser(user) - assert.Nil(t, err, err) + assert.Nil(t, err) } t.Run("ExistingUser", func(t *testing.T) { deleted, err := DeleteUser("admin") - assert.Nil(t, err, err) + assert.Nil(t, err) assert.True(t, deleted) t.Log(deleted, err) }) t.Run("NonExistantUser", func(t *testing.T) { deleted, err := DeleteUser("admin") - assert.Nil(t, err, err) + assert.Nil(t, err) assert.False(t, deleted) }) } @@ -91,33 +91,37 @@ func TestValidateUser(t *testing.T) { user.UserName = "admin" user.Password = "validpass" err := ValidateUser("create", user) - assert.Nil(t, err, err) + assert.Nil(t, err) }) t.Run("ValidUpdate", func(t *testing.T) { user.UserName = "admin" user.Password = "password" err := ValidateUser("update", user) - assert.Nil(t, err, err) + assert.Nil(t, err) }) t.Run("InvalidUserName", func(t *testing.T) { user.UserName = "invalid*" err := ValidateUser("update", user) - assert.NotNil(t, err, err) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Field validation for 'UserName' failed") }) t.Run("ShortUserName", func(t *testing.T) { user.UserName = "12" err := ValidateUser("create", user) - assert.NotNil(t, err, err) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Field validation for 'UserName' failed") }) t.Run("EmptyPassword", func(t *testing.T) { user.Password = "" err := ValidateUser("create", user) - assert.NotNil(t, err, err) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Field validation for 'Password' failed") }) t.Run("ShortPassword", func(t *testing.T) { user.Password = "123" err := ValidateUser("create", user) - assert.NotNil(t, err, err) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Field validation for 'Password' failed") }) } @@ -125,18 +129,18 @@ func TestGetUser(t *testing.T) { t.Run("UserExisits", func(t *testing.T) { user := models.User{"admin", "password", true} hasadmin, err := HasAdmin() - assert.Nil(t, err, err) + assert.Nil(t, err) if !hasadmin { _, err := CreateUser(user) - assert.Nil(t, err, err) + assert.Nil(t, err) } admin, err := GetUser("admin") - assert.Nil(t, err, err) + assert.Nil(t, err) assert.Equal(t, user.UserName, admin.UserName) }) t.Run("NonExistantUser", func(t *testing.T) { _, err := DeleteUser("admin") - assert.Nil(t, err, err) + assert.Nil(t, err) admin, err := GetUser("admin") assert.Equal(t, "mongo: no documents in result", err.Error()) assert.Equal(t, "", admin.UserName) @@ -149,14 +153,14 @@ func TestUpdateUser(t *testing.T) { t.Run("UserExisits", func(t *testing.T) { _, err := DeleteUser("admin") _, err = CreateUser(user) - assert.Nil(t, err, err) + assert.Nil(t, err) admin, err := UpdateUser(newuser, user) - assert.Nil(t, err, err) + assert.Nil(t, err) assert.Equal(t, newuser.UserName, admin.UserName) }) t.Run("NonExistantUser", func(t *testing.T) { _, err := DeleteUser("hello") - assert.Nil(t, err, err) + assert.Nil(t, err) _, err = UpdateUser(newuser, user) assert.Equal(t, "mongo: no documents in result", err.Error()) }) @@ -165,12 +169,12 @@ func TestUpdateUser(t *testing.T) { func TestValidateToken(t *testing.T) { t.Run("EmptyToken", func(t *testing.T) { err := ValidateToken("") - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "Missing Auth Token.", err.Error()) }) t.Run("InvalidToken", func(t *testing.T) { err := ValidateToken("Bearer: badtoken") - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "Error Verifying Auth Token", err.Error()) }) t.Run("InvalidUser", func(t *testing.T) { @@ -179,7 +183,7 @@ func TestValidateToken(t *testing.T) { }) t.Run("ValidToken", func(t *testing.T) { err := ValidateToken("Bearer: secretkey") - assert.Nil(t, err, err) + assert.Nil(t, err) }) } @@ -189,7 +193,7 @@ func TestVerifyAuthRequest(t *testing.T) { authRequest.UserName = "" authRequest.Password = "Password" jwt, err := VerifyAuthRequest(authRequest) - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "", jwt) assert.Equal(t, "Username can't be empty", err.Error()) }) @@ -197,7 +201,7 @@ func TestVerifyAuthRequest(t *testing.T) { authRequest.UserName = "admin" authRequest.Password = "" jwt, err := VerifyAuthRequest(authRequest) - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "", jwt) assert.Equal(t, "Password can't be empty", err.Error()) }) @@ -206,7 +210,7 @@ func TestVerifyAuthRequest(t *testing.T) { authRequest.UserName = "admin" authRequest.Password = "password" jwt, err := VerifyAuthRequest(authRequest) - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "", jwt) assert.Equal(t, "User admin not found", err.Error()) }) @@ -218,7 +222,7 @@ func TestVerifyAuthRequest(t *testing.T) { assert.Nil(t, err) authRequest := models.UserAuthParams{"admin", "admin"} jwt, err := VerifyAuthRequest(authRequest) - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "", jwt) assert.Equal(t, "User is not an admin", err.Error()) }) @@ -229,14 +233,14 @@ func TestVerifyAuthRequest(t *testing.T) { assert.Nil(t, err) authRequest := models.UserAuthParams{"admin", "badpass"} jwt, err := VerifyAuthRequest(authRequest) - assert.NotNil(t, err, err) + assert.NotNil(t, err) assert.Equal(t, "", jwt) assert.Equal(t, "Wrong Password", err.Error()) }) t.Run("Success", func(t *testing.T) { authRequest := models.UserAuthParams{"admin", "password"} jwt, err := VerifyAuthRequest(authRequest) - assert.Nil(t, err, err) + assert.Nil(t, err) assert.NotNil(t, jwt) }) } diff --git a/go.mod b/go.mod index ca20babe5..e6623ef84 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,13 @@ module github.com/gravitl/netmaker go 1.15 require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible - github.com/go-playground/universal-translator v0.17.0 // indirect + github.com/go-playground/validator/v10 v10.5.0 github.com/golang/protobuf v1.4.3 github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 - github.com/leodido/go-urn v1.2.0 // indirect github.com/stretchr/testify v1.6.1 - github.com/txn2/txeh v1.3.0 // indirect + github.com/txn2/txeh v1.3.0 go.mongodb.org/mongo-driver v1.4.3 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20210119194325-5f4716e94777 // indirect @@ -22,7 +20,5 @@ require ( golang.zx2c4.com/wireguard/wgctrl v0.0.0-20200609130330-bd2cb7843e1b google.golang.org/genproto v0.0.0-20210201151548-94839c025ad4 // indirect google.golang.org/grpc v1.35.0 - gopkg.in/go-playground/assert.v1 v1.2.1 // indirect - gopkg.in/go-playground/validator.v9 v9.31.0 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c ) diff --git a/go.sum b/go.sum index 85b55347d..ab83eb1b5 100644 --- a/go.sum +++ b/go.sum @@ -22,10 +22,14 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.5.0 h1:X9rflw/KmpACwT8zdrm1upefpvdy6ur8d1kWyq6sg3E= +github.com/go-playground/validator/v10 v10.5.0/go.mod h1:xm76BBt941f7yWdGnI2DVPFFg1UK3YY04qifoXU3lOk= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -261,10 +265,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM= -gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= -gopkg.in/go-playground/validator.v9 v9.31.0 h1:bmXmP2RSNtFES+bn4uYuHT7iJFJv7Vj+an+ZQdDaD1M= -gopkg.in/go-playground/validator.v9 v9.31.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/models/network.go b/models/network.go index 05ca32009..ffd91eea2 100644 --- a/models/network.go +++ b/models/network.go @@ -1,74 +1,77 @@ package models import ( -// "../mongoconn" - "go.mongodb.org/mongo-driver/bson/primitive" - "time" + // "../mongoconn" + "time" + + "go.mongodb.org/mongo-driver/bson/primitive" ) //Network Struct //At some point, need to replace all instances of Name with something else like Identifier type Network struct { - ID primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"` - AddressRange string `json:"addressrange" bson:"addressrange" validate:"required,addressrange_valid"` - AddressRange6 string `json:"addressrange6" bson:"addressrange6" validate:"addressrange6_valid"` - DisplayName string `json:"displayname,omitempty" bson:"displayname,omitempty" validate:"omitempty,displayname_unique,min=1,max=100"` - NetID string `json:"netid" bson:"netid" validate:"required,netid_valid,min=1,max=12"` - NodesLastModified int64 `json:"nodeslastmodified" bson:"nodeslastmodified"` - NetworkLastModified int64 `json:"networklastmodified" bson:"networklastmodified"` - DefaultInterface string `json:"defaultinterface" bson:"defaultinterface"` - DefaultListenPort int32 `json:"defaultlistenport,omitempty" bson:"defaultlistenport,omitempty" validate:"omitempty,numeric,min=1024,max=65535"` - DefaultPostUp string `json:"defaultpostup" bson:"defaultpostup"` - DefaultPostDown string `json:"defaultpostdown" bson:"defaultpostdown"` - KeyUpdateTimeStamp int64 `json:"keyupdatetimestamp" bson:"keyupdatetimestamp"` - DefaultKeepalive int32 `json:"defaultkeepalive" bson:"defaultkeepalive" validate: "omitempty,numeric,max=1000"` - DefaultSaveConfig *bool `json:"defaultsaveconfig" bson:"defaultsaveconfig"` - AccessKeys []AccessKey `json:"accesskeys" bson:"accesskeys"` - AllowManualSignUp *bool `json:"allowmanualsignup" bson:"allowmanualsignup"` - IsLocal *bool `json:"islocal" bson:"islocal"` - IsDualStack *bool `json:"isdualstack" bson:"isdualstack"` - LocalRange string `json:"localrange" bson:"localrange" validate:"localrange_valid"` - DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=1,max=100000"` + ID primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"` + AddressRange string `json:"addressrange" bson:"addressrange" validate:"required,cidr"` + // bug in validator --- required_with does not work with bools issue#683 + // AddressRange6 string `json:"addressrange6" bson:"addressrange6" validate:"required_with=isdualstack true,cidrv6"` + AddressRange6 string `json:"addressrange6" bson:"addressrange6" validate:"addressrange6_valid"` + DisplayName string `json:"displayname,omitempty" bson:"displayname,omitempty" validate:"omitempty,alphanum,min=2,max=20,displayname_unique"` + NetID string `json:"netid" bson:"netid" validate:"required,alphanum,min=1,max=12,netid_valid"` + NodesLastModified int64 `json:"nodeslastmodified" bson:"nodeslastmodified"` + NetworkLastModified int64 `json:"networklastmodified" bson:"networklastmodified"` + DefaultInterface string `json:"defaultinterface" bson:"defaultinterface"` + DefaultListenPort int32 `json:"defaultlistenport,omitempty" bson:"defaultlistenport,omitempty" validate:"omitempty,min=1024,max=65535"` + DefaultPostUp string `json:"defaultpostup" bson:"defaultpostup"` + DefaultPostDown string `json:"defaultpostdown" bson:"defaultpostdown"` + KeyUpdateTimeStamp int64 `json:"keyupdatetimestamp" bson:"keyupdatetimestamp"` + DefaultKeepalive int32 `json:"defaultkeepalive" bson:"defaultkeepalive" validate:"omitempty,max=1000"` + DefaultSaveConfig *bool `json:"defaultsaveconfig" bson:"defaultsaveconfig"` + AccessKeys []AccessKey `json:"accesskeys" bson:"accesskeys"` + AllowManualSignUp *bool `json:"allowmanualsignup" bson:"allowmanualsignup"` + IsLocal *bool `json:"islocal" bson:"islocal"` + IsDualStack *bool `json:"isdualstack" bson:"isdualstack"` + LocalRange string `json:"localrange" bson:"localrange" validate:"omitempty,cidr"` + DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=2,max=100000"` } //TODO: //Not sure if we need the below two functions. Got rid of one of the calls. May want to revisit -func(network *Network) SetNodesLastModified(){ - network.NodesLastModified = time.Now().Unix() +func (network *Network) SetNodesLastModified() { + network.NodesLastModified = time.Now().Unix() } -func(network *Network) SetNetworkLastModified(){ - network.NetworkLastModified = time.Now().Unix() +func (network *Network) SetNetworkLastModified() { + network.NetworkLastModified = time.Now().Unix() } -func(network *Network) SetDefaults(){ - if network.DisplayName == "" { - network.DisplayName = network.NetID - } - if network.DefaultInterface == "" { - network.DefaultInterface = "nm-" + network.NetID - } - if network.DefaultListenPort == 0 { - network.DefaultListenPort = 51821 - } - if network.DefaultPostDown == "" { +func (network *Network) SetDefaults() { + if network.DisplayName == "" { + network.DisplayName = network.NetID + } + if network.DefaultInterface == "" { + network.DefaultInterface = "nm-" + network.NetID + } + if network.DefaultListenPort == 0 { + network.DefaultListenPort = 51821 + } + if network.DefaultPostDown == "" { - } - if network.DefaultSaveConfig == nil { - defaultsave := true - network.DefaultSaveConfig = &defaultsave - } - if network.DefaultKeepalive == 0 { - network.DefaultKeepalive = 20 - } - if network.DefaultPostUp == "" { - } - //Check-In Interval for Nodes, In Seconds - if network.DefaultCheckInInterval == 0 { - network.DefaultCheckInInterval = 30 - } - if network.AllowManualSignUp == nil { - signup := false - network.AllowManualSignUp = &signup - } + } + if network.DefaultSaveConfig == nil { + defaultsave := true + network.DefaultSaveConfig = &defaultsave + } + if network.DefaultKeepalive == 0 { + network.DefaultKeepalive = 20 + } + if network.DefaultPostUp == "" { + } + //Check-In Interval for Nodes, In Seconds + if network.DefaultCheckInInterval == 0 { + network.DefaultCheckInInterval = 30 + } + if network.AllowManualSignUp == nil { + signup := false + network.AllowManualSignUp = &signup + } }