diff --git a/hscontrol/app.go b/hscontrol/app.go index 4c2c861fad..78b72bf51f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -12,7 +12,6 @@ import ( "os" "os/signal" "runtime" - "strconv" "strings" "sync" "syscall" @@ -118,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - var dbString string - switch cfg.Database.Type { - case db.Postgres: - dbString = fmt.Sprintf( - "host=%s dbname=%s user=%s", - cfg.Database.Postgres.Host, - cfg.Database.Postgres.Name, - cfg.Database.Postgres.User, - ) - - if sslEnabled, err := strconv.ParseBool(cfg.Database.Postgres.Ssl); err == nil { - if !sslEnabled { - dbString += " sslmode=disable" - } - } else { - dbString += fmt.Sprintf(" sslmode=%s", cfg.Database.Postgres.Ssl) - } - - if cfg.Database.Postgres.Port != 0 { - dbString += fmt.Sprintf(" port=%d", cfg.Database.Postgres.Port) - } - - if cfg.Database.Postgres.Pass != "" { - dbString += fmt.Sprintf(" password=%s", cfg.Database.Postgres.Pass) - } - case db.Sqlite: - dbString = cfg.Database.Sqlite.Path - default: - return nil, errUnsupportedDatabase - } - registrationCache := cache.New( registerCacheExpiration, registerCacheCleanup, @@ -156,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app := Headscale{ cfg: cfg, - dbType: cfg.Database.Type, - dbString: dbString, noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, @@ -165,9 +131,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } database, err := db.NewHeadscaleDatabase( - cfg.Database.Type, - dbString, - app.dbDebug, + cfg.Database, + app.nodeNotifier, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { @@ -755,8 +720,10 @@ func (h *Headscale) Serve() error { var tailsqlContext context.Context if tailsqlEnabled { - if h.cfg.Database.Type != db.Sqlite { - log.Fatal().Str("type", h.cfg.Database.Type).Msgf("tailsql only support %q", db.Sqlite) + if h.cfg.Database.Type != types.DatabaseSqlite { + log.Fatal(). + Str("type", h.cfg.Database.Type). + Msgf("tailsql only support %q", types.DatabaseSqlite) } if tailsqlTSKey == "" { log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index df7b0a4c7f..fe77dda855 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -6,11 +6,13 @@ import ( "errors" "fmt" "net/netip" + "strconv" "strings" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -19,11 +21,6 @@ import ( "gorm.io/gorm/logger" ) -const ( - Postgres = "postgres" - Sqlite = "sqlite3" -) - var errDatabaseNotSupported = errors.New("database type not supported") // KV is a key-value store in a psql table. For future use... @@ -43,12 +40,12 @@ type HSDatabase struct { // TODO(kradalby): assemble this struct from toptions or something typed // rather than arguments. func NewHeadscaleDatabase( - dbType, connectionAddr string, - debug bool, + cfg types.DatabaseConfig, + notifier *notifier.Notifier, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { - dbConn, err := openDB(dbType, connectionAddr, debug) + dbConn, err := openDB(cfg) if err != nil { return nil, err } @@ -62,7 +59,7 @@ func NewHeadscaleDatabase( { ID: "202312101416", Migrate: func(tx *gorm.DB) error { - if dbType == Postgres { + if cfg.Type == types.DatabasePostgres { tx.Exec(`create extension if not exists "uuid-ossp";`) } @@ -321,20 +318,20 @@ func NewHeadscaleDatabase( return &db, err } -func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { - log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") +func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { + // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface - if debug { + if cfg.Debug { dbLogger = logger.Default } else { dbLogger = logger.Default.LogMode(logger.Silent) } - switch dbType { - case Sqlite: + switch cfg.Type { + case types.DatabaseSqlite: db, err := gorm.Open( - sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), + sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, @@ -353,8 +350,31 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return db, err - case Postgres: - return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ + case types.DatabasePostgres: + dbString := fmt.Sprintf( + "host=%s dbname=%s user=%s", + cfg.Postgres.Host, + cfg.Postgres.Name, + cfg.Postgres.User, + ) + + if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { + if !sslEnabled { + dbString += " sslmode=disable" + } + } else { + dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl) + } + + if cfg.Postgres.Port != 0 { + dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port) + } + + if cfg.Postgres.Pass != "" { + dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass) + } + + return gorm.Open(postgres.Open(dbString), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, }) @@ -362,7 +382,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return nil, fmt.Errorf( "database of type %s is not supported: %w", - dbType, + cfg.Type, errDatabaseNotSupported, ) } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 3b544aa70f..5d6281e83e 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" @@ -654,9 +655,13 @@ func TestFailoverRoute(t *testing.T) { assert.NoError(t, err) db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index d4b11b140e..e176e4b296 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -6,6 +6,8 @@ import ( "os" "testing" + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" ) @@ -44,9 +46,13 @@ func (s *Suite) ResetDB(c *check.C) { log.Printf("database path: %s", tmpDir+"/headscale_test.db") db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index d45f9d4cca..ceeceea004 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -12,7 +12,11 @@ import ( "tailscale.com/tailcfg" ) -const SelfUpdateIdentifier = "self-update" +const ( + SelfUpdateIdentifier = "self-update" + DatabasePostgres = "postgres" + DatabaseSqlite = "sqlite3" +) var ErrCannotParsePrefix = errors.New("cannot parse prefix") @@ -154,7 +158,9 @@ func (su *StateUpdate) Valid() bool { } case StateSelfUpdate: if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 { - panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node") + panic( + "Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node", + ) } case StateDERPUpdated: if su.DERPMap == nil { diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 9b0134430b..d83b21f761 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -85,7 +85,8 @@ type PostgresConfig struct { type DatabaseConfig struct { // Type sets the database type, either "sqlite3" or "postgres" - Type string + Type string + Debug bool Sqlite SqliteConfig Postgres PostgresConfig @@ -418,9 +419,12 @@ func GetLogConfig() LogConfig { } func GetDatabaseConfig() DatabaseConfig { + debug := viper.GetBool("database.debug") + type_ := viper.GetString("database.type") + switch type_ { - case "sqlite3", "postgres": + case DatabaseSqlite, DatabasePostgres: break case "sqlite": type_ = "sqlite3" @@ -429,7 +433,8 @@ func GetDatabaseConfig() DatabaseConfig { } return DatabaseConfig{ - Type: type_, + Type: type_, + Debug: debug, Sqlite: SqliteConfig{ Path: util.AbsolutePathFromConfigPath(viper.GetString("database.sqlite.path")), },