diff --git a/internal/handler/controller/utils.go b/internal/handler/controller/utils.go index 0305a5f..5b92a1d 100644 --- a/internal/handler/controller/utils.go +++ b/internal/handler/controller/utils.go @@ -5,11 +5,13 @@ import ( "encoding/json" "errors" "net/http" + "time" "github.com/go-chi/chi/v5/middleware" "github.com/go-playground/validator/v10" "github.com/oursky/pageship/internal/config" "github.com/oursky/pageship/internal/db" + "github.com/oursky/pageship/internal/httputil" "github.com/oursky/pageship/internal/models" ) @@ -26,8 +28,13 @@ func init() { } const maxJSONSize = 10 * 1024 * 1024 // 10MB +const requestIOTimeout = 10 * time.Second func bindJSON[T any](w http.ResponseWriter, r *http.Request, body *T) bool { + ctrl := http.NewResponseController(w) + dl := time.Now().Add(requestIOTimeout) + ctrl.SetReadDeadline(dl) + decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxJSONSize)) decoder.DisallowUnknownFields() @@ -36,6 +43,8 @@ func bindJSON[T any](w http.ResponseWriter, r *http.Request, body *T) bool { return false } + ctrl.SetReadDeadline(time.Time{}) + if err := validate.Struct(body); err != nil { writeJSON(w, http.StatusBadRequest, response{Error: err}) return false @@ -59,7 +68,9 @@ func (r response) MarshalJSON() ([]byte, error) { func writeJSON(w http.ResponseWriter, statusCode int, value any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) - encoder := json.NewEncoder(w) + + writer := httputil.NewTimeoutResponseWriter(w, requestIOTimeout) + encoder := json.NewEncoder(writer) if err := encoder.Encode(value); err != nil { panic(err) } diff --git a/internal/httputil/server.go b/internal/httputil/server.go index 8878dbb..bd34168 100644 --- a/internal/httputil/server.go +++ b/internal/httputil/server.go @@ -37,8 +37,6 @@ func (s *Server) makeServer(handler http.Handler) *http.Server { return &http.Server{ ErrorLog: zap.NewStdLog(s.Logger), ReadHeaderTimeout: 5 * time.Second, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, MaxHeaderBytes: 10 * 1024, Handler: handler, diff --git a/internal/httputil/timeout.go b/internal/httputil/timeout.go index 6ff9801..5b72014 100644 --- a/internal/httputil/timeout.go +++ b/internal/httputil/timeout.go @@ -9,30 +9,34 @@ import ( type timeoutReader struct { r io.Reader ctrl *http.ResponseController - idleTimeout time.Duration + readTimeout time.Duration } -func NewTimeoutReader(r io.Reader, ctrl *http.ResponseController, idleTimeout time.Duration) io.Reader { - return &timeoutReader{r: r, ctrl: ctrl, idleTimeout: idleTimeout} +func NewTimeoutReader(r io.Reader, ctrl *http.ResponseController, readTimeout time.Duration) io.Reader { + return &timeoutReader{r: r, ctrl: ctrl, readTimeout: readTimeout} } func (r *timeoutReader) Read(p []byte) (int, error) { - n, err := r.r.Read(p) - - dl := time.Now().Add(r.idleTimeout) + dl := time.Now().Add(r.readTimeout) r.ctrl.SetReadDeadline(dl) r.ctrl.SetWriteDeadline(dl) + + n, err := r.r.Read(p) + + r.ctrl.SetReadDeadline(time.Time{}) + r.ctrl.SetWriteDeadline(time.Time{}) + return n, err } type timeoutResponseWriter struct { - w http.ResponseWriter - ctrl *http.ResponseController - idleTimeout time.Duration + w http.ResponseWriter + ctrl *http.ResponseController + writeTimeout time.Duration } -func NewTimeoutResponseWriter(w http.ResponseWriter, idleTimeout time.Duration) http.ResponseWriter { - return &timeoutResponseWriter{w: w, ctrl: http.NewResponseController(w), idleTimeout: idleTimeout} +func NewTimeoutResponseWriter(w http.ResponseWriter, writeTimeout time.Duration) http.ResponseWriter { + return &timeoutResponseWriter{w: w, ctrl: http.NewResponseController(w), writeTimeout: writeTimeout} } func (w *timeoutResponseWriter) Header() http.Header { @@ -44,9 +48,12 @@ func (w *timeoutResponseWriter) WriteHeader(statusCode int) { } func (w *timeoutResponseWriter) Write(p []byte) (int, error) { + dl := time.Now().Add(w.writeTimeout) + w.ctrl.SetWriteDeadline(dl) + n, err := w.w.Write(p) - dl := time.Now().Add(w.idleTimeout) - w.ctrl.SetWriteDeadline(dl) + w.ctrl.SetWriteDeadline(time.Time{}) + return n, err }