diff --git a/pkg/rest/server.go b/pkg/rest/server.go index a223511..2651cb2 100644 --- a/pkg/rest/server.go +++ b/pkg/rest/server.go @@ -1,6 +1,7 @@ package rest import ( + "context" "crypto/tls" "embed" "fmt" @@ -8,6 +9,7 @@ import ( "io/fs" "log" "net/http" + "time" "github.com/in4it/wireguard-server/pkg/logging" localstorage "github.com/in4it/wireguard-server/pkg/storage/local" @@ -57,7 +59,28 @@ func StartServer(httpPort, httpsPort int, serverType string) { // TLS Configuration if !c.EnableTLS || !canEnableTLS(c.Hostname) { - <-enableTLSWaiter + // enable self signed tls server + logging.DebugLog(fmt.Errorf("enabling self signed TLS (let's encrypt not enabled)")) + selfSignedServer := &http.Server{ + Addr: fmt.Sprintf(":%d", httpsPort), + TLSConfig: &tls.Config{ + GetCertificate: c.getSelfSignedCertificate, + }, + Handler: c.loggingMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))), + } + go func() { + log.Printf("Start https server (self-signed) on port %d", httpsPort) + err := selfSignedServer.ListenAndServeTLS("", "") + logging.DebugLog(fmt.Errorf("shutting down self signed server: %s", err)) + }() + <-enableTLSWaiter // wait until real TLS is asked for (using let's encrypt) + logging.DebugLog(fmt.Errorf("disabling self signed TLS server")) + ctx, cancelCtx := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancelCtx() + err = selfSignedServer.Shutdown(ctx) + if err != nil { + logging.DebugLog(fmt.Errorf("self signed server stopped (error message: %s)", err)) + } } // only enable when TLS is enabled diff --git a/pkg/rest/tls.go b/pkg/rest/tls.go index 5eb94b0..570a926 100644 --- a/pkg/rest/tls.go +++ b/pkg/rest/tls.go @@ -1,8 +1,16 @@ package rest import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" "log" + "math/big" "strings" + "time" ) func canEnableTLS(hostname string) bool { @@ -15,3 +23,42 @@ func canEnableTLS(hostname string) bool { return false } + +func (c *Context) getSelfSignedCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + hostnameSplit := strings.Split(c.Hostname, ":") + if info.ServerName != "" && info.ServerName != hostnameSplit[0] { + return nil, fmt.Errorf("can't generate certificate for hostname %s while configured hostname is %s", info.ServerName, hostnameSplit[0]) + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"IN4IT VPN Server"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{c.Hostname}, + } + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %v", err) + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %v", err) + } + tlsCertificate := &tls.Certificate{ + PrivateKey: privateKey, + } + tlsCertificate.Certificate = append(tlsCertificate.Certificate, derBytes) + return tlsCertificate, nil +}