Skip to content

Commit

Permalink
initial replacing mkey with random registration ID
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Jan 10, 2025
1 parent 8ef323a commit 19945fd
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 103 deletions.
6 changes: 3 additions & 3 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier

registrationCache *zcache.Cache[string, types.Node]
registrationCache *zcache.Cache[types.RegistrationID, types.Node]

authProvider AuthProvider

Expand All @@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}

registrationCache := zcache.New[string, types.Node](
registrationCache := zcache.New[types.RegistrationID, types.Node](
registerCacheExpiration,
registerCacheCleanup,
)
Expand Down Expand Up @@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {

router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)

if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
Expand Down
44 changes: 30 additions & 14 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ import (

type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string
AuthURL(types.RegistrationID) string
}

func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) (func(string), func(string), func(error, string)) {
return func(msg string) {
log.Info().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -41,6 +43,7 @@ func logAuthFunc(
func(msg string) {
log.Trace().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -52,6 +55,7 @@ func logAuthFunc(
func(err error, msg string) {
log.Error().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -70,7 +74,18 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
registrationId, err := types.NewRegistrationID()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to generate registration ID")
http.Error(writer, "Internal server error", http.StatusInternalServerError)

return
}

logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
Expand All @@ -93,14 +108,14 @@ func (h *Headscale) handleRegister(
// successful RegisterResponse.
if regReq.Followup != "" {
logTrace("register request is a followup")
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
if _, ok := h.registrationCache.Get(registrationId); ok {
logTrace("Node is waiting for interactive login")

select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
h.handleNewNode(writer, regReq, machineKey)
h.handleNewNode(writer, regReq, registrationId)

return
}
Expand All @@ -127,11 +142,11 @@ func (h *Headscale) handleRegister(
}

h.registrationCache.Set(
machineKey.String(),
registrationId,
newNode,
)

h.handleNewNode(writer, regReq, machineKey)
h.handleNewNode(writer, regReq, registrationId)

return
}
Expand Down Expand Up @@ -214,7 +229,7 @@ func (h *Headscale) handleRegister(
}

// The node has expired or it is logged out
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey)
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)

// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
node.Expiry = &time.Time{}
Expand All @@ -225,7 +240,7 @@ func (h *Headscale) handleRegister(
// headscale-managed tailnets?
node.NodeKey = regReq.NodeKey
h.registrationCache.Set(
machineKey.String(),
registrationId,
*node,
)

Expand Down Expand Up @@ -444,16 +459,16 @@ func (h *Headscale) handleAuthKey(
func (h *Headscale) handleNewNode(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)

resp := tailcfg.RegisterResponse{}

// The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand Down Expand Up @@ -660,6 +675,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
regReq tailcfg.RegisterRequest,
node types.Node,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
resp := tailcfg.RegisterResponse{}

Expand All @@ -673,12 +689,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace().
Caller().
Str("node", node.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand All @@ -703,7 +719,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(

log.Trace().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Str("node", node.Hostname).
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type KV struct {
type HSDatabase struct {
DB *gorm.DB
cfg *types.DatabaseConfig
regCache *zcache.Cache[string, types.Node]
regCache *zcache.Cache[types.RegistrationID, types.Node]

baseDomain string
}
Expand All @@ -51,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase(
cfg types.DatabaseConfig,
baseDomain string,
regCache *zcache.Cache[string, types.Node],
regCache *zcache.Cache[types.RegistrationID, types.Node],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ func testCopyOfDatabase(src string) (string, error) {
return dst, err
}

func emptyCache() *zcache.Cache[string, types.Node] {
return zcache.New[string, types.Node](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.RegistrationID, types.Node] {
return zcache.New[types.RegistrationID, types.Node](time.Minute, time.Hour)
}

// requireConstraintFailed checks if the error is a constraint failure with
Expand Down
8 changes: 4 additions & 4 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,15 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
}

func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
mkey key.MachinePublic,
registrationID types.RegistrationID,
userID types.UserID,
nodeExpiry *time.Time,
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if node, ok := hsdb.regCache.Get(mkey.String()); ok {
if node, ok := hsdb.regCache.Get(registrationID); ok {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
Expand All @@ -338,7 +338,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
}

log.Debug().
Str("machine_key", mkey.ShortString()).
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Expand All @@ -365,7 +365,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
)

if err == nil {
hsdb.regCache.Delete(mkey.String())
hsdb.regCache.Delete(registrationID)
}

return node, err
Expand Down
21 changes: 9 additions & 12 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,10 @@ func (api headscaleV1APIServer) RegisterNode(
) (*v1.RegisterNodeResponse, error) {
log.Trace().
Str("user", request.GetUser()).
Str("machine_key", request.GetKey()).
Str("registration_id", request.GetKey()).
Msg("Registering node")

var mkey key.MachinePublic
err := mkey.UnmarshalText([]byte(request.GetKey()))
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}
Expand All @@ -247,7 +246,7 @@ func (api headscaleV1APIServer) RegisterNode(
}

node, err := api.h.db.RegisterNodeFromAuthCallback(
mkey,
registrationId,
types.UserID(user.ID),
nil,
util.RegisterMethodCLI,
Expand Down Expand Up @@ -839,19 +838,17 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: "DebugTestNode",
}

var mkey key.MachinePublic
err = mkey.UnmarshalText([]byte(request.GetKey()))
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}

nodeKey := key.NewNode()

newNode := types.Node{
MachineKey: mkey,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,

Expiry: &time.Time{},
LastSeen: &time.Time{},
Expand All @@ -860,11 +857,11 @@ func (api headscaleV1APIServer) DebugCreateNode(
}

log.Debug().
Str("machine_key", mkey.ShortString()).
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")

api.h.registrationCache.Set(
mkey.String(),
registrationId,
newNode,
)

Expand Down
5 changes: 3 additions & 2 deletions hscontrol/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/chasefleming/elem-go/styles"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
Expand Down Expand Up @@ -239,11 +240,11 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
}
}

func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string {
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
registrationId.String())
}

// RegisterWebAPI shows a simple message in the browser to point to the CLI
Expand Down
Loading

0 comments on commit 19945fd

Please sign in to comment.