From 07660e3ae6a5b005d0d9a5b8b09b4c2da1d60525 Mon Sep 17 00:00:00 2001 From: Brandon Liu Date: Thu, 31 Oct 2024 17:02:21 -0700 Subject: [PATCH] use rs/cors library to handle multiple CORS origins [#191] --- go.mod | 1 + go.sum | 2 ++ main.go | 21 +++++++++++++++++---- pmtiles/server.go | 24 ++++-------------------- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index 96864f6..b3ea72e 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/dustin/go-humanize v1.0.1 github.com/paulmach/orb v0.10.0 github.com/prometheus/client_golang v1.19.1 + github.com/rs/cors v1.11.1 github.com/schollz/progressbar/v3 v3.13.1 github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 diff --git a/go.sum b/go.sum index 2d270de..4c2039c 100644 --- a/go.sum +++ b/go.sum @@ -435,6 +435,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= diff --git a/main.go b/main.go index b4723fa..2a00fc2 100644 --- a/main.go +++ b/main.go @@ -8,10 +8,12 @@ import ( "os" "path/filepath" "strconv" + "strings" "time" "github.com/alecthomas/kong" "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/rs/cors" "github.com/protomaps/go-pmtiles/pmtiles" _ "gocloud.dev/blob/azureblob" @@ -93,7 +95,7 @@ var cli struct { Interface string `default:"0.0.0.0"` Port int `default:"8080"` AdminPort int `default:"-1"` - Cors string `help:"Value of HTTP CORS header."` + Cors string `help:"Comma-separated list of of allowed HTTP CORS origins."` CacheSize int `default:"64" help:"Size of cache in Megabytes."` Bucket string `help:"Remote bucket"` PublicURL string `help:"Public base URL of tile endpoint for TileJSON e.g. https://example.com/tiles/"` @@ -139,7 +141,7 @@ func main() { logger.Fatalf("Failed to show tile, %v", err) } case "serve ": - server, err := pmtiles.NewServer(cli.Serve.Bucket, cli.Serve.Path, logger, cli.Serve.CacheSize, cli.Serve.Cors, cli.Serve.PublicURL) + server, err := pmtiles.NewServer(cli.Serve.Bucket, cli.Serve.Path, logger, cli.Serve.CacheSize, cli.Serve.PublicURL) if err != nil { logger.Fatalf("Failed to create new server, %v", err) @@ -148,7 +150,9 @@ func main() { pmtiles.SetBuildInfo(version, commit, date) server.Start() - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { start := time.Now() statusCode := server.ServeHTTP(w, r) logger.Printf("served %d %s in %s", statusCode, url.PathEscape(r.URL.Path), time.Since(start)) @@ -164,7 +168,16 @@ func main() { logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+adminPort, adminMux)) }() } - logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), nil)) + + if cli.Serve.Cors != "" { + c := cors.New(cors.Options{ + AllowedOrigins: strings.Split(cli.Serve.Cors, ","), + }) + muxWithCors := c.Handler(mux) + logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), muxWithCors)) + } else { + logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), mux)) + } case "extract ": err := pmtiles.Extract(logger, cli.Extract.Bucket, cli.Extract.Input, cli.Extract.Minzoom, cli.Extract.Maxzoom, cli.Extract.Region, cli.Extract.Bbox, cli.Extract.Output, cli.Extract.DownloadThreads, cli.Extract.Overfetch, cli.Extract.DryRun) if err != nil { diff --git a/pmtiles/server.go b/pmtiles/server.go index ad1deec..6acfd75 100644 --- a/pmtiles/server.go +++ b/pmtiles/server.go @@ -49,13 +49,12 @@ type Server struct { bucket Bucket logger *log.Logger cacheSize int - cors string publicURL string metrics *metrics } // NewServer creates a new pmtiles HTTP server. -func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize int, cors string, publicURL string) (*Server, error) { +func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize int, publicURL string) (*Server, error) { ctx := context.Background() @@ -71,11 +70,11 @@ func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize in return nil, err } - return NewServerWithBucket(bucket, prefix, logger, cacheSize, cors, publicURL) + return NewServerWithBucket(bucket, prefix, logger, cacheSize, publicURL) } // NewServerWithBucket creates a new HTTP server for a gocloud Bucket. -func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize int, cors string, publicURL string) (*Server, error) { +func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize int, publicURL string) (*Server, error) { reqs := make(chan request, 8) @@ -84,7 +83,6 @@ func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize bucket: bucket, logger: logger, cacheSize: cacheSize, - cors: cors, publicURL: publicURL, metrics: createMetrics("", logger), // change scope string if there are multiple servers running in one process } @@ -474,9 +472,6 @@ func (server *Server) get(ctx context.Context, unsanitizedPath string) (archive, handler = "" archive = "" headers = make(map[string]string) - if len(server.cors) > 0 { - headers["Access-Control-Allow-Origin"] = server.cors - } if ok, key, z, x, y, ext := parseTilePath(unsanitizedPath); ok { archive, handler = key, "tile" @@ -518,18 +513,7 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) { // Serve an HTTP response from the archive func (server *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) int { tracker := server.metrics.startRequest() - if r.Method == http.MethodOptions { - if len(server.cors) > 0 { - w.Header().Set("Access-Control-Allow-Origin", server.cors) - } - w.WriteHeader(204) - tracker.finish(r.Context(), "", r.Method, 204, 0, false) - return 204 - } else if r.Method != http.MethodGet && r.Method != http.MethodHead { - w.WriteHeader(405) - tracker.finish(r.Context(), "", r.Method, 405, 0, false) - return 405 - } + archive, handler, statusCode, headers, body := server.get(r.Context(), r.URL.Path) for k, v := range headers { w.Header().Set(k, v)