diff --git a/caddy/pmtiles_proxy.go b/caddy/pmtiles_proxy.go index fc5ed01..56f45a0 100644 --- a/caddy/pmtiles_proxy.go +++ b/caddy/pmtiles_proxy.go @@ -46,7 +46,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error { m.logger = ctx.Logger() logger := log.New(io.Discard, "", log.Ldate) prefix := "." // serve only the root of the bucket for now, at the root route of Caddyfile - server, err := pmtiles.NewServer(m.Bucket, prefix, logger, m.CacheSize, "", m.PublicURL) + server, err := pmtiles.NewServer(m.Bucket, prefix, logger, m.CacheSize, m.PublicURL) if err != nil { return err } diff --git a/go.mod b/go.mod index 96864f6..b3ea72e 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/dustin/go-humanize v1.0.1 github.com/paulmach/orb v0.10.0 github.com/prometheus/client_golang v1.19.1 + github.com/rs/cors v1.11.1 github.com/schollz/progressbar/v3 v3.13.1 github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 diff --git a/go.sum b/go.sum index 2d270de..4c2039c 100644 --- a/go.sum +++ b/go.sum @@ -435,6 +435,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= diff --git a/main.go b/main.go index b4723fa..0710f52 100644 --- a/main.go +++ b/main.go @@ -93,7 +93,7 @@ var cli struct { Interface string `default:"0.0.0.0"` Port int `default:"8080"` AdminPort int `default:"-1"` - Cors string `help:"Value of HTTP CORS header."` + Cors string `help:"Comma-separated list of of allowed HTTP CORS origins."` CacheSize int `default:"64" help:"Size of cache in Megabytes."` Bucket string `help:"Remote bucket"` PublicURL string `help:"Public base URL of tile endpoint for TileJSON e.g. https://example.com/tiles/"` @@ -139,7 +139,7 @@ func main() { logger.Fatalf("Failed to show tile, %v", err) } case "serve ": - server, err := pmtiles.NewServer(cli.Serve.Bucket, cli.Serve.Path, logger, cli.Serve.CacheSize, cli.Serve.Cors, cli.Serve.PublicURL) + server, err := pmtiles.NewServer(cli.Serve.Bucket, cli.Serve.Path, logger, cli.Serve.CacheSize, cli.Serve.PublicURL) if err != nil { logger.Fatalf("Failed to create new server, %v", err) @@ -148,7 +148,9 @@ func main() { pmtiles.SetBuildInfo(version, commit, date) server.Start() - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { start := time.Now() statusCode := server.ServeHTTP(w, r) logger.Printf("served %d %s in %s", statusCode, url.PathEscape(r.URL.Path), time.Since(start)) @@ -164,7 +166,13 @@ func main() { logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+adminPort, adminMux)) }() } - logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), nil)) + + if cli.Serve.Cors != "" { + muxWithCors := pmtiles.NewCors(cli.Serve.Cors).Handler(mux) + logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), muxWithCors)) + } else { + logger.Fatal(startHTTPServer(cli.Serve.Interface+":"+strconv.Itoa(cli.Serve.Port), mux)) + } case "extract ": err := pmtiles.Extract(logger, cli.Extract.Bucket, cli.Extract.Input, cli.Extract.Minzoom, cli.Extract.Maxzoom, cli.Extract.Region, cli.Extract.Bbox, cli.Extract.Output, cli.Extract.DownloadThreads, cli.Extract.Overfetch, cli.Extract.DryRun) if err != nil { diff --git a/pmtiles/server.go b/pmtiles/server.go index ad1deec..0b2c5eb 100644 --- a/pmtiles/server.go +++ b/pmtiles/server.go @@ -7,11 +7,13 @@ import ( "context" "encoding/json" "errors" + "github.com/rs/cors" "io" "log" "net/http" "regexp" "strconv" + "strings" "time" ) @@ -49,13 +51,12 @@ type Server struct { bucket Bucket logger *log.Logger cacheSize int - cors string publicURL string metrics *metrics } // NewServer creates a new pmtiles HTTP server. -func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize int, cors string, publicURL string) (*Server, error) { +func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize int, publicURL string) (*Server, error) { ctx := context.Background() @@ -71,11 +72,11 @@ func NewServer(bucketURL string, prefix string, logger *log.Logger, cacheSize in return nil, err } - return NewServerWithBucket(bucket, prefix, logger, cacheSize, cors, publicURL) + return NewServerWithBucket(bucket, prefix, logger, cacheSize, publicURL) } // NewServerWithBucket creates a new HTTP server for a gocloud Bucket. -func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize int, cors string, publicURL string) (*Server, error) { +func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize int, publicURL string) (*Server, error) { reqs := make(chan request, 8) @@ -84,7 +85,6 @@ func NewServerWithBucket(bucket Bucket, _ string, logger *log.Logger, cacheSize bucket: bucket, logger: logger, cacheSize: cacheSize, - cors: cors, publicURL: publicURL, metrics: createMetrics("", logger), // change scope string if there are multiple servers running in one process } @@ -474,9 +474,6 @@ func (server *Server) get(ctx context.Context, unsanitizedPath string) (archive, handler = "" archive = "" headers = make(map[string]string) - if len(server.cors) > 0 { - headers["Access-Control-Allow-Origin"] = server.cors - } if ok, key, z, x, y, ext := parseTilePath(unsanitizedPath); ok { archive, handler = key, "tile" @@ -518,18 +515,13 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) { // Serve an HTTP response from the archive func (server *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) int { tracker := server.metrics.startRequest() - if r.Method == http.MethodOptions { - if len(server.cors) > 0 { - w.Header().Set("Access-Control-Allow-Origin", server.cors) - } - w.WriteHeader(204) - tracker.finish(r.Context(), "", r.Method, 204, 0, false) - return 204 - } else if r.Method != http.MethodGet && r.Method != http.MethodHead { + + if r.Method != http.MethodGet && r.Method != http.MethodHead { w.WriteHeader(405) tracker.finish(r.Context(), "", r.Method, 405, 0, false) return 405 } + archive, handler, statusCode, headers, body := server.get(r.Context(), r.URL.Path) for k, v := range headers { w.Header().Set(k, v) @@ -552,3 +544,10 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) int { return statusCode } + +func NewCors(corsOrigins string) *cors.Cors { + return cors.New(cors.Options{ + AllowedMethods: []string{http.MethodGet, http.MethodHead}, + AllowedOrigins: strings.Split(corsOrigins, ","), + }) +} diff --git a/pmtiles/server_test.go b/pmtiles/server_test.go index 38263d5..3f4d769 100644 --- a/pmtiles/server_test.go +++ b/pmtiles/server_test.go @@ -6,6 +6,8 @@ import ( "context" "encoding/json" "log" + "net/http" + "net/http/httptest" "sort" "testing" @@ -13,6 +15,11 @@ import ( "github.com/stretchr/testify/assert" ) +var testResponse = []byte("bar") +var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(testResponse) +}) + func TestRegex(t *testing.T) { ok, key, z, x, y, ext := parseTilePath("/foo/0/0/0") assert.False(t, ok) @@ -108,12 +115,20 @@ func fakeArchive(t *testing.T, header HeaderV3, metadata map[string]interface{}, func newServer(t *testing.T) (mockBucket, *Server) { prometheus.DefaultRegisterer = prometheus.NewRegistry() bucket := mockBucket{make(map[string][]byte)} - server, err := NewServerWithBucket(bucket, "", log.Default(), 10, "", "tiles.example.com") + server, err := NewServerWithBucket(bucket, "", log.Default(), 10, "tiles.example.com") assert.Nil(t, err) server.Start() return bucket, server } +func TestPostReturns405(t *testing.T) { + _, server := newServer(t) + res := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", nil) + server.ServeHTTP(res, req) + assert.Equal(t, 405, res.Code) +} + func TestMissingFileReturns404(t *testing.T) { _, server := newServer(t) statusCode, _, _ := server.Get(context.Background(), "/") @@ -413,3 +428,44 @@ func TestEtagResponsesFromTile(t *testing.T) { assert.NotEqual(t, headers000v1["ETag"], headers311v1["ETag"]) assert.NotEqual(t, headers000v1["ETag"], headers412v1["ETag"]) } + +func TestSingleCorsOrigin(t *testing.T) { + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example.com/foo", nil) + req.Header.Add("Origin", "http://example.com") + c := NewCors("http://example.com") + c.Handler(testHandler).ServeHTTP(res, req) + assert.Equal(t, 200, res.Code) + assert.Equal(t, "http://example.com", res.Header().Get("Access-Control-Allow-Origin")) +} + +func TestMultiCorsOrigin(t *testing.T) { + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example2.com/foo", nil) + req.Header.Add("Origin", "http://example2.com") + c := NewCors("http://example.com,http://example2.com") + c.Handler(testHandler).ServeHTTP(res, req) + assert.Equal(t, 200, res.Code) + assert.Equal(t, "http://example2.com", res.Header().Get("Access-Control-Allow-Origin")) +} + +func TestWildcardCors(t *testing.T) { + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://example.com/foo", nil) + req.Header.Add("Origin", "http://example.com") + c := NewCors("*") + c.Handler(testHandler).ServeHTTP(res, req) + assert.Equal(t, 200, res.Code) + assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin")) +} + +func TestCorsOptions(t *testing.T) { + res := httptest.NewRecorder() + req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) + req.Header.Add("Origin", "http://example.com") + req.Header.Add("Access-Control-Request-Method", "GET") + c := NewCors("*") + c.Handler(testHandler).ServeHTTP(res, req) + assert.Equal(t, 204, res.Code) + assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin")) +}