diff --git a/controllers/common.go b/controllers/common.go index ff017efc9..7ba14e311 100644 --- a/controllers/common.go +++ b/controllers/common.go @@ -2,11 +2,13 @@ package controller import ( "context" + "encoding/json" "fmt" "log" "time" "github.com/go-playground/validator/v10" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mongoconn" @@ -59,64 +61,41 @@ func GetPeersList(networkName string) ([]models.PeersResponse, error) { return peers, err } - func GetExtPeersList(networkName string, macaddress string) ([]models.ExtPeersResponse, error) { - var peers []models.ExtPeersResponse - - //Connection mongoDB with mongoconn class - collection := mongoconn.Client.Database("netmaker").Collection("extclients") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - //Get all nodes in the relevant network which are NOT in pending state - filter := bson.M{"network": networkName, "ingressgatewayid": macaddress} - cur, err := collection.Find(ctx, filter) - - if err != nil { - return peers, err - } - - // Close the cursor once finished and cancel if it takes too long - defer cancel() - - for cur.Next(context.TODO()) { - - var peer models.ExtPeersResponse - err := cur.Decode(&peer) - if err != nil { - log.Fatal(err) - } + var peers []models.ExtPeersResponse + records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME) - // add the node to our node array - //maybe better to just return this? But then that's just GetNodes... - peers = append(peers, peer) - } - - //Uh oh, fatal error! This needs some better error handling - //TODO: needs appropriate error handling so the server doesnt shut down. - if err := cur.Err(); err != nil { - log.Fatal(err) - } + if err != nil { + return peers, err + } - return peers, err + for _, value := range records { + var peer models.ExtPeersResponse + err = json.Unmarshal([]byte(value), &peer) + if err != nil { + functions.PrintUserLog("netmaker", "failed to unmarshal ext client", 2) + continue + } + peers = append(peers, peer) + } + return peers, err } - func ValidateNodeCreate(networkName string, node models.Node) error { v := validator.New() _ = v.RegisterValidation("macaddress_unique", func(fl validator.FieldLevel) bool { - var isFieldUnique bool = functions.IsFieldUnique(networkName, "macaddress", node.MacAddress) + isFieldUnique, _ := functions.IsMacAddressUnique(node.MacAddress, networkName) return isFieldUnique }) _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool { _, err := node.GetNetwork() return err == nil }) - _ = v.RegisterValidation("in_charset", func(fl validator.FieldLevel) bool { - isgood := functions.NameInNodeCharSet(node.Name) - return isgood - }) + _ = v.RegisterValidation("in_charset", func(fl validator.FieldLevel) bool { + isgood := functions.NameInNodeCharSet(node.Name) + return isgood + }) err := v.Struct(node) if err != nil { @@ -133,10 +112,10 @@ func ValidateNodeUpdate(networkName string, node models.NodeUpdate) error { _, err := node.GetNetwork() return err == nil }) - _ = v.RegisterValidation("in_charset", func(fl validator.FieldLevel) bool { - isgood := functions.NameInNodeCharSet(node.Name) - return isgood - }) + _ = v.RegisterValidation("in_charset", func(fl validator.FieldLevel) bool { + isgood := functions.NameInNodeCharSet(node.Name) + return isgood + }) err := v.Struct(node) if err != nil { for _, e := range err.(validator.ValidationErrors) { @@ -302,27 +281,27 @@ func DeleteNode(macaddress string, network string) (bool, error) { func DeleteIntClient(clientid string) (bool, error) { - deleted := false + deleted := false - collection := mongoconn.Client.Database("netmaker").Collection("intclients") + collection := mongoconn.Client.Database("netmaker").Collection("intclients") - filter := bson.M{"clientid": clientid} + filter := bson.M{"clientid": clientid} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - result, err := collection.DeleteOne(ctx, filter) + result, err := collection.DeleteOne(ctx, filter) - deletecount := result.DeletedCount + deletecount := result.DeletedCount - if deletecount > 0 { - deleted = true - } + if deletecount > 0 { + deleted = true + } - defer cancel() + defer cancel() err = serverctl.ReconfigureServerWireGuard() - return deleted, err + return deleted, err } func GetNode(macaddress string, network string) (models.Node, error) { @@ -343,18 +322,18 @@ func GetNode(macaddress string, network string) (models.Node, error) { func GetIntClient(clientid string) (models.IntClient, error) { - var client models.IntClient + var client models.IntClient - collection := mongoconn.Client.Database("netmaker").Collection("intclients") + collection := mongoconn.Client.Database("netmaker").Collection("intclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - filter := bson.M{"clientid": clientid} - err := collection.FindOne(ctx, filter, options.FindOne().SetProjection(bson.M{"_id": 0})).Decode(&clientid) + filter := bson.M{"clientid": clientid} + err := collection.FindOne(ctx, filter, options.FindOne().SetProjection(bson.M{"_id": 0})).Decode(&clientid) - defer cancel() + defer cancel() - return client, err + return client, err } func CreateNode(node models.Node, networkName string) (models.Node, error) { diff --git a/controllers/networkHttpController.go b/controllers/networkHttpController.go index ba2bce891..999f7aad2 100644 --- a/controllers/networkHttpController.go +++ b/controllers/networkHttpController.go @@ -545,7 +545,7 @@ func UpdateNetwork(networkChange models.NetworkUpdate, network models.Network) ( } } if haslocalrangeupdate { - err = functions.UpdateNetworkPrivateAddresses(network.NetID) + err = functions.UpdateNetworkLocalAddresses(network.NetID) if err != nil { return models.Network{}, err } @@ -734,14 +734,14 @@ func CreateAccessKey(accesskey models.AccessKey, network models.Network) (models s := servercfg.GetServerConfig() w := servercfg.GetWGConfig() servervals := models.ServerConfig{ - CoreDNSAddr: s.CoreDNSAddr, - APIConnString: s.APIConnString, - APIHost: s.APIHost, - APIPort: s.APIPort, - GRPCConnString: s.GRPCConnString, - GRPCHost: s.GRPCHost, - GRPCPort: s.GRPCPort, - GRPCSSL: s.GRPCSSL, + CoreDNSAddr: s.CoreDNSAddr, + APIConnString: s.APIConnString, + APIHost: s.APIHost, + APIPort: s.APIPort, + GRPCConnString: s.GRPCConnString, + GRPCHost: s.GRPCHost, + GRPCPort: s.GRPCPort, + GRPCSSL: s.GRPCSSL, } wgvals := models.WG{ GRPCWireGuard: w.GRPCWireGuard, diff --git a/database/database.go b/database/database.go new file mode 100644 index 000000000..dd5adf9e7 --- /dev/null +++ b/database/database.go @@ -0,0 +1,100 @@ +package database + +import ( + "log" + + "github.com/rqlite/gorqlite" +) + +const NETWORKS_TABLE_NAME = "networks" +const NODES_TABLE_NAME = "nodes" +const USERS_TABLE_NAME = "users" +const DNS_TABLE_NAME = "dns" +const EXT_CLIENT_TABLE_NAME = "extclients" +const INT_CLIENTS_TABLE_NAME = "intclients" +const DATABASE_FILENAME = "netmaker.db" + +var Database gorqlite.Connection + +func InitializeDatabase() error { + + conn, err := gorqlite.Open("http://") + if err != nil { + return err + } + + // sqliteDatabase, _ := sql.Open("sqlite3", "./database/"+dbFilename) + Database = conn + Database.SetConsistencyLevel("strong") + createTables() + return nil +} + +func createTables() { + createTable(NETWORKS_TABLE_NAME) + createTable(NODES_TABLE_NAME) + createTable(USERS_TABLE_NAME) + createTable(DNS_TABLE_NAME) + createTable(EXT_CLIENT_TABLE_NAME) + createTable(INT_CLIENTS_TABLE_NAME) +} + +func createTable(tableName string) error { + _, err := Database.WriteOne("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)") + if err != nil { + return err + } + return nil +} + +func Insert(key string, value string, tableName string) error { + _, err := Database.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')") + if err != nil { + return err + } + return nil +} + +func DeleteRecord(tableName string, key string) error { + _, err := Database.WriteOne("DELETE FROM " + tableName + " WHERE key = \"" + key + "\"") + if err != nil { + return err + } + return nil +} + +func DeleteAllRecords(tableName string) error { + _, err := Database.WriteOne("DELETE TABLE " + tableName) + if err != nil { + return err + } + err = createTable(tableName) + if err != nil { + return err + } + return nil +} + +func FetchRecord(tableName string, key string) (string, error) { + results, err := FetchRecords(tableName) + if err != nil { + return "", err + } + return results[key], nil +} + +func FetchRecords(tableName string) (map[string]string, error) { + row, err := Database.QueryOne("SELECT * FROM " + tableName + " ORDER BY key") + if err != nil { + return nil, err + } + records := make(map[string]string) + for row.Next() { // Iterate and fetch the records from result cursor + var key string + var value string + row.Scan(&key, &value) + records[key] = value + } + log.Println(tableName, records) + return records, nil +} diff --git a/functions/helpers.go b/functions/helpers.go index 389daf36a..5fec45a16 100644 --- a/functions/helpers.go +++ b/functions/helpers.go @@ -5,7 +5,6 @@ package functions import ( - "context" "encoding/base64" "encoding/json" "errors" @@ -16,13 +15,9 @@ import ( "strings" "time" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/mongoconn" "github.com/gravitl/netmaker/servercfg" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" ) func PrintUserLog(username string, message string, loglevel int) { @@ -76,179 +71,85 @@ func CreateServerToken(netID string) (string, error) { accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson)) network.AccessKeys = append(network.AccessKeys, accesskey) - - collection := mongoconn.Client.Database("netmaker").Collection("networks") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - // Create filter - filter := bson.M{"netid": netID} - - // prepare update model. - update := bson.D{ - {"$set", bson.D{ - {"accesskeys", network.AccessKeys}, - }}, + if data, err := json.Marshal(network); err != nil { + return "", err + } else { + database.Insert(netID, string(data), database.NETWORKS_TABLE_NAME) } - errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&network) - - defer cancel() - - if errN != nil { - return "", errN - } return accesskey.AccessString, nil } func GetPeersList(networkName string) ([]models.PeersResponse, error) { var peers []models.PeersResponse - - //Connection mongoDB with mongoconn class - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - //Get all nodes in the relevant network which are NOT in pending state - filter := bson.M{"network": networkName, "ispending": false} - cur, err := collection.Find(ctx, filter) - + collection, err := database.FetchRecords(database.NODES_TABLE_NAME) if err != nil { return peers, err } - // Close the cursor once finished and cancel if it takes too long - defer cancel() - - for cur.Next(context.TODO()) { + for _, value := range collection { var peer models.PeersResponse - err := cur.Decode(&peer) + err := json.Unmarshal([]byte(value), &peer) if err != nil { - log.Fatal(err) + continue // try the rest } - - // add the node to our node array - //maybe better to just return this? But then that's just GetNodes... peers = append(peers, peer) } - //Uh oh, fatal error! This needs some better error handling - //TODO: needs appropriate error handling so the server doesnt shut down. - if err := cur.Err(); err != nil { - log.Fatal(err) - } - return peers, err } func GetIntPeersList() ([]models.PeersResponse, error) { var peers []models.PeersResponse - - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"isserver": ""} - - cur, err := collection.Find(ctx, filter) + records, err := database.FetchRecords(database.INT_CLIENTS_TABLE_NAME) if err != nil { return peers, err } + // parse the peers - // Close the cursor once finished and cancel if it takes too long - defer cancel() - - for cur.Next(context.TODO()) { + for _, value := range records { var peer models.PeersResponse - err := cur.Decode(&peer) + err := json.Unmarshal([]byte(value), &peer) if err != nil { log.Fatal(err) } - // add the node to our node array //maybe better to just return this? But then that's just GetNodes... peers = append(peers, peer) } - //Uh oh, fatal error! This needs some better error handling - //TODO: needs appropriate error handling so the server doesnt shut down. - if err := cur.Err(); err != nil { - log.Fatal(err) - } - return peers, err } -func IsFieldUnique(network string, field string, value string) bool { - - var node models.Node - isunique := true +func GetServerIntClient() (*models.IntClient, error) { - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{field: value, "network": network} - - err := collection.FindOne(ctx, filter).Decode(&node) - - defer cancel() - - if err != nil { - return isunique - } - - if node.Name != "" { - isunique = false - } - - return isunique -} - -func ServerIntClientExists() (bool, error) { - - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"isserver": "yes"} - - var result bson.M - err := collection.FindOne(ctx, filter).Decode(&result) - - defer cancel() - - if err != nil { - if err == mongo.ErrNoDocuments { - return false, nil + intClients, err := database.FetchRecords(database.INT_CLIENTS_TABLE_NAME) + for _, value := range intClients { + var intClient models.IntClient + err = json.Unmarshal([]byte(value), &intClient) + if err != nil { + return nil, err + } + if intClient.IsServer == "yes" && intClient.Network == "comms" { + return &intClient, nil } } - return true, err + return nil, err } func NetworkExists(name string) (bool, error) { - collection := mongoconn.Client.Database("netmaker").Collection("networks") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"netid": name} - - var result bson.M - err := collection.FindOne(ctx, filter).Decode(&result) - - defer cancel() - - if err != nil { - if err == mongo.ErrNoDocuments { - return false, nil - } + var network string + var err error + if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil { + return false, err } - return true, err + return len(network) > 0, nil } //TODO: This is very inefficient (N-squared). Need to find a better way. @@ -256,25 +157,15 @@ func NetworkExists(name string) (bool, error) { //for each node, it gets a unique address. That requires checking against all other nodes once more func UpdateNetworkNodeAddresses(networkName string) error { - //Connection mongoDB with mongoconn class - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"network": networkName} - cur, err := collection.Find(ctx, filter) - + collections, err := database.FetchRecords(database.NODES_TABLE_NAME) if err != nil { return err } - defer cancel() - - for cur.Next(context.TODO()) { + for _, value := range collections { var node models.Node - - err := cur.Decode(&node) + err := json.Unmarshal([]byte(value), &node) if err != nil { fmt.Println("error in node address assignment!") return err @@ -285,42 +176,30 @@ func UpdateNetworkNodeAddresses(networkName string) error { return iperr } - filter := bson.M{"macaddress": node.MacAddress} - update := bson.D{{"$set", bson.D{{"address", ipaddr}}}} - - errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&node) - - defer cancel() - if errN != nil { - return errN + node.Address = ipaddr + data, err := json.Marshal(&node) + if err != nil { + return err } + database.Insert(node.MacAddress, string(data), database.NODES_TABLE_NAME) } - return err + return nil } -//TODO TODO TODO!!!!! -func UpdateNetworkPrivateAddresses(networkName string) error { - - //Connection mongoDB with mongoconn class - collection := mongoconn.Client.Database("netmaker").Collection("nodes") +func UpdateNetworkLocalAddresses(networkName string) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"network": networkName} - cur, err := collection.Find(ctx, filter) + collection, err := database.FetchRecords(database.NODES_TABLE_NAME) if err != nil { return err } - defer cancel() - - for cur.Next(context.TODO()) { + for _, value := range collection { var node models.Node - err := cur.Decode(&node) + err := json.Unmarshal([]byte(value), &node) if err != nil { fmt.Println("error in node address assignment!") return err @@ -331,18 +210,16 @@ func UpdateNetworkPrivateAddresses(networkName string) error { return iperr } - filter := bson.M{"macaddress": node.MacAddress} - update := bson.D{{"$set", bson.D{{"address", ipaddr}}}} - - errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&node) - - defer cancel() - if errN != nil { - return errN + node.Address = ipaddr + newNodeData, err := json.Marshal(&node) + if err != nil { + fmt.Println("error in node address assignment!") + return err } + database.Insert(node.MacAddress, string(newNodeData), database.NODES_TABLE_NAME) } - return err + return nil } //Checks to see if any other networks have the same name (id) @@ -385,60 +262,68 @@ func IsNetworkDisplayNameUnique(name string) (bool, error) { return isunique, nil } -func GetNetworkNodeNumber(networkName string) (int, error) { +func IsMacAddressUnique(macaddress string, networkName string) (bool, error) { - collection := mongoconn.Client.Database("netmaker").Collection("nodes") + collection, err := database.FetchRecords(database.NODES_TABLE_NAME) + if err != nil { + return false, err + } + for _, value := range collection { + var node models.Node + if err = json.Unmarshal([]byte(value), &node); err != nil { + return false, err + } else { + if node.MacAddress == macaddress && node.Network == networkName { + return false, nil + } + } + } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + return true, nil +} - filter := bson.M{"network": networkName} - count, err := collection.CountDocuments(ctx, filter) - returncount := int(count) +func GetNetworkNodeNumber(networkName string) (int, error) { - //not sure if this is the right way of handling this error... + collection, err := database.FetchRecords(database.NODES_TABLE_NAME) + count := 0 if err != nil { - return 9999, err + return count, err + } + for _, value := range collection { + var node models.Node + if err = json.Unmarshal([]byte(value), &node); err != nil { + return count, err + } else { + if node.Network == networkName { + count++ + } + } } - defer cancel() - - return returncount, err + return count, nil } -//Kind of a weird name. Should just be GetNetworks I think. Consider changing. -//Anyway, returns all the networks +// Anyway, returns all the networks func ListNetworks() ([]models.Network, error) { var networks []models.Network - collection := mongoconn.Client.Database("netmaker").Collection("networks") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - cur, err := collection.Find(ctx, bson.M{}, options.Find().SetProjection(bson.M{"_id": 0})) + collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME) if err != nil { return networks, err } - defer cancel() - - for cur.Next(context.TODO()) { + for _, value := range collection { var network models.Network - err := cur.Decode(&network) - if err != nil { + if err := json.Unmarshal([]byte(value), &network); err != nil { return networks, err } - // add network our array networks = append(networks, network) } - if err := cur.Err(); err != nil { - return networks, err - } - return networks, err } @@ -502,20 +387,13 @@ func IsKeyValidGlobal(keyvalue string) bool { func GetParentNetwork(networkname string) (models.Network, error) { var network models.Network - - collection := mongoconn.Client.Database("netmaker").Collection("networks") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"netid": networkname} - err := collection.FindOne(ctx, filter).Decode(&network) - - defer cancel() - + networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname) if err != nil { return network, err } - + if err = json.Unmarshal([]byte(networkData), network); err != nil { + return network, err + } return network, nil } @@ -545,31 +423,6 @@ func IsBase64(s string) bool { return err == nil } -//This should probably just be called GetNode -//It returns a node based on the ID of the node. -//Why do we need this? -//TODO: Check references. This seems unnecessary. -func GetNodeObj(id primitive.ObjectID) models.Node { - - var node models.Node - - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"_id": id} - err := collection.FindOne(ctx, filter).Decode(&node) - - defer cancel() - - if err != nil { - fmt.Println(err) - fmt.Println("Did not get the node...") - return node - } - fmt.Println("Got node " + node.Name) - return node -} - //This checks to make sure a network name is valid. //Switch to REGEX? func NameInNetworkCharSet(name string) bool { @@ -616,47 +469,41 @@ func GetNodeByMacAddress(network string, macaddress string) (models.Node, error) var node models.Node - filter := bson.M{"macaddress": macaddress, "network": network} - - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - err := collection.FindOne(ctx, filter).Decode(&node) - - defer cancel() + records, err := database.FetchRecords(database.NODES_TABLE_NAME) if err != nil { return node, err } - return node, nil + + for _, value := range records { + json.Unmarshal([]byte(value), &node) + if node.MacAddress == macaddress && node.Network == network { + return node, nil + } + } + + return models.Node{}, nil } func DeleteAllIntClients() error { - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - // Filter out them ID's again - err := collection.Drop(ctx) + err := database.DeleteAllRecords(database.INT_CLIENTS_TABLE_NAME) if err != nil { return err } - defer cancel() return nil } func GetAllIntClients() ([]models.IntClient, error) { - var client models.IntClient var clients []models.IntClient - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - // Filter out them ID's again - cur, err := collection.Find(ctx, bson.M{}, options.Find().SetProjection(bson.M{"_id": 0})) + collection, err := database.FetchRecords(database.INT_CLIENTS_TABLE_NAME) + if err != nil { - return []models.IntClient{}, err + return clients, err } - defer cancel() - for cur.Next(context.TODO()) { - err := cur.Decode(&client) + + for _, value := range collection { + var client models.IntClient + err := json.Unmarshal([]byte(value), &client) if err != nil { return []models.IntClient{}, err } @@ -664,26 +511,20 @@ func GetAllIntClients() ([]models.IntClient, error) { clients = append(clients, client) } - //TODO: Fatal error - if err := cur.Err(); err != nil { - return []models.IntClient{}, err - } return clients, nil } func GetAllExtClients() ([]models.ExtClient, error) { - var extclient models.ExtClient var extclients []models.ExtClient - collection := mongoconn.Client.Database("netmaker").Collection("extclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - // Filter out them ID's again - cur, err := collection.Find(ctx, bson.M{}, options.Find().SetProjection(bson.M{"_id": 0})) + collection, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME) + if err != nil { - return []models.ExtClient{}, err + return extclients, err } - defer cancel() - for cur.Next(context.TODO()) { - err := cur.Decode(&extclient) + + for _, value := range collection { + var extclient models.ExtClient + err := json.Unmarshal([]byte(value), &extclient) if err != nil { return []models.ExtClient{}, err } @@ -691,10 +532,6 @@ func GetAllExtClients() ([]models.ExtClient, error) { extclients = append(extclients, extclient) } - //TODO: Fatal error - if err := cur.Err(); err != nil { - return []models.ExtClient{}, err - } return extclients, nil } @@ -724,11 +561,11 @@ func UniqueAddress(networkName string) (string, error) { continue } if networkName == "comms" { - if IsIPUniqueClients(networkName, ip.String()) { + if IsIPUnique(networkName, ip.String(), database.INT_CLIENTS_TABLE_NAME, false) { return ip.String(), err } } else { - if IsIPUnique(networkName, ip.String()) && IsIPUniqueExtClients(networkName, ip.String()) { + if IsIPUnique(networkName, ip.String(), database.NODES_TABLE_NAME, false) && IsIPUnique(networkName, ip.String(), database.EXT_CLIENT_TABLE_NAME, false) { return ip.String(), err } } @@ -765,11 +602,11 @@ func UniqueAddress6(networkName string) (string, error) { continue } if networkName == "comms" { - if IsIP6UniqueClients(networkName, ip.String()) { + if IsIPUnique(networkName, ip.String(), database.INT_CLIENTS_TABLE_NAME, true) { return ip.String(), err } } else { - if IsIP6Unique(networkName, ip.String()) { + if IsIPUnique(networkName, ip.String(), database.NODES_TABLE_NAME, true) { return ip.String(), err } } @@ -814,136 +651,31 @@ func GenKeyName() string { return "key" + string(b) } -func IsIPUniqueExtClients(network string, ip string) bool { - - var extclient models.ExtClient +func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool { isunique := true - - collection := mongoconn.Client.Database("netmaker").Collection("extclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"address": ip, "network": network} - - err := collection.FindOne(ctx, filter).Decode(&extclient) - - defer cancel() + collection, err := database.FetchRecords(tableName) if err != nil { return isunique } - if extclient.Address == ip { - isunique = false - } - return isunique -} - -//checks if IP is unique in the address range -//used by UniqueAddress -func IsIPUnique(network string, ip string) bool { - - var node models.Node - - isunique := true - - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"address": ip, "network": network} - - err := collection.FindOne(ctx, filter).Decode(&node) - - defer cancel() - - if err != nil { - return isunique - } - - if node.Address == ip { - isunique = false - } - return isunique -} - -//checks if IP is unique in the address range -//used by UniqueAddress -func IsIP6Unique(network string, ip string) bool { - - var node models.Node - - isunique := true - - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"address6": ip, "network": network} - - err := collection.FindOne(ctx, filter).Decode(&node) - - defer cancel() - - if err != nil { - return isunique - } - - if node.Address6 == ip { - isunique = false - } - return isunique -} - -//checks if IP is unique in the address range -//used by UniqueAddress -func IsIP6UniqueClients(network string, ip string) bool { - - var client models.IntClient - - isunique := true - - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"address6": ip, "network": network} - - err := collection.FindOne(ctx, filter).Decode(&client) - - defer cancel() - - if err != nil { - return isunique - } - - if client.Address6 == ip { - isunique = false - } - return isunique -} - -//checks if IP is unique in the address range -//used by UniqueAddress -func IsIPUniqueClients(network string, ip string) bool { - - var client models.IntClient - - isunique := true - - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"address": ip, "network": network} - - err := collection.FindOne(ctx, filter).Decode(&client) - - defer cancel() - - if err != nil { - return isunique + for _, value := range collection { // filter + var node models.Node + if err = json.Unmarshal([]byte(value), &node); err != nil { + continue + } + if isIpv6 { + if node.Address6 == ip && node.Network == network { + return false + } + } else { + if node.Address == ip && node.Network == network { + return false + } + } } - if client.Address == ip { - isunique = false - } return isunique } @@ -964,31 +696,18 @@ func DecrimentKey(networkName string, keyvalue string) { if currentkey.Value == keyvalue { network.AccessKeys[i].Uses-- if network.AccessKeys[i].Uses < 1 { - //this is the part where it will call the delete - //not sure if there's edge cases I'm missing - DeleteKey(network, i) - return + network.AccessKeys = append(network.AccessKeys[:i], + network.AccessKeys[i+1:]...) + break } } } - collection := mongoconn.Client.Database("netmaker").Collection("networks") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - filter := bson.M{"netid": network.NetID} - - update := bson.D{ - {"$set", bson.D{ - {"accesskeys", network.AccessKeys}, - }}, - } - errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&network) - - defer cancel() - - if errN != nil { + if newNetworkData, err := json.Marshal(&network); err != nil { + PrintUserLog("netmaker", "failed to decrement key", 2) return + } else { + database.Insert(network.NetID, string(newNetworkData), database.NETWORKS_TABLE_NAME) } } @@ -998,26 +717,10 @@ func DeleteKey(network models.Network, i int) { network.AccessKeys = append(network.AccessKeys[:i], network.AccessKeys[i+1:]...) - collection := mongoconn.Client.Database("netmaker").Collection("networks") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - // Create filter - filter := bson.M{"netid": network.NetID} - - // prepare update model. - update := bson.D{ - {"$set", bson.D{ - {"accesskeys", network.AccessKeys}, - }}, - } - - errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&network) - - defer cancel() - - if errN != nil { + if networkData, err := json.Marshal(&network); err != nil { return + } else { + database.Insert(network.NetID, string(networkData), database.NETWORKS_TABLE_NAME) } } @@ -1032,28 +735,21 @@ func Inc(ip net.IP) { } func GetAllNodes() ([]models.Node, error) { - var node models.Node var nodes []models.Node - collection := mongoconn.Client.Database("netmaker").Collection("nodes") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - // Filter out them ID's again - cur, err := collection.Find(ctx, bson.M{}, options.Find().SetProjection(bson.M{"_id": 0})) + + collection, err := database.FetchRecords(database.NODES_TABLE_NAME) if err != nil { return []models.Node{}, err } - defer cancel() - for cur.Next(context.TODO()) { - err := cur.Decode(&node) - if err != nil { + + for _, value := range collection { + var node models.Node + if err := json.Unmarshal([]byte(value), &node); err != nil { return []models.Node{}, err } // add node to our array nodes = append(nodes, node) } - //TODO: Fatal error - if err := cur.Err(); err != nil { - return []models.Node{}, err - } return nodes, nil } diff --git a/go.mod b/go.mod index dd78ec8ee..981645b41 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,13 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/go-playground/validator/v10 v10.5.0 + github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/golang/protobuf v1.5.2 github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/jinzhu/copier v0.3.2 // indirect + github.com/mattn/go-sqlite3 v1.14.8 + github.com/rqlite/gorqlite v0.0.0-20210514125552-08ff1e76b22f github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.6.1 github.com/txn2/txeh v1.3.0 diff --git a/go.sum b/go.sum index abe0fb5d5..e3243a924 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+ github.com/go-playground/validator/v10 v10.5.0 h1:X9rflw/KmpACwT8zdrm1upefpvdy6ur8d1kWyq6sg3E= github.com/go-playground/validator/v10 v10.5.0/go.mod h1:xm76BBt941f7yWdGnI2DVPFFg1UK3YY04qifoXU3lOk= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= @@ -119,6 +121,8 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= +github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= +github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mdlayher/genetlink v1.0.0 h1:OoHN1OdyEIkScEmRgxLEe2M9U8ClMytqA5niynLtfj0= github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc= github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= @@ -142,6 +146,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rqlite/gorqlite v0.0.0-20210514125552-08ff1e76b22f h1:BSnJgAfHzEp7o8PYJ7YfwAVHhqu7BYUTggcn/LGlUWY= +github.com/rqlite/gorqlite v0.0.0-20210514125552-08ff1e76b22f/go.mod h1:UW/gxgQwSePTvL1KA8QEHsXeYHP4xkoXgbDdN781p34= github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= diff --git a/main.go b/main.go index f88a86c64..b605823b0 100644 --- a/main.go +++ b/main.go @@ -4,80 +4,75 @@ package main import ( - "log" - "github.com/gravitl/netmaker/controllers" - "github.com/gravitl/netmaker/servercfg" - "github.com/gravitl/netmaker/serverctl" - "github.com/gravitl/netmaker/mongoconn" - "github.com/gravitl/netmaker/functions" - "os" - "os/exec" - "net" - "context" - "strconv" - "sync" - "os/signal" - service "github.com/gravitl/netmaker/controllers" - nodepb "github.com/gravitl/netmaker/grpc" - "google.golang.org/grpc" + "context" + "log" + "net" + "os" + "os/exec" + "os/signal" + "strconv" + "sync" + + controller "github.com/gravitl/netmaker/controllers" + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/functions" + nodepb "github.com/gravitl/netmaker/grpc" + "github.com/gravitl/netmaker/mongoconn" + "github.com/gravitl/netmaker/servercfg" + "github.com/gravitl/netmaker/serverctl" + "google.golang.org/grpc" ) //Start MongoDB Connection and start API Request Handler func main() { + checkModes() // check which flags are set and if root or not + initialize() // initial db and grpc server + defer database.Database.Close() + startControllers() // start the grpc or rest endpoints +} +func checkModes() { // Client Mode Prereq Check + var err error + cmd := exec.Command("id", "-u") + output, err := cmd.Output() - //Client Mode Prereq Check - if servercfg.IsClientMode() { - cmd := exec.Command("id", "-u") - output, err := cmd.Output() + if err != nil { + log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.") + log.Fatal(err) + } + uid, err := strconv.Atoi(string(output[:len(output)-1])) + if err != nil { + log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.") + log.Fatal(err) + } + if uid != 0 { + log.Fatal("To run in client mode requires root privileges. Either disable client mode or run with sudo.") + } + if servercfg.IsDNSMode() { + err := functions.SetDNSDir() if err != nil { - log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.") log.Fatal(err) } - i, err := strconv.Atoi(string(output[:len(output)-1])) - if err != nil { - log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.") - log.Fatal(err) - } - if i != 0 { - log.Fatal("To run in client mode requires root privileges. Either disable client mode or run with sudo.") - } } - if servercfg.IsDNSMode() { - err := functions.SetDNSDir() - if err != nil { - log.Fatal(err) - } - } - //Start Mongodb - mongoconn.ConnectDatabase() - - installserver := false - - //Create the default network (default: 10.10.10.0/24) - created, err := serverctl.CreateDefaultNetwork() - if err != nil { - log.Printf("Error creating default network: %v", err) - } +} - if created && servercfg.IsClientMode() { - installserver = true - } +func initialize() { + database.InitializeDatabase() if servercfg.IsGRPCWireGuard() { - err = serverctl.InitServerWireGuard() - //err = serverctl.ReconfigureServerWireGuard() - if err != nil { - log.Fatal(err) + if err := serverctl.InitServerWireGuard(); err != nil { + log.Fatal(err) } } + functions.PrintUserLog("netmaker", "successfully created db tables if not present", 1) +} +func startControllers() { var waitnetwork sync.WaitGroup - //Run Agent Server if servercfg.IsAgentBackend() { - if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" { + if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" { err := servercfg.SetHost() if err != nil { log.Println("Unable to Set host. Exiting...") @@ -85,23 +80,23 @@ func main() { } } waitnetwork.Add(1) - go runGRPC(&waitnetwork, installserver) + go runGRPC(&waitnetwork) } - if servercfg.IsDNSMode() { + if servercfg.IsDNSMode() { err := controller.SetDNS() - if err != nil { - log.Fatal(err) - } - } + if err != nil { + log.Fatal(err) + } + } //Run Rest Server if servercfg.IsRestBackend() { - if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" { - err := servercfg.SetHost() - if err != nil { - log.Println("Unable to Set host. Exiting...") - log.Fatal(err) - } - } + if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" { + err := servercfg.SetHost() + if err != nil { + log.Println("Unable to Set host. Exiting...") + log.Fatal(err) + } + } waitnetwork.Add(1) controller.HandleRESTRequests(&waitnetwork) } @@ -112,88 +107,67 @@ func main() { log.Println("exiting") } - -func runGRPC(wg *sync.WaitGroup, installserver bool) { - +func runGRPC(wg *sync.WaitGroup) { defer wg.Done() - // Configure 'log' package to give file name and line number on eg. log.Fatal - // Pipe flags to one another (log.LstdFLags = log.Ldate | log.Ltime) - log.SetFlags(log.LstdFlags | log.Lshortfile) + // Configure 'log' package to give file name and line number on eg. log.Fatal + // Pipe flags to one another (log.LstdFLags = log.Ldate | log.Ltime) + log.SetFlags(log.LstdFlags | log.Lshortfile) grpcport := servercfg.GetGRPCPort() listener, err := net.Listen("tcp", ":"+grpcport) - // Handle errors if any - if err != nil { - log.Fatalf("Unable to listen on port " + grpcport + ", error: %v", err) - } - - s := grpc.NewServer( - authServerUnaryInterceptor(), - authServerStreamInterceptor(), - ) - // Create NodeService type - srv := &service.NodeServiceServer{} - - // Register the service with the server - nodepb.RegisterNodeServiceServer(s, srv) - - srv.NodeDB = mongoconn.NodeDB - - // Start the server in a child routine - go func() { - if err := s.Serve(listener); err != nil { - log.Fatalf("Failed to serve: %v", err) - } - }() - log.Println("Agent Server succesfully started on port " + grpcport + " (gRPC)") - - if installserver { - success := true - if !servercfg.DisableDefaultNet() { - log.Println("Adding server to default network") - success, err = serverctl.AddNetwork("default") - } - if err != nil { - log.Printf("Error adding to default network: %v", err) - log.Println("Unable to add server to network. Continuing.") - log.Println("Please investigate client installation on server.") - } else if !success { - log.Println("Unable to add server to network. Continuing.") - log.Println("Please investigate client installation on server.") - } else{ - log.Println("Server successfully added to default network.") - } + // Handle errors if any + if err != nil { + log.Fatalf("Unable to listen on port "+grpcport+", error: %v", err) } - log.Println("Setup complete. You are ready to begin using netmaker.") - - // Right way to stop the server using a SHUTDOWN HOOK - // Create a channel to receive OS signals - c := make(chan os.Signal) - - // Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) - // Ignore other incoming signals - signal.Notify(c, os.Interrupt) - - // Block main routine until a signal is received - // As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running - <-c - - // After receiving CTRL+C Properly stop the server - log.Println("Stopping the Agent server...") - s.Stop() - listener.Close() - log.Println("Agent server closed..") - log.Println("Closing MongoDB connection") - mongoconn.Client.Disconnect(context.TODO()) - log.Println("MongoDB connection closed.") + + s := grpc.NewServer( + authServerUnaryInterceptor(), + authServerStreamInterceptor(), + ) + // Create NodeService type + srv := &controller.NodeServiceServer{} + + // Register the service with the server + nodepb.RegisterNodeServiceServer(s, srv) + + srv.NodeDB = mongoconn.NodeDB + + // Start the server in a child routine + go func() { + if err := s.Serve(listener); err != nil { + log.Fatalf("Failed to serve: %v", err) + } + }() + log.Println("Agent Server succesfully started on port " + grpcport + " (gRPC)") + + // Right way to stop the server using a SHUTDOWN HOOK + // Create a channel to receive OS signals + c := make(chan os.Signal) + + // Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) + // Ignore other incoming signals + signal.Notify(c, os.Interrupt) + + // Block main routine until a signal is received + // As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running + <-c + + // After receiving CTRL+C Properly stop the server + log.Println("Stopping the Agent server...") + s.Stop() + listener.Close() + log.Println("Agent server closed..") + log.Println("Closing MongoDB connection") + mongoconn.Client.Disconnect(context.TODO()) + log.Println("MongoDB connection closed.") } func authServerUnaryInterceptor() grpc.ServerOption { return grpc.UnaryInterceptor(controller.AuthServerUnaryInterceptor) } func authServerStreamInterceptor() grpc.ServerOption { - return grpc.StreamInterceptor(controller.AuthServerStreamInterceptor) + return grpc.StreamInterceptor(controller.AuthServerStreamInterceptor) } diff --git a/serverctl/serverctl.go b/serverctl/serverctl.go index 8d00383e7..527105f1e 100644 --- a/serverctl/serverctl.go +++ b/serverctl/serverctl.go @@ -1,137 +1,92 @@ package serverctl import ( - "log" + "context" + "encoding/json" + "errors" + "io" + "log" + "os" + "os/exec" + "time" + + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mongoconn" "github.com/gravitl/netmaker/servercfg" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/options" - "io" - "time" - "context" - "errors" - "os" - "os/exec" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo/options" ) -func CreateDefaultNetwork() (bool, error) { - - log.Println("Creating default network...") - - iscreated := false - exists, err := functions.NetworkExists("default") - - if exists || err != nil { - log.Println("Default network already exists. Skipping...") - return iscreated, err - } else { - - var network models.Network - - network.NetID = "default" - network.AddressRange = "10.10.10.0/24" - network.DisplayName = "default" - network.SetDefaults() - network.SetNodesLastModified() - network.SetNetworkLastModified() - network.KeyUpdateTimeStamp = time.Now().Unix() - priv := false - network.IsLocal = &priv - network.KeyUpdateTimeStamp = time.Now().Unix() - allow := true - network.AllowManualSignUp = &allow - - collection := mongoconn.Client.Database("netmaker").Collection("networks") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - // insert our network into the network table - _, err = collection.InsertOne(ctx, network) - defer cancel() - - } - if err == nil { - iscreated = true - } - return iscreated, err - - -} - func GetServerWGConf() (models.IntClient, error) { - var server models.IntClient - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + var server models.IntClient + collection := mongoconn.Client.Database("netmaker").Collection("intclients") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) filter := bson.M{"network": "comms", "isserver": "yes"} - err := collection.FindOne(ctx, filter, options.FindOne().SetProjection(bson.M{"_id": 0})).Decode(&server) + err := collection.FindOne(ctx, filter, options.FindOne().SetProjection(bson.M{"_id": 0})).Decode(&server) defer cancel() return server, err } - func CreateCommsNetwork() (bool, error) { - iscreated := false - exists, err := functions.NetworkExists("comms") - - if exists || err != nil { - log.Println("comms network already exists. Skipping...") - return true, nil - } else { - - var network models.Network - - network.NetID = "comms" - network.IsIPv6 = "no" - network.IsIPv4 = "yes" - network.IsGRPCHub = "yes" - network.AddressRange = servercfg.GetGRPCWGAddressRange() - network.DisplayName = "comms" - network.SetDefaults() - network.SetNodesLastModified() - network.SetNetworkLastModified() - network.KeyUpdateTimeStamp = time.Now().Unix() - priv := false - network.IsLocal = &priv - network.KeyUpdateTimeStamp = time.Now().Unix() - - log.Println("Creating comms network...") - - collection := mongoconn.Client.Database("netmaker").Collection("networks") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - // insert our network into the network table - _, err = collection.InsertOne(ctx, network) - defer cancel() - - } - if err == nil { - iscreated = true - } - return iscreated, err + iscreated := false + exists, err := functions.NetworkExists("comms") + + if exists || err != nil { + log.Println("comms network already exists. Skipping...") + return true, err + } else { + var network models.Network + + network.NetID = "comms" + network.IsIPv6 = "no" + network.IsIPv4 = "yes" + network.IsGRPCHub = "yes" + network.AddressRange = servercfg.GetGRPCWGAddressRange() + network.DisplayName = "comms" + network.SetDefaults() + network.SetNodesLastModified() + network.SetNetworkLastModified() + network.KeyUpdateTimeStamp = time.Now().Unix() + priv := false + network.IsLocal = &priv + network.KeyUpdateTimeStamp = time.Now().Unix() + + log.Println("Creating comms network...") + value, err := json.Marshal(network) + if err != nil { + return false, err + } + database.Insert(network.NetID, string(value), database.NETWORKS_TABLE_NAME) + } + if err == nil { + iscreated = true + } + return iscreated, err } func DownloadNetclient() error { /* - // Get the data - resp, err := http.Get("https://github.com/gravitl/netmaker/releases/download/latest/netclient") - if err != nil { - log.Println("could not download netclient") - return err - } - defer resp.Body.Close() - - // Create the file - out, err := os.Create("/etc/netclient/netclient") - */ - if !FileExists("/etc/netclient/netclient") { + // Get the data + resp, err := http.Get("https://github.com/gravitl/netmaker/releases/download/latest/netclient") + if err != nil { + log.Println("could not download netclient") + return err + } + defer resp.Body.Close() + + // Create the file + out, err := os.Create("/etc/netclient/netclient") + */ + if !FileExists("/etc/netclient/netclient") { _, err := copy("./netclient/netclient", "/etc/netclient/netclient") - if err != nil { - log.Println("could not create /etc/netclient") - return err - } + if err != nil { + log.Println("could not create /etc/netclient") + return err + } } //defer out.Close() @@ -141,98 +96,97 @@ func DownloadNetclient() error { } func FileExists(f string) bool { - info, err := os.Stat(f) - if os.IsNotExist(err) { - return false - } - return !info.IsDir() + info, err := os.Stat(f) + if os.IsNotExist(err) { + return false + } + return !info.IsDir() } func copy(src, dst string) (int64, error) { - sourceFileStat, err := os.Stat(src) - if err != nil { - return 0, err - } - - if !sourceFileStat.Mode().IsRegular() { - return 0, errors.New(src + " is not a regular file") - } - - source, err := os.Open(src) - if err != nil { - return 0, err - } - defer source.Close() - - destination, err := os.Create(dst) - if err != nil { - return 0, err - } - defer destination.Close() - nBytes, err := io.Copy(destination, source) - err = os.Chmod(dst, 0755) - if err != nil { - log.Println(err) - } - return nBytes, err + sourceFileStat, err := os.Stat(src) + if err != nil { + return 0, err + } + + if !sourceFileStat.Mode().IsRegular() { + return 0, errors.New(src + " is not a regular file") + } + + source, err := os.Open(src) + if err != nil { + return 0, err + } + defer source.Close() + + destination, err := os.Create(dst) + if err != nil { + return 0, err + } + defer destination.Close() + nBytes, err := io.Copy(destination, source) + err = os.Chmod(dst, 0755) + if err != nil { + log.Println(err) + } + return nBytes, err } func RemoveNetwork(network string) (bool, error) { _, err := os.Stat("/etc/netclient/netclient") - if err != nil { - log.Println("could not find /etc/netclient") + if err != nil { + log.Println("could not find /etc/netclient") + return false, err + } + cmdoutput, err := exec.Command("/etc/netclient/netclient", "leave", "-n", network).Output() + if err != nil { + log.Println(string(cmdoutput)) return false, err } - cmdoutput, err := exec.Command("/etc/netclient/netclient","leave","-n",network).Output() - if err != nil { - log.Println(string(cmdoutput)) - return false, err - } - log.Println("Server removed from network " + network) - return true, err + log.Println("Server removed from network " + network) + return true, err } func AddNetwork(network string) (bool, error) { pubip, err := servercfg.GetPublicIP() - if err != nil { - log.Println("could not get public IP.") - return false, err - } + if err != nil { + log.Println("could not get public IP.") + return false, err + } _, err = os.Stat("/etc/netclient") - if os.IsNotExist(err) { - os.Mkdir("/etc/netclient", 744) - } else if err != nil { - log.Println("could not find or create /etc/netclient") - return false, err + if os.IsNotExist(err) { + os.Mkdir("/etc/netclient", 744) + } else if err != nil { + log.Println("could not find or create /etc/netclient") + return false, err } token, err := functions.CreateServerToken(network) if err != nil { - log.Println("could not create server token for " + network) + log.Println("could not create server token for " + network) return false, err - } - _, err = os.Stat("/etc/netclient/netclient") + } + _, err = os.Stat("/etc/netclient/netclient") if os.IsNotExist(err) { err = DownloadNetclient() if err != nil { return false, err } } - err = os.Chmod("/etc/netclient/netclient", 0755) - if err != nil { - log.Println("could not change netclient directory permissions") - return false, err - } - log.Println("executing network join: " + "/etc/netclient/netclient "+"join "+"-t "+token+" -name "+"netmaker"+" -endpoint "+pubip) - out, err := exec.Command("/etc/netclient/netclient","join","-t",token,"-name","netmaker","-endpoint",pubip).Output() - if string(out) != "" { - log.Println(string(out)) + err = os.Chmod("/etc/netclient/netclient", 0755) + if err != nil { + log.Println("could not change netclient directory permissions") + return false, err + } + log.Println("executing network join: " + "/etc/netclient/netclient " + "join " + "-t " + token + " -name " + "netmaker" + " -endpoint " + pubip) + out, err := exec.Command("/etc/netclient/netclient", "join", "-t", token, "-name", "netmaker", "-endpoint", pubip).Output() + if string(out) != "" { + log.Println(string(out)) } if err != nil { - return false, errors.New(string(out) + err.Error()) - } + return false, errors.New(string(out) + err.Error()) + } log.Println("Server added to network " + network) return true, err } - diff --git a/serverctl/wireguard.go b/serverctl/wireguard.go index 4a2eb838e..5128ee21c 100644 --- a/serverctl/wireguard.go +++ b/serverctl/wireguard.go @@ -1,21 +1,22 @@ package serverctl import ( - //"github.com/davecgh/go-spew/spew" - "os" + //"github.com/davecgh/go-spew/spew" + + "encoding/json" + "errors" "log" - "context" - "time" "net" + "os" "strconv" - "errors" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/functions" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/gravitl/netmaker/servercfg" - "github.com/gravitl/netmaker/functions" - "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/mongoconn" ) func InitServerWireGuard() error { @@ -50,9 +51,9 @@ func InitServerWireGuard() error { } err = netlink.AddrAdd(wglink, address) - if err != nil && !os.IsExist(err){ - return err - } + if err != nil && !os.IsExist(err) { + return err + } err = netlink.LinkSetUp(wglink) if err != nil { log.Println("could not bring up wireguard interface") @@ -69,105 +70,104 @@ func InitServerWireGuard() error { client.Address = servercfg.GetGRPCWGAddress() client.IsServer = "yes" client.Network = "comms" - exists, _ := functions.ServerIntClientExists() - if exists { - + exists, _ := functions.GetServerIntClient() + if exists != nil { + err = RegisterServer(client) } - err = RegisterServer(client) - return err + return err } func DeleteServerClient() error { return nil } - func RegisterServer(client models.IntClient) error { - if client.PrivateKey == "" { - privateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return err - } - - client.PrivateKey = privateKey.String() - client.PublicKey = privateKey.PublicKey().String() - } - - if client.Address == "" { - newAddress, err := functions.UniqueAddress(client.Network) - if err != nil { - return err - } + if client.PrivateKey == "" { + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return err + } + + client.PrivateKey = privateKey.String() + client.PublicKey = privateKey.PublicKey().String() + } + + if client.Address == "" { + newAddress, err := functions.UniqueAddress(client.Network) + if err != nil { + return err + } if newAddress == "" { return errors.New("Could not retrieve address") } - client.Address = newAddress - } - if client.Network == "" { client.Network = "comms" } - client.ServerKey = client.PublicKey - - collection := mongoconn.Client.Database("netmaker").Collection("intclients") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - // insert our network into the network table - _, err := collection.InsertOne(ctx, client) - defer cancel() + client.Address = newAddress + } + if client.Network == "" { + client.Network = "comms" + } + client.ServerKey = client.PublicKey + value, err := json.Marshal(client) + if err != nil { + return err + } + database.Insert(client.PublicKey, string(value), database.INT_CLIENTS_TABLE_NAME) - ReconfigureServerWireGuard() + ReconfigureServerWireGuard() - return err + return err } func ReconfigureServerWireGuard() error { - server, err := GetServerWGConf() + server, err := functions.GetServerIntClient() if err != nil { - return err - } + return err + } serverkey, err := wgtypes.ParseKey(server.PrivateKey) - if err != nil { - return err - } + if err != nil { + return err + } serverport, err := strconv.Atoi(servercfg.GetGRPCWGPort()) - if err != nil { - return err - } + if err != nil { + return err + } peers, err := functions.GetIntPeersList() - if err != nil { - return err - } + if err != nil { + return err + } wgserver, err := wgctrl.New() if err != nil { return err } - var serverpeers []wgtypes.PeerConfig + var serverpeers []wgtypes.PeerConfig for _, peer := range peers { - pubkey, err := wgtypes.ParseKey(peer.PublicKey) + pubkey, err := wgtypes.ParseKey(peer.PublicKey) if err != nil { return err } - var peercfg wgtypes.PeerConfig - var allowedips []net.IPNet - if peer.Address != "" { + var peercfg wgtypes.PeerConfig + var allowedips []net.IPNet + if peer.Address != "" { var peeraddr = net.IPNet{ - IP: net.ParseIP(peer.Address), - Mask: net.CIDRMask(32, 32), - } - allowedips = append(allowedips, peeraddr) + IP: net.ParseIP(peer.Address), + Mask: net.CIDRMask(32, 32), + } + allowedips = append(allowedips, peeraddr) } if peer.Address6 != "" { - var addr6 = net.IPNet{ - IP: net.ParseIP(peer.Address6), - Mask: net.CIDRMask(128, 128), - } - allowedips = append(allowedips, addr6) - } + var addr6 = net.IPNet{ + IP: net.ParseIP(peer.Address6), + Mask: net.CIDRMask(128, 128), + } + allowedips = append(allowedips, addr6) + } peercfg = wgtypes.PeerConfig{ - PublicKey: pubkey, - ReplaceAllowedIPs: true, - AllowedIPs: allowedips, - } - serverpeers = append(serverpeers, peercfg) + PublicKey: pubkey, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + serverpeers = append(serverpeers, peercfg) } wgconf := wgtypes.Config{