From 36e30290a594fac955e493cf97e82f6a22c7b139 Mon Sep 17 00:00:00 2001 From: Yehoyada Date: Fri, 18 Oct 2024 16:15:05 +0300 Subject: [PATCH 1/2] remove logging commented --- handler.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/handler.go b/handler.go index 621f7e3..23b3e7b 100644 --- a/handler.go +++ b/handler.go @@ -13,16 +13,6 @@ import ( var ErrUnsupportedBaseHandler = errors.New("base handler unsupported") -// func loggingMiddleware(next http.Handler) http.Handler { -// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// rc := middlewares.NewResponseCapture(w) -// next.ServeHTTP(rc, r) - -// w.WriteHeader(rc.Status()) -// w.Write(rc.Buffer()) -// }) -// } - func GetBaseHandler(service config.Service, path config.Path) (http.Handler, error) { if path.Destination != nil && *path.Destination != "" { return handlers.NewProxy(service, path) @@ -100,7 +90,5 @@ func NewHandler(service config.Service, path config.Path) (http.Handler, error) handlerWithMiddlewares.Add(middlewares.NewCacheMiddleware()) } - // handlerWithMiddlewares.Add(loggingMiddleware) - return handlerWithMiddlewares, nil } From beadd067550634155d005a53407ef2404a4857b9 Mon Sep 17 00:00:00 2001 From: Yehoyada Date: Tue, 22 Oct 2024 18:10:32 +0300 Subject: [PATCH 2/2] better routing and proj structure --- config/config.go | 62 +++++++++++++- config/config_test.go | 3 +- gatego.go | 103 +---------------------- handlertable.go | 63 --------------- pkg/multimux/multimux.go | 56 +++++++++++++ pkg/multimux/multimux_test.go | 148 ++++++++++++++++++++++++++++++++++ pkg/pathtree/tree.go | 70 ---------------- pkg/pathtree/tree_test.go | 93 --------------------- server.go | 107 ++++++++++++++++++++++++ 9 files changed, 374 insertions(+), 331 deletions(-) delete mode 100644 handlertable.go create mode 100644 pkg/multimux/multimux.go create mode 100644 pkg/multimux/multimux_test.go delete mode 100644 pkg/pathtree/tree.go delete mode 100644 pkg/pathtree/tree_test.go create mode 100644 server.go diff --git a/config/config.go b/config/config.go index 721fdbc..102e0a3 100644 --- a/config/config.go +++ b/config/config.go @@ -177,8 +177,44 @@ func (s Service) validate() error { } type TLS struct { - KeyFile *string `yaml:"keyfile"` - CertFile *string `yaml:"certfile"` + Auto bool `yaml:"auto"` + Domains []string `yaml:"domain"` + Email *string `yaml:"email"` + KeyFile *string `yaml:"keyfile"` + CertFile *string `yaml:"certfile"` +} + +func (tls TLS) validate() error { + if tls.Auto { + if len(tls.Domains) == 0 { + return errors.New("when using the auto tls feature you MUST include a list of domains to issue certificates for") + } + if tls.Email == nil || len(*tls.Email) == 0 || !isValidEmail(*tls.Email) { + return errors.New("when using the auto tls feature you MUST include a valid email for the lets-encrypt registration") + } + } + + if tls.CertFile != nil { + if tls.KeyFile == nil { + return errors.New("you MUST provide certfile AND keyfile") + } + } + + if tls.KeyFile != nil { + if tls.CertFile == nil { + return errors.New("you MUST provide certfile AND keyfile") + } + + if !isValidFile(*tls.CertFile) { + return errors.New("certfile path is invalid") + } + + if !isValidFile(*tls.KeyFile) { + return errors.New("keyfile path is invalid") + } + } + + return nil } type OTEL struct { @@ -216,7 +252,7 @@ type Config struct { OTEL *OTEL `yaml:"open_telemetry"` // TLS options - SSL TLS `yaml:"ssl"` + TLS TLS `yaml:"ssl"` Services []Service `yaml:"services"` } @@ -246,6 +282,18 @@ func (c Config) Validate(currentVersion string) error { } } + if c.Port == 0 { + return errors.New("port is required") + } + + if err := c.TLS.validate(); err != nil { + return err + } + + if c.TLS.Auto && c.Port != 443 { + return errors.New("the auto tls feature is only available if the server runs on port 443") + } + for _, service := range c.Services { if err := service.validate(); err != nil { return err @@ -405,3 +453,11 @@ func isValidGRPCAddress(address string) error { return nil } + +func isValidEmail(email string) bool { + // Define a regular expression for valid email addresses + var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) + + // Match the email string with the regular expression + return emailRegex.MatchString(email) +} diff --git a/config/config_test.go b/config/config_test.go index 40d0365..c66a6de 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -58,7 +58,8 @@ func TestConfigValidate(t *testing.T) { currentVersion string wantErr bool }{ - {"Valid config", Config{Version: "1.0.0", Host: "localhost", Services: []Service{{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}}}, "1.0.0", false}, + {"Valid config", Config{Version: "1.0.0", Host: "localhost", Port: 80, Services: []Service{{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}}}, "1.0.0", false}, + {"AutoTLS with port != 443", Config{Version: "1.0.0", Host: "localhost", Port: 80, TLS: TLS{Auto: true, Domains: []string{"example.com"}}, Services: []Service{{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}}}, "1.0.0", true}, {"Missing version", Config{Host: "localhost"}, "1.0.0", true}, {"Invalid version", Config{Version: "invalid", Host: "localhost"}, "1.0.0", true}, {"Future version", Config{Version: "2.0.0", Host: "localhost"}, "1.0.0", true}, diff --git a/gatego.go b/gatego.go index 053dcd8..925248a 100644 --- a/gatego.go +++ b/gatego.go @@ -2,16 +2,10 @@ package gatego import ( "context" - "fmt" - "log" - "net" - "net/http" - "os" "time" "github.com/hvuhsg/gatego/config" "github.com/hvuhsg/gatego/contextvalues" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) const serviceName = "gatego" @@ -48,15 +42,13 @@ func (gg GateGo) Run() error { checker := createChecker(gg.config.Services) checker.Start() - table, err := NewHandlersTable(gg.ctx, useOtel, gg.config.Services) + server, err := newServer(gg.ctx, gg.config, useOtel) if err != nil { return err } - - server := gg.createServer(table) defer server.Shutdown(gg.ctx) - serveErrChan, err := serve(server, gg.config.SSL.CertFile, gg.config.SSL.KeyFile) + serveErrChan, err := server.serve(gg.config.TLS.CertFile, gg.config.TLS.KeyFile) if err != nil { return err } @@ -92,94 +84,3 @@ func createChecker(services []config.Service) *Checker { return checker } - -func (gg GateGo) createServer(table HandlerTable) *http.Server { - mux := http.NewServeMux() - - // handleFunc is a replacement for mux.HandleFunc - // which enriches the handler's HTTP instrumentation with the pattern as the http.route. - handleFunc := func(pattern string, handlerFunc func(http.ResponseWriter, *http.Request)) { - // Configure the "http.route" for the HTTP instrumentation. - handler := otelhttp.WithRouteTag(pattern, http.HandlerFunc(handlerFunc)) - mux.Handle(pattern, handler) - } - - handleFunc("/", func(w http.ResponseWriter, r *http.Request) { - handler := table.GetHandler(r.Host, r.URL.Path) - - if handler == nil { - w.WriteHeader(http.StatusNotFound) - return - } - - handler.ServeHTTP(w, r) - }) - - // Add HTTP instrumentation for the whole server. - handler := otelhttp.NewHandler(mux, "/") - - addr := fmt.Sprintf("%s:%d", gg.config.Host, gg.config.Port) - - // Start HTTP server. - server := &http.Server{ - Addr: addr, - BaseContext: func(_ net.Listener) context.Context { return gg.ctx }, - ReadTimeout: time.Second, - WriteTimeout: 10 * time.Second, - Handler: handler, - } - - return server -} - -func serve(server *http.Server, certfile *string, keyfile *string) (chan error, error) { - supportTLS, err := checkTLSConfig(certfile, keyfile) - if err != nil { - return nil, err - } - - serveErr := make(chan error, 1) - - go func() { - if supportTLS { - log.Default().Printf("Serving proxy with TLS %s\n", server.Addr) - serveErr <- server.ListenAndServeTLS(*certfile, *keyfile) - } else { - log.Default().Printf("Serving proxy %s\n", server.Addr) - serveErr <- server.ListenAndServe() - } - }() - - return serveErr, nil -} - -func checkTLSConfig(certfile *string, keyfile *string) (bool, error) { - if keyfile == nil || certfile == nil || *keyfile == "" || *certfile == "" { - return false, nil - } - - if !fileExists(*keyfile) { - return false, fmt.Errorf("can't find keyfile at '%s'", *keyfile) - } - - if !fileExists(*certfile) { - return false, fmt.Errorf("can't find certfile at '%s'", *certfile) - } - - return true, nil -} - -func fileExists(filepath string) bool { - _, err := os.Stat(filepath) - - if os.IsNotExist(err) { - return false - } - - // If we cant check the file info we probably can't open the file - if err != nil { - return false - } - - return true -} diff --git a/handlertable.go b/handlertable.go deleted file mode 100644 index 3a9ce4a..0000000 --- a/handlertable.go +++ /dev/null @@ -1,63 +0,0 @@ -package gatego - -import ( - "context" - "net/http" - "strings" - - "github.com/hvuhsg/gatego/config" - "github.com/hvuhsg/gatego/pkg/pathtree" -) - -type HandlerTable map[string]*pathtree.Trie[http.Handler] - -func cleanDomain(domain string) string { - return removePort(strings.ToLower(domain)) -} - -func NewHandlersTable(ctx context.Context, useOtel bool, servicesConfig []config.Service) (HandlerTable, error) { - servers := make(map[string]*pathtree.Trie[http.Handler]) - - for _, service := range servicesConfig { - servicePathTree := pathtree.NewTrie[http.Handler]() - - cleanedDomain := cleanDomain(service.Domain) - - servers[cleanedDomain] = servicePathTree - - for _, path := range service.Paths { - handler, err := NewHandler(ctx, useOtel, service, path) - if err != nil { - return nil, err - } - - cleanPath := strings.ToLower(path.Path) - servicePathTree.Insert(cleanPath, handler) - } - } - - return servers, nil -} - -func (table HandlerTable) GetHandler(domain string, path string) http.Handler { - cleanedDomain := cleanDomain(domain) - - pathTree, ok := table[cleanedDomain] - if !ok { - return nil - } - - endpoint, server := pathTree.Search(path) - if len(endpoint) == 0 { - return nil - } - - return server -} - -func removePort(addr string) string { - if i := strings.LastIndex(addr, ":"); i != -1 { - return addr[:i] - } - return addr -} diff --git a/pkg/multimux/multimux.go b/pkg/multimux/multimux.go new file mode 100644 index 0000000..442771f --- /dev/null +++ b/pkg/multimux/multimux.go @@ -0,0 +1,56 @@ +// This package implement a mutil-mux an http handler +// that acts as seprate http.ServeMux for each registred host + +package multimux + +import ( + "net/http" + "strings" +) + +type MultiMux struct { + Hosts map[string]*http.ServeMux +} + +func NewMultiMux() *MultiMux { + hosts := make(map[string]*http.ServeMux) + return &MultiMux{Hosts: hosts} +} + +func (mm *MultiMux) RegisterHandler(host string, pattern string, handler http.Handler) { + cleanedHost := cleanHost(host) + mux, exists := mm.Hosts[cleanedHost] + + if !exists { + mux = http.NewServeMux() + mm.Hosts[cleanedHost] = mux + } + + cleanedPattern := strings.ToLower(pattern) + + mux.Handle(cleanedPattern, handler) +} + +func (mm *MultiMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + host := r.Host + cleanedHost := cleanHost(host) + mux, exists := mm.Hosts[cleanedHost] + + if !exists { + w.WriteHeader(http.StatusNotFound) + return + } + + mux.ServeHTTP(w, r) +} + +func cleanHost(domain string) string { + return removePort(strings.ToLower(domain)) +} + +func removePort(addr string) string { + if i := strings.LastIndex(addr, ":"); i != -1 { + return addr[:i] + } + return addr +} diff --git a/pkg/multimux/multimux_test.go b/pkg/multimux/multimux_test.go new file mode 100644 index 0000000..251cdb6 --- /dev/null +++ b/pkg/multimux/multimux_test.go @@ -0,0 +1,148 @@ +package multimux + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRegisterHandler(t *testing.T) { + tests := []struct { + name string + host string + pattern string + }{ + {"basic registration", "example.com", "/path"}, + {"with port", "example.com:8080", "/path"}, + {"uppercase host", "EXAMPLE.COM", "/path"}, + {"uppercase pattern", "/PATH", "/path"}, + {"with subdomain", "sub.example.com", "/path"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mm := NewMultiMux() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + mm.RegisterHandler(tt.host, tt.pattern, handler) + + cleanedHost := cleanHost(tt.host) + mux, exists := mm.Hosts[cleanedHost] + if !exists { + t.Errorf("Host %s was not registered", cleanedHost) + } + if mux == nil { + t.Errorf("ServeMux for host %s is nil", cleanedHost) + } + }) + } +} + +func TestServeHTTP(t *testing.T) { + tests := []struct { + name string + host string + path string + expectedStatus int + expectedBody string + }{ + { + name: "existing host and path", + host: "example.com", + path: "/test", + expectedStatus: http.StatusOK, + expectedBody: "handler1", + }, + { + name: "existing host with port", + host: "example.com:8080", + path: "/test", + expectedStatus: http.StatusOK, + expectedBody: "handler1", + }, + { + name: "non-existing host", + host: "unknown.com", + path: "/test", + expectedStatus: http.StatusNotFound, + expectedBody: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mm := NewMultiMux() + + // Register a test handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "handler1") + }) + mm.RegisterHandler("example.com", "/test", handler) + + // Create test request + req := httptest.NewRequest("GET", "http://"+tt.host+tt.path, nil) + req.Host = tt.host + w := httptest.NewRecorder() + + // Serve the request + mm.ServeHTTP(w, req) + + // Check status code + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) + } + + // Check response body if expected + if tt.expectedBody != "" && w.Body.String() != tt.expectedBody { + t.Errorf("expected body %q, got %q", tt.expectedBody, w.Body.String()) + } + }) + } +} + +func TestCleanHost(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"example.com", "example.com"}, + {"EXAMPLE.COM", "example.com"}, + {"example.com:8080", "example.com"}, + {"EXAMPLE.COM:8080", "example.com"}, + {"sub.example.com:8080", "sub.example.com"}, + {"localhost", "localhost"}, + {"localhost:8080", "localhost"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := cleanHost(tt.input) + if result != tt.expected { + t.Errorf("cleanHost(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestRemovePort(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"example.com", "example.com"}, + {"example.com:8080", "example.com"}, + {"example.com:80", "example.com"}, + {"localhost:8080", "localhost"}, + {"127.0.0.1:8080", "127.0.0.1"}, + {"[::1]:8080", "[::1]"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := removePort(tt.input) + if result != tt.expected { + t.Errorf("removePort(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/pkg/pathtree/tree.go b/pkg/pathtree/tree.go deleted file mode 100644 index f1da334..0000000 --- a/pkg/pathtree/tree.go +++ /dev/null @@ -1,70 +0,0 @@ -package pathtree - -import ( - "strings" -) - -type TrieNode[T any] struct { - children map[string]*TrieNode[T] - isEnd bool - value string - data T -} - -type Trie[T any] struct { - root *TrieNode[T] -} - -func NewTrie[T any]() *Trie[T] { - return &Trie[T]{ - root: &TrieNode[T]{ - children: make(map[string]*TrieNode[T]), - }, - } -} - -func (t *Trie[T]) Insert(path string, data T) { - if path == "/" { - t.root.value = path - t.root.data = data - t.root.isEnd = true - return - } - - node := t.root - parts := strings.Split(strings.Trim(path, "/"), "/") - - for _, part := range parts { - if _, exists := node.children[part]; !exists { - node.children[part] = &TrieNode[T]{ - children: make(map[string]*TrieNode[T]), - } - } - node = node.children[part] - } - - node.isEnd = true - node.value = path - node.data = data -} - -func (t *Trie[T]) Search(path string) (string, T) { - node := t.root - parts := strings.Split(strings.Trim(path, "/"), "/") - lastMatch := node.value - var lastData T = node.data - - for _, part := range parts { - if child, exists := node.children[part]; exists { - node = child - if node.isEnd { - lastMatch = node.value - lastData = node.data - } - } else { - break - } - } - - return lastMatch, lastData -} diff --git a/pkg/pathtree/tree_test.go b/pkg/pathtree/tree_test.go deleted file mode 100644 index 4566e13..0000000 --- a/pkg/pathtree/tree_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package pathtree - -import ( - "testing" -) - -func TestTrieInsertAndSearch(t *testing.T) { - trie := NewTrie[string]() - - // Test cases - testCases := []struct { - insert string - search string - expected string - }{ - {"/", "/", "/"}, - {"/api", "/api", "/api"}, - {"/api/users", "/api/users", "/api/users"}, - {"/api/users/123", "/api/users/123", "/api/users/123"}, - {"/api/users/123", "/api/users/456", "/api/users"}, - {"/api/posts", "/api/posts/1", "/api/posts"}, - {"/blog", "/blog/2023/05/01", "/blog"}, - } - - // Insert test data - for _, tc := range testCases { - trie.Insert(tc.insert, tc.insert) - } - - // Test searches - for _, tc := range testCases { - path, data := trie.Search(tc.search) - if path != tc.expected { - t.Errorf("Search(%s) = %s, expected %s", tc.search, path, tc.expected) - } - if data != tc.expected { - t.Errorf("Search(%s) data = %s, expected %s", tc.search, data, tc.expected) - } - } -} - -func TestTrieRootInsert(t *testing.T) { - trie := NewTrie[string]() - trie.Insert("/", "root") - - path, data := trie.Search("/") - if path != "/" { - t.Errorf("Search('/') path = %s, expected '/'", path) - } - if data != "root" { - t.Errorf("Search('/') data = %s, expected 'root'", data) - } -} - -func TestTrieEmptySearch(t *testing.T) { - trie := NewTrie[string]() - trie.Insert("/api", "api") - - path, data := trie.Search("") - if path != "" { - t.Errorf("Search('') path = %s, expected ''", path) - } - if data != "" { - t.Errorf("Search('') data = %s, expected ''", data) - } -} - -func TestTrieNonExistentPath(t *testing.T) { - trie := NewTrie[string]() - trie.Insert("/api/users", "users") - - path, data := trie.Search("/api/posts") - if path != "" { - t.Errorf("Search('/api/posts') path = %s, expected ''", path) - } - if data != "" { - t.Errorf("Search('/api/posts') data = %s, expected ''", data) - } -} - -func TestTrieWithIntData(t *testing.T) { - trie := NewTrie[int]() - trie.Insert("/api", 1) - trie.Insert("/api/users", 2) - - path, data := trie.Search("/api/users/123") - if path != "/api/users" { - t.Errorf("Search('/api/users/123') path = %s, expected '/api/users'", path) - } - if data != 2 { - t.Errorf("Search('/api/users/123') data = %d, expected 2", data) - } -} diff --git a/server.go b/server.go new file mode 100644 index 0000000..3cf047a --- /dev/null +++ b/server.go @@ -0,0 +1,107 @@ +package gatego + +import ( + "context" + "fmt" + "log" + "net" + "net/http" + "os" + "time" + + "github.com/hvuhsg/gatego/config" + "github.com/hvuhsg/gatego/pkg/multimux" +) + +type gategoServer struct { + *http.Server +} + +func newServer(ctx context.Context, config config.Config, useOtel bool) (*gategoServer, error) { + multimuxer, err := createMultiMuxer(ctx, config.Services, useOtel) + if err != nil { + return nil, err + } + + addr := fmt.Sprintf("%s:%d", config.Host, config.Port) + + // Start HTTP server. + server := &http.Server{ + Addr: addr, + BaseContext: func(_ net.Listener) context.Context { return ctx }, + ReadTimeout: time.Second, + WriteTimeout: 10 * time.Second, + Handler: multimuxer, + } + + return &gategoServer{Server: server}, nil +} + +func createMultiMuxer(ctx context.Context, services []config.Service, useOtel bool) (*multimux.MultiMux, error) { + mm := multimux.NewMultiMux() + + for _, service := range services { + for _, path := range service.Paths { + handler, err := NewHandler(ctx, useOtel, service, path) + if err != nil { + return nil, err + } + + mm.RegisterHandler(service.Domain, path.Path, handler) + } + } + + return mm, nil +} + +func (gs *gategoServer) serve(certfile *string, keyfile *string) (chan error, error) { + supportTLS, err := checkTLSConfig(certfile, keyfile) + if err != nil { + return nil, err + } + + serveErr := make(chan error, 1) + + go func() { + if supportTLS { + log.Default().Printf("Serving proxy with TLS %s\n", gs.Addr) + serveErr <- gs.ListenAndServeTLS(*certfile, *keyfile) + } else { + log.Default().Printf("Serving proxy %s\n", gs.Addr) + serveErr <- gs.ListenAndServe() + } + }() + + return serveErr, nil +} + +func checkTLSConfig(certfile *string, keyfile *string) (bool, error) { + if keyfile == nil || certfile == nil || *keyfile == "" || *certfile == "" { + return false, nil + } + + if !fileExists(*keyfile) { + return false, fmt.Errorf("can't find keyfile at '%s'", *keyfile) + } + + if !fileExists(*certfile) { + return false, fmt.Errorf("can't find certfile at '%s'", *certfile) + } + + return true, nil +} + +func fileExists(filepath string) bool { + _, err := os.Stat(filepath) + + if os.IsNotExist(err) { + return false + } + + // If we cant check the file info we probably can't open the file + if err != nil { + return false + } + + return true +}