Skip to content

Commit

Permalink
use key types throughout the code
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Nov 17, 2023
1 parent 462dbee commit d9e5b4a
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 168 deletions.
26 changes: 11 additions & 15 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (h *Headscale) handleRegister(
Msg("New node not yet in the database")

givenName, err := h.db.GenerateGivenName(
machineKey.String(),
machineKey,
registerRequest.Hostinfo.Hostname,
)
if err != nil {
Expand All @@ -97,10 +97,10 @@ func (h *Headscale) handleRegister(
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: machineKey.String(),
MachineKey: machineKey,
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: registerRequest.NodeKey.String(),
NodeKey: registerRequest.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
}
Expand All @@ -116,7 +116,7 @@ func (h *Headscale) handleRegister(
}

h.registrationCache.Set(
newNode.NodeKey,
newNode.NodeKey.String(),
newNode,
registerCacheExpiration,
)
Expand All @@ -134,11 +134,7 @@ func (h *Headscale) handleRegister(
// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021,
// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText(
[]byte(node.MachineKey),
)
if err != nil || storedMachineKey.IsZero() {
if err != nil || node.MachineKey.IsZero() {
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
log.Error().
Caller().
Expand All @@ -156,7 +152,7 @@ func (h *Headscale) handleRegister(
// - Trying to log out (sending a expiry in the past)
// - A valid, registered node, looking for /map
// - Expired node wanting to reauthenticate
if node.NodeKey == registerRequest.NodeKey.String() {
if node.NodeKey.String() == registerRequest.NodeKey.String() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() &&
Expand All @@ -176,7 +172,7 @@ func (h *Headscale) handleRegister(
}

// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if node.NodeKey == registerRequest.OldNodeKey.String() &&
if node.NodeKey.String() == registerRequest.OldNodeKey.String() &&
!node.IsExpired() {
h.handleNodeKeyRefresh(
writer,
Expand Down Expand Up @@ -207,7 +203,7 @@ func (h *Headscale) handleRegister(
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
node.NodeKey = registerRequest.NodeKey.String()
node.NodeKey = registerRequest.NodeKey
h.registrationCache.Set(
registerRequest.NodeKey.String(),
*node,
Expand Down Expand Up @@ -294,7 +290,7 @@ func (h *Headscale) handleAuthKey(
Str("node", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")

nodeKey := registerRequest.NodeKey.String()
nodeKey := registerRequest.NodeKey

// retrieve node information if it exist
// The error is not important, because if it does not
Expand Down Expand Up @@ -342,7 +338,7 @@ func (h *Headscale) handleAuthKey(
} else {
now := time.Now().UTC()

givenName, err := h.db.GenerateGivenName(machineKey.String(), registerRequest.Hostinfo.Hostname)
givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Expand All @@ -359,7 +355,7 @@ func (h *Headscale) handleAuthKey(
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
MachineKey: machineKey.String(),
MachineKey: machineKey,
RegisterMethod: util.RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
Expand Down
43 changes: 24 additions & 19 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
Preload("User").
Preload("Routes").
Where("node_key <> ?",
node.NodeKey).Find(&nodes).Error; err != nil {
node.NodeKey.String()).Find(&nodes).Error; err != nil {
return types.Nodes{}, err
}

Expand Down Expand Up @@ -268,7 +268,7 @@ func (hsdb *HSDatabase) SetTags(
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())

return nil
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())

return nil
}
Expand All @@ -330,7 +330,7 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())

return nil
}
Expand Down Expand Up @@ -448,8 +448,8 @@ func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey).
Str("node_key", node.NodeKey).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Msg("Registering node")

Expand All @@ -464,8 +464,8 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
log.Trace().
Caller().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey).
Str("node_key", node.NodeKey).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Msg("Node authorized again")

Expand Down Expand Up @@ -507,7 +507,7 @@ func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic)
defer hsdb.mu.Unlock()

if err := hsdb.db.Model(node).Updates(types.Node{
NodeKey: nodeKey.String(),
NodeKey: nodeKey,
}).Error; err != nil {
return err
}
Expand All @@ -524,7 +524,7 @@ func (hsdb *HSDatabase) NodeSetMachineKey(
defer hsdb.mu.Unlock()

if err := hsdb.db.Model(node).Updates(types.Node{
MachineKey: machineKey.String(),
MachineKey: machineKey,
}).Error; err != nil {
return err
}
Expand Down Expand Up @@ -703,7 +703,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey)
}, node.MachineKey.String())

return nil
}
Expand Down Expand Up @@ -734,7 +734,7 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return normalizedHostname, nil
}

func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
func (hsdb *HSDatabase) GenerateGivenName(mkey key.MachinePublic, suppliedName string) (string, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()

Expand All @@ -749,15 +749,20 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string
return "", err
}

for _, node := range nodes {
if node.MachineKey != machineKey && node.GivenName == givenName {
postfixedName, err := generateGivenName(suppliedName, true)
if err != nil {
return "", err
}
var nodeFound *types.Node
for idx, node := range nodes {
if node.GivenName == givenName {
nodeFound = nodes[idx]
}
}

givenName = postfixedName
if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
postfixedName, err := generateGivenName(suppliedName, true)
if err != nil {
return "", err
}

givenName = postfixedName
}

return givenName, nil
Expand Down
Loading

0 comments on commit d9e5b4a

Please sign in to comment.