diff --git a/main.go b/main.go index 0dc1464..a766814 100644 --- a/main.go +++ b/main.go @@ -83,19 +83,16 @@ func proxy(w http.ResponseWriter, r *http.Request) { pathExt := filepath.Ext(req.URL.Path) if pathExt == ".webp" { orgRes, err = doWebp(req) - if err != nil { - http.Error(w, "Get origin failed", http.StatusBadGateway) - log.Printf("Get origin failed. %v\n", err) - return - } } else { orgRes, err = client.Do(req) - if err != nil { - http.Error(w, "Get origin failed", http.StatusBadGateway) - log.Printf("Get origin failed. %v\n", err) - return - } } + + if err != nil || orgRes.StatusCode == http.StatusNotFound { + http.Error(w, "Get origin failed", orgRes.StatusCode) + log.Printf("Get origin failed. %v\n", err) + return + } + defer orgRes.Body.Close() if orgRes.Header.Get("Last-Modified") != "" { w.Header().Set("Last-Modified", orgRes.Header.Get("Last-Modified")) diff --git a/main_test.go b/main_test.go index 7fab35c..dba3943 100644 --- a/main_test.go +++ b/main_test.go @@ -161,6 +161,27 @@ func TestOriginNotExist(t *testing.T) { t.Fatal(err) } + if res.StatusCode != http.StatusNotFound { + t.Errorf("HTTP status is %d, want %d", res.StatusCode, http.StatusNotFound) + } +} + +func TestOriginBadGateWay(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(proxy)) + + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "502 Bad Gateway", http.StatusBadGateway) + })) + + orgSrvURL = origin.URL + + url := ts.URL + "/bad.jpg" + + res, err := http.Get(url) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusBadGateway { t.Errorf("HTTP status is %d, want %d", res.StatusCode, http.StatusBadGateway) }