Skip to content

Commit

Permalink
pass db config to db module, move url logic
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Feb 8, 2024
1 parent 793faab commit fdbe0dc
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 68 deletions.
45 changes: 6 additions & 39 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
Expand Down Expand Up @@ -118,56 +117,22 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}

var dbString string
switch cfg.Database.Type {
case db.Postgres:
dbString = fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.Database.Postgres.Host,
cfg.Database.Postgres.Name,
cfg.Database.Postgres.User,
)

if sslEnabled, err := strconv.ParseBool(cfg.Database.Postgres.Ssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"
}
} else {
dbString += fmt.Sprintf(" sslmode=%s", cfg.Database.Postgres.Ssl)
}

if cfg.Database.Postgres.Port != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.Database.Postgres.Port)
}

if cfg.Database.Postgres.Pass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.Database.Postgres.Pass)
}
case db.Sqlite:
dbString = cfg.Database.Sqlite.Path
default:
return nil, errUnsupportedDatabase
}

registrationCache := cache.New(
registerCacheExpiration,
registerCacheCleanup,
)

app := Headscale{
cfg: cfg,
dbType: cfg.Database.Type,
dbString: dbString,
noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(),
}

database, err := db.NewHeadscaleDatabase(
cfg.Database.Type,
dbString,
app.dbDebug,
cfg.Database,
app.nodeNotifier,
cfg.IPPrefixes,
cfg.BaseDomain)
if err != nil {
Expand Down Expand Up @@ -755,8 +720,10 @@ func (h *Headscale) Serve() error {

var tailsqlContext context.Context
if tailsqlEnabled {
if h.cfg.Database.Type != db.Sqlite {
log.Fatal().Str("type", h.cfg.Database.Type).Msgf("tailsql only support %q", db.Sqlite)
if h.cfg.Database.Type != types.DatabaseSqlite {
log.Fatal().
Str("type", h.cfg.Database.Type).
Msgf("tailsql only support %q", types.DatabaseSqlite)
}
if tailsqlTSKey == "" {
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
Expand Down
56 changes: 38 additions & 18 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"errors"
"fmt"
"net/netip"
"strconv"
"strings"
"time"

"github.com/glebarez/sqlite"
"github.com/go-gormigrate/gormigrate/v2"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
Expand All @@ -19,11 +21,6 @@ import (
"gorm.io/gorm/logger"
)

const (
Postgres = "postgres"
Sqlite = "sqlite3"
)

var errDatabaseNotSupported = errors.New("database type not supported")

// KV is a key-value store in a psql table. For future use...
Expand All @@ -43,12 +40,12 @@ type HSDatabase struct {
// TODO(kradalby): assemble this struct from toptions or something typed
// rather than arguments.
func NewHeadscaleDatabase(
dbType, connectionAddr string,
debug bool,
cfg types.DatabaseConfig,
notifier *notifier.Notifier,
ipPrefixes []netip.Prefix,
baseDomain string,
) (*HSDatabase, error) {
dbConn, err := openDB(dbType, connectionAddr, debug)
dbConn, err := openDB(cfg)
if err != nil {
return nil, err
}
Expand All @@ -62,7 +59,7 @@ func NewHeadscaleDatabase(
{
ID: "202312101416",
Migrate: func(tx *gorm.DB) error {
if dbType == Postgres {
if cfg.Type == types.DatabasePostgres {
tx.Exec(`create extension if not exists "uuid-ossp";`)
}

Expand Down Expand Up @@ -321,20 +318,20 @@ func NewHeadscaleDatabase(
return &db, err
}

func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {

// TODO(kradalby): Integrate this with zerolog
var dbLogger logger.Interface
if debug {
if cfg.Debug {
dbLogger = logger.Default
} else {
dbLogger = logger.Default.LogMode(logger.Silent)
}

switch dbType {
case Sqlite:
switch cfg.Type {
case types.DatabaseSqlite:
db, err := gorm.Open(
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger,
Expand All @@ -353,16 +350,39 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {

return db, err

case Postgres:
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
case types.DatabasePostgres:
dbString := fmt.Sprintf(
"host=%s dbname=%s user=%s",
cfg.Postgres.Host,
cfg.Postgres.Name,
cfg.Postgres.User,
)

if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
if !sslEnabled {
dbString += " sslmode=disable"
}
} else {
dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl)
}

if cfg.Postgres.Port != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port)
}

if cfg.Postgres.Pass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass)
}

return gorm.Open(postgres.Open(dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger,
})
}

return nil, fmt.Errorf(
"database of type %s is not supported: %w",
dbType,
cfg.Type,
errDatabaseNotSupported,
)
}
Expand Down
11 changes: 8 additions & 3 deletions hscontrol/db/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -654,9 +655,13 @@ func TestFailoverRoute(t *testing.T) {
assert.NoError(t, err)

db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
Expand Down
12 changes: 9 additions & 3 deletions hscontrol/db/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"os"
"testing"

"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
)

Expand Down Expand Up @@ -44,9 +46,13 @@ func (s *Suite) ResetDB(c *check.C) {
log.Printf("database path: %s", tmpDir+"/headscale_test.db")

db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
},
notifier.NewNotifier(),
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
Expand Down
10 changes: 8 additions & 2 deletions hscontrol/types/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ import (
"tailscale.com/tailcfg"
)

const SelfUpdateIdentifier = "self-update"
const (
SelfUpdateIdentifier = "self-update"
DatabasePostgres = "postgres"
DatabaseSqlite = "sqlite3"
)

var ErrCannotParsePrefix = errors.New("cannot parse prefix")

Expand Down Expand Up @@ -154,7 +158,9 @@ func (su *StateUpdate) Valid() bool {
}
case StateSelfUpdate:
if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node")
panic(
"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
)
}
case StateDERPUpdated:
if su.DERPMap == nil {
Expand Down
11 changes: 8 additions & 3 deletions hscontrol/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ type PostgresConfig struct {

type DatabaseConfig struct {
// Type sets the database type, either "sqlite3" or "postgres"
Type string
Type string
Debug bool

Sqlite SqliteConfig
Postgres PostgresConfig
Expand Down Expand Up @@ -418,9 +419,12 @@ func GetLogConfig() LogConfig {
}

func GetDatabaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug")

type_ := viper.GetString("database.type")

switch type_ {
case "sqlite3", "postgres":
case DatabaseSqlite, DatabasePostgres:
break
case "sqlite":
type_ = "sqlite3"
Expand All @@ -429,7 +433,8 @@ func GetDatabaseConfig() DatabaseConfig {
}

return DatabaseConfig{
Type: type_,
Type: type_,
Debug: debug,
Sqlite: SqliteConfig{
Path: util.AbsolutePathFromConfigPath(viper.GetString("database.sqlite.path")),
},
Expand Down

0 comments on commit fdbe0dc

Please sign in to comment.