From a02e126411b7dcf4d701066d749f98d556345138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 29 Jul 2024 16:25:23 +0800 Subject: [PATCH 01/55] WIP(x/http/get): Implementing get request using native socket --- go.mod | 2 +- go.sum | 2 + x/http/_demo/get/get.go | 17 ++ x/http/_demo/test.go | 141 +++++++++++++++ x/http/client.go | 388 ++++++++++++++++++++++++++++++++++++++++ x/http/header.go | 34 ++++ x/http/hyper-go.go | 12 ++ x/http/response.go | 87 +++++++++ 8 files changed, 682 insertions(+), 1 deletion(-) create mode 100644 x/http/_demo/get/get.go create mode 100644 x/http/_demo/test.go create mode 100644 x/http/client.go create mode 100644 x/http/header.go create mode 100644 x/http/hyper-go.go create mode 100644 x/http/response.go diff --git a/go.mod b/go.mod index 7e12cef..ff4ef8f 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b +require github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 diff --git a/go.sum b/go.sum index fdc017f..8799390 100644 --- a/go.sum +++ b/go.sum @@ -6,3 +6,5 @@ github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3 h1:2fZ2zQ8S58KvOsJTx github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b h1:z9FUoeAALL5ytBhhGhE1dXm4+L1Q2eMUTcfiqLAZgf8= github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 h1:02gSx3Oj3cLlBMed+9IWBUGHThEZMnCNiR67yaQbpqo= +github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go new file mode 100644 index 0000000..e7591a6 --- /dev/null +++ b/x/http/_demo/get/get.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + // 使用 http.Get 发送 GET 请求 + resp := http.Get("https://www.baidu.com/") + fmt.Println(resp.Status) + fmt.Println(resp.StatusCode) + resp.PrintHeaders() + fmt.Println() + resp.PrintBody() +} diff --git a/x/http/_demo/test.go b/x/http/_demo/test.go new file mode 100644 index 0000000..2ddd3ea --- /dev/null +++ b/x/http/_demo/test.go @@ -0,0 +1,141 @@ +package main + +import ( + "fmt" + "io" +) + +func main() { + + // 假设你有一个 []byte 数组 + data := []byte("This is some data that needs to be stored in Body.") + + // 创建一个 io.Pipe + pr, pw := io.Pipe() + + // 启动一个 goroutine 将数据写入 io.Pipe 的写入端 + go func() { + defer pw.Close() // 确保写入完成后关闭写入端 + + if _, err := pw.Write(data); err != nil { + fmt.Println("Error writing to pipe:", err) + return + } + }() + + // 读取 Body 中的数据进行验证 + readData, err := io.ReadAll(pr) + if err != nil { + fmt.Println("Error reading from Body:", err) + return + } + + // 输出 Body 中的数据 + fmt.Println("Body content:", string(readData)) + // + //http.Get() + + //r, w := io.Pipe() + // + //go func() { + // fmt.Fprint(w, "some io.Reader stream to be read\n") + // w.Close() + //}() + // + //if _, err := io.Copy(os.Stdout, r); err != nil { + // log.Fatal(err) + //} + + // 使用 http.Get 发送 GET 请求 + //resp, err := http.Get("https://www.baidu.com/") + //if err != nil { + // fmt.Println("Error:", err) + // return + //} + //defer resp.Body.Close() + // + //body, err := io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println("Error reading response:", err) + // return + //} + //fmt.Println("GET Response:\n", string(body)) + + //rawURL := "http://example.com:8080/path/to/resource?query=123#fragment" + //parsedURL, err := url.Parse(rawURL) + //if err != nil { + // fmt.Println("Error parsing URL:", err) + // return + //} + // + //hostname := parsedURL.Hostname() + //port := parsedURL.Port() + // + //uri := parsedURL.RequestURI() + // + //fmt.Println("Hostname:", hostname) + //fmt.Println("Port:", port) + //fmt.Println("URI:", uri) + + //// 使用 http.Post 发送 POST 请求上传文件 + //file, err := os.Open("path/to/your/file.jpg") + //if err != nil { + // fmt.Println("Error opening file:", err) + // return + //} + //defer file.Close() + // + //var buf bytes.Buffer + //writer := multipart.NewWriter(&buf) + //_, err = writer.CreateFormFile("file", "file.jpg") + //if err != nil { + // fmt.Println("Error creating form file:", err) + // return + //} + // + //_, err = io.ReadAll(file) + //if err != nil { + // fmt.Println("Error reading file:", err) + // return + //} + // + //err = writer.Close() + //if err != nil { + // fmt.Println("Error closing writer:", err) + // return + //} + // + //resp, err = http.Post("https://www.baidu.com/upload", writer.FormDataContentType(), &buf) + //if err != nil { + // fmt.Println("Error:", err) + // return + //} + //defer resp.Body.Close() + // + //body, err = io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println("Error reading response:", err) + // return + //} + //fmt.Println("POST Response:\n", string(body)) + // + //// 使用 http.PostForm 发送表单数据 + //formData := url.Values{ + // "key": {"Value"}, + // "id": {"123"}, + //} + // + //resp, err = http.PostForm("https://www.baidu.com/form", formData) + //if err != nil { + // fmt.Println("Error:", err) + // return + //} + //defer resp.Body.Close() + // + //body, err = io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println("Error reading response:", err) + // return + //} + //fmt.Println("POST Form Response:\n", string(body)) +} diff --git a/x/http/client.go b/x/http/client.go new file mode 100644 index 0000000..32cf658 --- /dev/null +++ b/x/http/client.go @@ -0,0 +1,388 @@ +package http + +import ( + "fmt" + "strconv" + "strings" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/net" + "github.com/goplus/llgo/c/os" + "github.com/goplus/llgo/c/sys" + "github.com/goplus/llgo/c/syscall" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type ConnData struct { + Fd c.Int + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + +type RequestConfig struct { + ReqMethod string + ReqHost string + ReqPort string + ReqUri string + ReqHeaders map[string]string + ReqHTTPVersion hyper.HTTPVersion + TimeoutSec int64 + TimeoutUsec int32 + //ReqBody + //ReqURIParts +} + +func Get(url string) *Response { + host, port, uri := parseURL(url) + req := hyper.NewRequest() + + // Prepare the request + // Set the request method and uri + if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { + panic(fmt.Sprintf("error setting method %s\n", "GET")) + } + if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + panic(fmt.Sprintf("error setting uri %s\n", uri)) + } + + // Set the request headers + reqHeaders := req.Headers() + if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + panic("error setting headers\n") + } + + //var response RequestResponse + + fd := ConnectTo(host, port) + + connData := NewConnData(fd) + + // Hookup the IO + io := NewIoWithConnReadWrite(connData) + + // We need an executor generally to poll futures + exec := hyper.NewExecutor() + + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(exec) + + handshakeTask := hyper.Handshake(io, opts) + SetUserData(handshakeTask, hyper.ExampleHandshake) + + // Let's wait for the handshake to finish... + exec.Push(handshakeTask) + + var fdsRead, fdsWrite, fdsExcep syscall.FdSet + var err *hyper.Error + var response Response + + // The polling state machine! + for { + // Poll all ready tasks and act on them... + for { + task := exec.Poll() + + if task == nil { + break + } + + switch (hyper.ExampleId)(uintptr(task.Userdata())) { + case hyper.ExampleHandshake: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake error!\n")) + err = (*hyper.Error)(task.Value()) + Fail(err) + } + if task.Type() != hyper.TaskClientConn { + c.Printf(c.Str("unexpected task type\n")) + Fail(err) + } + + client := (*hyper.ClientConn)(task.Value()) + task.Free() + + // Send it! + sendTask := client.Send(req) + SetUserData(sendTask, hyper.ExampleSend) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + panic("error send\n") + } + + // For this example, no longer need the client + client.Free() + + break + case hyper.ExampleSend: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send error!\n")) + err = (*hyper.Error)(task.Value()) + Fail(err) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + Fail(err) + } + + // Take the results + resp := (*hyper.Response)(task.Value()) + task.Free() + + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() + + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) + + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + respBody := resp.Body() + + foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) + + SetUserData(foreachTask, hyper.ExampleRespBody) + exec.Push(foreachTask) + + // No longer need the response + resp.Free() + + break + case hyper.ExampleRespBody: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + err = (*hyper.Error)(task.Value()) + Fail(err) + } + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + Fail(err) + } + + // Cleaning up before exiting + task.Free() + exec.Free() + FreeConnData(connData) + + if response.respBodyWriter != nil { + defer response.respBodyWriter.Close() + } + + return &response + case hyper.ExampleNotSet: + // A background task for hyper_client completed... + task.Free() + break + } + } + + // All futures are pending on IO work, so select on the fds. + + sys.FD_ZERO(&fdsRead) + sys.FD_ZERO(&fdsWrite) + sys.FD_ZERO(&fdsExcep) + + if connData.ReadWaker != nil { + sys.FD_SET(connData.Fd, &fdsRead) + } + if connData.WriteWaker != nil { + sys.FD_SET(connData.Fd, &fdsWrite) + } + + // Set the default request timeout + var tv syscall.Timeval + tv.Sec = 10 + + selRet := sys.Select(connData.Fd+1, &fdsRead, &fdsWrite, &fdsExcep, &tv) + if selRet < 0 { + panic("select() error\n") + } else if selRet == 0 { + panic("select() timeout\n") + } + + if sys.FD_ISSET(connData.Fd, &fdsRead) != 0 { + connData.ReadWaker.Wake() + connData.ReadWaker = nil + } + + if sys.FD_ISSET(connData.Fd, &fdsWrite) != 0 { + connData.WriteWaker.Wake() + connData.WriteWaker = nil + } + } +} + +// ConnectTo connects to a host and port +func ConnectTo(host string, port string) c.Int { + var hints net.AddrInfo + hints.Family = net.AF_UNSPEC + hints.SockType = net.SOCK_STREAM + + var result, rp *net.AddrInfo + + if net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &result) != 0 { + panic(fmt.Sprintf("dns failed for %s\n", host)) + } + + var sfd c.Int + for rp = result; rp != nil; rp = rp.Next { + sfd = net.Socket(rp.Family, rp.SockType, rp.Protocol) + if sfd == -1 { + continue + } + if net.Connect(sfd, rp.Addr, rp.AddrLen) != -1 { + break + } + os.Close(sfd) + } + + net.Freeaddrinfo(result) + + // no address succeeded + if rp == nil || sfd < 0 { + panic(fmt.Sprintf("connect failed for %s\n", host)) + } + + if os.Fcntl(sfd, os.F_SETFL, os.O_NONBLOCK) != 0 { + panic("failed to set net to non-blocking\n") + } + return sfd +} + +// ReadCallBack is the callback for reading from a socket +func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + + ret := os.Read(conn.Fd, c.Pointer(buf), bufLen) + + if ret >= 0 { + return uintptr(ret) + } + + if os.Errno != os.EAGAIN { + c.Perror(c.Str("[read callback fail]")) + // kaboom + return hyper.IoError + } + + // would block, register interest + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + } + conn.ReadWaker = ctx.Waker() + return hyper.IoPending +} + +// WriteCallBack is the callback for writing to a socket +func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + ret := os.Write(conn.Fd, c.Pointer(buf), bufLen) + + if int(ret) >= 0 { + return uintptr(ret) + } + + if os.Errno != os.EAGAIN { + c.Perror(c.Str("[write callback fail]")) + // kaboom + return hyper.IoError + } + + // would block, register interest + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + } + conn.WriteWaker = ctx.Waker() + return hyper.IoPending +} + +// FreeConnData frees the connection data +func FreeConnData(conn *ConnData) { + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } +} + +// Fail prints the error details and panics +func Fail(err *hyper.Error) { + if err != nil { + c.Printf(c.Str("error code: %d\n"), err.Code()) + // grab the error details + var errBuf [256]c.Char + errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) + + c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + c.Printf(c.Str("details: ")) + for i := 0; i < int(errLen); i++ { + c.Printf(c.Str("%c"), errBuf[i]) + } + c.Printf(c.Str("\n")) + + // clean up the error + err.Free() + panic("request failed\n") + } + return +} + +// NewConnData creates a new connection data +func NewConnData(fd c.Int) *ConnData { + return &ConnData{Fd: fd, ReadWaker: nil, WriteWaker: nil} +} + +// NewIoWithConnReadWrite creates a new IO with read and write callbacks +func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { + io := hyper.NewIo() + io.SetUserdata(c.Pointer(connData)) + io.SetRead(ReadCallBack) + io.SetWrite(WriteCallBack) + return io +} + +// parseURL Parse the URL and extract the host name, port number, and URI +func parseURL(rawURL string) (hostname, port, uri string) { + // 找到 "://" 的位置,以分隔协议和主机名 + schemeEnd := strings.Index(rawURL, "://") + if schemeEnd != -1 { + //scheme = rawURL[:schemeEnd] + rawURL = rawURL[schemeEnd+3:] + } else { + //scheme = "http" // 默认协议为 http + } + + // 找到第一个 "/" 的位置,以分隔主机名和路径 + pathStart := strings.Index(rawURL, "/") + if pathStart != -1 { + uri = rawURL[pathStart:] + rawURL = rawURL[:pathStart] + } else { + uri = "/" + } + + // 找到 ":" 的位置,以分隔主机名和端口号 + portStart := strings.LastIndex(rawURL, ":") + if portStart != -1 { + hostname = rawURL[:portStart] + port = rawURL[portStart+1:] + } else { + hostname = rawURL + port = "" // 未指定端口号 + } + + // 如果未指定端口号,根据协议设置默认端口号 + if port == "" { + //if scheme == "https" { + // port = "443" + //} else { + // port = "80" + //} + port = "80" + } + + return +} diff --git a/x/http/header.go b/x/http/header.go new file mode 100644 index 0000000..ea313a2 --- /dev/null +++ b/x/http/header.go @@ -0,0 +1,34 @@ +package http + +import ( + "fmt" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Header map[string][]string + +// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console +func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { + resp := (*Response)(userdata) + nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) + valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) + + if resp.Header == nil { + resp.Header = make(map[string][]string) + } + resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) + return hyper.IterContinue +} + +func (resp *Response) PrintHeaders() { + for key, values := range resp.Header { + fmt.Printf("%s: ", key) + for _, value := range values { + fmt.Printf(value + "; ") + } + fmt.Printf("\n") + } +} diff --git a/x/http/hyper-go.go b/x/http/hyper-go.go new file mode 100644 index 0000000..a1db081 --- /dev/null +++ b/x/http/hyper-go.go @@ -0,0 +1,12 @@ +package http + +import ( + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +// SetUserData Set the user data for the task +func SetUserData(task *hyper.Task, userData hyper.ExampleId) { + var data = userData + task.SetUserdata(c.Pointer(uintptr(data))) +} diff --git a/x/http/response.go b/x/http/response.go new file mode 100644 index 0000000..9263461 --- /dev/null +++ b/x/http/response.go @@ -0,0 +1,87 @@ +package http + +import ( + "fmt" + "io" + "unsafe" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Response struct { + Status string + StatusCode int + Header Header + ResponseBody io.ReadCloser + respBodyWriter *io.PipeWriter + ResponseBodyLen int64 +} + +// AppendToResponseBody (BodyForEachCallback) appends the body to the response +func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + resp := (*Response)(userdata) + len := chunk.Len() + buf := unsafe.Slice((*byte)(chunk.Bytes()), len) + + if resp.ResponseBody == nil { + var reader *io.PipeReader + reader, resp.respBodyWriter = io.Pipe() + resp.ResponseBody = io.ReadCloser(reader) + } + resp.ResponseBodyLen += int64(len) + var err error + go func() { + _, err = resp.respBodyWriter.Write(buf) + }() + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + return hyper.IterBreak + } + return hyper.IterContinue +} + +func (resp *Response) PrintBody() { + var buffer = make([]byte, resp.ResponseBodyLen) + for { + n, err := resp.ResponseBody.Read(buffer) + if err == io.EOF { + fmt.Printf("\n") + break + } + if err != nil { + fmt.Println("Error reading from pipe:", err) + break + } + fmt.Printf("%s", string(buffer[:n])) + } +} + +//// AppendToResponseBody (BodyForEachCallback) appends the body to the response +//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { +// resp := (*Response)(userdata) +// buf := chunk.Bytes() +// len := chunk.Len() +// responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) +// if responseBody == nil { +// c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) +// return hyper.IterBreak +// } +// +// // Copy the existing response body to the new buffer +// if resp.ResponseBody != nil { +// c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) +// c.Free(c.Pointer(resp.ResponseBody)) +// } +// +// // Append the new data +// c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) +// resp.ResponseBody = responseBody +// resp.ResponseBodyLen += len +// return hyper.IterContinue +//} + +//func (resp *Response) PrintBody() { +// //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) +// fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) +//} From dc55abcfc2205d17eaf81867a8f2fb6f7158c9f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 30 Jul 2024 16:16:20 +0800 Subject: [PATCH 02/55] feat(x/http/get): Using libuv to speed up http.Get() --- go.mod | 2 +- go.sum | 12 +- x/http/_demo/get/get.go | 11 +- x/http/_demo/test.go | 141 -------------- x/http/client.go | 393 +++++++++++++++++++++++++--------------- x/http/header.go | 4 +- x/http/hyper-go.go | 12 -- x/http/response.go | 117 ++++++------ 8 files changed, 325 insertions(+), 367 deletions(-) delete mode 100644 x/http/_demo/test.go delete mode 100644 x/http/hyper-go.go diff --git a/go.mod b/go.mod index ff4ef8f..39080a4 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 +require github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be diff --git a/go.sum b/go.sum index 8799390..e2c5d17 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,2 @@ -github.com/goplus/llgo v0.9.0 h1:yaJzQperGUafEaHc9VlVQVskIngacoTNweEXY0GRi0Q= -github.com/goplus/llgo v0.9.0/go.mod h1:M3UwiYdPZFyx7m2J0+6Ti1dYVA3uOO1WvSBocuE8N7M= -github.com/goplus/llgo v0.9.1-0.20240709104849-d6a38a567fda h1:UIPwlgzCb8dV/7WFMyprhZuq8CSLAQIqwFpH5AhrNOM= -github.com/goplus/llgo v0.9.1-0.20240709104849-d6a38a567fda/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= -github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3 h1:2fZ2zQ8S58KvOsJTx6s6MHoi6n1K4sqQwIbTauMrgEE= -github.com/goplus/llgo v0.9.1-0.20240712060421-858d38d314a3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= -github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b h1:z9FUoeAALL5ytBhhGhE1dXm4+L1Q2eMUTcfiqLAZgf8= -github.com/goplus/llgo v0.9.3-0.20240726020431-98d075728f2b/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= -github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0 h1:02gSx3Oj3cLlBMed+9IWBUGHThEZMnCNiR67yaQbpqo= -github.com/goplus/llgo v0.9.4-0.20240729010130-b3b4f55c68f0/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be h1:FTALxA3ivIeVRAO93e1hCSCLaPbjKn+RZx40p5lx8KE= +github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index e7591a6..09f32a0 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -8,10 +8,17 @@ import ( func main() { // 使用 http.Get 发送 GET 请求 - resp := http.Get("https://www.baidu.com/") + resp, err := http.Get("https://www.baidu.com/") + if err != nil { + fmt.Println(err) + return + } fmt.Println(resp.Status) fmt.Println(resp.StatusCode) resp.PrintHeaders() fmt.Println() - resp.PrintBody() + resp.PrintBody2() + + resp.PrintBody1() + defer resp.Content.Close() } diff --git a/x/http/_demo/test.go b/x/http/_demo/test.go deleted file mode 100644 index 2ddd3ea..0000000 --- a/x/http/_demo/test.go +++ /dev/null @@ -1,141 +0,0 @@ -package main - -import ( - "fmt" - "io" -) - -func main() { - - // 假设你有一个 []byte 数组 - data := []byte("This is some data that needs to be stored in Body.") - - // 创建一个 io.Pipe - pr, pw := io.Pipe() - - // 启动一个 goroutine 将数据写入 io.Pipe 的写入端 - go func() { - defer pw.Close() // 确保写入完成后关闭写入端 - - if _, err := pw.Write(data); err != nil { - fmt.Println("Error writing to pipe:", err) - return - } - }() - - // 读取 Body 中的数据进行验证 - readData, err := io.ReadAll(pr) - if err != nil { - fmt.Println("Error reading from Body:", err) - return - } - - // 输出 Body 中的数据 - fmt.Println("Body content:", string(readData)) - // - //http.Get() - - //r, w := io.Pipe() - // - //go func() { - // fmt.Fprint(w, "some io.Reader stream to be read\n") - // w.Close() - //}() - // - //if _, err := io.Copy(os.Stdout, r); err != nil { - // log.Fatal(err) - //} - - // 使用 http.Get 发送 GET 请求 - //resp, err := http.Get("https://www.baidu.com/") - //if err != nil { - // fmt.Println("Error:", err) - // return - //} - //defer resp.Body.Close() - // - //body, err := io.ReadAll(resp.Body) - //if err != nil { - // fmt.Println("Error reading response:", err) - // return - //} - //fmt.Println("GET Response:\n", string(body)) - - //rawURL := "http://example.com:8080/path/to/resource?query=123#fragment" - //parsedURL, err := url.Parse(rawURL) - //if err != nil { - // fmt.Println("Error parsing URL:", err) - // return - //} - // - //hostname := parsedURL.Hostname() - //port := parsedURL.Port() - // - //uri := parsedURL.RequestURI() - // - //fmt.Println("Hostname:", hostname) - //fmt.Println("Port:", port) - //fmt.Println("URI:", uri) - - //// 使用 http.Post 发送 POST 请求上传文件 - //file, err := os.Open("path/to/your/file.jpg") - //if err != nil { - // fmt.Println("Error opening file:", err) - // return - //} - //defer file.Close() - // - //var buf bytes.Buffer - //writer := multipart.NewWriter(&buf) - //_, err = writer.CreateFormFile("file", "file.jpg") - //if err != nil { - // fmt.Println("Error creating form file:", err) - // return - //} - // - //_, err = io.ReadAll(file) - //if err != nil { - // fmt.Println("Error reading file:", err) - // return - //} - // - //err = writer.Close() - //if err != nil { - // fmt.Println("Error closing writer:", err) - // return - //} - // - //resp, err = http.Post("https://www.baidu.com/upload", writer.FormDataContentType(), &buf) - //if err != nil { - // fmt.Println("Error:", err) - // return - //} - //defer resp.Body.Close() - // - //body, err = io.ReadAll(resp.Body) - //if err != nil { - // fmt.Println("Error reading response:", err) - // return - //} - //fmt.Println("POST Response:\n", string(body)) - // - //// 使用 http.PostForm 发送表单数据 - //formData := url.Values{ - // "key": {"Value"}, - // "id": {"123"}, - //} - // - //resp, err = http.PostForm("https://www.baidu.com/form", formData) - //if err != nil { - // fmt.Println("Error:", err) - // return - //} - //defer resp.Body.Close() - // - //body, err = io.ReadAll(resp.Body) - //if err != nil { - // fmt.Println("Error reading response:", err) - // return - //} - //fmt.Println("POST Form Response:\n", string(body)) -} diff --git a/x/http/client.go b/x/http/client.go index 32cf658..d173574 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -2,63 +2,127 @@ package http import ( "fmt" + io2 "io" "strconv" "strings" + "unsafe" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/net" - "github.com/goplus/llgo/c/os" - "github.com/goplus/llgo/c/sys" "github.com/goplus/llgo/c/syscall" "github.com/goplus/llgoexamples/rust/hyper" ) type ConnData struct { - Fd c.Int - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker } -type RequestConfig struct { - ReqMethod string - ReqHost string - ReqPort string - ReqUri string - ReqHeaders map[string]string - ReqHTTPVersion hyper.HTTPVersion - TimeoutSec int64 - TimeoutUsec int32 - //ReqBody - //ReqURIParts +type Client struct { + Transport RoundTripper } -func Get(url string) *Response { - host, port, uri := parseURL(url) - req := hyper.NewRequest() +var DefaultClient = &Client{} + +type RoundTripper interface { + RoundTrip(*hyper.Request) (*Response, error) +} + +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + +func Get2(url string) (*Response, error) { + return DefaultClient.Get(url) +} +func (c *Client) Get(url string) (*Response, error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +func (c *Client) Do(req *hyper.Request) (*Response, error) { + return c.do(req) +} + +func (c *Client) do(req *hyper.Request) (*Response, error) { + return c.send(req, nil) +} + +func (c *Client) send(req *hyper.Request, deadline any) (*Response, error) { + return send(req, c.transport(), deadline) +} + +func send(req *hyper.Request, rt RoundTripper, deadline any) (resp *Response, err error) { + return rt.RoundTrip(req) +} + +func NewRequest(method, url string, body io2.Reader) (*hyper.Request, error) { + host, _, uri := parseURL(url) // Prepare the request + req := hyper.NewRequest() // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { - panic(fmt.Sprintf("error setting method %s\n", "GET")) + if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", method) } if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - panic(fmt.Sprintf("error setting uri %s\n", uri)) + return nil, fmt.Errorf("error setting uri %s\n", uri) } // Set the request headers reqHeaders := req.Headers() if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - panic("error setting headers\n") + return nil, fmt.Errorf("error setting headers\n") + } + return req, nil +} + +func Get(url string) (_ *Response, err error) { + host, port, uri := parseURL(url) + + loop := libuv.DefaultLoop() + conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) + if conn == nil { + return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } - //var response RequestResponse + libuv.InitTcp(loop, &conn.TcpHandle) + //conn.TcpHandle.Data = c.Pointer(conn) + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) + + var hints net.AddrInfo + c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) + hints.Family = syscall.AF_UNSPEC + hints.SockType = syscall.SOCK_STREAM + + var res *net.AddrInfo + status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) + if status != 0 { + return nil, fmt.Errorf("getaddrinfo error\n") + } - fd := ConnectTo(host, port) + //conn.ConnectReq.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) + status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) + if status != 0 { + return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) + } - connData := NewConnData(fd) + net.Freeaddrinfo(res) // Hookup the IO - io := NewIoWithConnReadWrite(connData) + io := NewIoWithConnReadWrite(conn) // We need an executor generally to poll futures exec := hyper.NewExecutor() @@ -73,8 +137,7 @@ func Get(url string) *Response { // Let's wait for the handshake to finish... exec.Push(handshakeTask) - var fdsRead, fdsWrite, fdsExcep syscall.FdSet - var err *hyper.Error + var hyperErr *hyper.Error var response Response // The polling state machine! @@ -82,7 +145,6 @@ func Get(url string) *Response { // Poll all ready tasks and act on them... for { task := exec.Poll() - if task == nil { break } @@ -91,23 +153,41 @@ func Get(url string) *Response { case hyper.ExampleHandshake: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake error!\n")) - err = (*hyper.Error)(task.Value()) - Fail(err) + hyperErr = (*hyper.Error)(task.Value()) + err = Fail(hyperErr) + return nil, err } if task.Type() != hyper.TaskClientConn { c.Printf(c.Str("unexpected task type\n")) - Fail(err) + err = Fail(hyperErr) + return nil, err } client := (*hyper.ClientConn)(task.Value()) task.Free() + // Prepare the request + req := hyper.NewRequest() + // Set the request method and uri + if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", "GET") + } + if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + return nil, fmt.Errorf("error setting uri %s\n", uri) + } + + // Set the request headers + reqHeaders := req.Headers() + if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + return nil, fmt.Errorf("error setting headers\n") + } + // Send it! sendTask := client.Send(req) SetUserData(sendTask, hyper.ExampleSend) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { - panic("error send\n") + return nil, fmt.Errorf("error send\n") } // For this example, no longer need the client @@ -117,12 +197,14 @@ func Get(url string) *Response { case hyper.ExampleSend: if task.Type() == hyper.TaskError { c.Printf(c.Str("send error!\n")) - err = (*hyper.Error)(task.Value()) - Fail(err) + hyperErr = (*hyper.Error)(task.Value()) + err = Fail(hyperErr) + return nil, err } if task.Type() != hyper.TaskResponse { c.Printf(c.Str("unexpected task type\n")) - Fail(err) + err = Fail(hyperErr) + return nil, err } // Take the results @@ -139,133 +221,127 @@ func Get(url string) *Response { headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) respBody := resp.Body() + response.Body, response.respBodyWriter = io2.Pipe() + + /*go func() { + fmt.Println("writing...") + for { + fmt.Println("writing for...") + dataTask := respBody.Data() + exec.Push(dataTask) + dataTask = exec.Poll() + if dataTask.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(dataTask.Value()) + len := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), len) + _, err := response.respBodyWriter.Write(bytes) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + break + } + dataTask.Free() + } else if dataTask.Type() == hyper.TaskEmpty { + fmt.Println("writing empty") + dataTask.Free() + break + } + } + fmt.Println("end writing") + defer response.respBodyWriter.Close() + }()*/ + foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) SetUserData(foreachTask, hyper.ExampleRespBody) exec.Push(foreachTask) + return &response, nil + // No longer need the response - resp.Free() + //resp.Free() break case hyper.ExampleRespBody: + println("ExampleRespBody") if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) - err = (*hyper.Error)(task.Value()) - Fail(err) + hyperErr = (*hyper.Error)(task.Value()) + err = Fail(hyperErr) + return nil, err } if task.Type() != hyper.TaskEmpty { c.Printf(c.Str("unexpected task type\n")) - Fail(err) + err = Fail(hyperErr) + return nil, err } // Cleaning up before exiting task.Free() exec.Free() - FreeConnData(connData) + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) - if response.respBodyWriter != nil { - defer response.respBodyWriter.Close() - } + FreeConnData(conn) + + //if response.respBodyWriter != nil { + // defer response.respBodyWriter.Close() + //} - return &response + return &response, nil case hyper.ExampleNotSet: + println("ExampleNotSet") // A background task for hyper_client completed... task.Free() break } } - // All futures are pending on IO work, so select on the fds. - - sys.FD_ZERO(&fdsRead) - sys.FD_ZERO(&fdsWrite) - sys.FD_ZERO(&fdsExcep) - - if connData.ReadWaker != nil { - sys.FD_SET(connData.Fd, &fdsRead) - } - if connData.WriteWaker != nil { - sys.FD_SET(connData.Fd, &fdsWrite) - } - - // Set the default request timeout - var tv syscall.Timeval - tv.Sec = 10 - - selRet := sys.Select(connData.Fd+1, &fdsRead, &fdsWrite, &fdsExcep, &tv) - if selRet < 0 { - panic("select() error\n") - } else if selRet == 0 { - panic("select() timeout\n") - } - - if sys.FD_ISSET(connData.Fd, &fdsRead) != 0 { - connData.ReadWaker.Wake() - connData.ReadWaker = nil - } - - if sys.FD_ISSET(connData.Fd, &fdsWrite) != 0 { - connData.WriteWaker.Wake() - connData.WriteWaker = nil - } + libuv.Run(loop, libuv.RUN_ONCE) } } -// ConnectTo connects to a host and port -func ConnectTo(host string, port string) c.Int { - var hints net.AddrInfo - hints.Family = net.AF_UNSPEC - hints.SockType = net.SOCK_STREAM - - var result, rp *net.AddrInfo - - if net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &result) != 0 { - panic(fmt.Sprintf("dns failed for %s\n", host)) +// AllocBuffer allocates a buffer for reading from a socket +func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { + //conn := (*ConnData)(handle.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data + conn := (*ConnData)(handle.GetData()) + if conn.ReadBuf.Base == nil { + conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + conn.ReadBufFilled = 0 } + *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) +} - var sfd c.Int - for rp = result; rp != nil; rp = rp.Next { - sfd = net.Socket(rp.Family, rp.SockType, rp.Protocol) - if sfd == -1 { - continue - } - if net.Connect(sfd, rp.Addr, rp.AddrLen) != -1 { - break - } - os.Close(sfd) - } - - net.Freeaddrinfo(result) - - // no address succeeded - if rp == nil || sfd < 0 { - panic(fmt.Sprintf("connect failed for %s\n", host)) +// OnRead is the libuv callback for reading from a socket +func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { + //conn := (*ConnData)(stream.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data + conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + if nread > 0 { + conn.ReadBufFilled += uintptr(nread) } - - if os.Fcntl(sfd, os.F_SETFL, os.O_NONBLOCK) != 0 { - panic("failed to set net to non-blocking\n") + if conn.ReadWaker != nil { + conn.ReadWaker.Wake() + conn.ReadWaker = nil } - return sfd } -// ReadCallBack is the callback for reading from a socket +// ReadCallBack is the hyper callback for reading from a socket func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { conn := (*ConnData)(userdata) - ret := os.Read(conn.Fd, c.Pointer(buf), bufLen) - - if ret >= 0 { - return uintptr(ret) - } - - if os.Errno != os.EAGAIN { - c.Perror(c.Str("[read callback fail]")) - // kaboom - return hyper.IoError + if conn.ReadBufFilled > 0 { + var toCopy uintptr + if bufLen < conn.ReadBufFilled { + toCopy = bufLen + } else { + toCopy = conn.ReadBufFilled + } + c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + conn.ReadBufFilled -= toCopy + return toCopy } - // would block, register interest if conn.ReadWaker != nil { conn.ReadWaker.Free() } @@ -273,22 +349,32 @@ func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin return hyper.IoPending } -// WriteCallBack is the callback for writing to a socket -func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - conn := (*ConnData)(userdata) - ret := os.Write(conn.Fd, c.Pointer(buf), bufLen) +// OnWrite is the libuv callback for writing to a socket +func OnWrite(req *libuv.Write, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) - if int(ret) >= 0 { - return uintptr(ret) + if conn.WriteWaker != nil { + conn.WriteWaker.Wake() + conn.WriteWaker = nil } + c.Free(c.Pointer(req)) +} + +// WriteCallBack is the hyper callback for writing to a socket +func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) + req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) + //req.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) - if os.Errno != os.EAGAIN { - c.Perror(c.Str("[write callback fail]")) - // kaboom - return hyper.IoError + if ret >= 0 { + return bufLen } - // would block, register interest if conn.WriteWaker != nil { conn.WriteWaker.Free() } @@ -296,6 +382,19 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui return hyper.IoPending } +// OnConnect is the libuv callback for a successful connection +func OnConnect(req *libuv.Connect, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if status < 0 { + c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + return + } + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +} + // FreeConnData frees the connection data func FreeConnData(conn *ConnData) { if conn.ReadWaker != nil { @@ -306,10 +405,15 @@ func FreeConnData(conn *ConnData) { conn.WriteWaker.Free() conn.WriteWaker = nil } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } + c.Free(c.Pointer(conn)) } // Fail prints the error details and panics -func Fail(err *hyper.Error) { +func Fail(err *hyper.Error) error { if err != nil { c.Printf(c.Str("error code: %d\n"), err.Code()) // grab the error details @@ -317,22 +421,12 @@ func Fail(err *hyper.Error) { errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) - c.Printf(c.Str("details: ")) - for i := 0; i < int(errLen); i++ { - c.Printf(c.Str("%c"), errBuf[i]) - } - c.Printf(c.Str("\n")) // clean up the error err.Free() - panic("request failed\n") + return fmt.Errorf("hyper error\n") } - return -} - -// NewConnData creates a new connection data -func NewConnData(fd c.Int) *ConnData { - return &ConnData{Fd: fd, ReadWaker: nil, WriteWaker: nil} + return nil } // NewIoWithConnReadWrite creates a new IO with read and write callbacks @@ -344,6 +438,12 @@ func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { return io } +// SetUserData Set the user data for the task +func SetUserData(task *hyper.Task, userData hyper.ExampleId) { + var data = userData + task.SetUserdata(c.Pointer(uintptr(data))) +} + // parseURL Parse the URL and extract the host name, port number, and URI func parseURL(rawURL string) (hostname, port, uri string) { // 找到 "://" 的位置,以分隔协议和主机名 @@ -383,6 +483,5 @@ func parseURL(rawURL string) (hostname, port, uri string) { //} port = "80" } - return } diff --git a/x/http/header.go b/x/http/header.go index ea313a2..4710854 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -25,10 +25,8 @@ func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va func (resp *Response) PrintHeaders() { for key, values := range resp.Header { - fmt.Printf("%s: ", key) for _, value := range values { - fmt.Printf(value + "; ") + fmt.Printf("%s: %s\n", key, value) } - fmt.Printf("\n") } } diff --git a/x/http/hyper-go.go b/x/http/hyper-go.go deleted file mode 100644 index a1db081..0000000 --- a/x/http/hyper-go.go +++ /dev/null @@ -1,12 +0,0 @@ -package http - -import ( - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" -) - -// SetUserData Set the user data for the task -func SetUserData(task *hyper.Task, userData hyper.ExampleId) { - var data = userData - task.SetUserdata(c.Pointer(uintptr(data))) -} diff --git a/x/http/response.go b/x/http/response.go index 9263461..020f2d9 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -3,7 +3,6 @@ package http import ( "fmt" "io" - "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -13,38 +12,51 @@ type Response struct { Status string StatusCode int Header Header - ResponseBody io.ReadCloser + Content io.ReadCloser + ContentLen int64 respBodyWriter *io.PipeWriter - ResponseBodyLen int64 + ResponseBody *uint8 + ResponseBodyLen uintptr } // AppendToResponseBody (BodyForEachCallback) appends the body to the response -func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - resp := (*Response)(userdata) - len := chunk.Len() - buf := unsafe.Slice((*byte)(chunk.Bytes()), len) +//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { +// resp := (*Response)(userdata) +// len := chunk.Len() +// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) +// +// if resp.Content == nil { +// var reader *io.PipeReader +// reader, resp.respBodyWriter = io.Pipe() +// resp.Content = io.ReadCloser(reader) +// } +// resp.ContentLen += int64(len) +// var err error +// go func() { +// _, err = resp.respBodyWriter.Write(buf) +// }() +// if err != nil { +// fmt.Printf("Failed to write response body: %v\n", err) +// return hyper.IterBreak +// } +// return hyper.IterContinue +//} - if resp.ResponseBody == nil { - var reader *io.PipeReader - reader, resp.respBodyWriter = io.Pipe() - resp.ResponseBody = io.ReadCloser(reader) - } - resp.ResponseBodyLen += int64(len) - var err error +func (resp *Response) PrintBody1() { go func() { - _, err = resp.respBodyWriter.Write(buf) + var reader *io.PipeReader + reader, writer := io.Pipe() + resp.Content = reader + writer.Write((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen]) + defer writer.Close() }() - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - return hyper.IterBreak + for i := 0; i < 10; i++ { + c.Usleep(1 * 1000 * 1000) + fmt.Println("Sleeping...") } - return hyper.IterContinue -} - -func (resp *Response) PrintBody() { - var buffer = make([]byte, resp.ResponseBodyLen) + var buffer = make([]byte, 4096) for { - n, err := resp.ResponseBody.Read(buffer) + n, err := resp.Content.Read(buffer) if err == io.EOF { fmt.Printf("\n") break @@ -55,33 +67,36 @@ func (resp *Response) PrintBody() { } fmt.Printf("%s", string(buffer[:n])) } + buffer = nil + //body, _ := io.ReadAll(resp.Content) + //fmt.Println(string(body)) } -//// AppendToResponseBody (BodyForEachCallback) appends the body to the response -//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { -// resp := (*Response)(userdata) -// buf := chunk.Bytes() -// len := chunk.Len() -// responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) -// if responseBody == nil { -// c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) -// return hyper.IterBreak -// } -// -// // Copy the existing response body to the new buffer -// if resp.ResponseBody != nil { -// c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) -// c.Free(c.Pointer(resp.ResponseBody)) -// } -// -// // Append the new data -// c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) -// resp.ResponseBody = responseBody -// resp.ResponseBodyLen += len -// return hyper.IterContinue -//} +// AppendToResponseBody (BodyForEachCallback) appends the body to the response +func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + resp := (*Response)(userdata) + buf := chunk.Bytes() + len := chunk.Len() + responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) + if responseBody == nil { + c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) + return hyper.IterBreak + } -//func (resp *Response) PrintBody() { -// //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) -// fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) -//} + // Copy the existing response body to the new buffer + if resp.ResponseBody != nil { + c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) + c.Free(c.Pointer(resp.ResponseBody)) + } + + // Append the new data + c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) + resp.ResponseBody = responseBody + resp.ResponseBodyLen += len + return hyper.IterContinue +} + +func (resp *Response) PrintBody2() { + //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) + fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) +} From 685154ff53b5f2eb74fe0857325d80bc0e8b0797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 2 Aug 2024 18:08:38 +0800 Subject: [PATCH 03/55] WIP(x/http-get): Use channels to pass responses --- x/httpget/_demo/get/get.go | 24 ++ x/httpget/client.go | 46 ++++ x/httpget/header.go | 32 +++ x/httpget/request.go | 42 ++++ x/httpget/response.go | 54 ++++ x/httpget/transport.go | 502 +++++++++++++++++++++++++++++++++++++ 6 files changed, 700 insertions(+) create mode 100644 x/httpget/_demo/get/get.go create mode 100644 x/httpget/client.go create mode 100644 x/httpget/header.go create mode 100644 x/httpget/request.go create mode 100644 x/httpget/response.go create mode 100644 x/httpget/transport.go diff --git a/x/httpget/_demo/get/get.go b/x/httpget/_demo/get/get.go new file mode 100644 index 0000000..da674ba --- /dev/null +++ b/x/httpget/_demo/get/get.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgo/x/httpget" +) + +func main() { + resp, err := httpget.Get("www.baidu.com") + //req, _ := httpget.NewRequest("GET", "http://www.baidu.com", nil) + //resp, err := httpget.DefaultClient.Send(req, nil) + if err != nil { + fmt.Println(err) + return + } + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) +} diff --git a/x/httpget/client.go b/x/httpget/client.go new file mode 100644 index 0000000..8a1f610 --- /dev/null +++ b/x/httpget/client.go @@ -0,0 +1,46 @@ +package httpget + +type Client struct { + Transport RoundTripper +} + +var DefaultClient = &Client{} + +type RoundTripper interface { + RoundTrip(*Request) (*Response, error) +} + +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport + } + return DefaultTransport +} + +func Get(url string) (*Response, error) { + return DefaultClient.Get(url) +} + +func (c *Client) Get(url string) (*Response, error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +func (c *Client) Do(req *Request) (*Response, error) { + return c.do(req) +} + +func (c *Client) do(req *Request) (*Response, error) { + return c.send(req, nil) +} + +func (c *Client) send(req *Request, deadline any) (*Response, error) { + return send(req, c.transport(), deadline) +} + +func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { + return rt.RoundTrip(req) +} diff --git a/x/httpget/header.go b/x/httpget/header.go new file mode 100644 index 0000000..1768557 --- /dev/null +++ b/x/httpget/header.go @@ -0,0 +1,32 @@ +package httpget + +import ( + "fmt" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Header map[string][]string + +// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console +func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { + resp := (*Response)(userdata) + nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) + valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) + + if resp.Header == nil { + resp.Header = make(map[string][]string) + } + resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) + return hyper.IterContinue +} + +func (resp *Response) PrintHeaders() { + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } +} diff --git a/x/httpget/request.go b/x/httpget/request.go new file mode 100644 index 0000000..391311c --- /dev/null +++ b/x/httpget/request.go @@ -0,0 +1,42 @@ +package httpget + +import ( + "fmt" + "io" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Request struct { + Method string + Url string +} + +func NewRequest(method, url string, body io.Reader) (*Request, error) { + return &Request{ + Method: method, + Url: url, + }, nil +} + +func NewHyperRequest(request *Request) (*hyper.Request, error) { + host, _, uri := parseURL(request.Url) + method := request.Method + // Prepare the request + req := hyper.NewRequest() + // Set the request method and uri + if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", method) + } + if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + return nil, fmt.Errorf("error setting uri %s\n", uri) + } + + // Set the request headers + reqHeaders := req.Headers() + if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + return nil, fmt.Errorf("error setting headers\n") + } + return req, nil +} diff --git a/x/httpget/response.go b/x/httpget/response.go new file mode 100644 index 0000000..a9e4468 --- /dev/null +++ b/x/httpget/response.go @@ -0,0 +1,54 @@ +package httpget + +import ( + "fmt" + "io" + "unsafe" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type Response struct { + Status string + StatusCode int + Header Header + Body io.ReadCloser + ContentLength int64 + respBodyWriter *io.PipeWriter +} + +// AppendToResponseBody (BodyForEachCallback) appends the body to the response +func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + fmt.Println("reading1...") + resp := (*Response)(userdata) + len := chunk.Len() + buf := unsafe.Slice((*byte)(chunk.Bytes()), len) + _, err := resp.respBodyWriter.Write(buf) + resp.ContentLength += int64(len) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + return hyper.IterBreak + } + fmt.Println("reading2...") + return hyper.IterContinue +} + +func (resp *Response) PrintBody() { + var buffer = make([]byte, 4096) + for { + n, err := resp.Body.Read(buffer) + if err == io.EOF { + fmt.Printf("\n") + break + } + if err != nil { + fmt.Println("Error reading from pipe:", err) + break + } + fmt.Printf("%s", string(buffer[:n])) + } + buffer = nil + //body, _ := io.ReadAll(resp.Content) + //fmt.Println(string(body)) +} diff --git a/x/httpget/transport.go b/x/httpget/transport.go new file mode 100644 index 0000000..0450bbd --- /dev/null +++ b/x/httpget/transport.go @@ -0,0 +1,502 @@ +package httpget + +import ( + "bufio" + "fmt" + io2 "io" + "strconv" + "strings" + "unsafe" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/libuv" + "github.com/goplus/llgo/c/net" + "github.com/goplus/llgo/c/syscall" + "github.com/goplus/llgoexamples/rust/hyper" +) + +type ConnData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + +type Transport struct { +} + +var DefaultTransport RoundTripper = &Transport{} + +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + // alt optionally specifies the TLS NextProto RoundTripper. + // This is used for HTTP/2 today and future protocols later. + // If it's non-nil, the rest of the fields are unused. + alt RoundTripper + + conn *ConnData + t *Transport + br *bufio.Reader // from conn + bw *bufio.Writer // to conn + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // closed when conn closed +} + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type incomparable [0]func() + +type requestAndChan struct { + _ incomparable + req *hyper.Request + ch chan responseAndError // unbuffered; always send in select on callerGone +} + +// A writeRequest is sent by the caller's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + // req *transportRequest + ch chan<- error + + // Optional blocking chan for Expect: 100-continue (for receive). + // If not nil, writeLoop blocks sending request body until + // it receives from this chan. + continueCh <-chan struct{} +} + +// responseAndError is how the goroutine reading from an HTTP/1 server +// communicates with the goroutine doing the RoundTrip. +type responseAndError struct { + _ incomparable + res *Response // else use this response (see res method) + err error +} + +func (t *Transport) RoundTrip(request *Request) (*Response, error) { + req, err := NewHyperRequest(request) + if err != nil { + return nil, err + } + pconn, err := t.getConn(req) + var resp *Response + resp, err = pconn.roundTrip(req) + if err == nil { + return resp, nil + } + return nil, err +} + +func (t *Transport) getConn(req *hyper.Request) (pconn *persistConn, err error) { + host := "www.baidu.com" + port := "80" + loop := libuv.DefaultLoop() + conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) + if conn == nil { + return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") + } + + libuv.InitTcp(loop, &conn.TcpHandle) + //conn.TcpHandle.Data = c.Pointer(conn) + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) + + var hints net.AddrInfo + c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) + hints.Family = syscall.AF_UNSPEC + hints.SockType = syscall.SOCK_STREAM + + var res *net.AddrInfo + status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) + if status != 0 { + return nil, fmt.Errorf("getaddrinfo error\n") + } + + //conn.ConnectReq.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) + status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) + if status != 0 { + return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) + } + pconn = &persistConn{ + conn: conn, + t: t, + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + } + + net.Freeaddrinfo(res) + + go pconn.startLoop(loop) + return pconn, nil +} + +func (pc *persistConn) roundTrip(req *hyper.Request) (resp *Response, err error) { + resc := make(chan responseAndError) + pc.reqch <- requestAndChan{ + req: req, + ch: resc, + } + + select { + case re := <-resc: + if (re.res == nil) == (re.err == nil) { + panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + } + if re.err != nil { + return nil, err + } + return re.res, nil + } +} + +func (pc *persistConn) startLoop(loop *libuv.Loop) { + // Hookup the IO + io := NewIoWithConnReadWrite(pc.conn) + + // We need an executor generally to poll futures + exec := hyper.NewExecutor() + + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(exec) + + handshakeTask := hyper.Handshake(io, opts) + SetUserData(handshakeTask, hyper.ExampleHandshake) + + // Let's wait for the handshake to finish... + exec.Push(handshakeTask) + + var hyperErr *hyper.Error + var response Response + + var rc requestAndChan + + select { + case rc = <-pc.reqch: + } + // The polling state machine! + for { + // Poll all ready tasks and act on them... + for { + task := exec.Poll() + if task == nil { + break + } + + switch (hyper.ExampleId)(uintptr(task.Userdata())) { + case hyper.ExampleHandshake: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskClientConn { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + + client := (*hyper.ClientConn)(task.Value()) + task.Free() + + // Send it! + sendTask := client.Send(rc.req) + SetUserData(sendTask, hyper.ExampleSend) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + panic("error send\n") + } + + // For this example, no longer need the client + client.Free() + + break + case hyper.ExampleSend: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + + // Take the results + resp := (*hyper.Response)(task.Value()) + task.Free() + + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() + + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) + + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + respBody := resp.Body() + + response.Body, response.respBodyWriter = io2.Pipe() + + /*go func() { + fmt.Println("writing...") + for { + fmt.Println("writing for...") + dataTask := respBody.Data() + exec.Push(dataTask) + dataTask = exec.Poll() + if dataTask.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(dataTask.Value()) + len := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), len) + _, err := response.respBodyWriter.Write(bytes) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) + break + } + dataTask.Free() + } else if dataTask.Type() == hyper.TaskEmpty { + fmt.Println("writing empty") + dataTask.Free() + break + } + } + fmt.Println("end writing") + defer response.respBodyWriter.Close() + }()*/ + + foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) + + SetUserData(foreachTask, hyper.ExampleRespBody) + exec.Push(foreachTask) + + rc.ch <- responseAndError{res: &response} + // No longer need the response + //resp.Free() + + break + case hyper.ExampleRespBody: + println("ExampleRespBody") + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + + // Cleaning up before exiting + task.Free() + //exec.Free() + (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + + FreeConnData(pc.conn) + + //return &response, nil + break + case hyper.ExampleNotSet: + println("ExampleNotSet") + // A background task for hyper_client completed... + task.Free() + break + } + } + + libuv.Run(loop, libuv.RUN_ONCE) + } +} + +// AllocBuffer allocates a buffer for reading from a socket +func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { + //conn := (*ConnData)(handle.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data + conn := (*ConnData)(handle.GetData()) + if conn.ReadBuf.Base == nil { + conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + conn.ReadBufFilled = 0 + } + *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) +} + +// OnRead is the libuv callback for reading from a socket +func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { + //conn := (*ConnData)(stream.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data + conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + if nread > 0 { + conn.ReadBufFilled += uintptr(nread) + } + if conn.ReadWaker != nil { + conn.ReadWaker.Wake() + conn.ReadWaker = nil + } +} + +// ReadCallBack is the hyper callback for reading from a socket +func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + + if conn.ReadBufFilled > 0 { + var toCopy uintptr + if bufLen < conn.ReadBufFilled { + toCopy = bufLen + } else { + toCopy = conn.ReadBufFilled + } + c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + conn.ReadBufFilled -= toCopy + return toCopy + } + + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + } + conn.ReadWaker = ctx.Waker() + return hyper.IoPending +} + +// OnWrite is the libuv callback for writing to a socket +func OnWrite(req *libuv.Write, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if conn.WriteWaker != nil { + conn.WriteWaker.Wake() + conn.WriteWaker = nil + } + c.Free(c.Pointer(req)) +} + +// WriteCallBack is the hyper callback for writing to a socket +func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + conn := (*ConnData)(userdata) + initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) + req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) + //req.Data = c.Pointer(conn) + (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + + if ret >= 0 { + return bufLen + } + + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + } + conn.WriteWaker = ctx.Waker() + return hyper.IoPending +} + +// OnConnect is the libuv callback for a successful connection +func OnConnect(req *libuv.Connect, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if status < 0 { + c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + return + } + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +} + +// FreeConnData frees the connection data +func FreeConnData(conn *ConnData) { + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } + c.Free(c.Pointer(conn)) +} + +// Fail prints the error details and panics +func Fail(err *hyper.Error) { + if err != nil { + c.Printf(c.Str("error code: %d\n"), err.Code()) + // grab the error details + var errBuf [256]c.Char + errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) + + c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + + // clean up the error + err.Free() + panic("hyper error \n") + } +} + +// NewIoWithConnReadWrite creates a new IO with read and write callbacks +func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { + io := hyper.NewIo() + io.SetUserdata(c.Pointer(connData)) + io.SetRead(ReadCallBack) + io.SetWrite(WriteCallBack) + return io +} + +// SetUserData Set the user data for the task +func SetUserData(task *hyper.Task, userData hyper.ExampleId) { + var data = userData + task.SetUserdata(c.Pointer(uintptr(data))) +} + +// parseURL Parse the URL and extract the host name, port number, and URI +func parseURL(rawURL string) (hostname, port, uri string) { + // 找到 "://" 的位置,以分隔协议和主机名 + schemeEnd := strings.Index(rawURL, "://") + if schemeEnd != -1 { + //scheme = rawURL[:schemeEnd] + rawURL = rawURL[schemeEnd+3:] + } else { + //scheme = "http" // 默认协议为 http + } + + // 找到第一个 "/" 的位置,以分隔主机名和路径 + pathStart := strings.Index(rawURL, "/") + if pathStart != -1 { + uri = rawURL[pathStart:] + rawURL = rawURL[:pathStart] + } else { + uri = "/" + } + + // 找到 ":" 的位置,以分隔主机名和端口号 + portStart := strings.LastIndex(rawURL, ":") + if portStart != -1 { + hostname = rawURL[:portStart] + port = rawURL[portStart+1:] + } else { + hostname = rawURL + port = "" // 未指定端口号 + } + + // 如果未指定端口号,根据协议设置默认端口号 + if port == "" { + //if scheme == "https" { + // port = "443" + //} else { + // port = "80" + //} + port = "80" + } + return +} From c52bdd5f06fd7f22a7c0cb3c4709de030c457d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 5 Aug 2024 17:03:35 +0800 Subject: [PATCH 04/55] WIP(x/http/client/get): Use channels to pass response(Passed the test) --- go.mod | 2 +- go.sum | 4 +- x/http/_demo/get/get.go | 19 +- x/http/client.go | 453 +------------------------------ x/{httpget => http}/request.go | 25 +- x/http/response.go | 93 +------ x/{httpget => http}/transport.go | 360 +++++++++++------------- x/httpget/_demo/get/get.go | 24 -- x/httpget/client.go | 46 ---- x/httpget/header.go | 32 --- x/httpget/response.go | 54 ---- 11 files changed, 201 insertions(+), 911 deletions(-) rename x/{httpget => http}/request.go (65%) rename x/{httpget => http}/transport.go (56%) delete mode 100644 x/httpget/_demo/get/get.go delete mode 100644 x/httpget/client.go delete mode 100644 x/httpget/header.go delete mode 100644 x/httpget/response.go diff --git a/go.mod b/go.mod index 39080a4..fa05f1f 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be +require github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c diff --git a/go.sum b/go.sum index e2c5d17..ba1d000 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be h1:FTALxA3ivIeVRAO93e1hCSCLaPbjKn+RZx40p5lx8KE= -github.com/goplus/llgo v0.9.5-0.20240731053840-36072584d0be/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c h1:PhaSnZL8LLyRIHWc5Wim9No0Q475H8Ljikxfj1gHHjc= +github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index 09f32a0..73e7113 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -2,23 +2,24 @@ package main import ( "fmt" + "io" "github.com/goplus/llgoexamples/x/http" ) func main() { - // 使用 http.Get 发送 GET 请求 - resp, err := http.Get("https://www.baidu.com/") + resp, err := http.Get("https://www.baidu.com") if err != nil { fmt.Println(err) return } - fmt.Println(resp.Status) - fmt.Println(resp.StatusCode) + println(resp.Status) resp.PrintHeaders() - fmt.Println() - resp.PrintBody2() - - resp.PrintBody1() - defer resp.Content.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() } diff --git a/x/http/client.go b/x/http/client.go index d173574..ac0bc6e 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,28 +1,5 @@ package http -import ( - "fmt" - io2 "io" - "strconv" - "strings" - "unsafe" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgo/c/libuv" - "github.com/goplus/llgo/c/net" - "github.com/goplus/llgo/c/syscall" - "github.com/goplus/llgoexamples/rust/hyper" -) - -type ConnData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} - type Client struct { Transport RoundTripper } @@ -30,7 +7,7 @@ type Client struct { var DefaultClient = &Client{} type RoundTripper interface { - RoundTrip(*hyper.Request) (*Response, error) + RoundTrip(*Request) (*Response, error) } func (c *Client) transport() RoundTripper { @@ -40,7 +17,7 @@ func (c *Client) transport() RoundTripper { return DefaultTransport } -func Get2(url string) (*Response, error) { +func Get(url string) (*Response, error) { return DefaultClient.Get(url) } @@ -52,436 +29,18 @@ func (c *Client) Get(url string) (*Response, error) { return c.Do(req) } -func (c *Client) Do(req *hyper.Request) (*Response, error) { +func (c *Client) Do(req *Request) (*Response, error) { return c.do(req) } -func (c *Client) do(req *hyper.Request) (*Response, error) { +func (c *Client) do(req *Request) (*Response, error) { return c.send(req, nil) } -func (c *Client) send(req *hyper.Request, deadline any) (*Response, error) { +func (c *Client) send(req *Request, deadline any) (*Response, error) { return send(req, c.transport(), deadline) } -func send(req *hyper.Request, rt RoundTripper, deadline any) (resp *Response, err error) { +func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { return rt.RoundTrip(req) } - -func NewRequest(method, url string, body io2.Reader) (*hyper.Request, error) { - host, _, uri := parseURL(url) - // Prepare the request - req := hyper.NewRequest() - // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { - return nil, fmt.Errorf("error setting method %s\n", method) - } - if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - return nil, fmt.Errorf("error setting uri %s\n", uri) - } - - // Set the request headers - reqHeaders := req.Headers() - if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting headers\n") - } - return req, nil -} - -func Get(url string) (_ *Response, err error) { - host, port, uri := parseURL(url) - - loop := libuv.DefaultLoop() - conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) - if conn == nil { - return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") - } - - libuv.InitTcp(loop, &conn.TcpHandle) - //conn.TcpHandle.Data = c.Pointer(conn) - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) - - var hints net.AddrInfo - c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) - hints.Family = syscall.AF_UNSPEC - hints.SockType = syscall.SOCK_STREAM - - var res *net.AddrInfo - status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) - if status != 0 { - return nil, fmt.Errorf("getaddrinfo error\n") - } - - //conn.ConnectReq.Data = c.Pointer(conn) - (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) - status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) - if status != 0 { - return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) - } - - net.Freeaddrinfo(res) - - // Hookup the IO - io := NewIoWithConnReadWrite(conn) - - // We need an executor generally to poll futures - exec := hyper.NewExecutor() - - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(exec) - - handshakeTask := hyper.Handshake(io, opts) - SetUserData(handshakeTask, hyper.ExampleHandshake) - - // Let's wait for the handshake to finish... - exec.Push(handshakeTask) - - var hyperErr *hyper.Error - var response Response - - // The polling state machine! - for { - // Poll all ready tasks and act on them... - for { - task := exec.Poll() - if task == nil { - break - } - - switch (hyper.ExampleId)(uintptr(task.Userdata())) { - case hyper.ExampleHandshake: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - err = Fail(hyperErr) - return nil, err - } - if task.Type() != hyper.TaskClientConn { - c.Printf(c.Str("unexpected task type\n")) - err = Fail(hyperErr) - return nil, err - } - - client := (*hyper.ClientConn)(task.Value()) - task.Free() - - // Prepare the request - req := hyper.NewRequest() - // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte("GET")[0]), c.Strlen(c.Str("GET"))) != hyper.OK { - return nil, fmt.Errorf("error setting method %s\n", "GET") - } - if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - return nil, fmt.Errorf("error setting uri %s\n", uri) - } - - // Set the request headers - reqHeaders := req.Headers() - if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting headers\n") - } - - // Send it! - sendTask := client.Send(req) - SetUserData(sendTask, hyper.ExampleSend) - sendRes := exec.Push(sendTask) - if sendRes != hyper.OK { - return nil, fmt.Errorf("error send\n") - } - - // For this example, no longer need the client - client.Free() - - break - case hyper.ExampleSend: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("send error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - err = Fail(hyperErr) - return nil, err - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - err = Fail(hyperErr) - return nil, err - } - - // Take the results - resp := (*hyper.Response)(task.Value()) - task.Free() - - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody := resp.Body() - - response.Body, response.respBodyWriter = io2.Pipe() - - /*go func() { - fmt.Println("writing...") - for { - fmt.Println("writing for...") - dataTask := respBody.Data() - exec.Push(dataTask) - dataTask = exec.Poll() - if dataTask.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(dataTask.Value()) - len := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), len) - _, err := response.respBodyWriter.Write(bytes) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - break - } - dataTask.Free() - } else if dataTask.Type() == hyper.TaskEmpty { - fmt.Println("writing empty") - dataTask.Free() - break - } - } - fmt.Println("end writing") - defer response.respBodyWriter.Close() - }()*/ - - foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) - - SetUserData(foreachTask, hyper.ExampleRespBody) - exec.Push(foreachTask) - - return &response, nil - - // No longer need the response - //resp.Free() - - break - case hyper.ExampleRespBody: - println("ExampleRespBody") - if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - err = Fail(hyperErr) - return nil, err - } - if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - err = Fail(hyperErr) - return nil, err - } - - // Cleaning up before exiting - task.Free() - exec.Free() - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) - - FreeConnData(conn) - - //if response.respBodyWriter != nil { - // defer response.respBodyWriter.Close() - //} - - return &response, nil - case hyper.ExampleNotSet: - println("ExampleNotSet") - // A background task for hyper_client completed... - task.Free() - break - } - } - - libuv.Run(loop, libuv.RUN_ONCE) - } -} - -// AllocBuffer allocates a buffer for reading from a socket -func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { - //conn := (*ConnData)(handle.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data - conn := (*ConnData)(handle.GetData()) - if conn.ReadBuf.Base == nil { - conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) - conn.ReadBufFilled = 0 - } - *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) -} - -// OnRead is the libuv callback for reading from a socket -func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { - //conn := (*ConnData)(stream.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data - conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - if nread > 0 { - conn.ReadBufFilled += uintptr(nread) - } - if conn.ReadWaker != nil { - conn.ReadWaker.Wake() - conn.ReadWaker = nil - } -} - -// ReadCallBack is the hyper callback for reading from a socket -func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - conn := (*ConnData)(userdata) - - if conn.ReadBufFilled > 0 { - var toCopy uintptr - if bufLen < conn.ReadBufFilled { - toCopy = bufLen - } else { - toCopy = conn.ReadBufFilled - } - c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) - c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) - conn.ReadBufFilled -= toCopy - return toCopy - } - - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - } - conn.ReadWaker = ctx.Waker() - return hyper.IoPending -} - -// OnWrite is the libuv callback for writing to a socket -func OnWrite(req *libuv.Write, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) - - if conn.WriteWaker != nil { - conn.WriteWaker.Wake() - conn.WriteWaker = nil - } - c.Free(c.Pointer(req)) -} - -// WriteCallBack is the hyper callback for writing to a socket -func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - conn := (*ConnData)(userdata) - initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) - req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) - //req.Data = c.Pointer(conn) - (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) - - if ret >= 0 { - return bufLen - } - - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - } - conn.WriteWaker = ctx.Waker() - return hyper.IoPending -} - -// OnConnect is the libuv callback for a successful connection -func OnConnect(req *libuv.Connect, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) - - if status < 0 { - c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) - return - } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) -} - -// FreeConnData frees the connection data -func FreeConnData(conn *ConnData) { - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil - } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil - } - c.Free(c.Pointer(conn)) -} - -// Fail prints the error details and panics -func Fail(err *hyper.Error) error { - if err != nil { - c.Printf(c.Str("error code: %d\n"), err.Code()) - // grab the error details - var errBuf [256]c.Char - errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - - c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) - - // clean up the error - err.Free() - return fmt.Errorf("hyper error\n") - } - return nil -} - -// NewIoWithConnReadWrite creates a new IO with read and write callbacks -func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { - io := hyper.NewIo() - io.SetUserdata(c.Pointer(connData)) - io.SetRead(ReadCallBack) - io.SetWrite(WriteCallBack) - return io -} - -// SetUserData Set the user data for the task -func SetUserData(task *hyper.Task, userData hyper.ExampleId) { - var data = userData - task.SetUserdata(c.Pointer(uintptr(data))) -} - -// parseURL Parse the URL and extract the host name, port number, and URI -func parseURL(rawURL string) (hostname, port, uri string) { - // 找到 "://" 的位置,以分隔协议和主机名 - schemeEnd := strings.Index(rawURL, "://") - if schemeEnd != -1 { - //scheme = rawURL[:schemeEnd] - rawURL = rawURL[schemeEnd+3:] - } else { - //scheme = "http" // 默认协议为 http - } - - // 找到第一个 "/" 的位置,以分隔主机名和路径 - pathStart := strings.Index(rawURL, "/") - if pathStart != -1 { - uri = rawURL[pathStart:] - rawURL = rawURL[:pathStart] - } else { - uri = "/" - } - - // 找到 ":" 的位置,以分隔主机名和端口号 - portStart := strings.LastIndex(rawURL, ":") - if portStart != -1 { - hostname = rawURL[:portStart] - port = rawURL[portStart+1:] - } else { - hostname = rawURL - port = "" // 未指定端口号 - } - - // 如果未指定端口号,根据协议设置默认端口号 - if port == "" { - //if scheme == "https" { - // port = "443" - //} else { - // port = "80" - //} - port = "80" - } - return -} diff --git a/x/httpget/request.go b/x/http/request.go similarity index 65% rename from x/httpget/request.go rename to x/http/request.go index 391311c..98d48b4 100644 --- a/x/httpget/request.go +++ b/x/http/request.go @@ -1,8 +1,9 @@ -package httpget +package http import ( "fmt" "io" + "net/url" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -10,19 +11,29 @@ import ( type Request struct { Method string - Url string + URL *url.URL + Req *hyper.Request } -func NewRequest(method, url string, body io.Reader) (*Request, error) { +func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + parseURL, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + req, err := NewHyperRequest(method, parseURL) + if err != nil { + return nil, err + } return &Request{ Method: method, - Url: url, + URL: parseURL, + Req: req, }, nil } -func NewHyperRequest(request *Request) (*hyper.Request, error) { - host, _, uri := parseURL(request.Url) - method := request.Method +func NewHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { + host := URL.Hostname() + uri := URL.RequestURI() // Prepare the request req := hyper.NewRequest() // Set the request method and uri diff --git a/x/http/response.go b/x/http/response.go index 020f2d9..8d01b80 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -3,100 +3,31 @@ package http import ( "fmt" "io" + "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" ) type Response struct { - Status string - StatusCode int - Header Header - Content io.ReadCloser - ContentLen int64 - respBodyWriter *io.PipeWriter - ResponseBody *uint8 - ResponseBodyLen uintptr -} - -// AppendToResponseBody (BodyForEachCallback) appends the body to the response -//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { -// resp := (*Response)(userdata) -// len := chunk.Len() -// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) -// -// if resp.Content == nil { -// var reader *io.PipeReader -// reader, resp.respBodyWriter = io.Pipe() -// resp.Content = io.ReadCloser(reader) -// } -// resp.ContentLen += int64(len) -// var err error -// go func() { -// _, err = resp.respBodyWriter.Write(buf) -// }() -// if err != nil { -// fmt.Printf("Failed to write response body: %v\n", err) -// return hyper.IterBreak -// } -// return hyper.IterContinue -//} - -func (resp *Response) PrintBody1() { - go func() { - var reader *io.PipeReader - reader, writer := io.Pipe() - resp.Content = reader - writer.Write((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen]) - defer writer.Close() - }() - for i := 0; i < 10; i++ { - c.Usleep(1 * 1000 * 1000) - fmt.Println("Sleeping...") - } - var buffer = make([]byte, 4096) - for { - n, err := resp.Content.Read(buffer) - if err == io.EOF { - fmt.Printf("\n") - break - } - if err != nil { - fmt.Println("Error reading from pipe:", err) - break - } - fmt.Printf("%s", string(buffer[:n])) - } - buffer = nil - //body, _ := io.ReadAll(resp.Content) - //fmt.Println(string(body)) + Status string + StatusCode int + Header Header + Body io.ReadCloser + ContentLength int64 + respBodyWriter *io.PipeWriter } // AppendToResponseBody (BodyForEachCallback) appends the body to the response func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { resp := (*Response)(userdata) - buf := chunk.Bytes() len := chunk.Len() - responseBody := (*uint8)(c.Malloc(resp.ResponseBodyLen + len)) - if responseBody == nil { - c.Fprintf(c.Stderr, c.Str("Failed to allocate memory for response body\n")) + buf := unsafe.Slice((*byte)(chunk.Bytes()), len) + _, err := resp.respBodyWriter.Write(buf) + resp.ContentLength += int64(len) + if err != nil { + fmt.Printf("Failed to write response body: %v\n", err) return hyper.IterBreak } - - // Copy the existing response body to the new buffer - if resp.ResponseBody != nil { - c.Memcpy(c.Pointer(responseBody), c.Pointer(resp.ResponseBody), resp.ResponseBodyLen) - c.Free(c.Pointer(resp.ResponseBody)) - } - - // Append the new data - c.Memcpy(c.Pointer(uintptr(c.Pointer(responseBody))+resp.ResponseBodyLen), c.Pointer(buf), len) - resp.ResponseBody = responseBody - resp.ResponseBodyLen += len return hyper.IterContinue } - -func (resp *Response) PrintBody2() { - //c.Printf(c.Str("%.*s\n"), c.Int(resp.ResponseBodyLen), resp.ResponseBody) - fmt.Println(string((*[1 << 30]byte)(c.Pointer(resp.ResponseBody))[:resp.ResponseBodyLen:resp.ResponseBodyLen])) -} diff --git a/x/httpget/transport.go b/x/http/transport.go similarity index 56% rename from x/httpget/transport.go rename to x/http/transport.go index 0450bbd..c2559a8 100644 --- a/x/httpget/transport.go +++ b/x/http/transport.go @@ -1,11 +1,9 @@ -package httpget +package http import ( - "bufio" "fmt" io2 "io" "strconv" - "strings" "unsafe" "github.com/goplus/llgo/c" @@ -27,6 +25,16 @@ type ConnData struct { type Transport struct { } +// TaskId The unique identifier of the next task polled from the executor +type TaskId c.Int + +const ( + NotSet TaskId = iota + Send + ReceiveResp + ReceiveRespBody +) + var DefaultTransport RoundTripper = &Transport{} // persistConn wraps a connection, usually a persistent one @@ -35,16 +43,15 @@ type persistConn struct { // alt optionally specifies the TLS NextProto RoundTripper. // This is used for HTTP/2 today and future protocols later. // If it's non-nil, the rest of the fields are unused. - alt RoundTripper - - conn *ConnData - t *Transport - br *bufio.Reader // from conn - bw *bufio.Writer // to conn - nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; read by readLoop - writech chan writeRequest // written by roundTrip; read by writeLoop - closech chan struct{} // closed when conn closed + //alt RoundTripper + //br *bufio.Reader // from conn + //bw *bufio.Writer // to conn + //nwrite int64 // bytes written + //writech chan writeRequest // written by roundTrip; read by writeLoop + //closech chan struct{} // closed when conn closed + conn *ConnData + t *Transport + reqch chan requestAndChan // written by roundTrip; read by readLoop } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -58,20 +65,6 @@ type requestAndChan struct { ch chan responseAndError // unbuffered; always send in select on callerGone } -// A writeRequest is sent by the caller's goroutine to the -// writeLoop's goroutine to write a request while the read loop -// concurrently waits on both the write response and the server's -// reply. -type writeRequest struct { - // req *transportRequest - ch chan<- error - - // Optional blocking chan for Expect: 100-continue (for receive). - // If not nil, writeLoop blocks sending request body until - // it receives from this chan. - continueCh <-chan struct{} -} - // responseAndError is how the goroutine reading from an HTTP/1 server // communicates with the goroutine doing the RoundTrip. type responseAndError struct { @@ -80,23 +73,26 @@ type responseAndError struct { err error } -func (t *Transport) RoundTrip(request *Request) (*Response, error) { - req, err := NewHyperRequest(request) +func (t *Transport) RoundTrip(req *Request) (*Response, error) { + pconn, err := t.getConn(req) if err != nil { return nil, err } - pconn, err := t.getConn(req) var resp *Response resp, err = pconn.roundTrip(req) - if err == nil { - return resp, nil + if err != nil { + return nil, err } - return nil, err + return resp, nil } -func (t *Transport) getConn(req *hyper.Request) (pconn *persistConn, err error) { - host := "www.baidu.com" - port := "80" +func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { + host := req.URL.Hostname() + port := req.URL.Port() + if port == "" { + // Hyper only supports http + port = "80" + } loop := libuv.DefaultLoop() conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) if conn == nil { @@ -125,23 +121,23 @@ func (t *Transport) getConn(req *hyper.Request) (pconn *persistConn, err error) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } pconn = &persistConn{ - conn: conn, - t: t, - reqch: make(chan requestAndChan, 1), - writech: make(chan writeRequest, 1), - closech: make(chan struct{}), + conn: conn, + t: t, + reqch: make(chan requestAndChan, 1), + //writech: make(chan writeRequest, 1), + //closech: make(chan struct{}), } net.Freeaddrinfo(res) - go pconn.startLoop(loop) + go pconn.readWriteLoop(loop) return pconn, nil } -func (pc *persistConn) roundTrip(req *hyper.Request) (resp *Response, err error) { +func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { resc := make(chan responseAndError) pc.reqch <- requestAndChan{ - req: req, + req: req.Req, ch: resc, } @@ -157,7 +153,7 @@ func (pc *persistConn) roundTrip(req *hyper.Request) (resp *Response, err error) } } -func (pc *persistConn) startLoop(loop *libuv.Loop) { +func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Hookup the IO io := NewIoWithConnReadWrite(pc.conn) @@ -169,150 +165,140 @@ func (pc *persistConn) startLoop(loop *libuv.Loop) { opts.Exec(exec) handshakeTask := hyper.Handshake(io, opts) - SetUserData(handshakeTask, hyper.ExampleHandshake) + SetTaskId(handshakeTask, Send) // Let's wait for the handshake to finish... exec.Push(handshakeTask) + // The polling state machine! + //for { + // Poll all ready tasks and act on them... + rc := <-pc.reqch // blocking + alive := true var hyperErr *hyper.Error var response Response + var respBody *hyper.Body = nil + for alive { + task := exec.Poll() + if task == nil { + //break + libuv.Run(loop, libuv.RUN_ONCE) + continue + } - var rc requestAndChan + switch (TaskId)(uintptr(task.Userdata())) { + case Send: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskClientConn { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } - select { - case rc = <-pc.reqch: - } - // The polling state machine! - for { - // Poll all ready tasks and act on them... - for { - task := exec.Poll() - if task == nil { - break + client := (*hyper.ClientConn)(task.Value()) + task.Free() + + // Send it! + sendTask := client.Send(rc.req) + SetTaskId(sendTask, ReceiveResp) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + panic("error send\n") } - switch (hyper.ExampleId)(uintptr(task.Userdata())) { - case hyper.ExampleHandshake: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskClientConn { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) - } + // For this example, no longer need the client + client.Free() - client := (*hyper.ClientConn)(task.Value()) - task.Free() + case ReceiveResp: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } - // Send it! - sendTask := client.Send(rc.req) - SetUserData(sendTask, hyper.ExampleSend) - sendRes := exec.Push(sendTask) - if sendRes != hyper.OK { - panic("error send\n") - } + // Take the results + resp := (*hyper.Response)(task.Value()) + task.Free() - // For this example, no longer need the client - client.Free() + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() - break - case hyper.ExampleSend: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("send error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) - } + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) - // Take the results - resp := (*hyper.Response)(task.Value()) - task.Free() + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + //respBody := resp.Body() + respBody = resp.Body() - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody := resp.Body() - - response.Body, response.respBodyWriter = io2.Pipe() - - /*go func() { - fmt.Println("writing...") - for { - fmt.Println("writing for...") - dataTask := respBody.Data() - exec.Push(dataTask) - dataTask = exec.Poll() - if dataTask.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(dataTask.Value()) - len := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), len) - _, err := response.respBodyWriter.Write(bytes) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - break - } - dataTask.Free() - } else if dataTask.Type() == hyper.TaskEmpty { - fmt.Println("writing empty") - dataTask.Free() - break - } - } - fmt.Println("end writing") - defer response.respBodyWriter.Close() - }()*/ - - foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) - - SetUserData(foreachTask, hyper.ExampleRespBody) - exec.Push(foreachTask) - - rc.ch <- responseAndError{res: &response} - // No longer need the response - //resp.Free() + response.Body, response.respBodyWriter = io2.Pipe() - break - case hyper.ExampleRespBody: - println("ExampleRespBody") - if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) - } + //foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) + //SetTaskId(foreachTask, ReceiveRespBody) + //exec.Push(foreachTask) - // Cleaning up before exiting - task.Free() - //exec.Free() - (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + rc.ch <- responseAndError{res: &response} - FreeConnData(pc.conn) + dataTask := respBody.Data() + SetTaskId(dataTask, ReceiveRespBody) + exec.Push(dataTask) - //return &response, nil - break - case hyper.ExampleNotSet: - println("ExampleNotSet") - // A background task for hyper_client completed... + // No longer need the response + resp.Free() + + case ReceiveRespBody: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + hyperErr = (*hyper.Error)(task.Value()) + Fail(hyperErr) + } + if task.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(task.Value()) + bufLen := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) + _, err := response.respBodyWriter.Write(bytes) // blocking + if err != nil { + panic("[readWriteLoop(): case ReceiveRespBody] error write\n") + } + buf.Free() task.Free() + + dataTask := respBody.Data() + SetTaskId(dataTask, ReceiveRespBody) + exec.Push(dataTask) + break } + // task.Type() == hyper.TaskEmpty + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + Fail(hyperErr) + } + // Cleaning up before exiting + task.Free() + respBody.Free() + response.respBodyWriter.Close() + exec.Free() + (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + FreeConnData(pc.conn) + + close(rc.ch) + close(pc.reqch) + + alive = false + case NotSet: + // A background task for hyper_client completed... + task.Free() } - - libuv.Run(loop, libuv.RUN_ONCE) } + //} } // AllocBuffer allocates a buffer for reading from a socket @@ -453,50 +439,8 @@ func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { return io } -// SetUserData Set the user data for the task -func SetUserData(task *hyper.Task, userData hyper.ExampleId) { +// SetTaskId Set TaskId to the task's userdata as a unique identifier +func SetTaskId(task *hyper.Task, userData TaskId) { var data = userData task.SetUserdata(c.Pointer(uintptr(data))) } - -// parseURL Parse the URL and extract the host name, port number, and URI -func parseURL(rawURL string) (hostname, port, uri string) { - // 找到 "://" 的位置,以分隔协议和主机名 - schemeEnd := strings.Index(rawURL, "://") - if schemeEnd != -1 { - //scheme = rawURL[:schemeEnd] - rawURL = rawURL[schemeEnd+3:] - } else { - //scheme = "http" // 默认协议为 http - } - - // 找到第一个 "/" 的位置,以分隔主机名和路径 - pathStart := strings.Index(rawURL, "/") - if pathStart != -1 { - uri = rawURL[pathStart:] - rawURL = rawURL[:pathStart] - } else { - uri = "/" - } - - // 找到 ":" 的位置,以分隔主机名和端口号 - portStart := strings.LastIndex(rawURL, ":") - if portStart != -1 { - hostname = rawURL[:portStart] - port = rawURL[portStart+1:] - } else { - hostname = rawURL - port = "" // 未指定端口号 - } - - // 如果未指定端口号,根据协议设置默认端口号 - if port == "" { - //if scheme == "https" { - // port = "443" - //} else { - // port = "80" - //} - port = "80" - } - return -} diff --git a/x/httpget/_demo/get/get.go b/x/httpget/_demo/get/get.go deleted file mode 100644 index da674ba..0000000 --- a/x/httpget/_demo/get/get.go +++ /dev/null @@ -1,24 +0,0 @@ -package main - -import ( - "fmt" - "io" - - "github.com/goplus/llgo/x/httpget" -) - -func main() { - resp, err := httpget.Get("www.baidu.com") - //req, _ := httpget.NewRequest("GET", "http://www.baidu.com", nil) - //resp, err := httpget.DefaultClient.Send(req, nil) - if err != nil { - fmt.Println(err) - return - } - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(string(body)) -} diff --git a/x/httpget/client.go b/x/httpget/client.go deleted file mode 100644 index 8a1f610..0000000 --- a/x/httpget/client.go +++ /dev/null @@ -1,46 +0,0 @@ -package httpget - -type Client struct { - Transport RoundTripper -} - -var DefaultClient = &Client{} - -type RoundTripper interface { - RoundTrip(*Request) (*Response, error) -} - -func (c *Client) transport() RoundTripper { - if c.Transport != nil { - return c.Transport - } - return DefaultTransport -} - -func Get(url string) (*Response, error) { - return DefaultClient.Get(url) -} - -func (c *Client) Get(url string) (*Response, error) { - req, err := NewRequest("GET", url, nil) - if err != nil { - return nil, err - } - return c.Do(req) -} - -func (c *Client) Do(req *Request) (*Response, error) { - return c.do(req) -} - -func (c *Client) do(req *Request) (*Response, error) { - return c.send(req, nil) -} - -func (c *Client) send(req *Request, deadline any) (*Response, error) { - return send(req, c.transport(), deadline) -} - -func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { - return rt.RoundTrip(req) -} diff --git a/x/httpget/header.go b/x/httpget/header.go deleted file mode 100644 index 1768557..0000000 --- a/x/httpget/header.go +++ /dev/null @@ -1,32 +0,0 @@ -package httpget - -import ( - "fmt" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" -) - -type Header map[string][]string - -// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console -func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { - resp := (*Response)(userdata) - nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) - valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) - - if resp.Header == nil { - resp.Header = make(map[string][]string) - } - resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) - //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) - return hyper.IterContinue -} - -func (resp *Response) PrintHeaders() { - for key, values := range resp.Header { - for _, value := range values { - fmt.Printf("%s: %s\n", key, value) - } - } -} diff --git a/x/httpget/response.go b/x/httpget/response.go deleted file mode 100644 index a9e4468..0000000 --- a/x/httpget/response.go +++ /dev/null @@ -1,54 +0,0 @@ -package httpget - -import ( - "fmt" - "io" - "unsafe" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" -) - -type Response struct { - Status string - StatusCode int - Header Header - Body io.ReadCloser - ContentLength int64 - respBodyWriter *io.PipeWriter -} - -// AppendToResponseBody (BodyForEachCallback) appends the body to the response -func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - fmt.Println("reading1...") - resp := (*Response)(userdata) - len := chunk.Len() - buf := unsafe.Slice((*byte)(chunk.Bytes()), len) - _, err := resp.respBodyWriter.Write(buf) - resp.ContentLength += int64(len) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - return hyper.IterBreak - } - fmt.Println("reading2...") - return hyper.IterContinue -} - -func (resp *Response) PrintBody() { - var buffer = make([]byte, 4096) - for { - n, err := resp.Body.Read(buffer) - if err == io.EOF { - fmt.Printf("\n") - break - } - if err != nil { - fmt.Println("Error reading from pipe:", err) - break - } - fmt.Printf("%s", string(buffer[:n])) - } - buffer = nil - //body, _ := io.ReadAll(resp.Content) - //fmt.Println(string(body)) -} From 763dd7227d621dbc4dc0a9eb1de98496c31983e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 8 Aug 2024 18:04:50 +0800 Subject: [PATCH 05/55] WIP(x/http/client/get): Some code optimization and comment addition --- go.mod | 2 +- go.sum | 4 +- x/http/_demo/get/get.go | 2 +- x/http/request.go | 15 +- x/http/response.go | 40 +++--- x/http/transport.go | 305 +++++++++++++++++++++++++++------------- 6 files changed, 240 insertions(+), 128 deletions(-) diff --git a/go.mod b/go.mod index fa05f1f..978043d 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c +require github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 diff --git a/go.sum b/go.sum index ba1d000..17cad08 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c h1:PhaSnZL8LLyRIHWc5Wim9No0Q475H8Ljikxfj1gHHjc= -github.com/goplus/llgo v0.9.5-0.20240805045323-4bff9cc3df0c/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 h1:il9j5kdSnaoO57XJ8ebSHppPWIJ8iwqgcegOJNkipt4= +github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index 73e7113..c8460e4 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -13,7 +13,7 @@ func main() { fmt.Println(err) return } - println(resp.Status) + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/http/request.go b/x/http/request.go index 98d48b4..23e98b4 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -4,15 +4,19 @@ import ( "fmt" "io" "net/url" + "time" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" ) type Request struct { - Method string - URL *url.URL - Req *hyper.Request + Method string + URL *url.URL + Req *hyper.Request + Host string + Header Header + Timeout time.Duration } func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { @@ -20,7 +24,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if err != nil { return nil, err } - req, err := NewHyperRequest(method, parseURL) + req, err := newHyperRequest(method, parseURL) if err != nil { return nil, err } @@ -28,10 +32,11 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { Method: method, URL: parseURL, Req: req, + Host: parseURL.Hostname(), }, nil } -func NewHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { +func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { host := URL.Hostname() uri := URL.RequestURI() // Prepare the request diff --git a/x/http/response.go b/x/http/response.go index 8d01b80..2f3a641 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -1,33 +1,27 @@ package http import ( - "fmt" "io" - "unsafe" - - "github.com/goplus/llgo/c" - "github.com/goplus/llgoexamples/rust/hyper" ) type Response struct { - Status string - StatusCode int - Header Header - Body io.ReadCloser - ContentLength int64 - respBodyWriter *io.PipeWriter + Status string + StatusCode int + Header Header + Body io.ReadCloser + ContentLength int64 } // AppendToResponseBody (BodyForEachCallback) appends the body to the response -func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - resp := (*Response)(userdata) - len := chunk.Len() - buf := unsafe.Slice((*byte)(chunk.Bytes()), len) - _, err := resp.respBodyWriter.Write(buf) - resp.ContentLength += int64(len) - if err != nil { - fmt.Printf("Failed to write response body: %v\n", err) - return hyper.IterBreak - } - return hyper.IterContinue -} +//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { +// resp := (*Response)(userdata) +// len := chunk.Len() +// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) +// _, err := resp.respBodyWriter.Write(buf) +// resp.ContentLength += int64(len) +// if err != nil { +// fmt.Printf("Failed to write response body: %v\n", err) +// return hyper.IterBreak +// } +// return hyper.IterContinue +//} diff --git a/x/http/transport.go b/x/http/transport.go index c2559a8..ed08e45 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -2,7 +2,7 @@ package http import ( "fmt" - io2 "io" + "io" "strconv" "unsafe" @@ -35,6 +35,10 @@ const ( ReceiveRespBody ) +const ( + DefaultHTTPPort = "80" +) + var DefaultTransport RoundTripper = &Transport{} // persistConn wraps a connection, usually a persistent one @@ -87,14 +91,15 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { - host := req.URL.Hostname() + host := req.Host port := req.URL.Port() if port == "" { // Hyper only supports http - port = "80" + port = DefaultHTTPPort } loop := libuv.DefaultLoop() - conn := (*ConnData)(c.Malloc(unsafe.Sizeof(ConnData{}))) + //conn := (*ConnData)(c.Calloc(1, unsafe.Sizeof(ConnData{}))) + conn := new(ConnData) if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } @@ -144,7 +149,7 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { select { case re := <-resc: if (re.res == nil) == (re.err == nil) { - panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } if re.err != nil { return nil, err @@ -153,18 +158,19 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { } } +// readWriteLoop handles the main I/O loop for a persistent connection. +// It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Hookup the IO - io := NewIoWithConnReadWrite(pc.conn) + hyperIo := NewIoWithConnReadWrite(pc.conn) // We need an executor generally to poll futures exec := hyper.NewExecutor() - // Prepare client options opts := hyper.NewClientConnOptions() opts.Exec(exec) - handshakeTask := hyper.Handshake(io, opts) + handshakeTask := hyper.Handshake(hyperIo, opts) SetTaskId(handshakeTask, Send) // Let's wait for the handshake to finish... @@ -175,27 +181,25 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Poll all ready tasks and act on them... rc := <-pc.reqch // blocking alive := true - var hyperErr *hyper.Error var response Response + var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { task := exec.Poll() if task == nil { //break - libuv.Run(loop, libuv.RUN_ONCE) + loop.Run(libuv.RUN_ONCE) continue } switch (TaskId)(uintptr(task.Userdata())) { case Send: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskClientConn { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) + err := CheckTaskType(task, Send) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } client := (*hyper.ClientConn)(task.Value()) @@ -206,21 +210,21 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { SetTaskId(sendTask, ReceiveResp) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { - panic("error send\n") + rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } // For this example, no longer need the client client.Free() - case ReceiveResp: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("send error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) + err := CheckTaskType(task, ReceiveResp) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } // Take the results @@ -235,14 +239,25 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { headers := resp.Headers() headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - //respBody := resp.Body() respBody = resp.Body() - response.Body, response.respBodyWriter = io2.Pipe() + response.Body, bodyWriter = io.Pipe() - //foreachTask := respBody.Foreach(AppendToResponseBody, c.Pointer(&response)) - //SetTaskId(foreachTask, ReceiveRespBody) - //exec.Push(foreachTask) + // TODO(spongehah) Replace header operations with using the textproto package + lengthSlice := response.Header["content-length"] + if lengthSlice == nil { + response.ContentLength = 0 + } else { + contentLength := response.Header["content-length"][0] + length, err := strconv.Atoi(contentLength) + if err != nil { + rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + response.ContentLength = int64(length) + } rc.ch <- responseAndError{res: &response} @@ -252,20 +267,31 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // No longer need the response resp.Free() - case ReceiveRespBody: - if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) - hyperErr = (*hyper.Error)(task.Value()) - Fail(hyperErr) + err := CheckTaskType(task, ReceiveRespBody) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } + if task.Type() == hyper.TaskBuf { buf := (*hyper.Buf)(task.Value()) bufLen := buf.Len() bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) - _, err := response.respBodyWriter.Write(bytes) // blocking + if bodyWriter == nil { + rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + _, err := bodyWriter.Write(bytes) // blocking if err != nil { - panic("[readWriteLoop(): case ReceiveRespBody] error write\n") + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } buf.Free() task.Free() @@ -276,21 +302,18 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { break } - // task.Type() == hyper.TaskEmpty + + // We are done with the response body if task.Type() != hyper.TaskEmpty { c.Printf(c.Str("unexpected task type\n")) - Fail(hyperErr) + rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return } - // Cleaning up before exiting - task.Free() - respBody.Free() - response.respBodyWriter.Close() - exec.Free() - (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) - FreeConnData(pc.conn) - close(rc.ch) - close(pc.reqch) + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) alive = false case NotSet: @@ -301,6 +324,19 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} } +// OnConnect is the libuv callback for a successful connection +func OnConnect(req *libuv.Connect, status c.Int) { + //conn := (*ConnData)(req.Data) + //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + + if status < 0 { + c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + return + } + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +} + // AllocBuffer allocates a buffer for reading from a socket func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { //conn := (*ConnData)(handle.Data) @@ -314,108 +350,160 @@ func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { } // OnRead is the libuv callback for reading from a socket +// This callback function is called when data is available to be read func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { + // Get the connection data associated with the stream + conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) //conn := (*ConnData)(stream.Data) //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data - conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + + // If data was read (nread > 0) if nread > 0 { + // Update the amount of filled buffer conn.ReadBufFilled += uintptr(nread) } + // If there's a pending read waker if conn.ReadWaker != nil { + // Wake up the pending read operation of Hyper conn.ReadWaker.Wake() + // Clear the waker reference conn.ReadWaker = nil } } -// ReadCallBack is the hyper callback for reading from a socket +// ReadCallBack read callback function for Hyper library func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + // Get the user data (connection data) conn := (*ConnData)(userdata) + // If there's data in the buffer if conn.ReadBufFilled > 0 { + // Calculate how much data to copy (minimum of filled amount and requested amount) var toCopy uintptr if bufLen < conn.ReadBufFilled { toCopy = bufLen } else { toCopy = conn.ReadBufFilled } + // Copy data from read buffer to Hyper's buffer c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + // Move remaining data to the beginning of the buffer c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + // Update the amount of filled buffer conn.ReadBufFilled -= toCopy + // Return the number of bytes copied return toCopy } + // If no data in buffer, set up a waker to wait for more data + // Free the old waker if it exists if conn.ReadWaker != nil { conn.ReadWaker.Free() } + // Create a new waker conn.ReadWaker = ctx.Waker() + // Return HYPER_IO_PENDING to indicate operation is pending, waiting for more data return hyper.IoPending } // OnWrite is the libuv callback for writing to a socket +// Callback function called after a write operation completes func OnWrite(req *libuv.Write, status c.Int) { + // Get the connection data associated with the write request + conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) //conn := (*ConnData)(req.Data) //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + // If there's a pending write waker if conn.WriteWaker != nil { + // Wake up the pending write operation conn.WriteWaker.Wake() + // Clear the waker reference conn.WriteWaker = nil } - c.Free(c.Pointer(req)) } -// WriteCallBack is the hyper callback for writing to a socket +// WriteCallBack write callback function for Hyper library func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { + // Get the user data (connection data) conn := (*ConnData)(userdata) + // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) - req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) - //req.Data = c.Pointer(conn) + //req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) + req := &libuv.Write{} + // Associate the connection data with the write request (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + //req.Data = c.Pointer(conn) + // Perform the asynchronous write operation + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + // If the write operation was successfully initiated if ret >= 0 { + // Return the number of bytes to be written return bufLen } + // If the write operation can't complete immediately, set up a waker to wait for completion if conn.WriteWaker != nil { + // Free the old waker if it exists conn.WriteWaker.Free() } + // Create a new waker conn.WriteWaker = ctx.Waker() + // Return HYPER_IO_PENDING to indicate operation is pending, waiting for write to complete return hyper.IoPending } -// OnConnect is the libuv callback for a successful connection -func OnConnect(req *libuv.Connect, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) +// NewIoWithConnReadWrite creates a new IO with read and write callbacks +func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { + hyperIo := hyper.NewIo() + hyperIo.SetUserdata(c.Pointer(connData)) + hyperIo.SetRead(ReadCallBack) + hyperIo.SetWrite(WriteCallBack) + return hyperIo +} - if status < 0 { - c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) - return - } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) +// SetTaskId Set TaskId to the task's userdata as a unique identifier +func SetTaskId(task *hyper.Task, userData TaskId) { + var data = userData + task.SetUserdata(unsafe.Pointer(uintptr(data))) } -// FreeConnData frees the connection data -func FreeConnData(conn *ConnData) { - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil - } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil +// CheckTaskType checks the task type +func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { + switch curTaskId { + case Send: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("handshake task error!\n")) + return Fail((*hyper.Error)(task.Value())) + } + if task.Type() != hyper.TaskClientConn { + return fmt.Errorf("unexpected task type\n") + } + return nil + case ReceiveResp: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("send task error!\n")) + return Fail((*hyper.Error)(task.Value())) + } + if task.Type() != hyper.TaskResponse { + c.Printf(c.Str("unexpected task type\n")) + return fmt.Errorf("unexpected task type\n") + } + return nil + case ReceiveRespBody: + if task.Type() == hyper.TaskError { + c.Printf(c.Str("body error!\n")) + return Fail((*hyper.Error)(task.Value())) + } + return nil + case NotSet: } - c.Free(c.Pointer(conn)) + return fmt.Errorf("unexpected TaskId\n") } // Fail prints the error details and panics -func Fail(err *hyper.Error) { +func Fail(err *hyper.Error) error { if err != nil { c.Printf(c.Str("error code: %d\n"), err.Code()) // grab the error details @@ -426,21 +514,46 @@ func Fail(err *hyper.Error) { // clean up the error err.Free() - panic("hyper error \n") + return fmt.Errorf("hyper request error, error code: %d\n", int(err.Code())) } + return nil } -// NewIoWithConnReadWrite creates a new IO with read and write callbacks -func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { - io := hyper.NewIo() - io.SetUserdata(c.Pointer(connData)) - io.SetRead(ReadCallBack) - io.SetWrite(WriteCallBack) - return io +// FreeResources frees the resources +func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { + // Cleaning up before exiting + if task != nil { + task.Free() + } + if respBody != nil { + respBody.Free() + } + if bodyWriter != nil { + bodyWriter.Close() + } + if exec != nil { + exec.Free() + } + (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) + FreeConnData(pc.conn) + + // Closing the channel + close(rc.ch) + close(pc.reqch) } -// SetTaskId Set TaskId to the task's userdata as a unique identifier -func SetTaskId(task *hyper.Task, userData TaskId) { - var data = userData - task.SetUserdata(c.Pointer(uintptr(data))) +// FreeConnData frees the connection data +func FreeConnData(conn *ConnData) { + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } } From b9f4944b87bbe886d6005a38e2404860e354adeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 9 Aug 2024 18:06:57 +0800 Subject: [PATCH 06/55] WIP(c/http/client): Add request timeout logic --- go.mod | 2 +- go.sum | 4 +- x/http/_demo/timeout/timeout.go | 33 ++++ x/http/client.go | 12 +- x/http/request.go | 11 +- x/http/transport.go | 319 ++++++++++++++++++++------------ 6 files changed, 250 insertions(+), 131 deletions(-) create mode 100644 x/http/_demo/timeout/timeout.go diff --git a/go.mod b/go.mod index 978043d..4082df2 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 +require github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 diff --git a/go.sum b/go.sum index 17cad08..e3abd53 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3 h1:il9j5kdSnaoO57XJ8ebSHppPWIJ8iwqgcegOJNkipt4= -github.com/goplus/llgo v0.9.6-0.20240808082624-c5b96f4e9cf3/go.mod h1:zsrtWZapL4aklZc99xBSZRynGzLTIT1mLRjP0VSn9iw= +github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 h1:VIJ38bCFRIIr62YXyRKkxy6GXYVA6R3xqAb0HkcoUgw= +github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= diff --git a/x/http/_demo/timeout/timeout.go b/x/http/_demo/timeout/timeout.go new file mode 100644 index 0000000..42f8bf8 --- /dev/null +++ b/x/http/_demo/timeout/timeout.go @@ -0,0 +1,33 @@ +package main + +import ( + "fmt" + "io" + "time" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + client := &http.Client{ + Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + //Timeout: time.Second * 5, + } + req, err := http.NewRequest("GET", "https://www.baidu.com", nil) + if err != nil { + fmt.Println(err.Error()) + return + } + resp, err := client.Do(req) + if err != nil { + fmt.Println(err.Error()) + return + } + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err.Error()) + return + } + println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/client.go b/x/http/client.go index ac0bc6e..9ce8506 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,7 +1,10 @@ package http +import "time" + type Client struct { Transport RoundTripper + Timeout time.Duration } var DefaultClient = &Client{} @@ -34,13 +37,14 @@ func (c *Client) Do(req *Request) (*Response, error) { } func (c *Client) do(req *Request) (*Response, error) { - return c.send(req, nil) + return c.send(req, c.Timeout) } -func (c *Client) send(req *Request, deadline any) (*Response, error) { - return send(req, c.transport(), deadline) +func (c *Client) send(req *Request, timeout time.Duration) (*Response, error) { + return send(req, c.transport(), timeout) } -func send(req *Request, rt RoundTripper, deadline any) (resp *Response, err error) { +func send(req *Request, rt RoundTripper, timeout time.Duration) (resp *Response, err error) { + req.timeout = timeout return rt.RoundTrip(req) } diff --git a/x/http/request.go b/x/http/request.go index 23e98b4..2e04939 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -16,7 +16,7 @@ type Request struct { Req *hyper.Request Host string Header Header - Timeout time.Duration + timeout time.Duration } func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { @@ -29,10 +29,11 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { return nil, err } return &Request{ - Method: method, - URL: parseURL, - Req: req, - Host: parseURL.Hostname(), + Method: method, + URL: parseURL, + Req: req, + Host: parseURL.Hostname(), + timeout: 0, }, nil } diff --git a/x/http/transport.go b/x/http/transport.go index ed08e45..1eed648 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -17,6 +17,8 @@ type ConnData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect ReadBuf libuv.Buf + TimeoutTimer libuv.Timer + IsCompleted int ReadBufFilled uintptr ReadWaker *hyper.Waker WriteWaker *hyper.Waker @@ -53,9 +55,11 @@ type persistConn struct { //nwrite int64 // bytes written //writech chan writeRequest // written by roundTrip; read by writeLoop //closech chan struct{} // closed when conn closed - conn *ConnData - t *Transport - reqch chan requestAndChan // written by roundTrip; read by readLoop + conn *ConnData + t *Transport + reqch chan requestAndChan // written by roundTrip; read by readLoop + cancelch chan freeChan + timeoutch chan struct{} } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -65,7 +69,7 @@ type incomparable [0]func() type requestAndChan struct { _ incomparable - req *hyper.Request + req *Request ch chan responseAndError // unbuffered; always send in select on callerGone } @@ -77,6 +81,17 @@ type responseAndError struct { err error } +type connAndTimeoutChan struct { + _ incomparable + conn *ConnData + timeoutch chan struct{} +} + +type freeChan struct { + _ incomparable + freech chan struct{} +} + func (t *Transport) RoundTrip(req *Request) (*Response, error) { pconn, err := t.getConn(req) if err != nil { @@ -104,6 +119,18 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } + // If timeout is set, start the timer + timeoutch := make(chan struct{}, 1) + if req.timeout != 0 { + libuv.InitTimer(loop, &conn.TimeoutTimer) + ct := &connAndTimeoutChan{ + conn: conn, + timeoutch: timeoutch, + } + (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) + conn.TimeoutTimer.Start(OnTimeout, uint64(req.timeout.Milliseconds()), 0) + } + libuv.InitTcp(loop, &conn.TcpHandle) //conn.TcpHandle.Data = c.Pointer(conn) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) @@ -116,6 +143,7 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { var res *net.AddrInfo status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { + close(timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } @@ -123,38 +151,57 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) if status != 0 { + close(timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } pconn = &persistConn{ - conn: conn, - t: t, - reqch: make(chan requestAndChan, 1), + conn: conn, + t: t, + reqch: make(chan requestAndChan, 1), + cancelch: make(chan freeChan, 1), + timeoutch: timeoutch, //writech: make(chan writeRequest, 1), //closech: make(chan struct{}), } net.Freeaddrinfo(res) - go pconn.readWriteLoop(loop) + if pconn.conn.IsCompleted != 1 { + go pconn.readWriteLoop(loop) + } return pconn, nil } -func (pc *persistConn) roundTrip(req *Request) (resp *Response, err error) { - resc := make(chan responseAndError) +func (pc *persistConn) roundTrip(req *Request) (*Response, error) { + resc := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{ - req: req.Req, + req: req, ch: resc, } - + // Determine whether timeout has occurred + if pc.conn.IsCompleted == 1 { + rc := <-pc.reqch // blocking + // Free the resources + FreeResources(nil, nil, nil, nil, pc, rc) + } select { case re := <-resc: if (re.res == nil) == (re.err == nil) { return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } if re.err != nil { - return nil, err + return nil, re.err } return re.res, nil + case <-pc.timeoutch: + freech := make(chan struct{}, 1) + pc.cancelch <- freeChan{ + freech: freech, + } + <-freech + close(freech) + return nil, fmt.Errorf("request timeout\n") } } @@ -185,140 +232,156 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { - task := exec.Poll() - if task == nil { - //break - loop.Run(libuv.RUN_ONCE) - continue - } - - switch (TaskId)(uintptr(task.Userdata())) { - case Send: - err := CheckTaskType(task, Send) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - - client := (*hyper.ClientConn)(task.Value()) - task.Free() - - // Send it! - sendTask := client.Send(rc.req) - SetTaskId(sendTask, ReceiveResp) - sendRes := exec.Push(sendTask) - if sendRes != hyper.OK { - rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - - // For this example, no longer need the client - client.Free() - case ReceiveResp: - err := CheckTaskType(task, ReceiveResp) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return + select { + case fc := <-pc.cancelch: + // Free the resources + FreeResources(nil, respBody, bodyWriter, exec, pc, rc) + alive = false + fc.freech <- struct{}{} + return + default: + task := exec.Poll() + if task == nil { + //break + loop.Run(libuv.RUN_ONCE) + continue } - - // Take the results - resp := (*hyper.Response)(task.Value()) - task.Free() - - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody = resp.Body() - - response.Body, bodyWriter = io.Pipe() - - // TODO(spongehah) Replace header operations with using the textproto package - lengthSlice := response.Header["content-length"] - if lengthSlice == nil { - response.ContentLength = 0 - } else { - contentLength := response.Header["content-length"][0] - length, err := strconv.Atoi(contentLength) + switch (TaskId)(uintptr(task.Userdata())) { + case Send: + err := CheckTaskType(task, Send) if err != nil { - rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + rc.ch <- responseAndError{err: err} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) return } - response.ContentLength = int64(length) - } - - rc.ch <- responseAndError{res: &response} - - dataTask := respBody.Data() - SetTaskId(dataTask, ReceiveRespBody) - exec.Push(dataTask) - // No longer need the response - resp.Free() - case ReceiveRespBody: - err := CheckTaskType(task, ReceiveRespBody) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } + client := (*hyper.ClientConn)(task.Value()) + task.Free() - if task.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(task.Value()) - bufLen := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) - if bodyWriter == nil { - rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} + // Send it! + sendTask := client.Send(rc.req.Req) + SetTaskId(sendTask, ReceiveResp) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) return } - _, err := bodyWriter.Write(bytes) // blocking + + // For this example, no longer need the client + client.Free() + case ReceiveResp: + err := CheckTaskType(task, ReceiveResp) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) return } - buf.Free() + + // Take the results + resp := (*hyper.Response)(task.Value()) task.Free() + rp := resp.ReasonPhrase() + rpLen := resp.ReasonPhraseLen() + + response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + response.StatusCode = int(resp.Status()) + + headers := resp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) + respBody = resp.Body() + + response.Body, bodyWriter = io.Pipe() + + // TODO(spongehah) Replace header operations with using the textproto package + lengthSlice := response.Header["content-length"] + if lengthSlice == nil { + response.ContentLength = 0 + } else { + contentLength := response.Header["content-length"][0] + length, err := strconv.Atoi(contentLength) + if err != nil { + rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + response.ContentLength = int64(length) + } + + rc.ch <- responseAndError{res: &response} + + // Response has been returned, stop the timer + pc.conn.IsCompleted = 1 + // Stop the timer + if rc.req.timeout != 0 { + pc.conn.TimeoutTimer.Stop() + (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) + } + dataTask := respBody.Data() SetTaskId(dataTask, ReceiveRespBody) exec.Push(dataTask) - break - } + // No longer need the response + resp.Free() + case ReceiveRespBody: + err := CheckTaskType(task, ReceiveRespBody) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + + if task.Type() == hyper.TaskBuf { + buf := (*hyper.Buf)(task.Value()) + bufLen := buf.Len() + bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) + if bodyWriter == nil { + rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + _, err := bodyWriter.Write(bytes) // blocking + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + buf.Free() + task.Free() + + dataTask := respBody.Data() + SetTaskId(dataTask, ReceiveRespBody) + exec.Push(dataTask) + + break + } + + // We are done with the response body + if task.Type() != hyper.TaskEmpty { + c.Printf(c.Str("unexpected task type\n")) + rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } - // We are done with the response body - if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} // Free the resources FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - alive = false - case NotSet: - // A background task for hyper_client completed... - task.Free() + alive = false + case NotSet: + // A background task for hyper_client completed... + task.Free() + } } } //} @@ -454,6 +517,17 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui return hyper.IoPending } +// OnTimeout is the libuv callback for a timeout +func OnTimeout(handle *libuv.Timer) { + ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) + if ct.conn.IsCompleted != 1 { + ct.conn.IsCompleted = 1 + ct.timeoutch <- struct{}{} + } + // Close the timer + (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) +} + // NewIoWithConnReadWrite creates a new IO with read and write callbacks func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { hyperIo := hyper.NewIo() @@ -537,9 +611,16 @@ func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWr (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) FreeConnData(pc.conn) + CloseChannels(rc, pc) +} + +// CloseChannels closes the channels +func CloseChannels(rc requestAndChan, pc *persistConn) { // Closing the channel close(rc.ch) close(pc.reqch) + close(pc.timeoutch) + close(pc.cancelch) } // FreeConnData frees the connection data From 2c9394c1f6d590b686cbb450725573801d0d4aaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 12 Aug 2024 16:47:07 +0800 Subject: [PATCH 07/55] WIP(x/http/client/get): Introducing textproto for header & implementing custom header --- x/http/_demo/headers/headers.go | 45 ++++++++++++++++++++++ x/http/client.go | 5 +++ x/http/header.go | 67 ++++++++++++++++++++++++++++++++- x/http/request.go | 33 ++++++++++++++-- x/http/response.go | 14 ------- x/http/transport.go | 25 +++++++++--- 6 files changed, 164 insertions(+), 25 deletions(-) create mode 100644 x/http/_demo/headers/headers.go diff --git a/x/http/_demo/headers/headers.go b/x/http/_demo/headers/headers.go new file mode 100644 index 0000000..98cda79 --- /dev/null +++ b/x/http/_demo/headers/headers.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgo/x/http" +) + +func main() { + client := &http.Client{} + req, err := http.NewRequest("GET", "https://jsonplaceholder.typicode.com/comments?postId=1", nil) + if err != nil { + println(err.Error()) + return + } + + //req.Header.Set("accept", "*/*") + //req.Header.Set("accept-encoding", "identity") + //req.Header.Set("cache-control", "no-cache") + //req.Header.Set("pragma", "no-cache") + //req.Header.Set("priority", "u=0, i") + //req.Header.Set("referer", "https://jsonplaceholder.typicode.com/") + //req.Header.Set("sec-ch-ua", "\"Not)A;Brand\";v=\"99\", \"Google Chrome\";v=\"127\", \"Chromium\";v=\"127\"") + //req.Header.Set("sec-ch-ua-mobile", "?0") + //req.Header.Set("sec-ch-ua-platform", "\"macOS\"") + //req.Header.Set("sec-fetch-dest", "document") + //req.Header.Set("sec-fetch-mode", "navigate") + //req.Header.Set("sec-fetch-site", "same-origin") + //req.Header.Set("sec-fetch-user", "?1") + //req.Header.Set("upgrade-insecure-requests", "1") + //req.Header.Set("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36") + + resp, err := client.Do(req) + if err != nil { + println(err.Error()) + return + } + body, err := io.ReadAll(resp.Body) + if err != nil { + println(err.Error()) + return + } + fmt.Println(string(body)) +} diff --git a/x/http/client.go b/x/http/client.go index 9ce8506..177b089 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -37,6 +37,11 @@ func (c *Client) Do(req *Request) (*Response, error) { } func (c *Client) do(req *Request) (*Response, error) { + // Add user-defined request headers to hyper.Request + err := req.setHeaders() + if err != nil { + return nil, err + } return c.send(req, c.Timeout) } diff --git a/x/http/header.go b/x/http/header.go index 4710854..aac05ce 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -4,11 +4,74 @@ import ( "fmt" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/x/textproto" "github.com/goplus/llgoexamples/rust/hyper" ) +// A Header represents the key-value pairs in an HTTP header. +// +// The keys should be in canonical form, as returned by +// CanonicalHeaderKey. type Header map[string][]string +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +// To use non-canonical keys, assign to the map directly. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. To use non-canonical keys, +// access the map directly. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h Header) Values(key string) []string { + return textproto.MIMEHeader(h).Values(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func (h Header) has(key string) bool { + _, ok := h[key] + return ok +} + +// Del deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + // AppendToResponseHeader (HeadersForEachCallback) prints each header to the console func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { resp := (*Response)(userdata) @@ -18,8 +81,8 @@ func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va if resp.Header == nil { resp.Header = make(map[string][]string) } - resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) - //c.Printf(c.Str("%.*s: %.*s\n"), int(nameLen), name, int(valueLen), value) + resp.Header.Add(nameStr, valueStr) + //resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) return hyper.IterContinue } diff --git a/x/http/request.go b/x/http/request.go index 2e04939..933ae3b 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -28,13 +28,16 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if err != nil { return nil, err } - return &Request{ + request := &Request{ Method: method, URL: parseURL, Req: req, Host: parseURL.Hostname(), + Header: make(Header), timeout: 0, - }, nil + } + request.Header.Set("Host", request.Host) + return request, nil } func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { @@ -49,11 +52,33 @@ func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { return nil, fmt.Errorf("error setting uri %s\n", uri) } - // Set the request headers reqHeaders := req.Headers() if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting headers\n") + return nil, fmt.Errorf("error setting header: Host: %s\n", host) } + return req, nil } + +// setHeaders sets the headers of the request +func (req *Request) setHeaders() error { + headers := req.Req.Headers() + for key, values := range req.Header { + valueLen := len(values) + if valueLen > 1 { + for _, value := range values { + if headers.Add((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(value)[0]), c.Strlen(c.AllocaCStr(value))) != hyper.OK { + return fmt.Errorf("error adding header %s: %s\n", key, value) + } + } + } else if valueLen == 1 { + if headers.Set((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(values[0])[0]), c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { + return fmt.Errorf("error setting header %s: %s\n", key, values[0]) + } + } else { + return fmt.Errorf("error setting header %s: empty value\n", key) + } + } + return nil +} diff --git a/x/http/response.go b/x/http/response.go index 2f3a641..e69c273 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -11,17 +11,3 @@ type Response struct { Body io.ReadCloser ContentLength int64 } - -// AppendToResponseBody (BodyForEachCallback) appends the body to the response -//func AppendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { -// resp := (*Response)(userdata) -// len := chunk.Len() -// buf := unsafe.Slice((*byte)(chunk.Bytes()), len) -// _, err := resp.respBodyWriter.Write(buf) -// resp.ContentLength += int64(len) -// if err != nil { -// fmt.Printf("Failed to write response body: %v\n", err) -// return hyper.IterBreak -// } -// return hyper.IterContinue -//} diff --git a/x/http/transport.go b/x/http/transport.go index 1eed648..ec5ae02 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -184,6 +184,7 @@ func (pc *persistConn) roundTrip(req *Request) (*Response, error) { rc := <-pc.reqch // blocking // Free the resources FreeResources(nil, nil, nil, nil, pc, rc) + return nil, fmt.Errorf("request timeout\n") } select { case re := <-resc: @@ -297,12 +298,26 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { response.Body, bodyWriter = io.Pipe() - // TODO(spongehah) Replace header operations with using the textproto package - lengthSlice := response.Header["content-length"] - if lengthSlice == nil { - response.ContentLength = 0 + //// TODO(spongehah) Replace header operations with using the textproto package + //lengthSlice := response.Header["content-length"] + //if lengthSlice == nil { + // response.ContentLength = -1 + //} else { + // contentLength := response.Header["content-length"][0] + // length, err := strconv.Atoi(contentLength) + // if err != nil { + // rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} + // // Free the resources + // FreeResources(task, respBody, bodyWriter, exec, pc, rc) + // return + // } + // response.ContentLength = int64(length) + //} + + contentLength := response.Header.Get("content-length") + if contentLength == "" { + response.ContentLength = -1 } else { - contentLength := response.Header["content-length"][0] length, err := strconv.Atoi(contentLength) if err != nil { rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} From 5744fd69371df5ebc36d29579e63fd84e346f007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 13 Aug 2024 18:16:44 +0800 Subject: [PATCH 08/55] WIP(x/http/client/get): Extract the readTransfer function and complete its content --- x/http/header.go | 122 ++++++++-------- x/http/request.go | 3 +- x/http/response.go | 30 +++- x/http/transfer.go | 332 ++++++++++++++++++++++++++++++++++++++++++++ x/http/transport.go | 61 +++----- 5 files changed, 444 insertions(+), 104 deletions(-) create mode 100644 x/http/transfer.go diff --git a/x/http/header.go b/x/http/header.go index aac05ce..0533ed7 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/goplus/llgo/c" - "github.com/goplus/llgo/x/textproto" "github.com/goplus/llgoexamples/rust/hyper" ) @@ -18,59 +17,68 @@ type Header map[string][]string // It appends to any existing values associated with key. // The key is case insensitive; it is canonicalized by // CanonicalHeaderKey. -func (h Header) Add(key, value string) { - textproto.MIMEHeader(h).Add(key, value) -} - -// Set sets the header entries associated with key to the -// single element value. It replaces any existing values -// associated with key. The key is case insensitive; it is -// canonicalized by textproto.CanonicalMIMEHeaderKey. -// To use non-canonical keys, assign to the map directly. -func (h Header) Set(key, value string) { - textproto.MIMEHeader(h).Set(key, value) -} - -// Get gets the first value associated with the given key. If -// there are no values associated with the key, Get returns "". -// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -// used to canonicalize the provided key. Get assumes that all -// keys are stored in canonical form. To use non-canonical keys, -// access the map directly. -func (h Header) Get(key string) string { - return textproto.MIMEHeader(h).Get(key) -} - -// Values returns all values associated with the given key. -// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -// used to canonicalize the provided key. To use non-canonical -// keys, access the map directly. -// The returned slice is not a copy. -func (h Header) Values(key string) []string { - return textproto.MIMEHeader(h).Values(key) -} - -// get is like Get, but key must already be in CanonicalHeaderKey form. -func (h Header) get(key string) string { - if v := h[key]; len(v) > 0 { - return v[0] - } - return "" -} - -// has reports whether h has the provided key defined, even if it's -// set to 0-length slice. -func (h Header) has(key string) bool { - _, ok := h[key] - return ok -} - -// Del deletes the values associated with key. -// The key is case insensitive; it is canonicalized by -// CanonicalHeaderKey. -func (h Header) Del(key string) { - textproto.MIMEHeader(h).Del(key) -} +//func (h Header) Add(key, value string) { +// textproto.MIMEHeader(h).Add(key, value) +//} +// +//// Set sets the header entries associated with key to the +//// single element value. It replaces any existing values +//// associated with key. The key is case insensitive; it is +//// canonicalized by textproto.CanonicalMIMEHeaderKey. +//// To use non-canonical keys, assign to the map directly. +//func (h Header) Set(key, value string) { +// textproto.MIMEHeader(h).Set(key, value) +//} +// +//// Get gets the first value associated with the given key. If +//// there are no values associated with the key, Get returns "". +//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +//// used to canonicalize the provided key. Get assumes that all +//// keys are stored in canonical form. To use non-canonical keys, +//// access the map directly. +//func (h Header) Get(key string) string { +// return textproto.MIMEHeader(h).Get(key) +//} +// +//// Values returns all values associated with the given key. +//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +//// used to canonicalize the provided key. To use non-canonical +//// keys, access the map directly. +//// The returned slice is not a copy. +//func (h Header) Values(key string) []string { +// return textproto.MIMEHeader(h).Values(key) +//} +// +//// get is like Get, but key must already be in CanonicalHeaderKey form. +//func (h Header) get(key string) string { +// if v := h[key]; len(v) > 0 { +// return v[0] +// } +// return "" +//} +// +//// has reports whether h has the provided key defined, even if it's +//// set to 0-length slice. +//func (h Header) has(key string) bool { +// _, ok := h[key] +// return ok +//} +// +//// Del deletes the values associated with key. +//// The key is case insensitive; it is canonicalized by +//// CanonicalHeaderKey. +//func (h Header) Del(key string) { +// textproto.MIMEHeader(h).Del(key) +//} +// +//// CanonicalHeaderKey returns the canonical format of the +//// header key s. The canonicalization converts the first +//// letter and any letter following a hyphen to upper case; +//// the rest are converted to lowercase. For example, the +//// canonical key for "accept-encoding" is "Accept-Encoding". +//// If s contains a space or invalid header field bytes, it is +//// returned without modifications. +//func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } // AppendToResponseHeader (HeadersForEachCallback) prints each header to the console func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { @@ -79,10 +87,10 @@ func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) if resp.Header == nil { - resp.Header = make(map[string][]string) + resp.Header = make(Header) } - resp.Header.Add(nameStr, valueStr) - //resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + //resp.Header.Add(nameStr, valueStr) + resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) return hyper.IterContinue } diff --git a/x/http/request.go b/x/http/request.go index 933ae3b..1d219a1 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -36,7 +36,8 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { Header: make(Header), timeout: 0, } - request.Header.Set("Host", request.Host) + //request.Header.Set("Host", request.Host) + request.Header["Host"] = []string{request.Host} return request, nil } diff --git a/x/http/response.go b/x/http/response.go index e69c273..a08d77e 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -1,13 +1,39 @@ package http import ( + "fmt" "io" + "strconv" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" ) type Response struct { - Status string - StatusCode int + Status string // e.g. "200 OK" + StatusCode int // e.g. 200 + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 Header Header Body io.ReadCloser ContentLength int64 + Trailer Header + Chunked bool + Request *Request +} + +func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { + rp := hyperResp.ReasonPhrase() + rpLen := hyperResp.ReasonPhraseLen() + + resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + resp.StatusCode = int(hyperResp.Status()) + + version := int(hyperResp.Version()) + resp.ProtoMajor, resp.ProtoMinor = splitTwoDigitNumber(version) + resp.Proto = fmt.Sprintf("HTTP/%d.%d", resp.ProtoMajor, resp.ProtoMinor) + + headers := hyperResp.Headers() + headers.Foreach(AppendToResponseHeader, c.Pointer(resp)) } diff --git a/x/http/transfer.go b/x/http/transfer.go new file mode 100644 index 0000000..5157324 --- /dev/null +++ b/x/http/transfer.go @@ -0,0 +1,332 @@ +package http +// +//import ( +// "fmt" +// "io" +// "net/textproto" +// "strconv" +// "strings" +// +// "github.com/goplus/llgoexamples/rust/hyper" +//) +// +//type transferReader struct { +// // Input +// Header Header +// StatusCode int +// RequestMethod string +// ProtoMajor int +// ProtoMinor int +// // Output +// Body io.ReadCloser +// ContentLength int64 +// Chunked bool +// Close bool +// Trailer Header +//} +// +//// unsupportedTEError reports unsupported transfer-encodings. +//type unsupportedTEError struct { +// err string +//} +// +//func (uste *unsupportedTEError) Error() string { +// return uste.err +//} +// +//func readTransfer(resp *Response, hyperResp *hyper.Response) (err error) { +// //// TODO(spongehah) Replace header operations with using the textproto package +// //lengthSlice := resp.Header["content-length"] +// //if lengthSlice == nil { +// // resp.ContentLength = -1 +// //} else { +// // contentLength := resp.Header["content-length"][0] +// // length, err := strconv.Atoi(contentLength) +// // if err != nil { +// // return err +// // } +// // resp.ContentLength = int64(length) +// //} +// +// t := &transferReader{ +// Header: resp.Header, +// StatusCode: resp.StatusCode, +// RequestMethod: resp.Request.Method, +// ProtoMajor: resp.ProtoMajor, +// ProtoMinor: resp.ProtoMinor, +// } +// +// // Transfer-Encoding: chunked, and overriding Content-Length. +// if err = t.parseTransferEncoding(); err != nil { +// return err +// } +// +// realLength, err := fixLength(true, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) +// if err != nil { +// return err +// } +// if t.RequestMethod == "HEAD" { +// if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { +// return err +// } else { +// t.ContentLength = n +// } +// } else { +// t.ContentLength = realLength +// } +// +// // Trailer +// t.Trailer, err = fixTrailer(t.Header, t.Chunked) +// +// // If there is no Content-Length or chunked Transfer-Encoding on a *Response +// // and the status is not 1xx, 204 or 304, then the body is unbounded. +// // See RFC 7230, section 3.3. +// if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { +// // Unbounded body. +// t.Close = true +// } +// +// return nil +//} +// +//// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +//func (t *transferReader) parseTransferEncoding() error { +// raw, present := t.Header["Transfer-Encoding"] +// if !present { +// return nil +// } +// delete(t.Header, "Transfer-Encoding") +// +// // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. +// if !t.protoAtLeast(1, 1) { +// return nil +// } +// +// // Like nginx, we only support a single Transfer-Encoding header field, and +// // only if set to "chunked". This is one of the most security sensitive +// // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it +// // strict and simple. +// if len(raw) != 1 { +// return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} +// } +// if !equalFold(raw[0], "chunked") { +// return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} +// } +// +// // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field +// // in any message that contains a Transfer-Encoding header field." +// // +// // but also: "If a message is received with both a Transfer-Encoding and a +// // Content-Length header field, the Transfer-Encoding overrides the +// // Content-Length. Such a message might indicate an attempt to perform +// // request smuggling (Section 9.5) or response splitting (Section 9.4) and +// // ought to be handled as an error. A sender MUST remove the received +// // Content-Length field prior to forwarding such a message downstream." +// // +// // Reportedly, these appear in the wild. +// delete(t.Header, "Content-Length") +// +// t.Chunked = true +// return nil +//} +// +//func (t *transferReader) protoAtLeast(m, n int) bool { +// return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +//} +// +//// equalFold is strings.EqualFold, ASCII only. It reports whether s and t +//// are equal, ASCII-case-insensitively. +//func equalFold(s, t string) bool { +// if len(s) != len(t) { +// return false +// } +// for i := 0; i < len(s); i++ { +// if lower(s[i]) != lower(t[i]) { +// return false +// } +// } +// return true +//} +// +//// Determine the expected body length, using RFC 7230 Section 3.3. This +//// function is not a method, because ultimately it should be shared by +//// ReadResponse and ReadRequest. +//func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { +// isRequest := !isResponse +// contentLens := header["Content-Length"] +// +// // Hardening against HTTP request smuggling +// if len(contentLens) > 1 { +// // Per RFC 7230 Section 3.3.2, prevent multiple +// // Content-Length headers if they differ in value. +// // If there are dups of the value, remove the dups. +// // See Issue 16490. +// first := textproto.TrimString(contentLens[0]) +// for _, ct := range contentLens[1:] { +// if first != textproto.TrimString(ct) { +// return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) +// } +// } +// +// // deduplicate Content-Length +// header.Del("Content-Length") +// header.Add("Content-Length", first) +// +// contentLens = header["Content-Length"] +// } +// +// // Logic based on response type or status +// if isResponse && noResponseBodyExpected(requestMethod) { +// return 0, nil +// } +// if status/100 == 1 { +// return 0, nil +// } +// switch status { +// case 204, 304: +// return 0, nil +// } +// +// // Logic based on Transfer-Encoding +// if chunked { +// return -1, nil +// } +// +// // Logic based on Content-Length +// var cl string +// if len(contentLens) == 1 { +// cl = textproto.TrimString(contentLens[0]) +// } +// if cl != "" { +// n, err := parseContentLength(cl) +// if err != nil { +// return -1, err +// } +// return n, nil +// } +// header.Del("Content-Length") +// +// if isRequest { +// // RFC 7230 neither explicitly permits nor forbids an +// // entity-body on a GET request so we permit one if +// // declared, but we default to 0 here (not -1 below) +// // if there's no mention of a body. +// // Likewise, all other request methods are assumed to have +// // no body if neither Transfer-Encoding chunked nor a +// // Content-Length are set. +// return 0, nil +// } +// +// // Body-EOF logic based on other methods (like closing, or chunked coding) +// return -1, nil +//} +// +//// parseContentLength trims whitespace from s and returns -1 if no value +//// is set, or the value if it's >= 0. +//func parseContentLength(cl string) (int64, error) { +// cl = textproto.TrimString(cl) +// if cl == "" { +// return -1, nil +// } +// n, err := strconv.ParseUint(cl, 10, 63) +// if err != nil { +// return 0, badStringError("bad Content-Length", cl) +// } +// return int64(n), nil +// +//} +// +//// Parse the trailer header. +//func fixTrailer(header Header, chunked bool) (Header, error) { +// vv, ok := header["Trailer"] +// if !ok { +// return nil, nil +// } +// if !chunked { +// // Trailer and no chunking: +// // this is an invalid use case for trailer header. +// // Nevertheless, no error will be returned and we +// // let users decide if this is a valid HTTP message. +// // The Trailer header will be kept in Response.Header +// // but not populate Response.Trailer. +// // See issue #27197. +// return nil, nil +// } +// header.Del("Trailer") +// +// trailer := make(Header) +// var err error +// for _, v := range vv { +// foreachHeaderElement(v, func(key string) { +// key = CanonicalHeaderKey(key) +// switch key { +// case "Transfer-Encoding", "Trailer", "Content-Length": +// if err == nil { +// err = badStringError("bad trailer key", key) +// return +// } +// } +// trailer[key] = nil +// }) +// } +// if err != nil { +// return nil, err +// } +// if len(trailer) == 0 { +// return nil, nil +// } +// return trailer, nil +//} +// +//// splitTwoDigitNumber splits a two-digit number into two digits. +func splitTwoDigitNumber(num int) (int, int) { + tens := num / 10 + ones := num % 10 + return tens, ones +} +// +//// lower returns the ASCII lowercase version of b. +//func lower(b byte) byte { +// if 'A' <= b && b <= 'Z' { +// return b + ('a' - 'A') +// } +// return b +//} +// +//// foreachHeaderElement splits v according to the "#rule" construction +//// in RFC 7230 section 7 and calls fn for each non-empty element. +//func foreachHeaderElement(v string, fn func(string)) { +// v = textproto.TrimString(v) +// if v == "" { +// return +// } +// if !strings.Contains(v, ",") { +// fn(v) +// return +// } +// for _, f := range strings.Split(v, ",") { +// if f = textproto.TrimString(f); f != "" { +// fn(f) +// } +// } +//} +// +//func noResponseBodyExpected(requestMethod string) bool { +// return requestMethod == "HEAD" +//} +// +//func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } +// +//// bodyAllowedForStatus reports whether a given response status code +//// permits a body. See RFC 7230, section 3.3. +//func bodyAllowedForStatus(status int) bool { +// switch { +// case status >= 100 && status <= 199: +// return false +// case status == 204: +// return false +// case status == 304: +// return false +// } +// return true +//} diff --git a/x/http/transport.go b/x/http/transport.go index ec5ae02..b9f845c 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -3,7 +3,6 @@ package http import ( "fmt" "io" - "strconv" "unsafe" "github.com/goplus/llgo/c" @@ -229,7 +228,11 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Poll all ready tasks and act on them... rc := <-pc.reqch // blocking alive := true - var response Response + resp := &Response{ + Request: rc.req, + Header: make(Header), + Trailer: make(Header), + } var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { @@ -283,52 +286,22 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } // Take the results - resp := (*hyper.Response)(task.Value()) + hyperResp := (*hyper.Response)(task.Value()) task.Free() - rp := resp.ReasonPhrase() - rpLen := resp.ReasonPhraseLen() - - response.Status = strconv.Itoa(int(resp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) - response.StatusCode = int(resp.Status()) - - headers := resp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(&response)) - respBody = resp.Body() - - response.Body, bodyWriter = io.Pipe() - - //// TODO(spongehah) Replace header operations with using the textproto package - //lengthSlice := response.Header["content-length"] - //if lengthSlice == nil { - // response.ContentLength = -1 - //} else { - // contentLength := response.Header["content-length"][0] - // length, err := strconv.Atoi(contentLength) - // if err != nil { - // rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} - // // Free the resources - // FreeResources(task, respBody, bodyWriter, exec, pc, rc) - // return - // } - // response.ContentLength = int64(length) + readResponseLineAndHeader(resp, hyperResp) + //err = readTransfer(resp, hyperResp) + //if err != nil { + // rc.ch <- responseAndError{err: err} + // // Free the resources + // FreeResources(task, respBody, bodyWriter, exec, pc, rc) + // return //} - contentLength := response.Header.Get("content-length") - if contentLength == "" { - response.ContentLength = -1 - } else { - length, err := strconv.Atoi(contentLength) - if err != nil { - rc.ch <- responseAndError{err: fmt.Errorf("failed to parse content-length")} - // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) - return - } - response.ContentLength = int64(length) - } + respBody = hyperResp.Body() + resp.Body, bodyWriter = io.Pipe() - rc.ch <- responseAndError{res: &response} + rc.ch <- responseAndError{res: resp} // Response has been returned, stop the timer pc.conn.IsCompleted = 1 @@ -343,7 +316,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { exec.Push(dataTask) // No longer need the response - resp.Free() + hyperResp.Free() case ReceiveRespBody: err := CheckTaskType(task, ReceiveRespBody) if err != nil { From d25cbfedaf77d80b31f5ce3705d20b127dbd3ddb Mon Sep 17 00:00:00 2001 From: hackerchai Date: Wed, 14 Aug 2024 17:38:08 +0800 Subject: [PATCH 09/55] feat(x/net/http): Init server skeleton Signed-off-by: hackerchai --- x/net/http/header.go | 18 + x/net/http/request.go | 48 + x/net/http/response.go | 73 ++ x/net/http/server.go | 155 +++ x/net/http/servermux.go | 54 + x/net/http/status.go | 210 ++++ x/net/url/url.go | 2626 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 3184 insertions(+) create mode 100644 x/net/http/header.go create mode 100644 x/net/http/request.go create mode 100644 x/net/http/response.go create mode 100644 x/net/http/server.go create mode 100644 x/net/http/servermux.go create mode 100644 x/net/http/status.go create mode 100644 x/net/url/url.go diff --git a/x/net/http/header.go b/x/net/http/header.go new file mode 100644 index 0000000..39130e4 --- /dev/null +++ b/x/net/http/header.go @@ -0,0 +1,18 @@ +package http + +type Header map[string][]string + +func (h Header) Add(key, value string) { + h[key] = append(h[key], value) +} + +func (h Header) Set(key, value string) { + h[key] = []string{value} +} + +func (h Header) Get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} \ No newline at end of file diff --git a/x/net/http/request.go b/x/net/http/request.go new file mode 100644 index 0000000..96b834c --- /dev/null +++ b/x/net/http/request.go @@ -0,0 +1,48 @@ +package http + +import ( + "fmt" + "io" + "unsafe" + + "github.com/goplus/llgo/rust/hyper" +) + +type Request struct { + Method string + URL string + Header Header + Body io.ReadCloser +} + +func newRequest(hyperReq *hyper.Request) (*Request, error) { + method := make([]byte, 32) + methodLen := uintptr(len(method)) + if err := hyperReq.Method(&method[0], &methodLen); err != hyper.OK { + return nil, fmt.Errorf("failed to get method: %v", err) + } + + var scheme, authority, pathAndQuery [1024]byte + schemeLen, authorityLen, pathAndQueryLen := uintptr(len(scheme)), uintptr(len(authority)), uintptr(len(pathAndQuery)) + if err := hyperReq.URIParts(&scheme[0], &schemeLen, &authority[0], &authorityLen, &pathAndQuery[0], &pathAndQueryLen); err != hyper.OK { + return nil, fmt.Errorf("failed to get URI parts: %v", err) + } + + req := &Request{ + Method: string(method[:methodLen]), + URL: fmt.Sprintf("%s://%s%s", string(scheme[:schemeLen]), string(authority[:authorityLen]), string(pathAndQuery[:pathAndQueryLen])), + Header: make(Header), + } + + headers := hyperReq.Headers() + if headers != nil { + headers.Foreach(func(name *byte, nameLen uintptr, value *byte, valueLen uintptr) int { + key := string(unsafe.Slice(name, nameLen)) + val := string(unsafe.Slice(value, valueLen)) + req.Header.Add(key, val) + return hyper.IterContinue + }, nil) + } + + return req, nil +} \ No newline at end of file diff --git a/x/net/http/response.go b/x/net/http/response.go new file mode 100644 index 0000000..8892339 --- /dev/null +++ b/x/net/http/response.go @@ -0,0 +1,73 @@ +package http + +import ( + "unsafe" + + "github.com/goplus/llgo/rust/hyper" +) + +type Response struct { + header Header + statusCode int + written bool + body []byte + channel *hyper.ResponseChannel +} + +func newResponse(channel *hyper.ResponseChannel) *Response { + return &Response{ + header: make(Header), + channel: channel, + } +} + +func (r *Response) Header() Header { + return r.header +} + +func (r *Response) Write(data []byte) (int, error) { + if !r.written { + r.WriteHeader(200) + } + r.body = append(r.body, data...) + return len(data), nil +} + +func (r *Response) WriteHeader(statusCode int) { + if r.written { + return + } + r.written = true + r.statusCode = statusCode + + resp := hyper.NewResponse() + resp.SetStatus(uint(statusCode)) + + headers := resp.Headers() + for k, v := range r.header { + for _, val := range v { + headers.Set([]byte(k), uintptr(len(k)), []byte(val), uintptr(len(val))) + } + } + + r.channel.Send(resp) +} + +func (r *Response) finalize() error { + if !r.written { + r.WriteHeader(200) + } + + body := hyper.NewBody() + body.SetDataFunc(func(userdata unsafe.Pointer, ctx *hyper.Context, chunk **hyper.Buf) int { + *chunk = hyper.CopyBuf(r.body, uintptr(len(r.body))) + r.body = nil // Clear the body after sending + return hyper.PollReady + }) + + resp := hyper.NewResponse() + resp.SetBody(body) + + r.channel.Send(resp) + return nil +} \ No newline at end of file diff --git a/x/net/http/server.go b/x/net/http/server.go new file mode 100644 index 0000000..e1ed90f --- /dev/null +++ b/x/net/http/server.go @@ -0,0 +1,155 @@ +package http + +import ( + "fmt" + "unsafe" + + "github.com/goplus/llgo/c/libuv" + "github.com/goplus/llgo/c/net" + "github.com/goplus/llgo/rust/hyper" +) + +type Handler interface { + ServeHTTP(ResponseWriter, *Request) +} + +type ResponseWriter interface { + Header() Header + Write([]byte) (int, error) + WriteHeader(statusCode int) +} + +type Server struct { + Addr string + Handler Handler + + uvLoop *libuv.Loop + uvServer libuv.Tcp + hyperExecutor *hyper.Executor +} + +func NewServer(addr string) *Server { + return &Server{ + Addr: addr, + Handler: DefaultServeMux, + } +} + +func (srv *Server) ListenAndServe() error { + srv.uvLoop = libuv.DefaultLoop() + srv.hyperExecutor = hyper.NewExecutor() + + if err := libuv.InitTcp(srv.uvLoop, &srv.uvServer); err != 0 { + return fmt.Errorf("failed to init TCP: %v", err) + } + + var sockaddr net.SockaddrIn + if err := libuv.Ip4Addr(srv.Addr, 0, &sockaddr); err != 0 { + return fmt.Errorf("failed to create IP address: %v", err) + } + + if err := srv.uvServer.Bind((*net.SockAddr)(unsafe.Pointer(&sockaddr)), 0); err != 0 { + return fmt.Errorf("failed to bind: %v", err) + } + + if err := srv.uvServer.Listen(128, srv.onNewConnection); err != 0 { + return fmt.Errorf("failed to listen: %v", err) + } + + fmt.Printf("Listening on %s\n", srv.Addr) + + for { + srv.uvLoop.Run(libuv.RUN_NOWAIT) + + task := srv.hyperExecutor.Poll() + for task != nil { + srv.handleTask(task) + task.Free() + task = srv.hyperExecutor.Poll() + } + } +} + +func (srv *Server) onNewConnection(serverStream *libuv.Stream, status int) { + if status < 0 { + fmt.Printf("New connection error: %s\n", libuv.Strerror(libuv.Errno(status))) + return + } + + client := new(libuv.Tcp) + libuv.InitTcp(srv.uvLoop, client) + + if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(client))) == 0 { + io := createIo(client) + service := hyper.ServiceNew(srv.serverCallback) + + http1Opts := hyper.Http1ServerconnOptionsNew(srv.hyperExecutor) + http2Opts := hyper.Http2ServerconnOptionsNew(srv.hyperExecutor) + + serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) + srv.hyperExecutor.Push(serverconn) + + http1Opts.Free() + http2Opts.Free() + } else { + (*libuv.Handle)(unsafe.Pointer(client)).Close(nil) + } +} + +func (srv *Server) serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { + req, err := newRequest(hyperReq) + if err != nil { + fmt.Printf("Error creating request: %v\n", err) + return + } + + res := newResponse(channel) + + srv.Handler.ServeHTTP(res, req) + + res.finalize() +} + +func (srv *Server) handleTask(task *hyper.Task) { + switch task.Type() { + case hyper.TaskServerconn: + fmt.Println("New server connection") + case hyper.TaskResponse: + fmt.Println("Response sent") + case hyper.TaskError: + err := (*hyper.Error)(task.Value()) + fmt.Printf("Task error: %s\n", err.Message()) + } +} + +func createIo(client *libuv.Tcp) *hyper.Io { + io := hyper.NewIo() + io.SetRead(func(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { + ret := client.Read(unsafe.Pointer(buf), bufLen) + if ret < 0 { + return hyper.IoError + } + return uintptr(ret) + }) + io.SetWrite(func(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { + ret := client.Write(unsafe.Pointer(buf), bufLen) + if ret < 0 { + return hyper.IoError + } + return uintptr(ret) + }) + return io +} + +type HandlerFunc func(ResponseWriter, *Request) + +func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { + f(w, r) +} + +func NotFoundHandler() Handler { return HandlerFunc(NotFound) } + +func NotFound(w ResponseWriter, r *Request) { + w.WriteHeader(404) + w.Write([]byte("404 page not found")) +} \ No newline at end of file diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go new file mode 100644 index 0000000..d6fc8bf --- /dev/null +++ b/x/net/http/servermux.go @@ -0,0 +1,54 @@ +package http + +import ( + "sync" +) + +type ServeMux struct { + mu sync.RWMutex + m map[string]muxEntry +} + +type muxEntry struct { + h Handler + pattern string +} + +var DefaultServeMux = &ServeMux{m: make(map[string]muxEntry)} + +func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { + h, _ := mux.Handler(r) + h.ServeHTTP(w, r) +} + +func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + mux.mu.RLock() + defer mux.mu.RUnlock() + + h, pattern = mux.m[r.URL].h, r.URL + if h == nil { + h, pattern = NotFoundHandler(), "" + } + return +} + +func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + mux.Handle(pattern, HandlerFunc(handler)) +} + +func (mux *ServeMux) Handle(pattern string, handler Handler) { + mux.mu.Lock() + defer mux.mu.Unlock() + + if pattern == "" { + panic("http: invalid pattern") + } + if handler == nil { + panic("http: nil handler") + } + if _, exist := mux.m[pattern]; exist { + panic("http: multiple registrations for " + pattern) + } + + mux.m[pattern] = muxEntry{h: handler, pattern: pattern} +} \ No newline at end of file diff --git a/x/net/http/status.go b/x/net/http/status.go new file mode 100644 index 0000000..cd90877 --- /dev/null +++ b/x/net/http/status.go @@ -0,0 +1,210 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// HTTP status codes as registered with IANA. +// See: https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml +const ( + StatusContinue = 100 // RFC 9110, 15.2.1 + StatusSwitchingProtocols = 101 // RFC 9110, 15.2.2 + StatusProcessing = 102 // RFC 2518, 10.1 + StatusEarlyHints = 103 // RFC 8297 + + StatusOK = 200 // RFC 9110, 15.3.1 + StatusCreated = 201 // RFC 9110, 15.3.2 + StatusAccepted = 202 // RFC 9110, 15.3.3 + StatusNonAuthoritativeInfo = 203 // RFC 9110, 15.3.4 + StatusNoContent = 204 // RFC 9110, 15.3.5 + StatusResetContent = 205 // RFC 9110, 15.3.6 + StatusPartialContent = 206 // RFC 9110, 15.3.7 + StatusMultiStatus = 207 // RFC 4918, 11.1 + StatusAlreadyReported = 208 // RFC 5842, 7.1 + StatusIMUsed = 226 // RFC 3229, 10.4.1 + + StatusMultipleChoices = 300 // RFC 9110, 15.4.1 + StatusMovedPermanently = 301 // RFC 9110, 15.4.2 + StatusFound = 302 // RFC 9110, 15.4.3 + StatusSeeOther = 303 // RFC 9110, 15.4.4 + StatusNotModified = 304 // RFC 9110, 15.4.5 + StatusUseProxy = 305 // RFC 9110, 15.4.6 + _ = 306 // RFC 9110, 15.4.7 (Unused) + StatusTemporaryRedirect = 307 // RFC 9110, 15.4.8 + StatusPermanentRedirect = 308 // RFC 9110, 15.4.9 + + StatusBadRequest = 400 // RFC 9110, 15.5.1 + StatusUnauthorized = 401 // RFC 9110, 15.5.2 + StatusPaymentRequired = 402 // RFC 9110, 15.5.3 + StatusForbidden = 403 // RFC 9110, 15.5.4 + StatusNotFound = 404 // RFC 9110, 15.5.5 + StatusMethodNotAllowed = 405 // RFC 9110, 15.5.6 + StatusNotAcceptable = 406 // RFC 9110, 15.5.7 + StatusProxyAuthRequired = 407 // RFC 9110, 15.5.8 + StatusRequestTimeout = 408 // RFC 9110, 15.5.9 + StatusConflict = 409 // RFC 9110, 15.5.10 + StatusGone = 410 // RFC 9110, 15.5.11 + StatusLengthRequired = 411 // RFC 9110, 15.5.12 + StatusPreconditionFailed = 412 // RFC 9110, 15.5.13 + StatusRequestEntityTooLarge = 413 // RFC 9110, 15.5.14 + StatusRequestURITooLong = 414 // RFC 9110, 15.5.15 + StatusUnsupportedMediaType = 415 // RFC 9110, 15.5.16 + StatusRequestedRangeNotSatisfiable = 416 // RFC 9110, 15.5.17 + StatusExpectationFailed = 417 // RFC 9110, 15.5.18 + StatusTeapot = 418 // RFC 9110, 15.5.19 (Unused) + StatusMisdirectedRequest = 421 // RFC 9110, 15.5.20 + StatusUnprocessableEntity = 422 // RFC 9110, 15.5.21 + StatusLocked = 423 // RFC 4918, 11.3 + StatusFailedDependency = 424 // RFC 4918, 11.4 + StatusTooEarly = 425 // RFC 8470, 5.2. + StatusUpgradeRequired = 426 // RFC 9110, 15.5.22 + StatusPreconditionRequired = 428 // RFC 6585, 3 + StatusTooManyRequests = 429 // RFC 6585, 4 + StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5 + StatusUnavailableForLegalReasons = 451 // RFC 7725, 3 + + StatusInternalServerError = 500 // RFC 9110, 15.6.1 + StatusNotImplemented = 501 // RFC 9110, 15.6.2 + StatusBadGateway = 502 // RFC 9110, 15.6.3 + StatusServiceUnavailable = 503 // RFC 9110, 15.6.4 + StatusGatewayTimeout = 504 // RFC 9110, 15.6.5 + StatusHTTPVersionNotSupported = 505 // RFC 9110, 15.6.6 + StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1 + StatusInsufficientStorage = 507 // RFC 4918, 11.5 + StatusLoopDetected = 508 // RFC 5842, 7.2 + StatusNotExtended = 510 // RFC 2774, 7 + StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6 +) + +// StatusText returns a text for the HTTP status code. It returns the empty +// string if the code is unknown. +func StatusText(code int) string { + switch code { + case StatusContinue: + return "Continue" + case StatusSwitchingProtocols: + return "Switching Protocols" + case StatusProcessing: + return "Processing" + case StatusEarlyHints: + return "Early Hints" + case StatusOK: + return "OK" + case StatusCreated: + return "Created" + case StatusAccepted: + return "Accepted" + case StatusNonAuthoritativeInfo: + return "Non-Authoritative Information" + case StatusNoContent: + return "No Content" + case StatusResetContent: + return "Reset Content" + case StatusPartialContent: + return "Partial Content" + case StatusMultiStatus: + return "Multi-Status" + case StatusAlreadyReported: + return "Already Reported" + case StatusIMUsed: + return "IM Used" + case StatusMultipleChoices: + return "Multiple Choices" + case StatusMovedPermanently: + return "Moved Permanently" + case StatusFound: + return "Found" + case StatusSeeOther: + return "See Other" + case StatusNotModified: + return "Not Modified" + case StatusUseProxy: + return "Use Proxy" + case StatusTemporaryRedirect: + return "Temporary Redirect" + case StatusPermanentRedirect: + return "Permanent Redirect" + case StatusBadRequest: + return "Bad Request" + case StatusUnauthorized: + return "Unauthorized" + case StatusPaymentRequired: + return "Payment Required" + case StatusForbidden: + return "Forbidden" + case StatusNotFound: + return "Not Found" + case StatusMethodNotAllowed: + return "Method Not Allowed" + case StatusNotAcceptable: + return "Not Acceptable" + case StatusProxyAuthRequired: + return "Proxy Authentication Required" + case StatusRequestTimeout: + return "Request Timeout" + case StatusConflict: + return "Conflict" + case StatusGone: + return "Gone" + case StatusLengthRequired: + return "Length Required" + case StatusPreconditionFailed: + return "Precondition Failed" + case StatusRequestEntityTooLarge: + return "Request Entity Too Large" + case StatusRequestURITooLong: + return "Request URI Too Long" + case StatusUnsupportedMediaType: + return "Unsupported Media Type" + case StatusRequestedRangeNotSatisfiable: + return "Requested Range Not Satisfiable" + case StatusExpectationFailed: + return "Expectation Failed" + case StatusTeapot: + return "I'm a teapot" + case StatusMisdirectedRequest: + return "Misdirected Request" + case StatusUnprocessableEntity: + return "Unprocessable Entity" + case StatusLocked: + return "Locked" + case StatusFailedDependency: + return "Failed Dependency" + case StatusTooEarly: + return "Too Early" + case StatusUpgradeRequired: + return "Upgrade Required" + case StatusPreconditionRequired: + return "Precondition Required" + case StatusTooManyRequests: + return "Too Many Requests" + case StatusRequestHeaderFieldsTooLarge: + return "Request Header Fields Too Large" + case StatusUnavailableForLegalReasons: + return "Unavailable For Legal Reasons" + case StatusInternalServerError: + return "Internal Server Error" + case StatusNotImplemented: + return "Not Implemented" + case StatusBadGateway: + return "Bad Gateway" + case StatusServiceUnavailable: + return "Service Unavailable" + case StatusGatewayTimeout: + return "Gateway Timeout" + case StatusHTTPVersionNotSupported: + return "HTTP Version Not Supported" + case StatusVariantAlsoNegotiates: + return "Variant Also Negotiates" + case StatusInsufficientStorage: + return "Insufficient Storage" + case StatusLoopDetected: + return "Loop Detected" + case StatusNotExtended: + return "Not Extended" + case StatusNetworkAuthenticationRequired: + return "Network Authentication Required" + default: + return "" + } +} diff --git a/x/net/url/url.go b/x/net/url/url.go new file mode 100644 index 0000000..340b74f --- /dev/null +++ b/x/net/url/url.go @@ -0,0 +1,2626 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package url parses URLs and implements query escaping. +package url + +// See RFC 3986. This package generally follows RFC 3986, except where +// it deviates for compatibility reasons. When sending changes, first +// search old issues for history on decisions. Unit tests should also +// contain references to issue numbers with details. + +import ( + "errors" + "fmt" + "path" + "slices" + "strconv" + "strings" + _ "unsafe" // for linkname +) + +// Error reports an error and the operation and URL that caused it. +type Error struct { + Op string + URL string + Err error +} + +func (e *Error) Unwrap() error { return e.Err } +func (e *Error) Error() string { return fmt.Sprintf("%s %q: %s", e.Op, e.URL, e.Err) } + +func (e *Error) Timeout() bool { + t, ok := e.Err.(interface { + Timeout() bool + }) + return ok && t.Timeout() +} + +func (e *Error) Temporary() bool { + t, ok := e.Err.(interface { + Temporary() bool + }) + return ok && t.Temporary() +} + +const upperhex = "0123456789ABCDEF" + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +type encoding int + +const ( + encodePath encoding = 1 + iota + encodePathSegment + encodeHost + encodeZone + encodeUserPassword + encodeQueryComponent + encodeFragment +) + +type EscapeError string + +func (e EscapeError) Error() string { + return "invalid URL escape " + strconv.Quote(string(e)) +} + +type InvalidHostError string + +func (e InvalidHostError) Error() string { + return "invalid character " + strconv.Quote(string(e)) + " in host name" +} + +// Return true if the specified character should be escaped when +// appearing in a URL string, according to RFC 3986. +// +// Please be informed that for now shouldEscape does not check all +// reserved characters correctly. See golang.org/issue/5684. +func shouldEscape(c byte, mode encoding) bool { + // §2.3 Unreserved characters (alphanum) + if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { + return false + } + + if mode == encodeHost || mode == encodeZone { + // §3.2.2 Host allows + // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" + // as part of reg-name. + // We add : because we include :port as part of host. + // We add [ ] because we include [ipv6]:port as part of host. + // We add < > because they're the only characters left that + // we could possibly allow, and Parse will reject them if we + // escape them (because hosts can't use %-encoding for + // ASCII bytes). + switch c { + case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"': + return false + } + } + + switch c { + case '-', '_', '.', '~': // §2.3 Unreserved characters (mark) + return false + + case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved) + // Different sections of the URL allow a few of + // the reserved characters to appear unescaped. + switch mode { + case encodePath: // §3.3 + // The RFC allows : @ & = + $ but saves / ; , for assigning + // meaning to individual path segments. This package + // only manipulates the path as a whole, so we allow those + // last three as well. That leaves only ? to escape. + return c == '?' + + case encodePathSegment: // §3.3 + // The RFC allows : @ & = + $ but saves / ; , for assigning + // meaning to individual path segments. + return c == '/' || c == ';' || c == ',' || c == '?' + + case encodeUserPassword: // §3.2.1 + // The RFC allows ';', ':', '&', '=', '+', '$', and ',' in + // userinfo, so we must escape only '@', '/', and '?'. + // The parsing of userinfo treats ':' as special so we must escape + // that too. + return c == '@' || c == '/' || c == '?' || c == ':' + + case encodeQueryComponent: // §3.4 + // The RFC reserves (so we must escape) everything. + return true + + case encodeFragment: // §4.1 + // The RFC text is silent but the grammar allows + // everything, so escape nothing. + return false + } + } + + if mode == encodeFragment { + // RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are + // included in reserved from RFC 2396 §2.2. The remaining sub-delims do not + // need to be escaped. To minimize potential breakage, we apply two restrictions: + // (1) we always escape sub-delims outside of the fragment, and (2) we always + // escape single quote to avoid breaking callers that had previously assumed that + // single quotes would be escaped. See issue #19917. + switch c { + case '!', '(', ')', '*': + return false + } + } + + // Everything else must be escaped. + return true +} + +// QueryUnescape does the inverse transformation of [QueryEscape], +// converting each 3-byte encoded substring of the form "%AB" into the +// hex-decoded byte 0xAB. +// It returns an error if any % is not followed by two hexadecimal +// digits. +func QueryUnescape(s string) (string, error) { + return unescape(s, encodeQueryComponent) +} + +// PathUnescape does the inverse transformation of [PathEscape], +// converting each 3-byte encoded substring of the form "%AB" into the +// hex-decoded byte 0xAB. It returns an error if any % is not followed +// by two hexadecimal digits. +// +// PathUnescape is identical to [QueryUnescape] except that it does not +// unescape '+' to ' ' (space). +func PathUnescape(s string) (string, error) { + return unescape(s, encodePathSegment) +} + +// unescape unescapes a string; the mode specifies +// which section of the URL string is being unescaped. +func unescape(s string, mode encoding) (string, error) { + // Count %, check that they're well-formed. + n := 0 + hasPlus := false + for i := 0; i < len(s); { + switch s[i] { + case '%': + n++ + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + s = s[i:] + if len(s) > 3 { + s = s[:3] + } + return "", EscapeError(s) + } + // Per https://tools.ietf.org/html/rfc3986#page-21 + // in the host component %-encoding can only be used + // for non-ASCII bytes. + // But https://tools.ietf.org/html/rfc6874#section-2 + // introduces %25 being allowed to escape a percent sign + // in IPv6 scoped-address literals. Yay. + if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" { + return "", EscapeError(s[i : i+3]) + } + if mode == encodeZone { + // RFC 6874 says basically "anything goes" for zone identifiers + // and that even non-ASCII can be redundantly escaped, + // but it seems prudent to restrict %-escaped bytes here to those + // that are valid host name bytes in their unescaped form. + // That is, you can use escaping in the zone identifier but not + // to introduce bytes you couldn't just write directly. + // But Windows puts spaces here! Yay. + v := unhex(s[i+1])<<4 | unhex(s[i+2]) + if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) { + return "", EscapeError(s[i : i+3]) + } + } + i += 3 + case '+': + hasPlus = mode == encodeQueryComponent + i++ + default: + if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) { + return "", InvalidHostError(s[i : i+1]) + } + i++ + } + } + + if n == 0 && !hasPlus { + return s, nil + } + + var t strings.Builder + t.Grow(len(s) - 2*n) + for i := 0; i < len(s); i++ { + switch s[i] { + case '%': + t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2])) + i += 2 + case '+': + if mode == encodeQueryComponent { + t.WriteByte(' ') + } else { + t.WriteByte('+') + } + default: + t.WriteByte(s[i]) + } + } + return t.String(), nil +} + +// QueryEscape escapes the string so it can be safely placed +// inside a [URL] query. +func QueryEscape(s string) string { + return escape(s, encodeQueryComponent) +} + +// PathEscape escapes the string so it can be safely placed inside a [URL] path segment, +// replacing special characters (including /) with %XX sequences as needed. +func PathEscape(s string) string { + return escape(s, encodePathSegment) +} + +func escape(s string, mode encoding) string { + spaceCount, hexCount := 0, 0 + for i := 0; i < len(s); i++ { + c := s[i] + if shouldEscape(c, mode) { + if c == ' ' && mode == encodeQueryComponent { + spaceCount++ + } else { + hexCount++ + } + } + } + + if spaceCount == 0 && hexCount == 0 { + return s + } + + var buf [64]byte + var t []byte + // Copyright 2009 The Go Authors. All rights reserved. + // Use of this source code is governed by a BSD-style + // license that can be found in the LICENSE file. + + // Package url parses URLs and implements query escaping. + package url + + // See RFC 3986. This package generally follows RFC 3986, except where + // it deviates for compatibility reasons. When sending changes, first + // search old issues for history on decisions. Unit tests should also + // contain references to issue numbers with details. + + import ( + "errors" + "fmt" + "path" + "slices" + "strconv" + "strings" + _ "unsafe" // for linkname + ) + + // Error reports an error and the operation and URL that caused it. + type Error struct { + Op string + URL string + Err error + } + + func (e *Error) Unwrap() error { return e.Err } + func (e *Error) Error() string { return fmt.Sprintf("%s %q: %s", e.Op, e.URL, e.Err) } + + func (e *Error) Timeout() bool { + t, ok := e.Err.(interface { + Timeout() bool + }) + return ok && t.Timeout() + } + + func (e *Error) Temporary() bool { + t, ok := e.Err.(interface { + Temporary() bool + }) + return ok && t.Temporary() + } + + const upperhex = "0123456789ABCDEF" + + func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false + } + + func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 + } + + type encoding int + + const ( + encodePath encoding = 1 + iota + encodePathSegment + encodeHost + encodeZone + encodeUserPassword + encodeQueryComponent + encodeFragment + ) + + type EscapeError string + + func (e EscapeError) Error() string { + return "invalid URL escape " + strconv.Quote(string(e)) + } + + type InvalidHostError string + + func (e InvalidHostError) Error() string { + return "invalid character " + strconv.Quote(string(e)) + " in host name" + } + + // Return true if the specified character should be escaped when + // appearing in a URL string, according to RFC 3986. + // + // Please be informed that for now shouldEscape does not check all + // reserved characters correctly. See golang.org/issue/5684. + func shouldEscape(c byte, mode encoding) bool { + // §2.3 Unreserved characters (alphanum) + if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { + return false + } + + if mode == encodeHost || mode == encodeZone { + // §3.2.2 Host allows + // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" + // as part of reg-name. + // We add : because we include :port as part of host. + // We add [ ] because we include [ipv6]:port as part of host. + // We add < > because they're the only characters left that + // we could possibly allow, and Parse will reject them if we + // escape them (because hosts can't use %-encoding for + // ASCII bytes). + switch c { + case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"': + return false + } + } + + switch c { + case '-', '_', '.', '~': // §2.3 Unreserved characters (mark) + return false + + case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved) + // Different sections of the URL allow a few of + // the reserved characters to appear unescaped. + switch mode { + case encodePath: // §3.3 + // The RFC allows : @ & = + $ but saves / ; , for assigning + // meaning to individual path segments. This package + // only manipulates the path as a whole, so we allow those + // last three as well. That leaves only ? to escape. + return c == '?' + + case encodePathSegment: // §3.3 + // The RFC allows : @ & = + $ but saves / ; , for assigning + // meaning to individual path segments. + return c == '/' || c == ';' || c == ',' || c == '?' + + case encodeUserPassword: // §3.2.1 + // The RFC allows ';', ':', '&', '=', '+', '$', and ',' in + // userinfo, so we must escape only '@', '/', and '?'. + // The parsing of userinfo treats ':' as special so we must escape + // that too. + return c == '@' || c == '/' || c == '?' || c == ':' + + case encodeQueryComponent: // §3.4 + // The RFC reserves (so we must escape) everything. + return true + + case encodeFragment: // §4.1 + // The RFC text is silent but the grammar allows + // everything, so escape nothing. + return false + } + } + + if mode == encodeFragment { + // RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are + // included in reserved from RFC 2396 §2.2. The remaining sub-delims do not + // need to be escaped. To minimize potential breakage, we apply two restrictions: + // (1) we always escape sub-delims outside of the fragment, and (2) we always + // escape single quote to avoid breaking callers that had previously assumed that + // single quotes would be escaped. See issue #19917. + switch c { + case '!', '(', ')', '*': + return false + } + } + + // Everything else must be escaped. + return true + } + + // QueryUnescape does the inverse transformation of [QueryEscape], + // converting each 3-byte encoded substring of the form "%AB" into the + // hex-decoded byte 0xAB. + // It returns an error if any % is not followed by two hexadecimal + // digits. + func QueryUnescape(s string) (string, error) { + return unescape(s, encodeQueryComponent) + } + + // PathUnescape does the inverse transformation of [PathEscape], + // converting each 3-byte encoded substring of the form "%AB" into the + // hex-decoded byte 0xAB. It returns an error if any % is not followed + // by two hexadecimal digits. + // + // PathUnescape is identical to [QueryUnescape] except that it does not + // unescape '+' to ' ' (space). + func PathUnescape(s string) (string, error) { + return unescape(s, encodePathSegment) + } + + // unescape unescapes a string; the mode specifies + // which section of the URL string is being unescaped. + func unescape(s string, mode encoding) (string, error) { + // Count %, check that they're well-formed. + n := 0 + hasPlus := false + for i := 0; i < len(s); { + switch s[i] { + case '%': + n++ + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + s = s[i:] + if len(s) > 3 { + s = s[:3] + } + return "", EscapeError(s) + } + // Per https://tools.ietf.org/html/rfc3986#page-21 + // in the host component %-encoding can only be used + // for non-ASCII bytes. + // But https://tools.ietf.org/html/rfc6874#section-2 + // introduces %25 being allowed to escape a percent sign + // in IPv6 scoped-address literals. Yay. + if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" { + return "", EscapeError(s[i : i+3]) + } + if mode == encodeZone { + // RFC 6874 says basically "anything goes" for zone identifiers + // and that even non-ASCII can be redundantly escaped, + // but it seems prudent to restrict %-escaped bytes here to those + // that are valid host name bytes in their unescaped form. + // That is, you can use escaping in the zone identifier but not + // to introduce bytes you couldn't just write directly. + // But Windows puts spaces here! Yay. + v := unhex(s[i+1])<<4 | unhex(s[i+2]) + if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) { + return "", EscapeError(s[i : i+3]) + } + } + i += 3 + case '+': + hasPlus = mode == encodeQueryComponent + i++ + default: + if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) { + return "", InvalidHostError(s[i : i+1]) + } + i++ + } + } + + if n == 0 && !hasPlus { + return s, nil + } + + var t strings.Builder + t.Grow(len(s) - 2*n) + for i := 0; i < len(s); i++ { + switch s[i] { + case '%': + t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2])) + i += 2 + case '+': + if mode == encodeQueryComponent { + t.WriteByte(' ') + } else { + t.WriteByte('+') + } + default: + t.WriteByte(s[i]) + } + } + return t.String(), nil + } + + // QueryEscape escapes the string so it can be safely placed + // inside a [URL] query. + func QueryEscape(s string) string { + return escape(s, encodeQueryComponent) + } + + // PathEscape escapes the string so it can be safely placed inside a [URL] path segment, + // replacing special characters (including /) with %XX sequences as needed. + func PathEscape(s string) string { + return escape(s, encodePathSegment) + } + + func escape(s string, mode encoding) string { + spaceCount, hexCount := 0, 0 + for i := 0; i < len(s); i++ { + c := s[i] + if shouldEscape(c, mode) { + if c == ' ' && mode == encodeQueryComponent { + spaceCount++ + } else { + hexCount++ + } + } + } + + if spaceCount == 0 && hexCount == 0 { + return s + } + + var buf [64]byte + var t []byte + + required := len(s) + 2*hexCount + if required <= len(buf) { + t = buf[:required] + } else { + t = make([]byte, required) + } + + if hexCount == 0 { + copy(t, s) + for i := 0; i < len(s); i++ { + if s[i] == ' ' { + t[i] = '+' + } + } + return string(t) + } + + j := 0 + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case c == ' ' && mode == encodeQueryComponent: + t[j] = '+' + j++ + case shouldEscape(c, mode): + t[j] = '%' + t[j+1] = upperhex[c>>4] + t[j+2] = upperhex[c&15] + j += 3 + default: + t[j] = s[i] + j++ + } + } + return string(t) + } + + // A URL represents a parsed URL (technically, a URI reference). + // + // The general form represented is: + // + // [scheme:][//[userinfo@]host][/]path[?query][#fragment] + // + // URLs that do not start with a slash after the scheme are interpreted as: + // + // scheme:opaque[?query][#fragment] + // + // The Host field contains the host and port subcomponents of the URL. + // When the port is present, it is separated from the host with a colon. + // When the host is an IPv6 address, it must be enclosed in square brackets: + // "[fe80::1]:80". The [net.JoinHostPort] function combines a host and port + // into a string suitable for the Host field, adding square brackets to + // the host when necessary. + // + // Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/. + // A consequence is that it is impossible to tell which slashes in the Path were + // slashes in the raw URL and which were %2f. This distinction is rarely important, + // but when it is, the code should use the [URL.EscapedPath] method, which preserves + // the original encoding of Path. + // + // The RawPath field is an optional field which is only set when the default + // encoding of Path is different from the escaped path. See the EscapedPath method + // for more details. + // + // URL's String method uses the EscapedPath method to obtain the path. + type URL struct { + Scheme string + Opaque string // encoded opaque data + User *Userinfo // username and password information + Host string // host or host:port (see Hostname and Port methods) + Path string // path (relative paths may omit leading slash) + RawPath string // encoded path hint (see EscapedPath method) + OmitHost bool // do not emit empty host (authority) + ForceQuery bool // append a query ('?') even if RawQuery is empty + RawQuery string // encoded query values, without '?' + Fragment string // fragment for references, without '#' + RawFragment string // encoded fragment hint (see EscapedFragment method) + } + + // User returns a [Userinfo] containing the provided username + // and no password set. + func User(username string) *Userinfo { + return &Userinfo{username, "", false} + } + + // UserPassword returns a [Userinfo] containing the provided username + // and password. + // + // This functionality should only be used with legacy web sites. + // RFC 2396 warns that interpreting Userinfo this way + // “is NOT RECOMMENDED, because the passing of authentication + // information in clear text (such as URI) has proven to be a + // security risk in almost every case where it has been used.” + func UserPassword(username, password string) *Userinfo { + return &Userinfo{username, password, true} + } + + // The Userinfo type is an immutable encapsulation of username and + // password details for a [URL]. An existing Userinfo value is guaranteed + // to have a username set (potentially empty, as allowed by RFC 2396), + // and optionally a password. + type Userinfo struct { + username string + password string + passwordSet bool + } + + // Username returns the username. + func (u *Userinfo) Username() string { + if u == nil { + return "" + } + return u.username + } + + // Password returns the password in case it is set, and whether it is set. + func (u *Userinfo) Password() (string, bool) { + if u == nil { + return "", false + } + return u.password, u.passwordSet + } + + // String returns the encoded userinfo information in the standard form + // of "username[:password]". + func (u *Userinfo) String() string { + if u == nil { + return "" + } + s := escape(u.username, encodeUserPassword) + if u.passwordSet { + s += ":" + escape(u.password, encodeUserPassword) + } + return s + } + + // Maybe rawURL is of the form scheme:path. + // (Scheme must be [a-zA-Z][a-zA-Z0-9+.-]*) + // If so, return scheme, path; else return "", rawURL. + func getScheme(rawURL string) (scheme, path string, err error) { + for i := 0; i < len(rawURL); i++ { + c := rawURL[i] + switch { + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // do nothing + case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.': + if i == 0 { + return "", rawURL, nil + } + case c == ':': + if i == 0 { + return "", "", errors.New("missing protocol scheme") + } + return rawURL[:i], rawURL[i+1:], nil + default: + // we have encountered an invalid character, + // so there is no valid scheme + return "", rawURL, nil + } + } + return "", rawURL, nil + } + + // Parse parses a raw url into a [URL] structure. + // + // The url may be relative (a path, without a host) or absolute + // (starting with a scheme). Trying to parse a hostname and path + // without a scheme is invalid but may not necessarily return an + // error, due to parsing ambiguities. + func Parse(rawURL string) (*URL, error) { + // Cut off #frag + u, frag, _ := strings.Cut(rawURL, "#") + url, err := parse(u, false) + if err != nil { + return nil, &Error{"parse", u, err} + } + if frag == "" { + return url, nil + } + if err = url.setFragment(frag); err != nil { + return nil, &Error{"parse", rawURL, err} + } + return url, nil + } + + // ParseRequestURI parses a raw url into a [URL] structure. It assumes that + // url was received in an HTTP request, so the url is interpreted + // only as an absolute URI or an absolute path. + // The string url is assumed not to have a #fragment suffix. + // (Web browsers strip #fragment before sending the URL to a web server.) + func ParseRequestURI(rawURL string) (*URL, error) { + url, err := parse(rawURL, true) + if err != nil { + return nil, &Error{"parse", rawURL, err} + } + return url, nil + } + + // parse parses a URL from a string in one of two contexts. If + // viaRequest is true, the URL is assumed to have arrived via an HTTP request, + // in which case only absolute URLs or path-absolute relative URLs are allowed. + // If viaRequest is false, all forms of relative URLs are allowed. + func parse(rawURL string, viaRequest bool) (*URL, error) { + var rest string + var err error + + if stringContainsCTLByte(rawURL) { + return nil, errors.New("net/url: invalid control character in URL") + } + + if rawURL == "" && viaRequest { + return nil, errors.New("empty url") + } + url := new(URL) + + if rawURL == "*" { + url.Path = "*" + return url, nil + } + + // Split off possible leading "http:", "mailto:", etc. + // Cannot contain escaped characters. + if url.Scheme, rest, err = getScheme(rawURL); err != nil { + return nil, err + } + url.Scheme = strings.ToLower(url.Scheme) + + if strings.HasSuffix(rest, "?") && strings.Count(rest, "?") == 1 { + url.ForceQuery = true + rest = rest[:len(rest)-1] + } else { + rest, url.RawQuery, _ = strings.Cut(rest, "?") + } + + if !strings.HasPrefix(rest, "/") { + if url.Scheme != "" { + // We consider rootless paths per RFC 3986 as opaque. + url.Opaque = rest + return url, nil + } + if viaRequest { + return nil, errors.New("invalid URI for request") + } + + // Avoid confusion with malformed schemes, like cache_object:foo/bar. + // See golang.org/issue/16822. + // + // RFC 3986, §3.3: + // In addition, a URI reference (Section 4.1) may be a relative-path reference, + // in which case the first path segment cannot contain a colon (":") character. + if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") { + // First path segment has colon. Not allowed in relative URL. + return nil, errors.New("first path segment in URL cannot contain colon") + } + } + + if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { + var authority string + authority, rest = rest[2:], "" + if i := strings.Index(authority, "/"); i >= 0 { + authority, rest = authority[:i], authority[i:] + } + url.User, url.Host, err = parseAuthority(authority) + if err != nil { + return nil, err + } + } else if url.Scheme != "" && strings.HasPrefix(rest, "/") { + // OmitHost is set to true when rawURL has an empty host (authority). + // See golang.org/issue/46059. + url.OmitHost = true + } + + // Set Path and, optionally, RawPath. + // RawPath is a hint of the encoding of Path. We don't want to set it if + // the default escaping of Path is equivalent, to help make sure that people + // don't rely on it in general. + if err := url.setPath(rest); err != nil { + return nil, err + } + return url, nil + } + + func parseAuthority(authority string) (user *Userinfo, host string, err error) { + i := strings.LastIndex(authority, "@") + if i < 0 { + host, err = parseHost(authority) + } else { + host, err = parseHost(authority[i+1:]) + } + if err != nil { + return nil, "", err + } + if i < 0 { + return nil, host, nil + } + userinfo := authority[:i] + if !validUserinfo(userinfo) { + return nil, "", errors.New("net/url: invalid userinfo") + } + if !strings.Contains(userinfo, ":") { + if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil { + return nil, "", err + } + user = User(userinfo) + } else { + username, password, _ := strings.Cut(userinfo, ":") + if username, err = unescape(username, encodeUserPassword); err != nil { + return nil, "", err + } + if password, err = unescape(password, encodeUserPassword); err != nil { + return nil, "", err + } + user = UserPassword(username, password) + } + return user, host, nil + } + + // parseHost parses host as an authority without user + // information. That is, as host[:port]. + func parseHost(host string) (string, error) { + if strings.HasPrefix(host, "[") { + // Parse an IP-Literal in RFC 3986 and RFC 6874. + // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80". + i := strings.LastIndex(host, "]") + if i < 0 { + return "", errors.New("missing ']' in host") + } + colonPort := host[i+1:] + if !validOptionalPort(colonPort) { + return "", fmt.Errorf("invalid port %q after host", colonPort) + } + + // RFC 6874 defines that %25 (%-encoded percent) introduces + // the zone identifier, and the zone identifier can use basically + // any %-encoding it likes. That's different from the host, which + // can only %-encode non-ASCII bytes. + // We do impose some restrictions on the zone, to avoid stupidity + // like newlines. + zone := strings.Index(host[:i], "%25") + if zone >= 0 { + host1, err := unescape(host[:zone], encodeHost) + if err != nil { + return "", err + } + host2, err := unescape(host[zone:i], encodeZone) + if err != nil { + return "", err + } + host3, err := unescape(host[i:], encodeHost) + if err != nil { + return "", err + } + return host1 + host2 + host3, nil + } + } else if i := strings.LastIndex(host, ":"); i != -1 { + colonPort := host[i:] + if !validOptionalPort(colonPort) { + return "", fmt.Errorf("invalid port %q after host", colonPort) + } + } + + var err error + if host, err = unescape(host, encodeHost); err != nil { + return "", err + } + return host, nil + } + + // setPath sets the Path and RawPath fields of the URL based on the provided + // escaped path p. It maintains the invariant that RawPath is only specified + // when it differs from the default encoding of the path. + // For example: + // - setPath("/foo/bar") will set Path="/foo/bar" and RawPath="" + // - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar" + // setPath will return an error only if the provided path contains an invalid + // escaping. + // + // setPath should be an internal detail, + // but widely used packages access it using linkname. + // Notable members of the hall of shame include: + // - github.com/sagernet/sing + // + // Do not remove or change the type signature. + // See go.dev/issue/67401. + // + //go:linkname badSetPath net/url.(*URL).setPath + func (u *URL) setPath(p string) error { + path, err := unescape(p, encodePath) + if err != nil { + return err + } + u.Path = path + if escp := escape(path, encodePath); p == escp { + // Default encoding is fine. + u.RawPath = "" + } else { + u.RawPath = p + } + return nil + } + + // for linkname because we cannot linkname methods directly + func badSetPath(*URL, string) error + + // EscapedPath returns the escaped form of u.Path. + // In general there are multiple possible escaped forms of any path. + // EscapedPath returns u.RawPath when it is a valid escaping of u.Path. + // Otherwise EscapedPath ignores u.RawPath and computes an escaped + // form on its own. + // The [URL.String] and [URL.RequestURI] methods use EscapedPath to construct + // their results. + // In general, code should call EscapedPath instead of + // reading u.RawPath directly. + func (u *URL) EscapedPath() string { + if u.RawPath != "" && validEncoded(u.RawPath, encodePath) { + p, err := unescape(u.RawPath, encodePath) + if err == nil && p == u.Path { + return u.RawPath + } + } + if u.Path == "*" { + return "*" // don't escape (Issue 11202) + } + return escape(u.Path, encodePath) + } + + // validEncoded reports whether s is a valid encoded path or fragment, + // according to mode. + // It must not contain any bytes that require escaping during encoding. + func validEncoded(s string, mode encoding) bool { + for i := 0; i < len(s); i++ { + // RFC 3986, Appendix A. + // pchar = unreserved / pct-encoded / sub-delims / ":" / "@". + // shouldEscape is not quite compliant with the RFC, + // so we check the sub-delims ourselves and let + // shouldEscape handle the others. + switch s[i] { + case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '@': + // ok + case '[', ']': + // ok - not specified in RFC 3986 but left alone by modern browsers + case '%': + // ok - percent encoded, will decode + default: + if shouldEscape(s[i], mode) { + return false + } + } + } + return true + } + + // setFragment is like setPath but for Fragment/RawFragment. + func (u *URL) setFragment(f string) error { + frag, err := unescape(f, encodeFragment) + if err != nil { + return err + } + u.Fragment = frag + if escf := escape(frag, encodeFragment); f == escf { + // Default encoding is fine. + u.RawFragment = "" + } else { + u.RawFragment = f + } + return nil + } + + // EscapedFragment returns the escaped form of u.Fragment. + // In general there are multiple possible escaped forms of any fragment. + // EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment. + // Otherwise EscapedFragment ignores u.RawFragment and computes an escaped + // form on its own. + // The [URL.String] method uses EscapedFragment to construct its result. + // In general, code should call EscapedFragment instead of + // reading u.RawFragment directly. + func (u *URL) EscapedFragment() string { + if u.RawFragment != "" && validEncoded(u.RawFragment, encodeFragment) { + f, err := unescape(u.RawFragment, encodeFragment) + if err == nil && f == u.Fragment { + return u.RawFragment + } + } + return escape(u.Fragment, encodeFragment) + } + + // validOptionalPort reports whether port is either an empty string + // or matches /^:\d*$/ + func validOptionalPort(port string) bool { + if port == "" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true + } + + // String reassembles the [URL] into a valid URL string. + // The general form of the result is one of: + // + // scheme:opaque?query#fragment + // scheme://userinfo@host/path?query#fragment + // + // If u.Opaque is non-empty, String uses the first form; + // otherwise it uses the second form. + // Any non-ASCII characters in host are escaped. + // To obtain the path, String uses u.EscapedPath(). + // + // In the second form, the following rules apply: + // - if u.Scheme is empty, scheme: is omitted. + // - if u.User is nil, userinfo@ is omitted. + // - if u.Host is empty, host/ is omitted. + // - if u.Scheme and u.Host are empty and u.User is nil, + // the entire scheme://userinfo@host/ is omitted. + // - if u.Host is non-empty and u.Path begins with a /, + // the form host/path does not add its own /. + // - if u.RawQuery is empty, ?query is omitted. + // - if u.Fragment is empty, #fragment is omitted. + func (u *URL) String() string { + var buf strings.Builder + + n := len(u.Scheme) + if u.Opaque != "" { + n += len(u.Opaque) + } else { + if !u.OmitHost && (u.Scheme != "" || u.Host != "" || u.User != nil) { + username := u.User.Username() + password, _ := u.User.Password() + n += len(username) + len(password) + len(u.Host) + } + n += len(u.Path) + } + n += len(u.RawQuery) + len(u.RawFragment) + n += len(":" + "//" + "//" + ":" + "@" + "/" + "./" + "?" + "#") + buf.Grow(n) + + if u.Scheme != "" { + buf.WriteString(u.Scheme) + buf.WriteByte(':') + } + if u.Opaque != "" { + buf.WriteString(u.Opaque) + } else { + if u.Scheme != "" || u.Host != "" || u.User != nil { + if u.OmitHost && u.Host == "" && u.User == nil { + // omit empty host + } else { + if u.Host != "" || u.Path != "" || u.User != nil { + buf.WriteString("//") + } + if ui := u.User; ui != nil { + buf.WriteString(ui.String()) + buf.WriteByte('@') + } + if h := u.Host; h != "" { + buf.WriteString(escape(h, encodeHost)) + } + } + } + path := u.EscapedPath() + if path != "" && path[0] != '/' && u.Host != "" { + buf.WriteByte('/') + } + if buf.Len() == 0 { + // RFC 3986 §4.2 + // A path segment that contains a colon character (e.g., "this:that") + // cannot be used as the first segment of a relative-path reference, as + // it would be mistaken for a scheme name. Such a segment must be + // preceded by a dot-segment (e.g., "./this:that") to make a relative- + // path reference. + if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") { + buf.WriteString("./") + } + } + buf.WriteString(path) + } + if u.ForceQuery || u.RawQuery != "" { + buf.WriteByte('?') + buf.WriteString(u.RawQuery) + } + if u.Fragment != "" { + buf.WriteByte('#') + buf.WriteString(u.EscapedFragment()) + } + return buf.String() + } + + // Redacted is like [URL.String] but replaces any password with "xxxxx". + // Only the password in u.User is redacted. + func (u *URL) Redacted() string { + if u == nil { + return "" + } + + ru := *u + if _, has := ru.User.Password(); has { + ru.User = UserPassword(ru.User.Username(), "xxxxx") + } + return ru.String() + } + + // Values maps a string key to a list of values. + // It is typically used for query parameters and form values. + // Unlike in the http.Header map, the keys in a Values map + // are case-sensitive. + type Values map[string][]string + + // Get gets the first value associated with the given key. + // If there are no values associated with the key, Get returns + // the empty string. To access multiple values, use the map + // directly. + func (v Values) Get(key string) string { + vs := v[key] + if len(vs) == 0 { + return "" + } + return vs[0] + } + + // Set sets the key to value. It replaces any existing + // values. + func (v Values) Set(key, value string) { + v[key] = []string{value} + } + + // Add adds the value to key. It appends to any existing + // values associated with key. + func (v Values) Add(key, value string) { + v[key] = append(v[key], value) + } + + // Del deletes the values associated with key. + func (v Values) Del(key string) { + delete(v, key) + } + + // Has checks whether a given key is set. + func (v Values) Has(key string) bool { + _, ok := v[key] + return ok + } + + // ParseQuery parses the URL-encoded query string and returns + // a map listing the values specified for each key. + // ParseQuery always returns a non-nil map containing all the + // valid query parameters found; err describes the first decoding error + // encountered, if any. + // + // Query is expected to be a list of key=value settings separated by ampersands. + // A setting without an equals sign is interpreted as a key set to an empty + // value. + // Settings containing a non-URL-encoded semicolon are considered invalid. + func ParseQuery(query string) (Values, error) { + m := make(Values) + err := parseQuery(m, query) + return m, err + } + + func parseQuery(m Values, query string) (err error) { + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + if strings.Contains(key, ";") { + err = fmt.Errorf("invalid semicolon separator in query") + continue + } + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + key, err1 := QueryUnescape(key) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + value, err1 = QueryUnescape(value) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + m[key] = append(m[key], value) + } + return err + } + + // Encode encodes the values into “URL encoded” form + // ("bar=baz&foo=quux") sorted by key. + func (v Values) Encode() string { + if len(v) == 0 { + return "" + } + var buf strings.Builder + keys := make([]string, 0, len(v)) + for k := range v { + keys = append(keys, k) + } + slices.Sort(keys) + for _, k := range keys { + vs := v[k] + keyEscaped := QueryEscape(k) + for _, v := range vs { + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(keyEscaped) + buf.WriteByte('=') + buf.WriteString(QueryEscape(v)) + } + } + return buf.String() + } + + // resolvePath applies special path segments from refs and applies + // them to base, per RFC 3986. + func resolvePath(base, ref string) string { + var full string + if ref == "" { + full = base + } else if ref[0] != '/' { + i := strings.LastIndex(base, "/") + full = base[:i+1] + ref + } else { + full = ref + } + if full == "" { + return "" + } + + var ( + elem string + dst strings.Builder + ) + first := true + remaining := full + // We want to return a leading '/', so write it now. + dst.WriteByte('/') + found := true + for found { + elem, remaining, found = strings.Cut(remaining, "/") + if elem == "." { + first = false + // drop + continue + } + + if elem == ".." { + // Ignore the leading '/' we already wrote. + str := dst.String()[1:] + index := strings.LastIndexByte(str, '/') + + dst.Reset() + dst.WriteByte('/') + if index == -1 { + first = true + } else { + dst.WriteString(str[:index]) + } + } else { + if !first { + dst.WriteByte('/') + } + dst.WriteString(elem) + first = false + } + } + + if elem == "." || elem == ".." { + dst.WriteByte('/') + } + + // We wrote an initial '/', but we don't want two. + r := dst.String() + if len(r) > 1 && r[1] == '/' { + r = r[1:] + } + return r + } + + // IsAbs reports whether the [URL] is absolute. + // Absolute means that it has a non-empty scheme. + func (u *URL) IsAbs() bool { + return u.Scheme != "" + } + + // Parse parses a [URL] in the context of the receiver. The provided URL + // may be relative or absolute. Parse returns nil, err on parse + // failure, otherwise its return value is the same as [URL.ResolveReference]. + func (u *URL) Parse(ref string) (*URL, error) { + refURL, err := Parse(ref) + if err != nil { + return nil, err + } + return u.ResolveReference(refURL), nil + } + + // ResolveReference resolves a URI reference to an absolute URI from + // an absolute base URI u, per RFC 3986 Section 5.2. The URI reference + // may be relative or absolute. ResolveReference always returns a new + // [URL] instance, even if the returned URL is identical to either the + // base or reference. If ref is an absolute URL, then ResolveReference + // ignores base and returns a copy of ref. + func (u *URL) ResolveReference(ref *URL) *URL { + url := *ref + if ref.Scheme == "" { + url.Scheme = u.Scheme + } + if ref.Scheme != "" || ref.Host != "" || ref.User != nil { + // The "absoluteURI" or "net_path" cases. + // We can ignore the error from setPath since we know we provided a + // validly-escaped path. + url.setPath(resolvePath(ref.EscapedPath(), "")) + return &url + } + if ref.Opaque != "" { + url.User = nil + url.Host = "" + url.Path = "" + return &url + } + if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" { + url.RawQuery = u.RawQuery + if ref.Fragment == "" { + url.Fragment = u.Fragment + url.RawFragment = u.RawFragment + } + } + if ref.Path == "" && u.Opaque != "" { + url.Opaque = u.Opaque + url.User = nil + url.Host = "" + url.Path = "" + return &url + } + // The "abs_path" or "rel_path" cases. + url.Host = u.Host + url.User = u.User + url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath())) + return &url + } + + // Query parses RawQuery and returns the corresponding values. + // It silently discards malformed value pairs. + // To check errors use [ParseQuery]. + func (u *URL) Query() Values { + v, _ := ParseQuery(u.RawQuery) + return v + } + + // RequestURI returns the encoded path?query or opaque?query + // string that would be used in an HTTP request for u. + func (u *URL) RequestURI() string { + result := u.Opaque + if result == "" { + result = u.EscapedPath() + if result == "" { + result = "/" + } + } else { + if strings.HasPrefix(result, "//") { + result = u.Scheme + ":" + result + } + } + if u.ForceQuery || u.RawQuery != "" { + result += "?" + u.RawQuery + } + return result + } + + // Hostname returns u.Host, stripping any valid port number if present. + // + // If the result is enclosed in square brackets, as literal IPv6 addresses are, + // the square brackets are removed from the result. + func (u *URL) Hostname() string { + host, _ := splitHostPort(u.Host) + return host + } + + // Port returns the port part of u.Host, without the leading colon. + // + // If u.Host doesn't contain a valid numeric port, Port returns an empty string. + func (u *URL) Port() string { + _, port := splitHostPort(u.Host) + return port + } + + // splitHostPort separates host and port. If the port is not valid, it returns + // the entire input as host, and it doesn't check the validity of the host. + // Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. + func splitHostPort(hostPort string) (host, port string) { + host = hostPort + + colon := strings.LastIndexByte(host, ':') + if colon != -1 && validOptionalPort(host[colon:]) { + host, port = host[:colon], host[colon+1:] + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return + } + + // Marshaling interface implementations. + // Would like to implement MarshalText/UnmarshalText but that will change the JSON representation of URLs. + + func (u *URL) MarshalBinary() (text []byte, err error) { + return u.AppendBinary(nil) + } + + func (u *URL) AppendBinary(b []byte) ([]byte, error) { + return append(b, u.String()...), nil + } + + func (u *URL) UnmarshalBinary(text []byte) error { + u1, err := Parse(string(text)) + if err != nil { + return err + } + *u = *u1 + return nil + } + + // JoinPath returns a new [URL] with the provided path elements joined to + // any existing path and the resulting path cleaned of any ./ or ../ elements. + // Any sequences of multiple / characters will be reduced to a single /. + func (u *URL) JoinPath(elem ...string) *URL { + elem = append([]string{u.EscapedPath()}, elem...) + var p string + if !strings.HasPrefix(elem[0], "/") { + // Return a relative path if u is relative, + // but ensure that it contains no ../ elements. + elem[0] = "/" + elem[0] + p = path.Join(elem...)[1:] + } else { + p = path.Join(elem...) + } + // path.Join will remove any trailing slashes. + // Preserve at least one. + if strings.HasSuffix(elem[len(elem)-1], "/") && !strings.HasSuffix(p, "/") { + p += "/" + } + url := *u + url.setPath(p) + return &url + } + + // validUserinfo reports whether s is a valid userinfo string per RFC 3986 + // Section 3.2.1: + // + // userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" + // / "*" / "+" / "," / ";" / "=" + // + // It doesn't validate pct-encoded. The caller does that via func unescape. + func validUserinfo(s string) bool { + for _, r := range s { + if 'A' <= r && r <= 'Z' { + continue + } + if 'a' <= r && r <= 'z' { + continue + } + if '0' <= r && r <= '9' { + continue + } + switch r { + case '-', '.', '_', ':', '~', '!', '$', '&', '\'', + '(', ')', '*', '+', ',', ';', '=', '%', '@': + continue + default: + return false + } + } + return true + } + + // stringContainsCTLByte reports whether s contains any ASCII control character. + func stringContainsCTLByte(s string) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false + } + + // JoinPath returns a [URL] string with the provided path elements joined to + // the existing path of base and the resulting path cleaned of any ./ or ../ elements. + func JoinPath(base string, elem ...string) (result string, err error) { + url, err := Parse(base) + if err != nil { + return + } + result = url.JoinPath(elem...).String() + return + } + + required := len(s) + 2*hexCount + if required <= len(buf) { + t = buf[:required] + } else { + t = make([]byte, required) + } + + if hexCount == 0 { + copy(t, s) + for i := 0; i < len(s); i++ { + if s[i] == ' ' { + t[i] = '+' + } + } + return string(t) + } + + j := 0 + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case c == ' ' && mode == encodeQueryComponent: + t[j] = '+' + j++ + case shouldEscape(c, mode): + t[j] = '%' + t[j+1] = upperhex[c>>4] + t[j+2] = upperhex[c&15] + j += 3 + default: + t[j] = s[i] + j++ + } + } + return string(t) +} + +// A URL represents a parsed URL (technically, a URI reference). +// +// The general form represented is: +// +// [scheme:][//[userinfo@]host][/]path[?query][#fragment] +// +// URLs that do not start with a slash after the scheme are interpreted as: +// +// scheme:opaque[?query][#fragment] +// +// The Host field contains the host and port subcomponents of the URL. +// When the port is present, it is separated from the host with a colon. +// When the host is an IPv6 address, it must be enclosed in square brackets: +// "[fe80::1]:80". The [net.JoinHostPort] function combines a host and port +// into a string suitable for the Host field, adding square brackets to +// the host when necessary. +// +// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/. +// A consequence is that it is impossible to tell which slashes in the Path were +// slashes in the raw URL and which were %2f. This distinction is rarely important, +// but when it is, the code should use the [URL.EscapedPath] method, which preserves +// the original encoding of Path. +// +// The RawPath field is an optional field which is only set when the default +// encoding of Path is different from the escaped path. See the EscapedPath method +// for more details. +// +// URL's String method uses the EscapedPath method to obtain the path. +type URL struct { + Scheme string + Opaque string // encoded opaque data + User *Userinfo // username and password information + Host string // host or host:port (see Hostname and Port methods) + Path string // path (relative paths may omit leading slash) + RawPath string // encoded path hint (see EscapedPath method) + OmitHost bool // do not emit empty host (authority) + ForceQuery bool // append a query ('?') even if RawQuery is empty + RawQuery string // encoded query values, without '?' + Fragment string // fragment for references, without '#' + RawFragment string // encoded fragment hint (see EscapedFragment method) +} + +// User returns a [Userinfo] containing the provided username +// and no password set. +func User(username string) *Userinfo { + return &Userinfo{username, "", false} +} + +// UserPassword returns a [Userinfo] containing the provided username +// and password. +// +// This functionality should only be used with legacy web sites. +// RFC 2396 warns that interpreting Userinfo this way +// “is NOT RECOMMENDED, because the passing of authentication +// information in clear text (such as URI) has proven to be a +// security risk in almost every case where it has been used.” +func UserPassword(username, password string) *Userinfo { + return &Userinfo{username, password, true} +} + +// The Userinfo type is an immutable encapsulation of username and +// password details for a [URL]. An existing Userinfo value is guaranteed +// to have a username set (potentially empty, as allowed by RFC 2396), +// and optionally a password. +type Userinfo struct { + username string + password string + passwordSet bool +} + +// Username returns the username. +func (u *Userinfo) Username() string { + if u == nil { + return "" + } + return u.username +} + +// Password returns the password in case it is set, and whether it is set. +func (u *Userinfo) Password() (string, bool) { + if u == nil { + return "", false + } + return u.password, u.passwordSet +} + +// String returns the encoded userinfo information in the standard form +// of "username[:password]". +func (u *Userinfo) String() string { + if u == nil { + return "" + } + s := escape(u.username, encodeUserPassword) + if u.passwordSet { + s += ":" + escape(u.password, encodeUserPassword) + } + return s +} + +// Maybe rawURL is of the form scheme:path. +// (Scheme must be [a-zA-Z][a-zA-Z0-9+.-]*) +// If so, return scheme, path; else return "", rawURL. +func getScheme(rawURL string) (scheme, path string, err error) { + for i := 0; i < len(rawURL); i++ { + c := rawURL[i] + switch { + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // do nothing + case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.': + if i == 0 { + return "", rawURL, nil + } + case c == ':': + if i == 0 { + return "", "", errors.New("missing protocol scheme") + } + return rawURL[:i], rawURL[i+1:], nil + default: + // we have encountered an invalid character, + // so there is no valid scheme + return "", rawURL, nil + } + } + return "", rawURL, nil +} + +// Parse parses a raw url into a [URL] structure. +// +// The url may be relative (a path, without a host) or absolute +// (starting with a scheme). Trying to parse a hostname and path +// without a scheme is invalid but may not necessarily return an +// error, due to parsing ambiguities. +func Parse(rawURL string) (*URL, error) { + // Cut off #frag + u, frag, _ := strings.Cut(rawURL, "#") + url, err := parse(u, false) + if err != nil { + return nil, &Error{"parse", u, err} + } + if frag == "" { + return url, nil + } + if err = url.setFragment(frag); err != nil { + return nil, &Error{"parse", rawURL, err} + } + return url, nil +} + +// ParseRequestURI parses a raw url into a [URL] structure. It assumes that +// url was received in an HTTP request, so the url is interpreted +// only as an absolute URI or an absolute path. +// The string url is assumed not to have a #fragment suffix. +// (Web browsers strip #fragment before sending the URL to a web server.) +func ParseRequestURI(rawURL string) (*URL, error) { + url, err := parse(rawURL, true) + if err != nil { + return nil, &Error{"parse", rawURL, err} + } + return url, nil +} + +// parse parses a URL from a string in one of two contexts. If +// viaRequest is true, the URL is assumed to have arrived via an HTTP request, +// in which case only absolute URLs or path-absolute relative URLs are allowed. +// If viaRequest is false, all forms of relative URLs are allowed. +func parse(rawURL string, viaRequest bool) (*URL, error) { + var rest string + var err error + + if stringContainsCTLByte(rawURL) { + return nil, errors.New("net/url: invalid control character in URL") + } + + if rawURL == "" && viaRequest { + return nil, errors.New("empty url") + } + url := new(URL) + + if rawURL == "*" { + url.Path = "*" + return url, nil + } + + // Split off possible leading "http:", "mailto:", etc. + // Cannot contain escaped characters. + if url.Scheme, rest, err = getScheme(rawURL); err != nil { + return nil, err + } + url.Scheme = strings.ToLower(url.Scheme) + + if strings.HasSuffix(rest, "?") && strings.Count(rest, "?") == 1 { + url.ForceQuery = true + rest = rest[:len(rest)-1] + } else { + rest, url.RawQuery, _ = strings.Cut(rest, "?") + } + + if !strings.HasPrefix(rest, "/") { + if url.Scheme != "" { + // We consider rootless paths per RFC 3986 as opaque. + url.Opaque = rest + return url, nil + } + if viaRequest { + return nil, errors.New("invalid URI for request") + } + + // Avoid confusion with malformed schemes, like cache_object:foo/bar. + // See golang.org/issue/16822. + // + // RFC 3986, §3.3: + // In addition, a URI reference (Section 4.1) may be a relative-path reference, + // in which case the first path segment cannot contain a colon (":") character. + if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") { + // First path segment has colon. Not allowed in relative URL. + return nil, errors.New("first path segment in URL cannot contain colon") + } + } + + if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { + var authority string + authority, rest = rest[2:], "" + if i := strings.Index(authority, "/"); i >= 0 { + authority, rest = authority[:i], authority[i:] + } + url.User, url.Host, err = parseAuthority(authority) + if err != nil { + return nil, err + } + } else if url.Scheme != "" && strings.HasPrefix(rest, "/") { + // OmitHost is set to true when rawURL has an empty host (authority). + // See golang.org/issue/46059. + url.OmitHost = true + } + + // Set Path and, optionally, RawPath. + // RawPath is a hint of the encoding of Path. We don't want to set it if + // the default escaping of Path is equivalent, to help make sure that people + // don't rely on it in general. + if err := url.setPath(rest); err != nil { + return nil, err + } + return url, nil +} + +func parseAuthority(authority string) (user *Userinfo, host string, err error) { + i := strings.LastIndex(authority, "@") + if i < 0 { + host, err = parseHost(authority) + } else { + host, err = parseHost(authority[i+1:]) + } + if err != nil { + return nil, "", err + } + if i < 0 { + return nil, host, nil + } + userinfo := authority[:i] + if !validUserinfo(userinfo) { + return nil, "", errors.New("net/url: invalid userinfo") + } + if !strings.Contains(userinfo, ":") { + if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil { + return nil, "", err + } + user = User(userinfo) + } else { + username, password, _ := strings.Cut(userinfo, ":") + if username, err = unescape(username, encodeUserPassword); err != nil { + return nil, "", err + } + if password, err = unescape(password, encodeUserPassword); err != nil { + return nil, "", err + } + user = UserPassword(username, password) + } + return user, host, nil +} + +// parseHost parses host as an authority without user +// information. That is, as host[:port]. +func parseHost(host string) (string, error) { + if strings.HasPrefix(host, "[") { + // Parse an IP-Literal in RFC 3986 and RFC 6874. + // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80". + i := strings.LastIndex(host, "]") + if i < 0 { + return "", errors.New("missing ']' in host") + } + colonPort := host[i+1:] + if !validOptionalPort(colonPort) { + return "", fmt.Errorf("invalid port %q after host", colonPort) + } + + // RFC 6874 defines that %25 (%-encoded percent) introduces + // the zone identifier, and the zone identifier can use basically + // any %-encoding it likes. That's different from the host, which + // can only %-encode non-ASCII bytes. + // We do impose some restrictions on the zone, to avoid stupidity + // like newlines. + zone := strings.Index(host[:i], "%25") + if zone >= 0 { + host1, err := unescape(host[:zone], encodeHost) + if err != nil { + return "", err + } + host2, err := unescape(host[zone:i], encodeZone) + if err != nil { + return "", err + } + host3, err := unescape(host[i:], encodeHost) + if err != nil { + return "", err + } + return host1 + host2 + host3, nil + } + } else if i := strings.LastIndex(host, ":"); i != -1 { + colonPort := host[i:] + if !validOptionalPort(colonPort) { + return "", fmt.Errorf("invalid port %q after host", colonPort) + } + } + + var err error + if host, err = unescape(host, encodeHost); err != nil { + return "", err + } + return host, nil +} + +// setPath sets the Path and RawPath fields of the URL based on the provided +// escaped path p. It maintains the invariant that RawPath is only specified +// when it differs from the default encoding of the path. +// For example: +// - setPath("/foo/bar") will set Path="/foo/bar" and RawPath="" +// - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar" +// setPath will return an error only if the provided path contains an invalid +// escaping. +// +// setPath should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/sagernet/sing +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname badSetPath net/url.(*URL).setPath +func (u *URL) setPath(p string) error { + path, err := unescape(p, encodePath) + if err != nil { + return err + } + u.Path = path + if escp := escape(path, encodePath); p == escp { + // Default encoding is fine. + u.RawPath = "" + } else { + u.RawPath = p + } + return nil +} + +// for linkname because we cannot linkname methods directly +func badSetPath(*URL, string) error + +// EscapedPath returns the escaped form of u.Path. +// In general there are multiple possible escaped forms of any path. +// EscapedPath returns u.RawPath when it is a valid escaping of u.Path. +// Otherwise EscapedPath ignores u.RawPath and computes an escaped +// form on its own. +// The [URL.String] and [URL.RequestURI] methods use EscapedPath to construct +// their results. +// In general, code should call EscapedPath instead of +// reading u.RawPath directly. +func (u *URL) EscapedPath() string { + if u.RawPath != "" && validEncoded(u.RawPath, encodePath) { + p, err := unescape(u.RawPath, encodePath) + if err == nil && p == u.Path { + return u.RawPath + } + } + if u.Path == "*" { + return "*" // don't escape (Issue 11202) + } + return escape(u.Path, encodePath) +} + +// validEncoded reports whether s is a valid encoded path or fragment, +// according to mode. +// It must not contain any bytes that require escaping during encoding. +func validEncoded(s string, mode encoding) bool { + for i := 0; i < len(s); i++ { + // RFC 3986, Appendix A. + // pchar = unreserved / pct-encoded / sub-delims / ":" / "@". + // shouldEscape is not quite compliant with the RFC, + // so we check the sub-delims ourselves and let + // shouldEscape handle the others. + switch s[i] { + case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '@': + // ok + case '[', ']': + // ok - not specified in RFC 3986 but left alone by modern browsers + case '%': + // ok - percent encoded, will decode + default: + if shouldEscape(s[i], mode) { + return false + } + } + } + return true +} + +// setFragment is like setPath but for Fragment/RawFragment. +func (u *URL) setFragment(f string) error { + frag, err := unescape(f, encodeFragment) + if err != nil { + return err + } + u.Fragment = frag + if escf := escape(frag, encodeFragment); f == escf { + // Default encoding is fine. + u.RawFragment = "" + } else { + u.RawFragment = f + } + return nil +} + +// EscapedFragment returns the escaped form of u.Fragment. +// In general there are multiple possible escaped forms of any fragment. +// EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment. +// Otherwise EscapedFragment ignores u.RawFragment and computes an escaped +// form on its own. +// The [URL.String] method uses EscapedFragment to construct its result. +// In general, code should call EscapedFragment instead of +// reading u.RawFragment directly. +func (u *URL) EscapedFragment() string { + if u.RawFragment != "" && validEncoded(u.RawFragment, encodeFragment) { + f, err := unescape(u.RawFragment, encodeFragment) + if err == nil && f == u.Fragment { + return u.RawFragment + } + } + return escape(u.Fragment, encodeFragment) +} + +// validOptionalPort reports whether port is either an empty string +// or matches /^:\d*$/ +func validOptionalPort(port string) bool { + if port == "" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true +} + +// String reassembles the [URL] into a valid URL string. +// The general form of the result is one of: +// +// scheme:opaque?query#fragment +// scheme://userinfo@host/path?query#fragment +// +// If u.Opaque is non-empty, String uses the first form; +// otherwise it uses the second form. +// Any non-ASCII characters in host are escaped. +// To obtain the path, String uses u.EscapedPath(). +// +// In the second form, the following rules apply: +// - if u.Scheme is empty, scheme: is omitted. +// - if u.User is nil, userinfo@ is omitted. +// - if u.Host is empty, host/ is omitted. +// - if u.Scheme and u.Host are empty and u.User is nil, +// the entire scheme://userinfo@host/ is omitted. +// - if u.Host is non-empty and u.Path begins with a /, +// the form host/path does not add its own /. +// - if u.RawQuery is empty, ?query is omitted. +// - if u.Fragment is empty, #fragment is omitted. +func (u *URL) String() string { + var buf strings.Builder + + n := len(u.Scheme) + if u.Opaque != "" { + n += len(u.Opaque) + } else { + if !u.OmitHost && (u.Scheme != "" || u.Host != "" || u.User != nil) { + username := u.User.Username() + password, _ := u.User.Password() + n += len(username) + len(password) + len(u.Host) + } + n += len(u.Path) + } + n += len(u.RawQuery) + len(u.RawFragment) + n += len(":" + "//" + "//" + ":" + "@" + "/" + "./" + "?" + "#") + buf.Grow(n) + + if u.Scheme != "" { + buf.WriteString(u.Scheme) + buf.WriteByte(':') + } + if u.Opaque != "" { + buf.WriteString(u.Opaque) + } else { + if u.Scheme != "" || u.Host != "" || u.User != nil { + if u.OmitHost && u.Host == "" && u.User == nil { + // omit empty host + } else { + if u.Host != "" || u.Path != "" || u.User != nil { + buf.WriteString("//") + } + if ui := u.User; ui != nil { + buf.WriteString(ui.String()) + buf.WriteByte('@') + } + if h := u.Host; h != "" { + buf.WriteString(escape(h, encodeHost)) + } + } + } + path := u.EscapedPath() + if path != "" && path[0] != '/' && u.Host != "" { + buf.WriteByte('/') + } + if buf.Len() == 0 { + // RFC 3986 §4.2 + // A path segment that contains a colon character (e.g., "this:that") + // cannot be used as the first segment of a relative-path reference, as + // it would be mistaken for a scheme name. Such a segment must be + // preceded by a dot-segment (e.g., "./this:that") to make a relative- + // path reference. + if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") { + buf.WriteString("./") + } + } + buf.WriteString(path) + } + if u.ForceQuery || u.RawQuery != "" { + buf.WriteByte('?') + buf.WriteString(u.RawQuery) + } + if u.Fragment != "" { + buf.WriteByte('#') + buf.WriteString(u.EscapedFragment()) + } + return buf.String() +} + +// Redacted is like [URL.String] but replaces any password with "xxxxx". +// Only the password in u.User is redacted. +func (u *URL) Redacted() string { + if u == nil { + return "" + } + + ru := *u + if _, has := ru.User.Password(); has { + ru.User = UserPassword(ru.User.Username(), "xxxxx") + } + return ru.String() +} + +// Values maps a string key to a list of values. +// It is typically used for query parameters and form values. +// Unlike in the http.Header map, the keys in a Values map +// are case-sensitive. +type Values map[string][]string + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns +// the empty string. To access multiple values, use the map +// directly. +func (v Values) Get(key string) string { + vs := v[key] + if len(vs) == 0 { + return "" + } + return vs[0] +} + +// Set sets the key to value. It replaces any existing +// values. +func (v Values) Set(key, value string) { + v[key] = []string{value} +} + +// Add adds the value to key. It appends to any existing +// values associated with key. +func (v Values) Add(key, value string) { + v[key] = append(v[key], value) +} + +// Del deletes the values associated with key. +func (v Values) Del(key string) { + delete(v, key) +} + +// Has checks whether a given key is set. +func (v Values) Has(key string) bool { + _, ok := v[key] + return ok +} + +// ParseQuery parses the URL-encoded query string and returns +// a map listing the values specified for each key. +// ParseQuery always returns a non-nil map containing all the +// valid query parameters found; err describes the first decoding error +// encountered, if any. +// +// Query is expected to be a list of key=value settings separated by ampersands. +// A setting without an equals sign is interpreted as a key set to an empty +// value. +// Settings containing a non-URL-encoded semicolon are considered invalid. +func ParseQuery(query string) (Values, error) { + m := make(Values) + err := parseQuery(m, query) + return m, err +} + +func parseQuery(m Values, query string) (err error) { + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + if strings.Contains(key, ";") { + err = fmt.Errorf("invalid semicolon separator in query") + continue + } + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + key, err1 := QueryUnescape(key) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + value, err1 = QueryUnescape(value) + if err1 != nil { + if err == nil { + err = err1 + } + continue + } + m[key] = append(m[key], value) + } + return err +} + +// Encode encodes the values into “URL encoded” form +// ("bar=baz&foo=quux") sorted by key. +func (v Values) Encode() string { + if len(v) == 0 { + return "" + } + var buf strings.Builder + keys := make([]string, 0, len(v)) + for k := range v { + keys = append(keys, k) + } + slices.Sort(keys) + for _, k := range keys { + vs := v[k] + keyEscaped := QueryEscape(k) + for _, v := range vs { + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(keyEscaped) + buf.WriteByte('=') + buf.WriteString(QueryEscape(v)) + } + } + return buf.String() +} + +// resolvePath applies special path segments from refs and applies +// them to base, per RFC 3986. +func resolvePath(base, ref string) string { + var full string + if ref == "" { + full = base + } else if ref[0] != '/' { + i := strings.LastIndex(base, "/") + full = base[:i+1] + ref + } else { + full = ref + } + if full == "" { + return "" + } + + var ( + elem string + dst strings.Builder + ) + first := true + remaining := full + // We want to return a leading '/', so write it now. + dst.WriteByte('/') + found := true + for found { + elem, remaining, found = strings.Cut(remaining, "/") + if elem == "." { + first = false + // drop + continue + } + + if elem == ".." { + // Ignore the leading '/' we already wrote. + str := dst.String()[1:] + index := strings.LastIndexByte(str, '/') + + dst.Reset() + dst.WriteByte('/') + if index == -1 { + first = true + } else { + dst.WriteString(str[:index]) + } + } else { + if !first { + dst.WriteByte('/') + } + dst.WriteString(elem) + first = false + } + } + + if elem == "." || elem == ".." { + dst.WriteByte('/') + } + + // We wrote an initial '/', but we don't want two. + r := dst.String() + if len(r) > 1 && r[1] == '/' { + r = r[1:] + } + return r +} + +// IsAbs reports whether the [URL] is absolute. +// Absolute means that it has a non-empty scheme. +func (u *URL) IsAbs() bool { + return u.Scheme != "" +} + +// Parse parses a [URL] in the context of the receiver. The provided URL +// may be relative or absolute. Parse returns nil, err on parse +// failure, otherwise its return value is the same as [URL.ResolveReference]. +func (u *URL) Parse(ref string) (*URL, error) { + refURL, err := Parse(ref) + if err != nil { + return nil, err + } + return u.ResolveReference(refURL), nil +} + +// ResolveReference resolves a URI reference to an absolute URI from +// an absolute base URI u, per RFC 3986 Section 5.2. The URI reference +// may be relative or absolute. ResolveReference always returns a new +// [URL] instance, even if the returned URL is identical to either the +// base or reference. If ref is an absolute URL, then ResolveReference +// ignores base and returns a copy of ref. +func (u *URL) ResolveReference(ref *URL) *URL { + url := *ref + if ref.Scheme == "" { + url.Scheme = u.Scheme + } + if ref.Scheme != "" || ref.Host != "" || ref.User != nil { + // The "absoluteURI" or "net_path" cases. + // We can ignore the error from setPath since we know we provided a + // validly-escaped path. + url.setPath(resolvePath(ref.EscapedPath(), "")) + return &url + } + if ref.Opaque != "" { + url.User = nil + url.Host = "" + url.Path = "" + return &url + } + if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" { + url.RawQuery = u.RawQuery + if ref.Fragment == "" { + url.Fragment = u.Fragment + url.RawFragment = u.RawFragment + } + } + if ref.Path == "" && u.Opaque != "" { + url.Opaque = u.Opaque + url.User = nil + url.Host = "" + url.Path = "" + return &url + } + // The "abs_path" or "rel_path" cases. + url.Host = u.Host + url.User = u.User + url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath())) + return &url +} + +// Query parses RawQuery and returns the corresponding values. +// It silently discards malformed value pairs. +// To check errors use [ParseQuery]. +func (u *URL) Query() Values { + v, _ := ParseQuery(u.RawQuery) + return v +} + +// RequestURI returns the encoded path?query or opaque?query +// string that would be used in an HTTP request for u. +func (u *URL) RequestURI() string { + result := u.Opaque + if result == "" { + result = u.EscapedPath() + if result == "" { + result = "/" + } + } else { + if strings.HasPrefix(result, "//") { + result = u.Scheme + ":" + result + } + } + if u.ForceQuery || u.RawQuery != "" { + result += "?" + u.RawQuery + } + return result +} + +// Hostname returns u.Host, stripping any valid port number if present. +// +// If the result is enclosed in square brackets, as literal IPv6 addresses are, +// the square brackets are removed from the result. +func (u *URL) Hostname() string { + host, _ := splitHostPort(u.Host) + return host +} + +// Port returns the port part of u.Host, without the leading colon. +// +// If u.Host doesn't contain a valid numeric port, Port returns an empty string. +func (u *URL) Port() string { + _, port := splitHostPort(u.Host) + return port +} + +// splitHostPort separates host and port. If the port is not valid, it returns +// the entire input as host, and it doesn't check the validity of the host. +// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. +func splitHostPort(hostPort string) (host, port string) { + host = hostPort + + colon := strings.LastIndexByte(host, ':') + if colon != -1 && validOptionalPort(host[colon:]) { + host, port = host[:colon], host[colon+1:] + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return +} + +// Marshaling interface implementations. +// Would like to implement MarshalText/UnmarshalText but that will change the JSON representation of URLs. + +func (u *URL) MarshalBinary() (text []byte, err error) { + return u.AppendBinary(nil) +} + +func (u *URL) AppendBinary(b []byte) ([]byte, error) { + return append(b, u.String()...), nil +} + +func (u *URL) UnmarshalBinary(text []byte) error { + u1, err := Parse(string(text)) + if err != nil { + return err + } + *u = *u1 + return nil +} + +// JoinPath returns a new [URL] with the provided path elements joined to +// any existing path and the resulting path cleaned of any ./ or ../ elements. +// Any sequences of multiple / characters will be reduced to a single /. +func (u *URL) JoinPath(elem ...string) *URL { + elem = append([]string{u.EscapedPath()}, elem...) + var p string + if !strings.HasPrefix(elem[0], "/") { + // Return a relative path if u is relative, + // but ensure that it contains no ../ elements. + elem[0] = "/" + elem[0] + p = path.Join(elem...)[1:] + } else { + p = path.Join(elem...) + } + // path.Join will remove any trailing slashes. + // Preserve at least one. + if strings.HasSuffix(elem[len(elem)-1], "/") && !strings.HasSuffix(p, "/") { + p += "/" + } + url := *u + url.setPath(p) + return &url +} + +// validUserinfo reports whether s is a valid userinfo string per RFC 3986 +// Section 3.2.1: +// +// userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) +// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" +// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" +// / "*" / "+" / "," / ";" / "=" +// +// It doesn't validate pct-encoded. The caller does that via func unescape. +func validUserinfo(s string) bool { + for _, r := range s { + if 'A' <= r && r <= 'Z' { + continue + } + if 'a' <= r && r <= 'z' { + continue + } + if '0' <= r && r <= '9' { + continue + } + switch r { + case '-', '.', '_', ':', '~', '!', '$', '&', '\'', + '(', ')', '*', '+', ',', ';', '=', '%', '@': + continue + default: + return false + } + } + return true +} + +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s string) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} + +// JoinPath returns a [URL] string with the provided path elements joined to +// the existing path of base and the resulting path cleaned of any ./ or ../ elements. +func JoinPath(base string, elem ...string) (result string, err error) { + url, err := Parse(base) + if err != nil { + return + } + result = url.JoinPath(elem...).String() + return +} From 9fb6f37a7cf3ad8cb4f5de94dc54841095de8d88 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 15 Aug 2024 17:56:45 +0800 Subject: [PATCH 10/55] feat(x/net/http): Implement server conn logic and modify response & request logic Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 13 ++ x/net/http/request.go | 57 +++++++-- x/net/http/response.go | 18 +-- x/net/http/server.go | 251 ++++++++++++++++++++++++++++++++++----- 4 files changed, 292 insertions(+), 47 deletions(-) create mode 100644 x/net/http/_demo/http.go diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go new file mode 100644 index 0000000..7fd6b37 --- /dev/null +++ b/x/net/http/_demo/http.go @@ -0,0 +1,13 @@ +func main() { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %s!", r.URL) + }) + + server := http.NewServer(":8080") + server.Handler = mux + err := server.ListenAndServe() + if err != nil { + fmt.Printf("Server error: %v\n", err) + } +} \ No newline at end of file diff --git a/x/net/http/request.go b/x/net/http/request.go index 96b834c..037ff6d 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -5,23 +5,28 @@ import ( "io" "unsafe" + "github.com/goplus/llgo/c" "github.com/goplus/llgo/rust/hyper" + cos "github.com/goplus/llgo/c/os" ) type Request struct { + Conn *Conn Method string URL string Header Header Body io.ReadCloser } -func newRequest(hyperReq *hyper.Request) (*Request, error) { +func newRequest(conn *Conn, hyperReq *hyper.Request) (*Request, error) { method := make([]byte, 32) methodLen := uintptr(len(method)) if err := hyperReq.Method(&method[0], &methodLen); err != hyper.OK { return nil, fmt.Errorf("failed to get method: %v", err) } + methodStr := string(method[:methodLen]) + var scheme, authority, pathAndQuery [1024]byte schemeLen, authorityLen, pathAndQueryLen := uintptr(len(scheme)), uintptr(len(authority)), uintptr(len(pathAndQuery)) if err := hyperReq.URIParts(&scheme[0], &schemeLen, &authority[0], &authorityLen, &pathAndQuery[0], &pathAndQueryLen); err != hyper.OK { @@ -29,20 +34,56 @@ func newRequest(hyperReq *hyper.Request) (*Request, error) { } req := &Request{ - Method: string(method[:methodLen]), + Conn: conn, + Method: methodStr, URL: fmt.Sprintf("%s://%s%s", string(scheme[:schemeLen]), string(authority[:authorityLen]), string(pathAndQuery[:pathAndQueryLen])), Header: make(Header), } headers := hyperReq.Headers() if headers != nil { - headers.Foreach(func(name *byte, nameLen uintptr, value *byte, valueLen uintptr) int { - key := string(unsafe.Slice(name, nameLen)) - val := string(unsafe.Slice(value, valueLen)) - req.Header.Add(key, val) - return hyper.IterContinue - }, nil) + headers.Foreach(addHeader, c.Pointer(req)) + } else { + return nil, fmt.Errorf("failed to get request headers") } + if methodStr == "POST" || methodStr == "PUT" { + body := hyperReq.Body() + if body != nil { + // task := body.Foreach(getBodyChunk, c.Pointer(req), nil) + // if task != nil { + // r := conn.Executor.Push(task) + // if r != hyper.OK { + // task.Free() + // return nil, fmt.Errorf("failed to push body foreach task: %v", r) + // } + // } else { + // return nil, fmt.Errorf("failed to create body foreach task") + // } + + } else { + return nil, fmt.Errorf("failed to get request body") + } + } + + return req, nil +} + +func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, valueLen uintptr) c.Int { + req := (*Request)(data) + key := string(unsafe.Slice(name, nameLen)) + val := string(unsafe.Slice(value, valueLen)) + req.Header.Add(key, val) + return hyper.IterContinue +} + +//TODO(hackerchai): implement body chunk reader +func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { + req := (*Request)(userdata) + buf := chunk.Bytes() + len := chunk.Len() + cos.Write(1, unsafe.Pointer(buf), len) + + return hyper.IterContinue } \ No newline at end of file diff --git a/x/net/http/response.go b/x/net/http/response.go index 8892339..d712118 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -1,8 +1,7 @@ package http import ( - "unsafe" - + "github.com/goplus/llgo/c" "github.com/goplus/llgo/rust/hyper" ) @@ -41,12 +40,12 @@ func (r *Response) WriteHeader(statusCode int) { r.statusCode = statusCode resp := hyper.NewResponse() - resp.SetStatus(uint(statusCode)) + resp.SetStatus(uint16(statusCode)) headers := resp.Headers() for k, v := range r.header { for _, val := range v { - headers.Set([]byte(k), uintptr(len(k)), []byte(val), uintptr(len(val))) + headers.Set(&[]byte(k)[0], uintptr(len(k)), &[]byte(val)[0], uintptr(len(val))) } } @@ -59,15 +58,16 @@ func (r *Response) finalize() error { } body := hyper.NewBody() - body.SetDataFunc(func(userdata unsafe.Pointer, ctx *hyper.Context, chunk **hyper.Buf) int { - *chunk = hyper.CopyBuf(r.body, uintptr(len(r.body))) - r.body = nil // Clear the body after sending - return hyper.PollReady - }) + //TODO(hackerchai): implement body data func + body.SetDataFunc() resp := hyper.NewResponse() resp.SetBody(body) r.channel.Send(resp) return nil +} + +//TODO(hackerchai): implement body chunk reader +func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { } \ No newline at end of file diff --git a/x/net/http/server.go b/x/net/http/server.go index e1ed90f..75e1006 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -2,10 +2,16 @@ package http import ( "fmt" + "os" + "sync" + "sync/atomic" "unsafe" + "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/net" + cos "github.com/goplus/llgo/c/os" + "github.com/goplus/llgo/c/syscall" "github.com/goplus/llgo/rust/hyper" ) @@ -23,9 +29,23 @@ type Server struct { Addr string Handler Handler - uvLoop *libuv.Loop - uvServer libuv.Tcp - hyperExecutor *hyper.Executor + uvLoop *libuv.Loop + uvServer libuv.Tcp + inShutdown atomic.Bool + + mu sync.Mutex + activeConnections map[*Conn]struct{} +} + +type Conn struct { + Stream *libuv.Tcp + PollHandle *libuv.Poll + EventMask c.Uint + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker + ConnTask *hyper.Task + IsClosing c.Int + Executor *hyper.Executor } func NewServer(addr string) *Server { @@ -37,14 +57,13 @@ func NewServer(addr string) *Server { func (srv *Server) ListenAndServe() error { srv.uvLoop = libuv.DefaultLoop() - srv.hyperExecutor = hyper.NewExecutor() if err := libuv.InitTcp(srv.uvLoop, &srv.uvServer); err != 0 { return fmt.Errorf("failed to init TCP: %v", err) } var sockaddr net.SockaddrIn - if err := libuv.Ip4Addr(srv.Addr, 0, &sockaddr); err != 0 { + if err := libuv.Ip4Addr(c.AllocaCStr(srv.Addr), 0, &sockaddr); err != 0 { return fmt.Errorf("failed to create IP address: %v", err) } @@ -52,7 +71,14 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("failed to bind: %v", err) } - if err := srv.uvServer.Listen(128, srv.onNewConnection); err != 0 { + // Set SO_REUSEADDR + yes := c.Int(1) + result := net.SetSockOpt(srv.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) + if result != 0 { + return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) + } + + if err := (*libuv.Stream)(&srv.uvServer).Listen(128, srv.onNewConnection); err != 0 { return fmt.Errorf("failed to listen: %v", err) } @@ -61,16 +87,18 @@ func (srv *Server) ListenAndServe() error { for { srv.uvLoop.Run(libuv.RUN_NOWAIT) - task := srv.hyperExecutor.Poll() - for task != nil { - srv.handleTask(task) - task.Free() - task = srv.hyperExecutor.Poll() + for conn := range srv.activeConnections { + task := conn.Executor.Poll() + for task != nil { + srv.handleTask(task) + task.Free() + task = conn.Executor.Poll() + } } } } -func (srv *Server) onNewConnection(serverStream *libuv.Stream, status int) { +func (srv *Server) onNewConnection(serverStream *libuv.Stream, status c.Int) { if status < 0 { fmt.Printf("New connection error: %s\n", libuv.Strerror(libuv.Errno(status))) return @@ -80,14 +108,23 @@ func (srv *Server) onNewConnection(serverStream *libuv.Stream, status int) { libuv.InitTcp(srv.uvLoop, client) if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(client))) == 0 { - io := createIo(client) + conn := createConnData(srv.uvLoop, client) + if conn == nil { + fmt.Fprintf(os.Stderr, "Failed to create Conn\n") + (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) + return + } + srv.trackConn(conn, true) + + io := createIo(conn) service := hyper.ServiceNew(srv.serverCallback) + service.SetUserdata(unsafe.Pointer(conn), freeConnData) - http1Opts := hyper.Http1ServerconnOptionsNew(srv.hyperExecutor) - http2Opts := hyper.Http2ServerconnOptionsNew(srv.hyperExecutor) + http1Opts := hyper.Http1ServerconnOptionsNew(conn.Executor) + http2Opts := hyper.Http2ServerconnOptionsNew(conn.Executor) serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) - srv.hyperExecutor.Push(serverconn) + conn.Executor.Push(serverconn) http1Opts.Free() http2Opts.Free() @@ -97,7 +134,14 @@ func (srv *Server) onNewConnection(serverStream *libuv.Stream, status int) { } func (srv *Server) serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { - req, err := newRequest(hyperReq) + conn := (*Conn)(userdata) + + if hyperReq == nil { + fmt.Fprintf(os.Stderr, "Error: Received null request\n") + return + } + + req, err := newRequest(conn, hyperReq) if err != nil { fmt.Printf("Error creating request: %v\n", err) return @@ -118,27 +162,174 @@ func (srv *Server) handleTask(task *hyper.Task) { fmt.Println("Response sent") case hyper.TaskError: err := (*hyper.Error)(task.Value()) - fmt.Printf("Task error: %s\n", err.Message()) + var errbuf [256]byte + errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) + fmt.Printf("Task error: %.*s\n", errlen, (*c.Char)(unsafe.Pointer(&errbuf[0]))) + err.Free() + } +} + +func (s *Server) trackConn(c *Conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConnections == nil { + s.activeConnections = make(map[*Conn]struct{}) + } + if add { + s.activeConnections[c] = struct{}{} + } else { + delete(s.activeConnections, c) + } +} + +func (srv *Server) Close() error { + srv.inShutdown.Store(true) + srv.mu.Lock() + defer srv.mu.Unlock() + + for c := range srv.activeConnections { + delete(srv.activeConnections, c) } + return nil } -func createIo(client *libuv.Tcp) *hyper.Io { +func createIo(conn *Conn) *hyper.Io { io := hyper.NewIo() - io.SetRead(func(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { - ret := client.Read(unsafe.Pointer(buf), bufLen) - if ret < 0 { + io.SetUserdata(unsafe.Pointer(conn), freeConnData) + io.SetRead(readCb) + io.SetWrite(writeCb) + return io +} + +func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { + conn := (*Conn)(userdata) + ret := net.Recv(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + + if ret >= 0 { + return uintptr(ret) + } + + if uintptr(cos.Errno) != syscall.EAGAIN && uintptr(cos.Errno) != syscall.EWOULDBLOCK { + return hyper.IoError + } + + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + } + + if conn.EventMask&c.Uint(libuv.READABLE) == 0 { + conn.EventMask |= c.Uint(libuv.READABLE) + if !updateConnRegistrations(conn, false) { return hyper.IoError } + } + + conn.ReadWaker = ctx.Waker() + return hyper.IoPending +} + +func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { + conn := (*Conn)(userdata) + ret := net.Send(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + + if ret >= 0 { return uintptr(ret) - }) - io.SetWrite(func(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { - ret := client.Write(unsafe.Pointer(buf), bufLen) - if ret < 0 { + } + + if uintptr(cos.Errno) != syscall.EAGAIN && uintptr(cos.Errno) != syscall.EWOULDBLOCK { + return hyper.IoError + } + + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + } + + if conn.EventMask&c.Uint(libuv.WRITABLE) == 0 { + conn.EventMask |= c.Uint(libuv.WRITABLE) + if !updateConnRegistrations(conn, false) { return hyper.IoError } - return uintptr(ret) - }) - return io + } + + conn.WriteWaker = ctx.Waker() + return hyper.IoPending +} + +func onClose(handle *libuv.Handle) { + c.Free(unsafe.Pointer(handle)) +} + +func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { + conn := (*Conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) + + if status < 0 { + fmt.Fprintf(os.Stderr, "Poll error: %s\n", libuv.Strerror(libuv.Errno(status))) + return + } + + if events&c.Int(libuv.READABLE) != 0 && conn.ReadWaker != nil { + conn.ReadWaker.Wake() + conn.ReadWaker = nil + } + + if events&c.Int(libuv.WRITABLE) != 0 && conn.WriteWaker != nil { + conn.WriteWaker.Wake() + conn.WriteWaker = nil + } +} + +func updateConnRegistrations(conn *Conn, create bool) bool { + events := c.Int(0) + if conn.EventMask&c.Uint(libuv.READABLE) != 0 { + events |= c.Int(libuv.READABLE) + } + if conn.EventMask&c.Uint(libuv.WRITABLE) != 0 { + events |= c.Int(libuv.WRITABLE) + } + + r := conn.PollHandle.Start(events, onPoll) + if r < 0 { + fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", libuv.Strerror(libuv.Errno(r))) + return false + } + return true +} + +func createConnData(loop *libuv.Loop, client *libuv.Tcp) *Conn { + conn := (*Conn)(c.Calloc(1, unsafe.Sizeof(Conn{}))) + if conn == nil { + fmt.Fprintf(os.Stderr, "Failed to allocate conn_data\n") + return nil + } + c.Memcpy(unsafe.Pointer(&conn.Stream), unsafe.Pointer(client), unsafe.Sizeof(libuv.Tcp{})) + conn.IsClosing = 0 + + r := libuv.PollInit(loop, conn.PollHandle, libuv.OsFd(client.GetIoWatcherFd())) + if r < 0 { + fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) + c.Free(unsafe.Pointer(conn)) + return nil + } + + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Data = unsafe.Pointer(conn) + conn.Stream.Data = unsafe.Pointer(conn) + + if !updateConnRegistrations(conn, true) { + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) + c.Free(unsafe.Pointer(conn)) + return nil + } + + return conn +} + +func freeConnData(userdata c.Pointer) { + conn := (*Conn)(userdata) + if conn != nil && conn.IsClosing == 0 { + conn.IsClosing = 1 + // We don't immediately close the connection here. + // Instead, we'll let the main loop handle the closure when appropriate. + } } type HandlerFunc func(ResponseWriter, *Request) @@ -152,4 +343,4 @@ func NotFoundHandler() Handler { return HandlerFunc(NotFound) } func NotFound(w ResponseWriter, r *Request) { w.WriteHeader(404) w.Write([]byte("404 page not found")) -} \ No newline at end of file +} From 2e9e3384d9913d506f5108059c509dc978609c89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 15 Aug 2024 18:16:45 +0800 Subject: [PATCH 11/55] WIP(x/http/client): Code tweaks made; Post request runs successfully. Initial redirection check implemented. --- x/http/_demo/get/get.go | 1 + x/http/_demo/headers/headers.go | 10 +- x/http/_demo/post/post.go | 24 + x/http/_demo/upload/example.txt | 1 + x/http/_demo/upload/upload.go | 24 + x/http/client.go | 179 ++++++- x/http/header.go | 132 +++--- x/http/request.go | 164 +++++-- x/http/response.go | 58 ++- x/http/transfer.go | 808 +++++++++++++++++++------------- x/http/transport.go | 33 +- 11 files changed, 974 insertions(+), 460 deletions(-) create mode 100644 x/http/_demo/post/post.go create mode 100755 x/http/_demo/upload/example.txt create mode 100644 x/http/_demo/upload/upload.go diff --git a/x/http/_demo/get/get.go b/x/http/_demo/get/get.go index c8460e4..bff1bd1 100644 --- a/x/http/_demo/get/get.go +++ b/x/http/_demo/get/get.go @@ -14,6 +14,7 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + fmt.Println(resp.Proto) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/http/_demo/headers/headers.go b/x/http/_demo/headers/headers.go index 98cda79..2672a66 100644 --- a/x/http/_demo/headers/headers.go +++ b/x/http/_demo/headers/headers.go @@ -4,19 +4,19 @@ import ( "fmt" "io" - "github.com/goplus/llgo/x/http" + "github.com/goplus/llgoexamples/x/http" ) func main() { client := &http.Client{} - req, err := http.NewRequest("GET", "https://jsonplaceholder.typicode.com/comments?postId=1", nil) + req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { println(err.Error()) return } //req.Header.Set("accept", "*/*") - //req.Header.Set("accept-encoding", "identity") + req.Header.Set("accept-encoding", "identity") //req.Header.Set("cache-control", "no-cache") //req.Header.Set("pragma", "no-cache") //req.Header.Set("priority", "u=0, i") @@ -28,7 +28,7 @@ func main() { //req.Header.Set("sec-fetch-mode", "navigate") //req.Header.Set("sec-fetch-site", "same-origin") //req.Header.Set("sec-fetch-user", "?1") - //req.Header.Set("upgrade-insecure-requests", "1") + ////req.Header.Set("upgrade-insecure-requests", "1") //req.Header.Set("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36") resp, err := client.Do(req) @@ -36,10 +36,12 @@ func main() { println(err.Error()) return } + resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { println(err.Error()) return } fmt.Println(string(body)) + defer resp.Body.Close() } diff --git a/x/http/_demo/post/post.go b/x/http/_demo/post/post.go new file mode 100644 index 0000000..a86e805 --- /dev/null +++ b/x/http/_demo/post/post.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + resp, err := http.Post("https://jsonplaceholder.typicode.com/posts", "application/json; charset=UTF-8", nil) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status) + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/_demo/upload/example.txt b/x/http/_demo/upload/example.txt new file mode 100755 index 0000000..1253cd4 --- /dev/null +++ b/x/http/_demo/upload/example.txt @@ -0,0 +1 @@ +hello upload \ No newline at end of file diff --git a/x/http/_demo/upload/upload.go b/x/http/_demo/upload/upload.go new file mode 100644 index 0000000..c6bb391 --- /dev/null +++ b/x/http/_demo/upload/upload.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + resp, err := http.Post("http://httpbin.org/post", "", nil) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status) + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/client.go b/x/http/client.go index 177b089..72e9688 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,6 +1,12 @@ package http -import "time" +import ( + "errors" + "io" + "net/url" + "strings" + "time" +) type Client struct { Transport RoundTripper @@ -32,17 +38,93 @@ func (c *Client) Get(url string) (*Response, error) { return c.Do(req) } -func (c *Client) Do(req *Request) (*Response, error) { - return c.do(req) +func Post(url, contentType string, body io.Reader) (resp *Response, err error) { + return DefaultClient.Post(url, contentType, body) } -func (c *Client) do(req *Request) (*Response, error) { - // Add user-defined request headers to hyper.Request - err := req.setHeaders() +func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, err error) { + req, err := NewRequest("POST", url, body) if err != nil { return nil, err } - return c.send(req, c.Timeout) + req.Header.Set("Content-Type", contentType) + return c.Do(req) +} + +func (c *Client) Do(req *Request) (*Response, error) { + return c.do(req) +} + +var testHookClientDoResult func(retres *Response, reterr error) + +func (c *Client) do(req *Request) (retres *Response, reterr error) { + if testHookClientDoResult != nil { + defer func() { testHookClientDoResult(retres, reterr) }() + } + + if req.URL == nil { + req.closeBody() + return nil, &url.Error{ + Op: urlErrorOp(req.Method), + Err: errors.New("http: nil Request.URL"), + } + } + var ( + //deadline = c.deadline() + reqs []*Request + resp *Response + //copyHeaders = c.makeHeadersCopier(req) + reqBodyClosed = false // have we closed the current req.Body? + + // Redirect behavior: + //redirectMethod string + //includeBody bool + ) + uerr := func(err error) error { + // the body may have been closed already by c.send() + if !reqBodyClosed { + req.closeBody() + } + var urlStr string + if resp != nil && resp.Request != nil { + urlStr = stripPassword(resp.Request.URL) + } else { + urlStr = stripPassword(req.URL) + } + return &url.Error{ + Op: urlErrorOp(reqs[0].Method), + URL: urlStr, + Err: err, + } + } + + // For all but the first request, create the next + // request hop and replace req. + for { + if len(reqs) > 0 { + + } + + reqs = append(reqs, req) + var err error + if resp, err = c.send(req, c.Timeout); err != nil { + // c.send() always closes req.Body + reqBodyClosed = true + return nil, uerr(err) + } + + var shouldRedirect bool + //redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) + _, shouldRedirect, _ = redirectBehavior(req.Method, resp, reqs[0]) + if !shouldRedirect { + return resp, nil + } else { + // TODO(spongehah) + return nil, errors.New("TODO: redirect not implemented") + } + + req.closeBody() + } } func (c *Client) send(req *Request, timeout time.Duration) (*Response, error) { @@ -53,3 +135,86 @@ func send(req *Request, rt RoundTripper, timeout time.Duration) (resp *Response, req.timeout = timeout return rt.RoundTrip(req) } + +// redirectBehavior describes what should happen when the +// client encounters a 3xx status code from the server. +func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect, includeBody bool) { + switch resp.StatusCode { + case 301, 302, 303: + redirectMethod = reqMethod + shouldRedirect = true + includeBody = false + + // RFC 2616 allowed automatic redirection only with GET and + // HEAD requests. RFC 7231 lifts this restriction, but we still + // restrict other methods to GET to maintain compatibility. + // See Issue 18570. + if reqMethod != "GET" && reqMethod != "HEAD" { + redirectMethod = "GET" + } + case 307, 308: + redirectMethod = reqMethod + shouldRedirect = true + includeBody = true + + if ireq.GetBody == nil && ireq.outgoingLength() != 0 { + // We had a request body, and 307/308 require + // re-sending it, but GetBody is not defined. So just + // return this response to the user instead of an + // error, like we did in Go 1.7 and earlier. + shouldRedirect = false + } + } + return redirectMethod, shouldRedirect, includeBody +} + +// outgoingLength reports the Content-Length of this outgoing (Client) request. +// It maps 0 into -1 (unknown) when the Body is non-nil. +func (r *Request) outgoingLength() int64 { + if r.Body == nil || r.Body == NoBody { + return 0 + } + if r.ContentLength != 0 { + return r.ContentLength + } + return -1 +} + +// urlErrorOp returns the (*url.Error).Op value to use for the +// provided (*Request).Method value. +func urlErrorOp(method string) string { + if method == "" { + return "Get" + } + if lowerMethod, ok := ToLower(method); ok { + return method[:1] + lowerMethod[1:] + } + return method +} + +// ToLower returns the lowercase version of s if s is ASCII and printable. +func ToLower(s string) (lower string, ok bool) { + if !IsPrint(s) { + return "", false + } + return strings.ToLower(s), true +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func IsPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +func stripPassword(u *url.URL) string { + _, passSet := u.User.Password() + if passSet { + return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) + } + return u.String() +} diff --git a/x/http/header.go b/x/http/header.go index 0533ed7..076db0f 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -2,6 +2,7 @@ package http import ( "fmt" + "net/textproto" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -17,80 +18,79 @@ type Header map[string][]string // It appends to any existing values associated with key. // The key is case insensitive; it is canonicalized by // CanonicalHeaderKey. -//func (h Header) Add(key, value string) { -// textproto.MIMEHeader(h).Add(key, value) -//} -// -//// Set sets the header entries associated with key to the -//// single element value. It replaces any existing values -//// associated with key. The key is case insensitive; it is -//// canonicalized by textproto.CanonicalMIMEHeaderKey. -//// To use non-canonical keys, assign to the map directly. -//func (h Header) Set(key, value string) { -// textproto.MIMEHeader(h).Set(key, value) -//} -// -//// Get gets the first value associated with the given key. If -//// there are no values associated with the key, Get returns "". -//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -//// used to canonicalize the provided key. Get assumes that all -//// keys are stored in canonical form. To use non-canonical keys, -//// access the map directly. -//func (h Header) Get(key string) string { -// return textproto.MIMEHeader(h).Get(key) -//} -// -//// Values returns all values associated with the given key. -//// It is case insensitive; textproto.CanonicalMIMEHeaderKey is -//// used to canonicalize the provided key. To use non-canonical -//// keys, access the map directly. -//// The returned slice is not a copy. -//func (h Header) Values(key string) []string { -// return textproto.MIMEHeader(h).Values(key) -//} -// -//// get is like Get, but key must already be in CanonicalHeaderKey form. -//func (h Header) get(key string) string { -// if v := h[key]; len(v) > 0 { -// return v[0] -// } -// return "" -//} -// -//// has reports whether h has the provided key defined, even if it's -//// set to 0-length slice. -//func (h Header) has(key string) bool { -// _, ok := h[key] -// return ok -//} -// -//// Del deletes the values associated with key. -//// The key is case insensitive; it is canonicalized by -//// CanonicalHeaderKey. -//func (h Header) Del(key string) { -// textproto.MIMEHeader(h).Del(key) -//} -// -//// CanonicalHeaderKey returns the canonical format of the -//// header key s. The canonicalization converts the first -//// letter and any letter following a hyphen to upper case; -//// the rest are converted to lowercase. For example, the -//// canonical key for "accept-encoding" is "Accept-Encoding". -//// If s contains a space or invalid header field bytes, it is -//// returned without modifications. -//func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +// To use non-canonical keys, assign to the map directly. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. To use non-canonical keys, +// access the map directly. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h Header) Values(key string) []string { + return textproto.MIMEHeader(h).Values(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func (h Header) has(key string) bool { + _, ok := h[key] + return ok +} + +// Del deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } // AppendToResponseHeader (HeadersForEachCallback) prints each header to the console func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { resp := (*Response)(userdata) - nameStr := string((*[1 << 30]byte)(c.Pointer(name))[:nameLen:nameLen]) - valueStr := string((*[1 << 30]byte)(c.Pointer(value))[:valueLen:valueLen]) + nameStr := c.GoString((*int8)(c.Pointer(name)), nameLen) + valueStr := c.GoString((*int8)(c.Pointer(value)), valueLen) if resp.Header == nil { resp.Header = make(Header) } - //resp.Header.Add(nameStr, valueStr) - resp.Header[nameStr] = append(resp.Header[nameStr], valueStr) + resp.Header.Add(nameStr, valueStr) return hyper.IterContinue } diff --git a/x/http/request.go b/x/http/request.go index 1d219a1..c84cd9b 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -7,74 +7,169 @@ import ( "time" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/os" "github.com/goplus/llgoexamples/rust/hyper" ) type Request struct { - Method string - URL *url.URL - Req *hyper.Request - Host string - Header Header - timeout time.Duration + Method string + URL *url.URL + Proto string // "HTTP/1.0" + ProtoMajor int // 1 + ProtoMinor int // 0 + Header Header + Body io.ReadCloser + GetBody func() (io.ReadCloser, error) + ContentLength int64 + TransferEncoding []string + Close bool + Host string + timeout time.Duration } +type postBody struct { + data []byte + len uintptr + readLen uintptr +} + +type uploadBody struct { + fd c.Int + buf []byte + len uintptr +} + +var DefaultChunkSize uintptr = 8192 + func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { - parseURL, err := url.Parse(urlStr) - if err != nil { - return nil, err - } - req, err := newHyperRequest(method, parseURL) + u, err := url.Parse(urlStr) if err != nil { return nil, err } + //rc, ok := body.(io.ReadCloser) + //if !ok && body != nil { + // rc = io.NopCloser(body) + //} request := &Request{ - Method: method, - URL: parseURL, - Req: req, - Host: parseURL.Hostname(), - Header: make(Header), + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Host: u.Host, + //Body: rc, timeout: 0, } - //request.Header.Set("Host", request.Host) - request.Header["Host"] = []string{request.Host} + request.Header.Set("Host", request.Host) + return request, nil } -func newHyperRequest(method string, URL *url.URL) (*hyper.Request, error) { - host := URL.Hostname() - uri := URL.RequestURI() +func PrintInformational(userdata c.Pointer, resp *hyper.Response) { + status := resp.Status() + fmt.Println("Informational (1xx): ", status) +} + +func SetPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + //upload := (*uploadBody)(userdata) + //res := os.Read(upload.fd, c.Pointer(&upload.buf[0]), upload.len) + //if res > 0 { + // *chunk = hyper.CopyBuf(&upload.buf[0], uintptr(res)) + // return hyper.PollReady + //} + //if res == 0 { + // *chunk = nil + // os.Close(upload.fd) + // return hyper.PollReady + //} + body := (*postBody)(userdata) + if body.len > 0 { + if body.len > DefaultChunkSize { + *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) + body.readLen += DefaultChunkSize + body.len -= DefaultChunkSize + } else { + *chunk = hyper.CopyBuf(&body.data[body.readLen], body.len) + body.readLen += body.len + body.len = 0 + } + return hyper.PollReady + } + if body.len == 0 { + *chunk = nil + return hyper.PollReady + } + + fmt.Printf("error reading upload file: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} + +func newHyperRequest(req *Request) (*hyper.Request, error) { + host := req.Host + uri := req.URL.Path + method := req.Method // Prepare the request - req := hyper.NewRequest() + hyperReq := hyper.NewRequest() // Set the request method and uri - if req.SetMethod((*uint8)(&[]byte(method)[0]), c.Strlen(c.AllocaCStr(method))) != hyper.OK { + if hyperReq.SetMethod(&[]byte(method)[0], c.Strlen(c.AllocaCStr(method))) != hyper.OK { return nil, fmt.Errorf("error setting method %s\n", method) } - if req.SetURI((*uint8)(&[]byte(uri)[0]), c.Strlen(c.AllocaCStr(uri))) != hyper.OK { + if hyperReq.SetURI(&[]byte(uri)[0], c.Strlen(c.AllocaCStr(uri))) != hyper.OK { return nil, fmt.Errorf("error setting uri %s\n", uri) } // Set the request headers - reqHeaders := req.Headers() - if reqHeaders.Set((*uint8)(&[]byte("Host")[0]), c.Strlen(c.Str("Host")), (*uint8)(&[]byte(host)[0]), c.Strlen(c.AllocaCStr(host))) != hyper.OK { + reqHeaders := hyperReq.Headers() + if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { return nil, fmt.Errorf("error setting header: Host: %s\n", host) } - return req, nil + if method == "POST" { + //var upload uploadBody + //upload.fd = os.Open(c.Str("/Users/spongehah/go/src/llgo/x/http/_demo/post/example.txt"), os.O_RDONLY) + //if upload.fd < 0 { + // return nil, fmt.Errorf("error opening file to upload: %s\n", c.GoString(c.Strerror(os.Errno))) + //} + //upload.len = 8192 + //upload.buf = make([]byte, upload.len) + req.Header.Set("expect", "100-continue") + hyperReq.OnInformational(PrintInformational, nil) + postData := []byte(`{"id":1,"title":"foo","body":"bar","userId":"1"}`) + + reqBody := &postBody{ + data: postData, + len: uintptr(len(postData)), + } + + hyperReqBody := hyper.NewBody() + hyperReqBody.SetUserdata(c.Pointer(reqBody)) + //hyperReqBody.SetUserdata(c.Pointer(&upload)) + hyperReqBody.SetDataFunc(SetPostData) + hyperReq.SetBody(hyperReqBody) + } + + // Add user-defined request headers to hyper.Request + err := req.setHeaders(hyperReq) + if err != nil { + return nil, err + } + + return hyperReq, nil } // setHeaders sets the headers of the request -func (req *Request) setHeaders() error { - headers := req.Req.Headers() +func (req *Request) setHeaders(hyperReq *hyper.Request) error { + headers := hyperReq.Headers() for key, values := range req.Header { valueLen := len(values) if valueLen > 1 { for _, value := range values { - if headers.Add((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(value)[0]), c.Strlen(c.AllocaCStr(value))) != hyper.OK { + if headers.Add(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(value)[0], c.Strlen(c.AllocaCStr(value))) != hyper.OK { return fmt.Errorf("error adding header %s: %s\n", key, value) } } } else if valueLen == 1 { - if headers.Set((*uint8)(&[]byte(key)[0]), c.Strlen(c.AllocaCStr(key)), (*uint8)(&[]byte(values[0])[0]), c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { + if headers.Set(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(values[0])[0], c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { return fmt.Errorf("error setting header %s: %s\n", key, values[0]) } } else { @@ -83,3 +178,10 @@ func (req *Request) setHeaders() error { } return nil } + +func (r *Request) closeBody() error { + if r.Body == nil { + return nil + } + return r.Body.Close() +} diff --git a/x/http/response.go b/x/http/response.go index a08d77e..c99bade 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -10,24 +10,43 @@ import ( ) type Response struct { - Status string // e.g. "200 OK" - StatusCode int // e.g. 200 - Proto string // e.g. "HTTP/1.0" - ProtoMajor int // e.g. 1 - ProtoMinor int // e.g. 0 - Header Header - Body io.ReadCloser - ContentLength int64 - Trailer Header - Chunked bool - Request *Request + Status string // e.g. "200 OK" + StatusCode int // e.g. 200 + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 + Header Header + Body io.ReadCloser + ContentLength int64 + TransferEncoding []string + Close bool + Trailer Header + Request *Request } +func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { + resp := &Response{ + Request: req, + Header: make(Header), + Trailer: make(Header), + } + readResponseLineAndHeader(resp, hyperResp) + + fixPragmaCacheControl(req.Header) + + err := readTransfer(resp) + if err != nil { + return nil, err + } + return resp, nil +} + +// readResponseLineAndHeader reads the response line and header from hyper response. func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { rp := hyperResp.ReasonPhrase() rpLen := hyperResp.ReasonPhraseLen() - resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + string((*[1 << 30]byte)(c.Pointer(rp))[:rpLen:rpLen]) + resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + c.GoString((*int8)(c.Pointer(rp)), rpLen) resp.StatusCode = int(hyperResp.Status()) version := int(hyperResp.Version()) @@ -37,3 +56,18 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { headers := hyperResp.Headers() headers.Foreach(AppendToResponseHeader, c.Pointer(resp)) } + +// RFC 7234, section 5.4: Should treat +// +// Pragma: no-cache +// +// like +// +// Cache-Control: no-cache +func fixPragmaCacheControl(header Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} diff --git a/x/http/transfer.go b/x/http/transfer.go index 5157324..70f082e 100644 --- a/x/http/transfer.go +++ b/x/http/transfer.go @@ -1,332 +1,488 @@ package http -// -//import ( -// "fmt" -// "io" -// "net/textproto" -// "strconv" -// "strings" -// -// "github.com/goplus/llgoexamples/rust/hyper" -//) -// -//type transferReader struct { -// // Input -// Header Header -// StatusCode int -// RequestMethod string -// ProtoMajor int -// ProtoMinor int -// // Output -// Body io.ReadCloser -// ContentLength int64 -// Chunked bool -// Close bool -// Trailer Header -//} -// -//// unsupportedTEError reports unsupported transfer-encodings. -//type unsupportedTEError struct { -// err string -//} -// -//func (uste *unsupportedTEError) Error() string { -// return uste.err -//} -// -//func readTransfer(resp *Response, hyperResp *hyper.Response) (err error) { -// //// TODO(spongehah) Replace header operations with using the textproto package -// //lengthSlice := resp.Header["content-length"] -// //if lengthSlice == nil { -// // resp.ContentLength = -1 -// //} else { -// // contentLength := resp.Header["content-length"][0] -// // length, err := strconv.Atoi(contentLength) -// // if err != nil { -// // return err -// // } -// // resp.ContentLength = int64(length) -// //} -// -// t := &transferReader{ -// Header: resp.Header, -// StatusCode: resp.StatusCode, -// RequestMethod: resp.Request.Method, -// ProtoMajor: resp.ProtoMajor, -// ProtoMinor: resp.ProtoMinor, -// } -// -// // Transfer-Encoding: chunked, and overriding Content-Length. -// if err = t.parseTransferEncoding(); err != nil { -// return err -// } -// -// realLength, err := fixLength(true, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) -// if err != nil { -// return err -// } -// if t.RequestMethod == "HEAD" { -// if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { -// return err -// } else { -// t.ContentLength = n -// } -// } else { -// t.ContentLength = realLength -// } -// -// // Trailer -// t.Trailer, err = fixTrailer(t.Header, t.Chunked) -// -// // If there is no Content-Length or chunked Transfer-Encoding on a *Response -// // and the status is not 1xx, 204 or 304, then the body is unbounded. -// // See RFC 7230, section 3.3. -// if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { -// // Unbounded body. -// t.Close = true -// } -// -// return nil -//} -// -//// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. -//func (t *transferReader) parseTransferEncoding() error { -// raw, present := t.Header["Transfer-Encoding"] -// if !present { -// return nil -// } -// delete(t.Header, "Transfer-Encoding") -// -// // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. -// if !t.protoAtLeast(1, 1) { -// return nil -// } -// -// // Like nginx, we only support a single Transfer-Encoding header field, and -// // only if set to "chunked". This is one of the most security sensitive -// // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it -// // strict and simple. -// if len(raw) != 1 { -// return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} -// } -// if !equalFold(raw[0], "chunked") { -// return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} -// } -// -// // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field -// // in any message that contains a Transfer-Encoding header field." -// // -// // but also: "If a message is received with both a Transfer-Encoding and a -// // Content-Length header field, the Transfer-Encoding overrides the -// // Content-Length. Such a message might indicate an attempt to perform -// // request smuggling (Section 9.5) or response splitting (Section 9.4) and -// // ought to be handled as an error. A sender MUST remove the received -// // Content-Length field prior to forwarding such a message downstream." -// // -// // Reportedly, these appear in the wild. -// delete(t.Header, "Content-Length") -// -// t.Chunked = true -// return nil -//} -// -//func (t *transferReader) protoAtLeast(m, n int) bool { -// return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) -//} -// -//// equalFold is strings.EqualFold, ASCII only. It reports whether s and t -//// are equal, ASCII-case-insensitively. -//func equalFold(s, t string) bool { -// if len(s) != len(t) { -// return false -// } -// for i := 0; i < len(s); i++ { -// if lower(s[i]) != lower(t[i]) { -// return false -// } -// } -// return true -//} -// -//// Determine the expected body length, using RFC 7230 Section 3.3. This -//// function is not a method, because ultimately it should be shared by -//// ReadResponse and ReadRequest. -//func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { -// isRequest := !isResponse -// contentLens := header["Content-Length"] -// -// // Hardening against HTTP request smuggling -// if len(contentLens) > 1 { -// // Per RFC 7230 Section 3.3.2, prevent multiple -// // Content-Length headers if they differ in value. -// // If there are dups of the value, remove the dups. -// // See Issue 16490. -// first := textproto.TrimString(contentLens[0]) -// for _, ct := range contentLens[1:] { -// if first != textproto.TrimString(ct) { -// return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) -// } -// } -// -// // deduplicate Content-Length -// header.Del("Content-Length") -// header.Add("Content-Length", first) -// -// contentLens = header["Content-Length"] -// } -// -// // Logic based on response type or status -// if isResponse && noResponseBodyExpected(requestMethod) { -// return 0, nil -// } -// if status/100 == 1 { -// return 0, nil -// } -// switch status { -// case 204, 304: -// return 0, nil -// } -// -// // Logic based on Transfer-Encoding -// if chunked { -// return -1, nil -// } -// -// // Logic based on Content-Length -// var cl string -// if len(contentLens) == 1 { -// cl = textproto.TrimString(contentLens[0]) -// } -// if cl != "" { -// n, err := parseContentLength(cl) -// if err != nil { -// return -1, err -// } -// return n, nil -// } -// header.Del("Content-Length") -// -// if isRequest { -// // RFC 7230 neither explicitly permits nor forbids an -// // entity-body on a GET request so we permit one if -// // declared, but we default to 0 here (not -1 below) -// // if there's no mention of a body. -// // Likewise, all other request methods are assumed to have -// // no body if neither Transfer-Encoding chunked nor a -// // Content-Length are set. -// return 0, nil -// } -// -// // Body-EOF logic based on other methods (like closing, or chunked coding) -// return -1, nil -//} -// -//// parseContentLength trims whitespace from s and returns -1 if no value -//// is set, or the value if it's >= 0. -//func parseContentLength(cl string) (int64, error) { -// cl = textproto.TrimString(cl) -// if cl == "" { -// return -1, nil -// } -// n, err := strconv.ParseUint(cl, 10, 63) -// if err != nil { -// return 0, badStringError("bad Content-Length", cl) -// } -// return int64(n), nil -// -//} -// -//// Parse the trailer header. -//func fixTrailer(header Header, chunked bool) (Header, error) { -// vv, ok := header["Trailer"] -// if !ok { -// return nil, nil -// } -// if !chunked { -// // Trailer and no chunking: -// // this is an invalid use case for trailer header. -// // Nevertheless, no error will be returned and we -// // let users decide if this is a valid HTTP message. -// // The Trailer header will be kept in Response.Header -// // but not populate Response.Trailer. -// // See issue #27197. -// return nil, nil -// } -// header.Del("Trailer") -// -// trailer := make(Header) -// var err error -// for _, v := range vv { -// foreachHeaderElement(v, func(key string) { -// key = CanonicalHeaderKey(key) -// switch key { -// case "Transfer-Encoding", "Trailer", "Content-Length": -// if err == nil { -// err = badStringError("bad trailer key", key) -// return -// } -// } -// trailer[key] = nil -// }) -// } -// if err != nil { -// return nil, err -// } -// if len(trailer) == 0 { -// return nil, nil -// } -// return trailer, nil -//} -// -//// splitTwoDigitNumber splits a two-digit number into two digits. + +import ( + "fmt" + "io" + "net/textproto" + "strconv" + "strings" + "unicode/utf8" +) + +type transferReader struct { + // Input + Header Header + StatusCode int + RequestMethod string + ProtoMajor int + ProtoMinor int + // Output + Body io.ReadCloser + ContentLength int64 + Chunked bool + Close bool + Trailer Header +} + +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// and Close always returns nil. It can be used in an outgoing client +// request to explicitly signal that a request has zero bytes. +// An alternative, however, is to simply set Request.Body to nil. +var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +func readTransfer(msg any) (err error) { + t := &transferReader{RequestMethod: "GET"} + + // Unify input + isResponse := false + switch rr := msg.(type) { + case *Response: + t.Header = rr.Header + t.StatusCode = rr.StatusCode + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true) + isResponse = true + if rr.Request != nil { + t.RequestMethod = rr.Request.Method + } + case *Request: + t.Header = rr.Header + t.RequestMethod = rr.Method + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + // Transfer semantics for Requests are exactly like those for + // Responses with status code 200, responding to a GET method + t.StatusCode = 200 + t.Close = rr.Close + default: + panic("unexpected type") + } + + // Transfer-Encoding: chunked, and overriding Content-Length. + if err = t.parseTransferEncoding(); err != nil { + return err + } + + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) + if err != nil { + return err + } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } + + // Trailer + t.Trailer, err = fixTrailer(t.Header, t.Chunked) + + // If there is no Content-Length or chunked Transfer-Encoding on a *Response + // and the status is not 1xx, 204 or 304, then the body is unbounded. + // See RFC 7230, section 3.3. + switch msg.(type) { + case *Response: + if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { + // Unbounded body. + t.Close = true + } + } + + // Prepare body reader. ContentLength < 0 means chunked encoding + // or close connection when finished, since multipart is not supported yet + //switch { + //case t.Chunked: + // if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { + // t.Body = NoBody + // } else { + // t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} + // } + //case realLength == 0: + // t.Body = NoBody + //case realLength > 0: + // t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + //default: + // // realLength < 0, i.e. "Content-Length" not mentioned in header + // if t.Close { + // // Close semantics (i.e. HTTP/1.0) + // t.Body = &body{src: r, closing: t.Close} + // } else { + // // Persistent connection (i.e. HTTP/1.1) + // t.Body = NoBody + // } + //} + + // Unify output + switch rr := msg.(type) { + case *Request: + //rr.Body = t.Body + //rr.ContentLength = t.ContentLength + //if t.Chunked { + // rr.TransferEncoding = []string{"chunked"} + //} + rr.Close = t.Close + //rr.Trailer = t.Trailer + case *Response: + //rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + } + + return nil +} + +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] + if !present { + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. + if !t.protoAtLeast(1, 1) { + return nil + } + + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} + } + if !equalFold(raw[0], "chunked") { + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + + t.Chunked = true + return nil +} + +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +} + +// equalFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func equalFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// Determine the expected body length, using RFC 7230 Section 3.3. This +// function is not a method, because ultimately it should be shared by +// ReadResponse and ReadRequest. +func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { + isRequest := !isResponse + contentLens := header["Content-Length"] + + // Hardening against HTTP request smuggling + if len(contentLens) > 1 { + // Per RFC 7230 Section 3.3.2, prevent multiple + // Content-Length headers if they differ in value. + // If there are dups of the value, remove the dups. + // See Issue 16490. + first := textproto.TrimString(contentLens[0]) + for _, ct := range contentLens[1:] { + if first != textproto.TrimString(ct) { + return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) + } + } + + // deduplicate Content-Length + header.Del("Content-Length") + header.Add("Content-Length", first) + + contentLens = header["Content-Length"] + } + + // Logic based on response type or status + if isResponse && noResponseBodyExpected(requestMethod) { + return 0, nil + } + if status/100 == 1 { + return 0, nil + } + switch status { + case 204, 304: + return 0, nil + } + + // Logic based on Transfer-Encoding + if chunked { + return -1, nil + } + + // Logic based on Content-Length + var cl string + if len(contentLens) == 1 { + cl = textproto.TrimString(contentLens[0]) + } + if cl != "" { + n, err := parseContentLength(cl) + if err != nil { + return -1, err + } + return n, nil + } + header.Del("Content-Length") + + if isRequest { + // RFC 7230 neither explicitly permits nor forbids an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + // Likewise, all other request methods are assumed to have + // no body if neither Transfer-Encoding chunked nor a + // Content-Length are set. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) + return -1, nil +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +func parseContentLength(cl string) (int64, error) { + cl = textproto.TrimString(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return 0, badStringError("bad Content-Length", cl) + } + return int64(n), nil + +} + +// Parse the trailer header. +func fixTrailer(header Header, chunked bool) (Header, error) { + vv, ok := header["Trailer"] + if !ok { + return nil, nil + } + if !chunked { + // Trailer and no chunking: + // this is an invalid use case for trailer header. + // Nevertheless, no error will be returned and we + // let users decide if this is a valid HTTP message. + // The Trailer header will be kept in Response.Header + // but not populate Response.Trailer. + // See issue #27197. + return nil, nil + } + header.Del("Trailer") + + trailer := make(Header) + var err error + for _, v := range vv { + foreachHeaderElement(v, func(key string) { + key = CanonicalHeaderKey(key) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + if err == nil { + err = badStringError("bad trailer key", key) + return + } + } + trailer[key] = nil + }) + } + if err != nil { + return nil, err + } + if len(trailer) == 0 { + return nil, nil + } + return trailer, nil +} + +// splitTwoDigitNumber splits a two-digit number into two digits. func splitTwoDigitNumber(num int) (int, int) { tens := num / 10 ones := num % 10 return tens, ones } -// -//// lower returns the ASCII lowercase version of b. -//func lower(b byte) byte { -// if 'A' <= b && b <= 'Z' { -// return b + ('a' - 'A') -// } -// return b -//} -// -//// foreachHeaderElement splits v according to the "#rule" construction -//// in RFC 7230 section 7 and calls fn for each non-empty element. -//func foreachHeaderElement(v string, fn func(string)) { -// v = textproto.TrimString(v) -// if v == "" { -// return -// } -// if !strings.Contains(v, ",") { -// fn(v) -// return -// } -// for _, f := range strings.Split(v, ",") { -// if f = textproto.TrimString(f); f != "" { -// fn(f) -// } -// } -//} -// -//func noResponseBodyExpected(requestMethod string) bool { -// return requestMethod == "HEAD" -//} -// -//func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } -// -//// bodyAllowedForStatus reports whether a given response status code -//// permits a body. See RFC 7230, section 3.3. -//func bodyAllowedForStatus(status int) bool { -// switch { -// case status >= 100 && status <= 199: -// return false -// case status == 204: -// return false -// case status == 304: -// return false -// } -// return true -//} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +func noResponseBodyExpected(requestMethod string) bool { + return requestMethod == "HEAD" +} + +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +// Determine whether to hang up after sending a request and body, or +// receiving a response and body +// 'header' is the request headers. +func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { + if major < 1 { + return true + } + + conv := header["Connection"] + hasClose := HeaderValuesContainsToken(conv, "close") + if major == 1 && minor == 0 { + return hasClose || !HeaderValuesContainsToken(conv, "keep-alive") + } + + if hasClose && removeCloseHeader { + header.Del("Connection") + } + + return hasClose +} + +// HeaderValuesContainsToken reports whether any string in values +// contains the provided token, ASCII case-insensitively. +func HeaderValuesContainsToken(values []string, token string) bool { + for _, v := range values { + if headerValueContainsToken(v, token) { + return true + } + } + return false +} + +// headerValueContainsToken reports whether v (assumed to be a +// 0#element, in the ABNF extension described in RFC 7230 section 7) +// contains token amongst its comma-separated tokens, ASCII +// case-insensitively. +func headerValueContainsToken(v string, token string) bool { + for comma := strings.IndexByte(v, ','); comma != -1; comma = strings.IndexByte(v, ',') { + if tokenEqual(trimOWS(v[:comma]), token) { + return true + } + v = v[comma+1:] + } + return tokenEqual(trimOWS(v), token) +} + +// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively. +func tokenEqual(t1, t2 string) bool { + if len(t1) != len(t2) { + return false + } + for i, b := range t1 { + if b >= utf8.RuneSelf { + // No UTF-8 or non-ASCII allowed in tokens. + return false + } + if lowerASCII(byte(b)) != lowerASCII(t2[i]) { + return false + } + } + return true +} + +// trimOWS returns x with all optional whitespace removes from the +// beginning and end. +func trimOWS(x string) string { + // TODO: consider using strings.Trim(x, " \t") instead, + // if and when it's fast enough. See issue 10292. + // But this ASCII-only code will probably always beat UTF-8 + // aware code. + for len(x) > 0 && isOWS(x[0]) { + x = x[1:] + } + for len(x) > 0 && isOWS(x[len(x)-1]) { + x = x[:len(x)-1] + } + return x +} + +// lowerASCII returns the ASCII lowercase version of b. +func lowerASCII(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// isOWS reports whether b is an optional whitespace byte, as defined +// by RFC 7230 section 3.2.3. +func isOWS(b byte) bool { return b == ' ' || b == '\t' } diff --git a/x/http/transport.go b/x/http/transport.go index b9f845c..2dad490 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -228,11 +228,6 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Poll all ready tasks and act on them... rc := <-pc.reqch // blocking alive := true - resp := &Response{ - Request: rc.req, - Header: make(Header), - Trailer: make(Header), - } var bodyWriter *io.PipeWriter var respBody *hyper.Body = nil for alive { @@ -263,8 +258,17 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { client := (*hyper.ClientConn)(task.Value()) task.Free() + // Prepare the hyper.Request + hyperReq, err := newHyperRequest(rc.req) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } + // Send it! - sendTask := client.Send(rc.req.Req) + sendTask := client.Send(hyperReq) SetTaskId(sendTask, ReceiveResp) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { @@ -289,14 +293,13 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { hyperResp := (*hyper.Response)(task.Value()) task.Free() - readResponseLineAndHeader(resp, hyperResp) - //err = readTransfer(resp, hyperResp) - //if err != nil { - // rc.ch <- responseAndError{err: err} - // // Free the resources - // FreeResources(task, respBody, bodyWriter, exec, pc, rc) - // return - //} + resp, err := ReadResponse(hyperResp, rc.req) + if err != nil { + rc.ch <- responseAndError{err: err} + // Free the resources + FreeResources(task, respBody, bodyWriter, exec, pc, rc) + return + } respBody = hyperResp.Body() resp.Body, bodyWriter = io.Pipe() @@ -395,6 +398,8 @@ func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { conn := (*ConnData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + //base := make([]byte, suggestedSize) + //conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Pointer(&base[0])), c.Uint(suggestedSize)) conn.ReadBufFilled = 0 } *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) From ba2a9d098f61d4f2a5fe85c712497cf57fbeb8ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 16 Aug 2024 18:09:37 +0800 Subject: [PATCH 12/55] WIP(x/http/client): Implement http.Post() and redirection logic --- go.mod | 7 +- go.sum | 8 +- x/http/_demo/post/post.go | 4 +- x/http/_demo/redirect/redirect.go | 26 ++ x/http/_demo/server/redirectServer.go | 26 ++ x/http/_demo/timeout/timeout.go | 4 +- x/http/client.go | 586 ++++++++++++++++++++++++-- x/http/clone.go | 11 + x/http/cookie.go | 232 ++++++++++ x/http/header.go | 71 +++- x/http/http.go | 27 ++ x/http/jar.go | 27 ++ x/http/request.go | 286 ++++++++++--- x/http/response.go | 19 +- x/http/transfer.go | 35 +- x/http/transport.go | 180 +++++--- x/http/util.go | 146 +++++++ 17 files changed, 1482 insertions(+), 213 deletions(-) create mode 100644 x/http/_demo/redirect/redirect.go create mode 100644 x/http/_demo/server/redirectServer.go create mode 100644 x/http/clone.go create mode 100644 x/http/cookie.go create mode 100644 x/http/http.go create mode 100644 x/http/jar.go create mode 100644 x/http/util.go diff --git a/go.mod b/go.mod index 4082df2..f961f75 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,9 @@ module github.com/goplus/llgoexamples go 1.20 -require github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 +require ( + github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 + golang.org/x/net v0.28.0 +) + +require golang.org/x/text v0.17.0 // indirect diff --git a/go.sum b/go.sum index e3abd53..4c64063 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ -github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641 h1:VIJ38bCFRIIr62YXyRKkxy6GXYVA6R3xqAb0HkcoUgw= -github.com/goplus/llgo v0.9.7-0.20240812013847-321766fd4641/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 h1:fqqbWhWaoseSplLJF8OTkNGl4Kruqm1wQWT/Yooq6E4= +github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= diff --git a/x/http/_demo/post/post.go b/x/http/_demo/post/post.go index a86e805..4958a8e 100644 --- a/x/http/_demo/post/post.go +++ b/x/http/_demo/post/post.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "fmt" "io" @@ -8,7 +9,8 @@ import ( ) func main() { - resp, err := http.Post("https://jsonplaceholder.typicode.com/posts", "application/json; charset=UTF-8", nil) + data := []byte(`{"id":1,"title":"foo","body":"bar","userId":"1"}`) + resp, err := http.Post("https://jsonplaceholder.typicode.com/posts", "application/json; charset=UTF-8", bytes.NewBuffer(data)) if err != nil { fmt.Println(err) return diff --git a/x/http/_demo/redirect/redirect.go b/x/http/_demo/redirect/redirect.go new file mode 100644 index 0000000..48465b7 --- /dev/null +++ b/x/http/_demo/redirect/redirect.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/http" +) + +func main() { + resp, err := http.Get("http://localhost:8080") // Start "../server/redirectServer.go" before running + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + fmt.Println(resp.Proto) + resp.PrintHeaders() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/_demo/server/redirectServer.go b/x/http/_demo/server/redirectServer.go new file mode 100644 index 0000000..a6830af --- /dev/null +++ b/x/http/_demo/server/redirectServer.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "log" + "net/http" +) + +func main() { + http.HandleFunc("/", handleInitialRequest) + http.HandleFunc("/redirect", handleRedirectRequest) + + fmt.Println("Server is running on http://localhost:8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +func handleInitialRequest(w http.ResponseWriter, r *http.Request) { + log.Println("Received initial request, redirecting...") + http.Redirect(w, r, "/redirect", http.StatusSeeOther) +} + +func handleRedirectRequest(w http.ResponseWriter, r *http.Request) { + log.Println("Received redirect request, sending response...") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "Hello redirect") +} diff --git a/x/http/_demo/timeout/timeout.go b/x/http/_demo/timeout/timeout.go index 42f8bf8..6eece04 100644 --- a/x/http/_demo/timeout/timeout.go +++ b/x/http/_demo/timeout/timeout.go @@ -10,8 +10,8 @@ import ( func main() { client := &http.Client{ - Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - //Timeout: time.Second * 5, + //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + Timeout: time.Second * 5, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { diff --git a/x/http/client.go b/x/http/client.go index 72e9688..31362a9 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -1,16 +1,25 @@ package http import ( + "context" + "encoding/base64" "errors" + "fmt" "io" + "log" "net/url" + "sort" "strings" + "sync" + "sync/atomic" "time" ) type Client struct { - Transport RoundTripper - Timeout time.Duration + Transport RoundTripper + CheckRedirect func(req *Request, via []*Request) error + Jar CookieJar + Timeout time.Duration } var DefaultClient = &Client{} @@ -38,6 +47,14 @@ func (c *Client) Get(url string) (*Response, error) { return c.Do(req) } +func alwaysFalse() bool { return false } + +// ErrUseLastResponse can be returned by Client.CheckRedirect hooks to +// control how redirects are processed. If returned, the next request +// is not sent and the most recent response is returned with its body +// unclosed. +var ErrUseLastResponse = errors.New("net/http: use last response") + func Post(url, contentType string, body io.Reader) (resp *Response, err error) { return DefaultClient.Post(url, contentType, body) } @@ -70,15 +87,15 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { } } var ( - //deadline = c.deadline() - reqs []*Request - resp *Response - //copyHeaders = c.makeHeadersCopier(req) + deadline = c.deadline() + reqs []*Request + resp *Response + copyHeaders = c.makeHeadersCopier(req) reqBodyClosed = false // have we closed the current req.Body? // Redirect behavior: - //redirectMethod string - //includeBody bool + redirectMethod string + includeBody bool ) uerr := func(err error) error { // the body may have been closed already by c.send() @@ -98,42 +115,236 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { } } - // For all but the first request, create the next - // request hop and replace req. for { + // For all but the first request, create the next + // request hop and replace req. if len(reqs) > 0 { + loc := resp.Header.Get("Location") + if loc == "" { + // While most 3xx responses include a Location, it is not + // required and 3xx responses without a Location have been + // observed in the wild. See issues #17773 and #49281. + return resp, nil + } + u, err := req.URL.Parse(loc) + if err != nil { + resp.closeBody() + return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) + } + // TODO(spongehah) redirect: Why use host := "" + //host := "" + host := u.Host + + if req.Host != "" && req.Host != req.URL.Host { + // If the caller specified a custom Host header and the + // redirect location is relative, preserve the Host header + // through the redirect. See issue #22233. + if u, _ := url.Parse(loc); u != nil && !u.IsAbs() { + host = req.Host + } + } + ireq := reqs[0] + req = &Request{ + Method: redirectMethod, + Response: resp, + URL: u, + Header: make(Header), + Host: host, + Cancel: ireq.Cancel, + ctx: ireq.ctx, + } + if includeBody && ireq.GetBody != nil { + req.Body, err = ireq.GetBody() + if err != nil { + resp.closeBody() + return nil, uerr(err) + } + req.ContentLength = ireq.ContentLength + } + + // Copy original headers before setting the Referer, + // in case the user set Referer on their first request. + // If they really want to override, they can do it in + // their CheckRedirect func. + copyHeaders(req) + // Add the Referer header from the most recent + // request URL to the new one, if it's not https->http: + if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL, req.Header.Get("Referer")); ref != "" { + req.Header.Set("Referer", ref) + } + err = c.checkRedirect(req, reqs) + + // Sentinel error to let users select the + // previous response, without closing its + // body. See Issue 10069. + if err == ErrUseLastResponse { + return resp, nil + } + + // Close the previous response's body. But + // read at least some of the body so if it's + // small the underlying TCP connection will be + // re-used. No need to check for errors: if it + // fails, the Transport won't reuse it anyway. + const maxBodySlurpSize = 2 << 10 + if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { + io.CopyN(io.Discard, resp.Body, maxBodySlurpSize) + } + resp.Body.Close() + + if err != nil { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See https://golang.org/issue/3795 + // The resp.Body has already been closed. + ue := uerr(err) + ue.(*url.Error).URL = loc + return resp, ue + } } reqs = append(reqs, req) var err error - if resp, err = c.send(req, c.Timeout); err != nil { + var didTimeout func() bool + if resp, didTimeout, err = c.send(req, deadline); err != nil { // c.send() always closes req.Body reqBodyClosed = true + if !deadline.IsZero() && didTimeout() { + err = &httpError{ + err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", + timeout: true, + } + } return nil, uerr(err) } var shouldRedirect bool - //redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) - _, shouldRedirect, _ = redirectBehavior(req.Method, resp, reqs[0]) + redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0]) if !shouldRedirect { return resp, nil - } else { - // TODO(spongehah) - return nil, errors.New("TODO: redirect not implemented") } req.closeBody() } } -func (c *Client) send(req *Request, timeout time.Duration) (*Response, error) { - return send(req, c.transport(), timeout) +// didTimeout is non-nil only if err != nil. +func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + // TODO(spongehah) cookie + if c.Jar != nil { + for _, cookie := range c.Jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + resp, didTimeout, err = send(req, c.transport(), deadline) + if err != nil { + return nil, didTimeout, err + } + if c.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + c.Jar.SetCookies(req.URL, rc) + } + } + return resp, nil, nil } -func send(req *Request, rt RoundTripper, timeout time.Duration) (resp *Response, err error) { - req.timeout = timeout - return rt.RoundTrip(req) +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. +func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + req := ireq // req is either the original request, or a modified fork + + if rt == nil { + req.closeBody() + return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport") + } + + if req.URL == nil { + req.closeBody() + return nil, alwaysFalse, errors.New("http: nil Request.URL") + } + + if req.RequestURI != "" { + req.closeBody() + return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests") + } + + // forkReq forks req into a shallow clone of ireq the first + // time it's called. + forkReq := func() { + if ireq == req { + req = new(Request) + *req = *ireq // shallow clone + } + } + + // Most the callers of send (Get, Post, et al) don't need + // Headers, leaving it uninitialized. We guarantee to the + // Transport that this has been initialized, though. + if req.Header == nil { + forkReq() + req.Header = make(Header) + } + + if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { + username := u.Username() + password, _ := u.Password() + forkReq() + req.Header = cloneOrMakeHeader(ireq.Header) + req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) + } + + if !deadline.IsZero() { + forkReq() + } + + // TODO(spongehah) timeout + //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) + sub := deadline.Sub(time.Now()) + req.timeout = sub + resp, err = rt.RoundTrip(req) + if err != nil { + //stopTimer() + if resp != nil { + log.Printf("RoundTripper returned a response & error; ignoring response") + } + //if tlsErr, ok := err.(tls.RecordHeaderError); ok { + // // If we get a bad TLS record header, check to see if the + // // response looks like HTTP and give a more helpful error. + // // See golang.org/issue/11111. + // if string(tlsErr.RecordHeader[:]) == "HTTP/" { + // err = ErrSchemeMismatch + // } + //} + return nil, didTimeout, err + } + if resp == nil { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a nil *Response with a nil error", rt) + } + if resp.Body == nil { + // The documentation on the Body field says “The http Client and Transport + // guarantee that Body is always non-nil, even on responses without a body + // or responses with a zero-length body.” Unfortunately, we didn't document + // that same constraint for arbitrary RoundTripper implementations, and + // RoundTripper implementations in the wild (mostly in tests) assume that + // they can use a nil Body to mean an empty one (similar to Request.Body). + // (See https://golang.org/issue/38095.) + // + // If the ContentLength allows the Body to be empty, fill in an empty one + // here to ensure that it is non-nil. + if resp.ContentLength > 0 && req.Method != "HEAD" { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength) + } + resp.Body = io.NopCloser(strings.NewReader("")) + } + //if !deadline.IsZero() { + // resp.Body = &cancelTimerBody{ + // stop: stopTimer, + // rc: resp.Body, + // reqDidTimeout: didTimeout, + // } + //} + return resp, nil, nil } // redirectBehavior describes what should happen when the @@ -192,29 +403,330 @@ func urlErrorOp(method string) string { return method } -// ToLower returns the lowercase version of s if s is ASCII and printable. -func ToLower(s string) (lower string, ok bool) { - if !IsPrint(s) { - return "", false +func stripPassword(u *url.URL) string { + _, passSet := u.User.Password() + if passSet { + return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) } - return strings.ToLower(s), true + return u.String() } -// IsPrint returns whether s is ASCII and printable according to -// https://tools.ietf.org/html/rfc20#section-4.2. -func IsPrint(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] < ' ' || s[i] > '~' { - return false +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func (c *Client) deadline() time.Time { + if c.Timeout > 0 { + return time.Now().Add(c.Timeout) + } + return time.Time{} +} + +// cancelTimerBody is an io.ReadCloser that wraps rc with two features: +// 1. On Read error or close, the stop func is called. +// 2. On Read failure, if reqDidTimeout is true, the error is wrapped and +// marked as net.Error that hit its timeout. +type cancelTimerBody struct { + stop func() // stops the time.Timer waiting to cancel the request + rc io.ReadCloser + reqDidTimeout func() bool +} + +func (b *cancelTimerBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err == nil { + return n, nil + } + if err == io.EOF { + return n, err + } + if b.reqDidTimeout() { + err = &httpError{ + err: err.Error() + " (Client.Timeout or context cancellation while reading body)", + timeout: true, } } + return n, err +} + +func (b *cancelTimerBody) Close() error { + err := b.rc.Close() + b.stop() + return err +} + +// setRequestCancel sets req.Cancel and adds a deadline context to req +// if deadline is non-zero. The RoundTripper's type is used to +// determine whether the legacy CancelRequest behavior should be used. +// +// As background, there are three ways to cancel a request: +// First was Transport.CancelRequest. (deprecated) +// Second was Request.Cancel. +// Third was Request.Context. +// This function populates the second and third, and uses the first if it really needs to. +func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { + if deadline.IsZero() { + return nop, alwaysFalse + } + // TODO(spongehah) todo: map[string]github.com/goplus/llgo/x/http.RoundTripper + //knownTransport := knownRoundTripperImpl(rt, req) + oldCtx := req.Context() + + //if req.Cancel == nil && knownTransport { + if req.Cancel == nil { + // If they already had a Request.Context that's + // expiring sooner, do nothing: + if !timeBeforeContextDeadline(deadline, oldCtx) { + return nop, alwaysFalse + } + + var cancelCtx func() + req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) + return cancelCtx, func() bool { return time.Now().After(deadline) } + } + initialReqCancel := req.Cancel // the user's original Request.Cancel, if any + + var cancelCtx func() + if timeBeforeContextDeadline(deadline, oldCtx) { + req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) + } + + cancel := make(chan struct{}) + req.Cancel = cancel + + doCancel := func() { + // The second way in the func comment above: + close(cancel) + // The first way, used only for RoundTripper + // implementations written before Go 1.5 or Go 1.6. + type canceler interface{ CancelRequest(*Request) } + if v, ok := rt.(canceler); ok { + v.CancelRequest(req) + } + } + + stopTimerCh := make(chan struct{}) + var once sync.Once + stopTimer = func() { + once.Do(func() { + close(stopTimerCh) + if cancelCtx != nil { + cancelCtx() + } + }) + } + + timer := time.NewTimer(time.Until(deadline)) + var timedOut atomic.Bool + + go func() { + select { + case <-initialReqCancel: + doCancel() + timer.Stop() + case <-timer.C: + timedOut.Store(true) + doCancel() + case <-stopTimerCh: + timer.Stop() + } + }() + + return stopTimer, timedOut.Load +} + +// timeBeforeContextDeadline reports whether the non-zero Time t is +// before ctx's deadline, if any. If ctx does not have a deadline, it +// always reports true (the deadline is considered infinite). +func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { + d, ok := ctx.Deadline() + if !ok { + return true + } + return t.Before(d) +} + +/* +// knownRoundTripperImpl reports whether rt is a RoundTripper that's +// maintained by the Go team and known to implement the latest +// optional semantics (notably contexts). The Request is used +// to check whether this particular request is using an alternate protocol, +// in which case we need to check the RoundTripper for that protocol. +func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { + switch t := rt.(type) { + case *Transport: + if altRT := t.alternateRoundTripper(req); altRT != nil { + return knownRoundTripperImpl(altRT, req) + } + return true + //case *http2Transport, http2noDialH2RoundTripper: + // return true + } + // There's a very minor chance of a false positive with this. + // Instead of detecting our golang.org/x/net/http2.Transport, + // it might detect a Transport type in a different http2 + // package. But I know of none, and the only problem would be + // some temporarily leaked goroutines if the transport didn't + // support contexts. So this is a good enough heuristic: + if reflect.TypeOf(rt).String() == "*http2.Transport" { + return true + } + return false +}*/ + +// makeHeadersCopier makes a function that copies headers from the +// initial Request, ireq. For every redirect, this function must be called +// so that it can copy headers into the upcoming Request. +func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) { + // The headers to copy are from the very initial request. + // We use a closured callback to keep a reference to these original headers. + var ( + ireqhdr = cloneOrMakeHeader(ireq.Header) + icookies map[string][]*Cookie + ) + if c.Jar != nil && ireq.Header.Get("Cookie") != "" { + icookies = make(map[string][]*Cookie) + for _, c := range ireq.Cookies() { + icookies[c.Name] = append(icookies[c.Name], c) + } + } + + preq := ireq // The previous request + return func(req *Request) { + // If Jar is present and there was some initial cookies provided + // via the request header, then we may need to alter the initial + // cookies as we follow redirects since each redirect may end up + // modifying a pre-existing cookie. + // + // Since cookies already set in the request header do not contain + // information about the original domain and path, the logic below + // assumes any new set cookies override the original cookie + // regardless of domain or path. + // + // See https://golang.org/issue/17494 + if c.Jar != nil && icookies != nil { + var changed bool + resp := req.Response // The response that caused the upcoming redirect + for _, c := range resp.Cookies() { + if _, ok := icookies[c.Name]; ok { + delete(icookies, c.Name) + changed = true + } + } + if changed { + ireqhdr.Del("Cookie") + var ss []string + for _, cs := range icookies { + for _, c := range cs { + ss = append(ss, c.Name+"="+c.Value) + } + } + sort.Strings(ss) // Ensure deterministic headers + ireqhdr.Set("Cookie", strings.Join(ss, "; ")) + } + } + + // Copy the initial request's Header values + // (at least the safe ones). + for k, vv := range ireqhdr { + if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) { + req.Header[k] = vv + } + } + + preq = req // Update previous Request with the current request + } +} + +func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool { + switch CanonicalHeaderKey(headerKey) { + case "Authorization", "Www-Authenticate", "Cookie", "Cookie2": + // Permit sending auth/cookie headers from "foo.com" + // to "sub.foo.com". + + // Note that we don't send all cookies to subdomains + // automatically. This function is only used for + // Cookies set explicitly on the initial outgoing + // client request. Cookies automatically added via the + // CookieJar mechanism continue to follow each + // cookie's scope as set by Set-Cookie. But for + // outgoing requests with the Cookie header set + // directly, we don't know their scope, so we assume + // it's for *.domain.com. + + ihost := idnaASCIIFromURL(initial) + dhost := idnaASCIIFromURL(dest) + return isDomainOrSubdomain(dhost, ihost) + } + // All other headers are copied: return true } -func stripPassword(u *url.URL) string { - _, passSet := u.User.Password() - if passSet { - return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) +// isDomainOrSubdomain reports whether sub is a subdomain (or exact +// match) of the parent domain. +// +// Both domains must already be in canonical form. +func isDomainOrSubdomain(sub, parent string) bool { + if sub == parent { + return true } - return u.String() + // If sub is "foo.example.com" and parent is "example.com", + // that means sub must end in "."+parent. + // Do it without allocating. + if !strings.HasSuffix(sub, parent) { + return false + } + return sub[len(sub)-len(parent)-1] == '.' +} + +// refererForURL returns a referer without any authentication info or +// an empty string if lastReq scheme is https and newReq scheme is http. +// If the referer was explicitly set, then it will continue to be used. +func refererForURL(lastReq, newReq *url.URL, explicitRef string) string { + // https://tools.ietf.org/html/rfc7231#section-5.5.2 + // "Clients SHOULD NOT include a Referer header field in a + // (non-secure) HTTP request if the referring page was + // transferred with a secure protocol." + if lastReq.Scheme == "https" && newReq.Scheme == "http" { + return "" + } + if explicitRef != "" { + return explicitRef + } + + referer := lastReq.String() + if lastReq.User != nil { + // This is not very efficient, but is the best we can + // do without: + // - introducing a new method on URL + // - creating a race condition + // - copying the URL struct manually, which would cause + // maintenance problems down the line + auth := lastReq.User.String() + "@" + referer = strings.Replace(referer, auth, "", 1) + } + return referer +} + +// checkRedirect calls either the user's configured CheckRedirect +// function, or the default. +func (c *Client) checkRedirect(req *Request, via []*Request) error { + fn := c.CheckRedirect + if fn == nil { + fn = defaultCheckRedirect + } + return fn(req, via) +} + +func defaultCheckRedirect(req *Request, via []*Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil } diff --git a/x/http/clone.go b/x/http/clone.go new file mode 100644 index 0000000..ff67949 --- /dev/null +++ b/x/http/clone.go @@ -0,0 +1,11 @@ +package http + +// cloneOrMakeHeader invokes Header.Clone but if the +// result is nil, it'll instead make and return a non-nil Header. +func cloneOrMakeHeader(hdr Header) Header { + clone := hdr.Clone() + if clone == nil { + clone = make(Header) + } + return clone +} diff --git a/x/http/cookie.go b/x/http/cookie.go new file mode 100644 index 0000000..4b7175c --- /dev/null +++ b/x/http/cookie.go @@ -0,0 +1,232 @@ +package http + +import ( + "log" + "net/textproto" + "strconv" + "strings" + "time" +) + +// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an +// HTTP response or the Cookie header of an HTTP request. +// +// See https://tools.ietf.org/html/rfc6265 for details. +type Cookie struct { + Name string + Value string + + Path string // optional + Domain string // optional + Expires time.Time // optional + RawExpires string // for reading cookies only + + // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + // MaxAge>0 means Max-Age attribute present and given in seconds + MaxAge int + Secure bool + HttpOnly bool + SameSite SameSite + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs +} + +// SameSite allows a server to define a cookie attribute making it impossible for +// the browser to send this cookie along with cross-site requests. The main +// goal is to mitigate the risk of cross-origin information leakage, and provide +// some protection against cross-site request forgery attacks. +// +// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details. +type SameSite int + +const ( + SameSiteDefaultMode SameSite = iota + 1 + SameSiteLaxMode + SameSiteStrictMode + SameSiteNoneMode +) + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeCookieName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +// sanitizeCookieValue produces a suitable cookie-value from v. +// https://tools.ietf.org/html/rfc6265#section-4.1.1 +// +// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE ) +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +// +// We loosen this as spaces and commas are common in cookie values +// but we produce a quoted cookie-value if and only if v contains +// commas or spaces. +// See https://golang.org/issue/7243 for the discussion. +func sanitizeCookieValue(v string) string { + v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + if len(v) == 0 { + return v + } + if strings.ContainsAny(v, " ,") { + return `"` + v + `"` + } + return v +} + +func validCookieValueByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' +} + +func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { + ok := true + for i := 0; i < len(v); i++ { + if valid(v[i]) { + continue + } + log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName) + ok = false + break + } + if ok { + return v + } + buf := make([]byte, 0, len(v)) + for i := 0; i < len(v); i++ { + if b := v[i]; valid(b) { + buf = append(buf, b) + } + } + return string(buf) +} + +// readSetCookies parses all "Set-Cookie" values from +// the header h and returns the successfully parsed Cookies. +func readSetCookies(h Header) []*Cookie { + cookieCount := len(h["Set-Cookie"]) + if cookieCount == 0 { + return []*Cookie{} + } + cookies := make([]*Cookie, 0, cookieCount) + for _, line := range h["Set-Cookie"] { + parts := strings.Split(textproto.TrimString(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + parts[0] = textproto.TrimString(parts[0]) + name, value, ok := strings.Cut(parts[0], "=") + if !ok { + continue + } + name = textproto.TrimString(name) + if !isCookieNameValid(name) { + continue + } + value, ok = parseCookieValue(value, true) + if !ok { + continue + } + c := &Cookie{ + Name: name, + Value: value, + Raw: line, + } + for i := 1; i < len(parts); i++ { + parts[i] = textproto.TrimString(parts[i]) + if len(parts[i]) == 0 { + continue + } + + attr, val, _ := strings.Cut(parts[i], "=") + lowerAttr, isASCII := ToLower(attr) + if !isASCII { + continue + } + val, ok = parseCookieValue(val, false) + if !ok { + c.Unparsed = append(c.Unparsed, parts[i]) + continue + } + + switch lowerAttr { + case "samesite": + lowerVal, ascii := ToLower(val) + if !ascii { + c.SameSite = SameSiteDefaultMode + continue + } + switch lowerVal { + case "lax": + c.SameSite = SameSiteLaxMode + case "strict": + c.SameSite = SameSiteStrictMode + case "none": + c.SameSite = SameSiteNoneMode + default: + c.SameSite = SameSiteDefaultMode + } + continue + case "secure": + c.Secure = true + continue + case "httponly": + c.HttpOnly = true + continue + case "domain": + c.Domain = val + continue + case "max-age": + secs, err := strconv.Atoi(val) + if err != nil || secs != 0 && val[0] == '0' { + break + } + if secs <= 0 { + secs = -1 + } + c.MaxAge = secs + continue + case "expires": + c.RawExpires = val + exptime, err := time.Parse(time.RFC1123, val) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val) + if err != nil { + c.Expires = time.Time{} + break + } + } + c.Expires = exptime.UTC() + continue + case "path": + c.Path = val + continue + } + c.Unparsed = append(c.Unparsed, parts[i]) + } + cookies = append(cookies, c) + } + return cookies +} + +func isCookieNameValid(raw string) bool { + if raw == "" { + return false + } + return strings.IndexFunc(raw, isNotToken) < 0 +} + +func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) { + // Strip the quotes, if present. + if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } + for i := 0; i < len(raw); i++ { + if !validCookieValueByte(raw[i]) { + return "", false + } + } + return raw, true +} diff --git a/x/http/header.go b/x/http/header.go index 076db0f..6515c48 100644 --- a/x/http/header.go +++ b/x/http/header.go @@ -81,8 +81,75 @@ func (h Header) Del(key string) { // returned without modifications. func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } -// AppendToResponseHeader (HeadersForEachCallback) prints each header to the console -func AppendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { +// Clone returns a copy of h or nil if h is nil. +func (h Header) Clone() Header { + if h == nil { + return nil + } + + // Find total number of values. + nv := 0 + for _, vv := range h { + nv += len(vv) + } + sv := make([]string, nv) // shared backing array for headers' values + h2 := make(Header, len(h)) + for k, vv := range h { + if vv == nil { + // Preserve nil values. ReverseProxy distinguishes + // between nil and zero-length header values. + h2[k] = nil + continue + } + n := copy(sv, vv) + h2[k] = sv[:n:n] + sv = sv[n:] + } + return h2 +} + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} + +// appendToResponseHeader (HeadersForEachCallback) prints each header to the console +func appendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { resp := (*Response)(userdata) nameStr := c.GoString((*int8)(c.Pointer(name)), nameLen) valueStr := c.GoString((*int8)(c.Pointer(value)), valueLen) diff --git a/x/http/http.go b/x/http/http.go new file mode 100644 index 0000000..f668906 --- /dev/null +++ b/x/http/http.go @@ -0,0 +1,27 @@ +package http + +import "strings" + +// splitTwoDigitNumber splits a two-digit number into two digits. +func splitTwoDigitNumber(num int) (int, int) { + tens := num / 10 + ones := num % 10 + return tens, ones +} + +func isNotToken(r rune) bool { + return !IsTokenRune(r) +} + +// removeEmptyPort strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +} + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } diff --git a/x/http/jar.go b/x/http/jar.go new file mode 100644 index 0000000..5c3de0d --- /dev/null +++ b/x/http/jar.go @@ -0,0 +1,27 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" +) + +// A CookieJar manages storage and use of cookies in HTTP requests. +// +// Implementations of CookieJar must be safe for concurrent use by multiple +// goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. +type CookieJar interface { + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. + SetCookies(u *url.URL, cookies []*Cookie) + + // Cookies returns the cookies to send in a request for the given URL. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. + Cookies(u *url.URL) []*Cookie +} diff --git a/x/http/request.go b/x/http/request.go index c84cd9b..f6e6f16 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -1,11 +1,17 @@ package http import ( + "bytes" + "context" "fmt" "io" + "net/textproto" "net/url" + "strings" "time" + "golang.org/x/net/idna" + "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/os" "github.com/goplus/llgoexamples/rust/hyper" @@ -24,79 +30,124 @@ type Request struct { TransferEncoding []string Close bool Host string - timeout time.Duration -} - -type postBody struct { - data []byte - len uintptr - readLen uintptr + //Form url.Values + //PostForm url.Values + //MultipartForm *multipart.Form + Trailer Header + RemoteAddr string + RequestURI string + //TLS *tls.ConnectionState + Cancel <-chan struct{} + Response *Response + timeout time.Duration + ctx context.Context } -type uploadBody struct { - fd c.Int - buf []byte - len uintptr -} - -var DefaultChunkSize uintptr = 8192 +var defaultChunkSize uintptr = 8192 func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + if method == "" { + // We document that "" means "GET" for Request.Method, and people have + // relied on that from NewRequest, so keep that working. + // We still enforce validMethod for non-empty methods. + method = "GET" + } + if !validMethod(method) { + return nil, fmt.Errorf("net/http: invalid method %q", method) + } + //if ctx == nil { + // return nil, errors.New("net/http: nil Context") + //} u, err := url.Parse(urlStr) if err != nil { return nil, err } - //rc, ok := body.(io.ReadCloser) - //if !ok && body != nil { - // rc = io.NopCloser(body) - //} - request := &Request{ + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = io.NopCloser(body) + } + // The host's colon:port should be normalized. See Issue 14836. + u.Host = removeEmptyPort(u.Host) + req := &Request{ + //ctx: ctx, Method: method, URL: u, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(Header), + Body: rc, Host: u.Host, - //Body: rc, - timeout: 0, } - request.Header.Set("Host", request.Host) + if body != nil { + switch v := body.(type) { + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + buf := v.Bytes() + req.GetBody = func() (io.ReadCloser, error) { + r := bytes.NewReader(buf) + return io.NopCloser(r), nil + } + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return io.NopCloser(&r), nil + } + case *strings.Reader: + req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return io.NopCloser(&r), nil + } + default: + // This is where we'd set it to -1 (at least + // if body != NoBody) to mean unknown, but + // that broke people during the Go 1.8 testing + // period. People depend on it being 0 I + // guess. Maybe retry later. See Issue 18117. + } + // For client requests, Request.ContentLength of 0 + // means either actually 0, or unknown. The only way + // to explicitly say that the ContentLength is zero is + // to set the Body to nil. But turns out too much code + // depends on NewRequest returning a non-nil Body, + // so we use a well-known ReadCloser variable instead + // and have the http package also treat that sentinel + // variable to mean explicitly zero. + if req.GetBody != nil && req.ContentLength == 0 { + req.Body = NoBody + req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil } + } + } - return request, nil + return req, nil } -func PrintInformational(userdata c.Pointer, resp *hyper.Response) { +func printInformational(userdata c.Pointer, resp *hyper.Response) { status := resp.Status() fmt.Println("Informational (1xx): ", status) } -func SetPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - //upload := (*uploadBody)(userdata) - //res := os.Read(upload.fd, c.Pointer(&upload.buf[0]), upload.len) - //if res > 0 { - // *chunk = hyper.CopyBuf(&upload.buf[0], uintptr(res)) - // return hyper.PollReady - //} - //if res == 0 { - // *chunk = nil - // os.Close(upload.fd) - // return hyper.PollReady - //} - body := (*postBody)(userdata) - if body.len > 0 { - if body.len > DefaultChunkSize { - *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) - body.readLen += DefaultChunkSize - body.len -= DefaultChunkSize - } else { - *chunk = hyper.CopyBuf(&body.data[body.readLen], body.len) - body.readLen += body.len - body.len = 0 +func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + req := (*Request)(userdata) + buffer := make([]byte, defaultChunkSize) + n, err := req.Body.Read(buffer) + if err != nil { + if err == io.EOF { + *chunk = nil + return hyper.PollReady } + fmt.Println("error reading upload file: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&buffer[0], uintptr(n)) return hyper.PollReady } - if body.len == 0 { + if n == 0 { *chunk = nil return hyper.PollReady } @@ -107,7 +158,7 @@ func SetPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In func newHyperRequest(req *Request) (*hyper.Request, error) { host := req.Host - uri := req.URL.Path + uri := req.URL.RequestURI() method := req.Method // Prepare the request hyperReq := hyper.NewRequest() @@ -124,27 +175,13 @@ func newHyperRequest(req *Request) (*hyper.Request, error) { return nil, fmt.Errorf("error setting header: Host: %s\n", host) } - if method == "POST" { - //var upload uploadBody - //upload.fd = os.Open(c.Str("/Users/spongehah/go/src/llgo/x/http/_demo/post/example.txt"), os.O_RDONLY) - //if upload.fd < 0 { - // return nil, fmt.Errorf("error opening file to upload: %s\n", c.GoString(c.Strerror(os.Errno))) - //} - //upload.len = 8192 - //upload.buf = make([]byte, upload.len) + if method == "POST" && req.Body != nil { req.Header.Set("expect", "100-continue") - hyperReq.OnInformational(PrintInformational, nil) - postData := []byte(`{"id":1,"title":"foo","body":"bar","userId":"1"}`) - - reqBody := &postBody{ - data: postData, - len: uintptr(len(postData)), - } + hyperReq.OnInformational(printInformational, nil) hyperReqBody := hyper.NewBody() - hyperReqBody.SetUserdata(c.Pointer(reqBody)) - //hyperReqBody.SetUserdata(c.Pointer(&upload)) - hyperReqBody.SetDataFunc(SetPostData) + hyperReqBody.SetUserdata(c.Pointer(req)) + hyperReqBody.SetDataFunc(setPostData) hyperReq.SetBody(hyperReqBody) } @@ -185,3 +222,120 @@ func (r *Request) closeBody() error { } return r.Body.Close() } + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +// Context returns the request's context. To change the context, use +// Clone or WithContext. +// +// The returned context is always non-nil; it defaults to the +// background context. +// +// For outgoing client requests, the context controls cancellation. +// +// For incoming server requests, the context is canceled when the +// client's connection closes, the request is canceled (with HTTP/2), +// or when the ServeHTTP method returns. +func (r *Request) Context() context.Context { + if r.ctx != nil { + return r.ctx + } + return context.Background() +} + +// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, +// AddCookie does not attach more than one Cookie header field. That +// means all cookies, if any, are written into the same line, +// separated by semicolon. +// AddCookie only sanitizes c's name and value, and does not sanitize +// a Cookie header already present in the request. +func (r *Request) AddCookie(c *Cookie) { + s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) + if c := r.Header.Get("Cookie"); c != "" { + r.Header.Set("Cookie", c+"; "+s) + } else { + r.Header.Set("Cookie", s) + } +} + +// requiresHTTP1 reports whether this request requires being sent on +// an HTTP/1 connection. +func (r *Request) requiresHTTP1() bool { + return hasToken(r.Header.Get("Connection"), "upgrade") && + EqualFold(r.Header.Get("Upgrade"), "websocket") +} + +// Cookies parses and returns the HTTP cookies sent with the request. +func (r *Request) Cookies() []*Cookie { + return readCookies(r.Header, "") +} + +// readCookies parses all "Cookie" values from the header h and +// returns the successfully parsed Cookies. +// +// if filter isn't empty, only cookies of that name are returned. +func readCookies(h Header, filter string) []*Cookie { + lines := h["Cookie"] + if len(lines) == 0 { + return []*Cookie{} + } + + cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";")) + for _, line := range lines { + line = textproto.TrimString(line) + + var part string + for len(line) > 0 { // continue since we have rest + part, line, _ = strings.Cut(line, ";") + part = textproto.TrimString(part) + if part == "" { + continue + } + name, val, _ := strings.Cut(part, "=") + name = textproto.TrimString(name) + if !isCookieNameValid(name) { + continue + } + if filter != "" && filter != name { + continue + } + val, ok := parseCookieValue(val, true) + if !ok { + continue + } + cookies = append(cookies, &Cookie{Name: name, Value: val}) + } + } + return cookies +} + +func idnaASCII(v string) (string, error) { + // TODO: Consider removing this check after verifying performance is okay. + // Right now punycode verification, length checks, context checks, and the + // permissible character tests are all omitted. It also prevents the ToASCII + // call from salvaging an invalid IDN, when possible. As a result it may be + // possible to have two IDNs that appear identical to the user where the + // ASCII-only version causes an error downstream whereas the non-ASCII + // version does not. + // Note that for correct ASCII IDNs ToASCII will only do considerably more + // work, but it will not cause an allocation. + if Is(v) { + return v, nil + } + return idna.Lookup.ToASCII(v) +} diff --git a/x/http/response.go b/x/http/response.go index c99bade..174d2fc 100644 --- a/x/http/response.go +++ b/x/http/response.go @@ -20,15 +20,21 @@ type Response struct { ContentLength int64 TransferEncoding []string Close bool - Trailer Header - Request *Request + //Trailer Header + Request *Request +} + +func (r *Response) closeBody() { + if r.Body != nil { + r.Body.Close() + } } func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), - Trailer: make(Header), + //Trailer: make(Header), } readResponseLineAndHeader(resp, hyperResp) @@ -54,7 +60,7 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { resp.Proto = fmt.Sprintf("HTTP/%d.%d", resp.ProtoMajor, resp.ProtoMinor) headers := hyperResp.Headers() - headers.Foreach(AppendToResponseHeader, c.Pointer(resp)) + headers.Foreach(appendToResponseHeader, c.Pointer(resp)) } // RFC 7234, section 5.4: Should treat @@ -71,3 +77,8 @@ func fixPragmaCacheControl(header Header) { } } } + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} diff --git a/x/http/transfer.go b/x/http/transfer.go index 70f082e..ac50296 100644 --- a/x/http/transfer.go +++ b/x/http/transfer.go @@ -94,7 +94,7 @@ func readTransfer(msg any) (err error) { } // Trailer - t.Trailer, err = fixTrailer(t.Header, t.Chunked) + //t.Trailer, err = fixTrailer(t.Header, t.Chunked) // If there is no Content-Length or chunked Transfer-Encoding on a *Response // and the status is not 1xx, 204 or 304, then the body is unbounded. @@ -148,7 +148,7 @@ func readTransfer(msg any) (err error) { rr.TransferEncoding = []string{"chunked"} } rr.Close = t.Close - rr.Trailer = t.Trailer + //rr.Trailer = t.Trailer } return nil @@ -174,7 +174,7 @@ func (t *transferReader) parseTransferEncoding() error { if len(raw) != 1 { return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} } - if !equalFold(raw[0], "chunked") { + if !EqualFold(raw[0], "chunked") { return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} } @@ -199,20 +199,6 @@ func (t *transferReader) protoAtLeast(m, n int) bool { return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) } -// equalFold is strings.EqualFold, ASCII only. It reports whether s and t -// are equal, ASCII-case-insensitively. -func equalFold(s, t string) bool { - if len(s) != len(t) { - return false - } - for i := 0; i < len(s); i++ { - if lower(s[i]) != lower(t[i]) { - return false - } - } - return true -} - // Determine the expected body length, using RFC 7230 Section 3.3. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. @@ -343,21 +329,6 @@ func fixTrailer(header Header, chunked bool) (Header, error) { return trailer, nil } -// splitTwoDigitNumber splits a two-digit number into two digits. -func splitTwoDigitNumber(num int) (int, int) { - tens := num / 10 - ones := num % 10 - return tens, ones -} - -// lower returns the ASCII lowercase version of b. -func lower(b byte) byte { - if 'A' <= b && b <= 'Z' { - return b + ('a' - 'A') - } - return b -} - // foreachHeaderElement splits v according to the "#rule" construction // in RFC 7230 section 7 and calls fn for each non-empty element. func foreachHeaderElement(v string, fn func(string)) { diff --git a/x/http/transport.go b/x/http/transport.go index 2dad490..f9bfa46 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -3,6 +3,8 @@ package http import ( "fmt" "io" + "net/url" + "sync/atomic" "unsafe" "github.com/goplus/llgo/c" @@ -12,7 +14,7 @@ import ( "github.com/goplus/llgoexamples/rust/hyper" ) -type ConnData struct { +type connData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect ReadBuf libuv.Buf @@ -24,20 +26,21 @@ type ConnData struct { } type Transport struct { + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme } -// TaskId The unique identifier of the next task polled from the executor -type TaskId c.Int +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int const ( - NotSet TaskId = iota - Send - ReceiveResp - ReceiveRespBody + notSet taskId = iota + sending + receiveResp + receiveRespBody ) const ( - DefaultHTTPPort = "80" + defaultHTTPPort = "80" ) var DefaultTransport RoundTripper = &Transport{} @@ -54,7 +57,7 @@ type persistConn struct { //nwrite int64 // bytes written //writech chan writeRequest // written by roundTrip; read by writeLoop //closech chan struct{} // closed when conn closed - conn *ConnData + conn *connData t *Transport reqch chan requestAndChan // written by roundTrip; read by readLoop cancelch chan freeChan @@ -82,7 +85,7 @@ type responseAndError struct { type connAndTimeoutChan struct { _ incomparable - conn *ConnData + conn *connData timeoutch chan struct{} } @@ -105,29 +108,29 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { - host := req.Host + host := req.URL.Hostname() port := req.URL.Port() if port == "" { // Hyper only supports http - port = DefaultHTTPPort + port = defaultHTTPPort } loop := libuv.DefaultLoop() //conn := (*ConnData)(c.Calloc(1, unsafe.Sizeof(ConnData{}))) - conn := new(ConnData) + conn := new(connData) if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } // If timeout is set, start the timer timeoutch := make(chan struct{}, 1) - if req.timeout != 0 { + if req.timeout > 0 { libuv.InitTimer(loop, &conn.TimeoutTimer) ct := &connAndTimeoutChan{ conn: conn, timeoutch: timeoutch, } (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) - conn.TimeoutTimer.Start(OnTimeout, uint64(req.timeout.Milliseconds()), 0) + conn.TimeoutTimer.Start(onTimeout, uint64(req.timeout.Milliseconds()), 0) } libuv.InitTcp(loop, &conn.TcpHandle) @@ -148,7 +151,7 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { //conn.ConnectReq.Data = c.Pointer(conn) (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) - status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, OnConnect) + status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { close(timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) @@ -209,7 +212,7 @@ func (pc *persistConn) roundTrip(req *Request) (*Response, error) { // It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Hookup the IO - hyperIo := NewIoWithConnReadWrite(pc.conn) + hyperIo := newIoWithConnReadWrite(pc.conn) // We need an executor generally to poll futures exec := hyper.NewExecutor() @@ -218,7 +221,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { opts.Exec(exec) handshakeTask := hyper.Handshake(hyperIo, opts) - SetTaskId(handshakeTask, Send) + setTaskId(handshakeTask, sending) // Let's wait for the handshake to finish... exec.Push(handshakeTask) @@ -241,13 +244,12 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { default: task := exec.Poll() if task == nil { - //break loop.Run(libuv.RUN_ONCE) continue } - switch (TaskId)(uintptr(task.Userdata())) { - case Send: - err := CheckTaskType(task, Send) + switch (taskId)(uintptr(task.Userdata())) { + case sending: + err := checkTaskType(task, sending) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources @@ -269,7 +271,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Send it! sendTask := client.Send(hyperReq) - SetTaskId(sendTask, ReceiveResp) + setTaskId(sendTask, receiveResp) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} @@ -280,8 +282,8 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // For this example, no longer need the client client.Free() - case ReceiveResp: - err := CheckTaskType(task, ReceiveResp) + case receiveResp: + err := checkTaskType(task, receiveResp) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources @@ -309,19 +311,19 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // Response has been returned, stop the timer pc.conn.IsCompleted = 1 // Stop the timer - if rc.req.timeout != 0 { + if rc.req.timeout > 0 { pc.conn.TimeoutTimer.Stop() (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) } dataTask := respBody.Data() - SetTaskId(dataTask, ReceiveRespBody) + setTaskId(dataTask, receiveRespBody) exec.Push(dataTask) // No longer need the response hyperResp.Free() - case ReceiveRespBody: - err := CheckTaskType(task, ReceiveRespBody) + case receiveRespBody: + err := checkTaskType(task, receiveRespBody) if err != nil { rc.ch <- responseAndError{err: err} // Free the resources @@ -350,7 +352,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { task.Free() dataTask := respBody.Data() - SetTaskId(dataTask, ReceiveRespBody) + setTaskId(dataTask, receiveRespBody) exec.Push(dataTask) break @@ -369,7 +371,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { FreeResources(task, respBody, bodyWriter, exec, pc, rc) alive = false - case NotSet: + case notSet: // A background task for hyper_client completed... task.Free() } @@ -378,24 +380,24 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} } -// OnConnect is the libuv callback for a successful connection -func OnConnect(req *libuv.Connect, status c.Int) { +// onConnect is the libuv callback for a successful connection +func onConnect(req *libuv.Connect, status c.Int) { //conn := (*ConnData)(req.Data) //conn := (*struct{ data *ConnData })(c.Pointer(req)).data - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) if status < 0 { c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) return } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(AllocBuffer, OnRead) + (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(allocBuffer, onRead) } -// AllocBuffer allocates a buffer for reading from a socket -func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { +// allocBuffer allocates a buffer for reading from a socket +func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { //conn := (*ConnData)(handle.Data) //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data - conn := (*ConnData)(handle.GetData()) + conn := (*connData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) //base := make([]byte, suggestedSize) @@ -405,11 +407,11 @@ func AllocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) } -// OnRead is the libuv callback for reading from a socket +// onRead is the libuv callback for reading from a socket // This callback function is called when data is available to be read -func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { +func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { // Get the connection data associated with the stream - conn := (*ConnData)((*libuv.Handle)(c.Pointer(stream)).GetData()) + conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) //conn := (*ConnData)(stream.Data) //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data @@ -427,10 +429,10 @@ func OnRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { } } -// ReadCallBack read callback function for Hyper library -func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { +// readCallBack read callback function for Hyper library +func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) - conn := (*ConnData)(userdata) + conn := (*connData)(userdata) // If there's data in the buffer if conn.ReadBufFilled > 0 { @@ -462,11 +464,11 @@ func ReadCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin return hyper.IoPending } -// OnWrite is the libuv callback for writing to a socket +// onWrite is the libuv callback for writing to a socket // Callback function called after a write operation completes -func OnWrite(req *libuv.Write, status c.Int) { +func onWrite(req *libuv.Write, status c.Int) { // Get the connection data associated with the write request - conn := (*ConnData)((*libuv.Req)(c.Pointer(req)).GetData()) + conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) //conn := (*ConnData)(req.Data) //conn := (*struct{ data *ConnData })(c.Pointer(req)).data @@ -479,10 +481,10 @@ func OnWrite(req *libuv.Write, status c.Int) { } } -// WriteCallBack write callback function for Hyper library -func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { +// writeCallBack write callback function for Hyper library +func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) - conn := (*ConnData)(userdata) + conn := (*connData)(userdata) // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) //req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) @@ -492,7 +494,7 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui //req.Data = c.Pointer(conn) // Perform the asynchronous write operation - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, OnWrite) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) // If the write operation was successfully initiated if ret >= 0 { // Return the number of bytes to be written @@ -510,8 +512,8 @@ func WriteCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui return hyper.IoPending } -// OnTimeout is the libuv callback for a timeout -func OnTimeout(handle *libuv.Timer) { +// onTimeout is the libuv callback for a timeout +func onTimeout(handle *libuv.Timer) { ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) if ct.conn.IsCompleted != 1 { ct.conn.IsCompleted = 1 @@ -521,25 +523,25 @@ func OnTimeout(handle *libuv.Timer) { (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) } -// NewIoWithConnReadWrite creates a new IO with read and write callbacks -func NewIoWithConnReadWrite(connData *ConnData) *hyper.Io { +// newIoWithConnReadWrite creates a new IO with read and write callbacks +func newIoWithConnReadWrite(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) - hyperIo.SetRead(ReadCallBack) - hyperIo.SetWrite(WriteCallBack) + hyperIo.SetRead(readCallBack) + hyperIo.SetWrite(writeCallBack) return hyperIo } -// SetTaskId Set TaskId to the task's userdata as a unique identifier -func SetTaskId(task *hyper.Task, userData TaskId) { +// setTaskId Set taskId to the task's userdata as a unique identifier +func setTaskId(task *hyper.Task, userData taskId) { var data = userData task.SetUserdata(unsafe.Pointer(uintptr(data))) } -// CheckTaskType checks the task type -func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { +// checkTaskType checks the task type +func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { - case Send: + case sending: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake task error!\n")) return Fail((*hyper.Error)(task.Value())) @@ -548,7 +550,7 @@ func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case ReceiveResp: + case receiveResp: if task.Type() == hyper.TaskError { c.Printf(c.Str("send task error!\n")) return Fail((*hyper.Error)(task.Value())) @@ -558,13 +560,13 @@ func CheckTaskType(task *hyper.Task, curTaskId TaskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case ReceiveRespBody: + case receiveRespBody: if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) return Fail((*hyper.Error)(task.Value())) } return nil - case NotSet: + case notSet: } return fmt.Errorf("unexpected TaskId\n") } @@ -617,7 +619,7 @@ func CloseChannels(rc requestAndChan, pc *persistConn) { } // FreeConnData frees the connection data -func FreeConnData(conn *ConnData) { +func FreeConnData(conn *connData) { if conn.ReadWaker != nil { conn.ReadWaker.Free() conn.ReadWaker = nil @@ -631,3 +633,49 @@ func FreeConnData(conn *ConnData) { conn.ReadBuf.Base = nil } } + +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} + +func nop() {} + +/*// alternateRoundTripper returns the alternate RoundTripper to use +// for this request if the Request's URL scheme requires one, +// or nil for the normal case of using the Transport. +func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { + if !t.useRegisteredProtocol(req) { + return nil + } + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + return altProto[req.URL.Scheme] +} + +// useRegisteredProtocol reports whether an alternate protocol (as registered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *Request) bool { + if req.URL.Scheme == "https" && req.requiresHTTP1() { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} +*/ + +func idnaASCIIFromURL(url *url.URL) string { + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v + } + return addr +} \ No newline at end of file diff --git a/x/http/util.go b/x/http/util.go new file mode 100644 index 0000000..674f481 --- /dev/null +++ b/x/http/util.go @@ -0,0 +1,146 @@ +package http + +import ( + "strings" + "unicode" +) + +/** + * Copied from the libraries that llgo cannot be used + */ + +var isTokenTable = [127]bool{ // httpguts.isTokenTable + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +func IsTokenRune(r rune) bool { // httpguts.IsTokenRune + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func IsPrint(s string) bool { // ascii.IsPrint + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// ToLower returns the lowercase version of s if s is ASCII and printable. +func ToLower(s string) (lower string, ok bool) { // ascii.ToLower + if !IsPrint(s) { + return "", false + } + return strings.ToLower(s), true +} + +// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func EqualFold(s, t string) bool { // ascii.EqualFold + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { // ascii.lower + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// Is returns whether s is ASCII. +func Is(s string) bool { // ascii.Is + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true +} From 19f642598519fbf1a2466e6ff89598dbc0b08cde Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 16 Aug 2024 18:29:00 +0800 Subject: [PATCH 13/55] feat(x/net/http): Implement resonse and request logic Signed-off-by: hackerchai --- x/net/http/request.go | 113 +- x/net/http/response.go | 76 +- x/net/http/server.go | 35 +- x/net/http/servermux.go | 2 +- x/net/url/url.go | 2626 --------------------------------------- 5 files changed, 169 insertions(+), 2683 deletions(-) delete mode 100644 x/net/url/url.go diff --git a/x/net/http/request.go b/x/net/http/request.go index 037ff6d..0404815 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -3,22 +3,32 @@ package http import ( "fmt" "io" + "net/url" + "strings" + "time" "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgo/rust/hyper" - cos "github.com/goplus/llgo/c/os" ) type Request struct { - Conn *Conn - Method string - URL string - Header Header - Body io.ReadCloser + Method string + URL *url.URL + Proto string // "HTTP/1.0" + ProtoMajor int // 1 + ProtoMinor int // 0 + Header Header + Body io.ReadCloser + GetBody func() (io.ReadCloser, error) + ContentLength int64 + TransferEncoding []string + Close bool + Host string + timeout time.Duration } -func newRequest(conn *Conn, hyperReq *hyper.Request) (*Request, error) { +func newRequest(conn *conn, hyperReq *hyper.Request) (*Request, error) { method := make([]byte, 32) methodLen := uintptr(len(method)) if err := hyperReq.Method(&method[0], &methodLen); err != hyper.OK { @@ -33,11 +43,45 @@ func newRequest(conn *Conn, hyperReq *hyper.Request) (*Request, error) { return nil, fmt.Errorf("failed to get URI parts: %v", err) } + var proto string + var protoMajor, protoMinor int + version := hyperReq.Version() + switch version { + case hyper.HTTPVersion10: + proto = "HTTP/1.0" + protoMajor = 1 + protoMinor = 0 + case hyper.HTTPVersion11: + proto = "HTTP/1.1" + protoMajor = 1 + protoMinor = 1 + case hyper.HTTPVersion2: + proto = "HTTP/2.0" + protoMajor = 2 + protoMinor = 0 + case hyper.HTTPVersionNone: + proto = "HTTP/0.0" + protoMajor = 0 + protoMinor = 0 + default: + return nil, fmt.Errorf("unknown HTTP version: %d", version) + } + + urlStr := fmt.Sprintf("%s://%s%s", string(scheme[:schemeLen]), string(authority[:authorityLen]), string(pathAndQuery[:pathAndQueryLen])) + url, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + req := &Request{ - Conn: conn, - Method: methodStr, - URL: fmt.Sprintf("%s://%s%s", string(scheme[:schemeLen]), string(authority[:authorityLen]), string(pathAndQuery[:pathAndQueryLen])), - Header: make(Header), + Method: methodStr, + URL: url, + Proto: proto, + ProtoMajor: protoMajor, + ProtoMinor: protoMinor, + Header: make(Header), + Host: string(authority[:authorityLen]), + timeout: 0, } headers := hyperReq.Headers() @@ -47,26 +91,28 @@ func newRequest(conn *Conn, hyperReq *hyper.Request) (*Request, error) { return nil, fmt.Errorf("failed to get request headers") } - if methodStr == "POST" || methodStr == "PUT" { + if methodStr == "POST" || methodStr == "PUT" || methodStr == "PATCH" { body := hyperReq.Body() if body != nil { - // task := body.Foreach(getBodyChunk, c.Pointer(req), nil) - // if task != nil { - // r := conn.Executor.Push(task) - // if r != hyper.OK { - // task.Free() - // return nil, fmt.Errorf("failed to push body foreach task: %v", r) - // } - // } else { - // return nil, fmt.Errorf("failed to create body foreach task") - // } + var bodyWriter *io.PipeWriter + req.Body, bodyWriter = io.Pipe() + + task := body.Foreach(getBodyChunk, c.Pointer(bodyWriter), freeBodyWriter) + if task != nil { + r := conn.Executor.Push(task) + if r != hyper.OK { + task.Free() + return nil, fmt.Errorf("failed to push body foreach task: %v", r) + } + } else { + return nil, fmt.Errorf("failed to create body foreach task") + } } else { return nil, fmt.Errorf("failed to get request body") } } - return req, nil } @@ -74,16 +120,27 @@ func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, va req := (*Request)(data) key := string(unsafe.Slice(name, nameLen)) val := string(unsafe.Slice(value, valueLen)) - req.Header.Add(key, val) + values := strings.Split(val, ",") + if len(values) > 1 { + for _, v := range values { + req.Header.Add(key, strings.TrimSpace(v)) + } + } else { + req.Header.Add(key, val) + } return hyper.IterContinue } -//TODO(hackerchai): implement body chunk reader func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { - req := (*Request)(userdata) + writer := (*io.PipeWriter)(userdata) buf := chunk.Bytes() len := chunk.Len() - cos.Write(1, unsafe.Pointer(buf), len) + writer.Write(unsafe.Slice(buf, len)) return hyper.IterContinue -} \ No newline at end of file +} + +func freeBodyWriter(userdata c.Pointer) { + writer := (*io.PipeWriter)(userdata) + writer.Close() +} diff --git a/x/net/http/response.go b/x/net/http/response.go index d712118..b3c110c 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -1,11 +1,15 @@ package http import ( + "fmt" + "unsafe" + "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/os" "github.com/goplus/llgo/rust/hyper" ) -type Response struct { +type response struct { header Header statusCode int written bool @@ -13,18 +17,27 @@ type Response struct { channel *hyper.ResponseChannel } -func newResponse(channel *hyper.ResponseChannel) *Response { - return &Response{ +type body struct { + data []byte + len uintptr + readLen uintptr +} + +var DefaultChunkSize uintptr = 8192 + + +func newResponse(channel *hyper.ResponseChannel) *response { + return &response{ header: make(Header), channel: channel, } } -func (r *Response) Header() Header { +func (r *response) Header() Header { return r.header } -func (r *Response) Write(data []byte) (int, error) { +func (r *response) Write(data []byte) (int, error) { if !r.written { r.WriteHeader(200) } @@ -32,7 +45,7 @@ func (r *Response) Write(data []byte) (int, error) { return len(data), nil } -func (r *Response) WriteHeader(statusCode int) { +func (r *response) WriteHeader(statusCode int) { if r.written { return } @@ -43,23 +56,41 @@ func (r *Response) WriteHeader(statusCode int) { resp.SetStatus(uint16(statusCode)) headers := resp.Headers() - for k, v := range r.header { - for _, val := range v { - headers.Set(&[]byte(k)[0], uintptr(len(k)), &[]byte(val)[0], uintptr(len(val))) + for key, values := range r.header { + valueLen := len(values) + if valueLen > 1 { + for _, value := range values { + if headers.Add(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(value)[0], c.Strlen(c.AllocaCStr(value))) != hyper.OK { + return + } + } + } else if valueLen == 1 { + if headers.Set(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(values[0])[0], c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { + return + } + } else { + return } } r.channel.Send(resp) } -func (r *Response) finalize() error { +func (r *response) finalize() error { if !r.written { r.WriteHeader(200) } + bodyData := &body{ + data: r.body, + len: uintptr(len(r.body)), + readLen: 0, + } + body := hyper.NewBody() - //TODO(hackerchai): implement body data func - body.SetDataFunc() + body.SetUserdata(unsafe.Pointer(bodyData), nil) + + body.SetDataFunc(setBodyDataFunc) resp := hyper.NewResponse() resp.SetBody(body) @@ -68,6 +99,25 @@ func (r *Response) finalize() error { return nil } -//TODO(hackerchai): implement body chunk reader func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + body := (*body)(userdata) + if body.len > 0 { + if body.len > DefaultChunkSize { + *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) + body.readLen += DefaultChunkSize + body.len -= DefaultChunkSize + } else { + *chunk = hyper.CopyBuf(&body.data[body.readLen], body.len) + body.readLen += body.len + body.len = 0 + } + return hyper.PollReady + } + if body.len == 0 { + *chunk = nil + return hyper.PollReady + } + + fmt.Printf("error setting body data: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError } \ No newline at end of file diff --git a/x/net/http/server.go b/x/net/http/server.go index 75e1006..f3196c3 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -29,15 +29,15 @@ type Server struct { Addr string Handler Handler - uvLoop *libuv.Loop - uvServer libuv.Tcp + uvLoop *libuv.Loop + uvServer libuv.Tcp inShutdown atomic.Bool mu sync.Mutex - activeConnections map[*Conn]struct{} + activeConnections map[*conn]struct{} } -type Conn struct { +type conn struct { Stream *libuv.Tcp PollHandle *libuv.Poll EventMask c.Uint @@ -55,6 +55,11 @@ func NewServer(addr string) *Server { } } +func ListenAndServe(addr string, handler Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServe() +} + func (srv *Server) ListenAndServe() error { srv.uvLoop = libuv.DefaultLoop() @@ -134,7 +139,7 @@ func (srv *Server) onNewConnection(serverStream *libuv.Stream, status c.Int) { } func (srv *Server) serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { - conn := (*Conn)(userdata) + conn := (*conn)(userdata) if hyperReq == nil { fmt.Fprintf(os.Stderr, "Error: Received null request\n") @@ -169,11 +174,11 @@ func (srv *Server) handleTask(task *hyper.Task) { } } -func (s *Server) trackConn(c *Conn, add bool) { +func (s *Server) trackConn(c *conn, add bool) { s.mu.Lock() defer s.mu.Unlock() if s.activeConnections == nil { - s.activeConnections = make(map[*Conn]struct{}) + s.activeConnections = make(map[*conn]struct{}) } if add { s.activeConnections[c] = struct{}{} @@ -193,7 +198,7 @@ func (srv *Server) Close() error { return nil } -func createIo(conn *Conn) *hyper.Io { +func createIo(conn *conn) *hyper.Io { io := hyper.NewIo() io.SetUserdata(unsafe.Pointer(conn), freeConnData) io.SetRead(readCb) @@ -202,7 +207,7 @@ func createIo(conn *Conn) *hyper.Io { } func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { - conn := (*Conn)(userdata) + conn := (*conn)(userdata) ret := net.Recv(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) if ret >= 0 { @@ -229,7 +234,7 @@ func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintp } func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { - conn := (*Conn)(userdata) + conn := (*conn)(userdata) ret := net.Send(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) if ret >= 0 { @@ -260,7 +265,7 @@ func onClose(handle *libuv.Handle) { } func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { - conn := (*Conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) + conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) if status < 0 { fmt.Fprintf(os.Stderr, "Poll error: %s\n", libuv.Strerror(libuv.Errno(status))) @@ -278,7 +283,7 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { } } -func updateConnRegistrations(conn *Conn, create bool) bool { +func updateConnRegistrations(conn *conn, create bool) bool { events := c.Int(0) if conn.EventMask&c.Uint(libuv.READABLE) != 0 { events |= c.Int(libuv.READABLE) @@ -295,8 +300,8 @@ func updateConnRegistrations(conn *Conn, create bool) bool { return true } -func createConnData(loop *libuv.Loop, client *libuv.Tcp) *Conn { - conn := (*Conn)(c.Calloc(1, unsafe.Sizeof(Conn{}))) +func createConnData(loop *libuv.Loop, client *libuv.Tcp) *conn { + conn := (*conn)(c.Calloc(1, unsafe.Sizeof(conn{}))) if conn == nil { fmt.Fprintf(os.Stderr, "Failed to allocate conn_data\n") return nil @@ -324,7 +329,7 @@ func createConnData(loop *libuv.Loop, client *libuv.Tcp) *Conn { } func freeConnData(userdata c.Pointer) { - conn := (*Conn)(userdata) + conn := (*conn)(userdata) if conn != nil && conn.IsClosing == 0 { conn.IsClosing = 1 // We don't immediately close the connection here. diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index d6fc8bf..a37b8a0 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -25,7 +25,7 @@ func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { mux.mu.RLock() defer mux.mu.RUnlock() - h, pattern = mux.m[r.URL].h, r.URL + h, pattern = mux.m[r.URL.Path].h, r.URL.Path if h == nil { h, pattern = NotFoundHandler(), "" } diff --git a/x/net/url/url.go b/x/net/url/url.go deleted file mode 100644 index 340b74f..0000000 --- a/x/net/url/url.go +++ /dev/null @@ -1,2626 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package url parses URLs and implements query escaping. -package url - -// See RFC 3986. This package generally follows RFC 3986, except where -// it deviates for compatibility reasons. When sending changes, first -// search old issues for history on decisions. Unit tests should also -// contain references to issue numbers with details. - -import ( - "errors" - "fmt" - "path" - "slices" - "strconv" - "strings" - _ "unsafe" // for linkname -) - -// Error reports an error and the operation and URL that caused it. -type Error struct { - Op string - URL string - Err error -} - -func (e *Error) Unwrap() error { return e.Err } -func (e *Error) Error() string { return fmt.Sprintf("%s %q: %s", e.Op, e.URL, e.Err) } - -func (e *Error) Timeout() bool { - t, ok := e.Err.(interface { - Timeout() bool - }) - return ok && t.Timeout() -} - -func (e *Error) Temporary() bool { - t, ok := e.Err.(interface { - Temporary() bool - }) - return ok && t.Temporary() -} - -const upperhex = "0123456789ABCDEF" - -func ishex(c byte) bool { - switch { - case '0' <= c && c <= '9': - return true - case 'a' <= c && c <= 'f': - return true - case 'A' <= c && c <= 'F': - return true - } - return false -} - -func unhex(c byte) byte { - switch { - case '0' <= c && c <= '9': - return c - '0' - case 'a' <= c && c <= 'f': - return c - 'a' + 10 - case 'A' <= c && c <= 'F': - return c - 'A' + 10 - } - return 0 -} - -type encoding int - -const ( - encodePath encoding = 1 + iota - encodePathSegment - encodeHost - encodeZone - encodeUserPassword - encodeQueryComponent - encodeFragment -) - -type EscapeError string - -func (e EscapeError) Error() string { - return "invalid URL escape " + strconv.Quote(string(e)) -} - -type InvalidHostError string - -func (e InvalidHostError) Error() string { - return "invalid character " + strconv.Quote(string(e)) + " in host name" -} - -// Return true if the specified character should be escaped when -// appearing in a URL string, according to RFC 3986. -// -// Please be informed that for now shouldEscape does not check all -// reserved characters correctly. See golang.org/issue/5684. -func shouldEscape(c byte, mode encoding) bool { - // §2.3 Unreserved characters (alphanum) - if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { - return false - } - - if mode == encodeHost || mode == encodeZone { - // §3.2.2 Host allows - // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" - // as part of reg-name. - // We add : because we include :port as part of host. - // We add [ ] because we include [ipv6]:port as part of host. - // We add < > because they're the only characters left that - // we could possibly allow, and Parse will reject them if we - // escape them (because hosts can't use %-encoding for - // ASCII bytes). - switch c { - case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"': - return false - } - } - - switch c { - case '-', '_', '.', '~': // §2.3 Unreserved characters (mark) - return false - - case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved) - // Different sections of the URL allow a few of - // the reserved characters to appear unescaped. - switch mode { - case encodePath: // §3.3 - // The RFC allows : @ & = + $ but saves / ; , for assigning - // meaning to individual path segments. This package - // only manipulates the path as a whole, so we allow those - // last three as well. That leaves only ? to escape. - return c == '?' - - case encodePathSegment: // §3.3 - // The RFC allows : @ & = + $ but saves / ; , for assigning - // meaning to individual path segments. - return c == '/' || c == ';' || c == ',' || c == '?' - - case encodeUserPassword: // §3.2.1 - // The RFC allows ';', ':', '&', '=', '+', '$', and ',' in - // userinfo, so we must escape only '@', '/', and '?'. - // The parsing of userinfo treats ':' as special so we must escape - // that too. - return c == '@' || c == '/' || c == '?' || c == ':' - - case encodeQueryComponent: // §3.4 - // The RFC reserves (so we must escape) everything. - return true - - case encodeFragment: // §4.1 - // The RFC text is silent but the grammar allows - // everything, so escape nothing. - return false - } - } - - if mode == encodeFragment { - // RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are - // included in reserved from RFC 2396 §2.2. The remaining sub-delims do not - // need to be escaped. To minimize potential breakage, we apply two restrictions: - // (1) we always escape sub-delims outside of the fragment, and (2) we always - // escape single quote to avoid breaking callers that had previously assumed that - // single quotes would be escaped. See issue #19917. - switch c { - case '!', '(', ')', '*': - return false - } - } - - // Everything else must be escaped. - return true -} - -// QueryUnescape does the inverse transformation of [QueryEscape], -// converting each 3-byte encoded substring of the form "%AB" into the -// hex-decoded byte 0xAB. -// It returns an error if any % is not followed by two hexadecimal -// digits. -func QueryUnescape(s string) (string, error) { - return unescape(s, encodeQueryComponent) -} - -// PathUnescape does the inverse transformation of [PathEscape], -// converting each 3-byte encoded substring of the form "%AB" into the -// hex-decoded byte 0xAB. It returns an error if any % is not followed -// by two hexadecimal digits. -// -// PathUnescape is identical to [QueryUnescape] except that it does not -// unescape '+' to ' ' (space). -func PathUnescape(s string) (string, error) { - return unescape(s, encodePathSegment) -} - -// unescape unescapes a string; the mode specifies -// which section of the URL string is being unescaped. -func unescape(s string, mode encoding) (string, error) { - // Count %, check that they're well-formed. - n := 0 - hasPlus := false - for i := 0; i < len(s); { - switch s[i] { - case '%': - n++ - if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { - s = s[i:] - if len(s) > 3 { - s = s[:3] - } - return "", EscapeError(s) - } - // Per https://tools.ietf.org/html/rfc3986#page-21 - // in the host component %-encoding can only be used - // for non-ASCII bytes. - // But https://tools.ietf.org/html/rfc6874#section-2 - // introduces %25 being allowed to escape a percent sign - // in IPv6 scoped-address literals. Yay. - if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" { - return "", EscapeError(s[i : i+3]) - } - if mode == encodeZone { - // RFC 6874 says basically "anything goes" for zone identifiers - // and that even non-ASCII can be redundantly escaped, - // but it seems prudent to restrict %-escaped bytes here to those - // that are valid host name bytes in their unescaped form. - // That is, you can use escaping in the zone identifier but not - // to introduce bytes you couldn't just write directly. - // But Windows puts spaces here! Yay. - v := unhex(s[i+1])<<4 | unhex(s[i+2]) - if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) { - return "", EscapeError(s[i : i+3]) - } - } - i += 3 - case '+': - hasPlus = mode == encodeQueryComponent - i++ - default: - if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) { - return "", InvalidHostError(s[i : i+1]) - } - i++ - } - } - - if n == 0 && !hasPlus { - return s, nil - } - - var t strings.Builder - t.Grow(len(s) - 2*n) - for i := 0; i < len(s); i++ { - switch s[i] { - case '%': - t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2])) - i += 2 - case '+': - if mode == encodeQueryComponent { - t.WriteByte(' ') - } else { - t.WriteByte('+') - } - default: - t.WriteByte(s[i]) - } - } - return t.String(), nil -} - -// QueryEscape escapes the string so it can be safely placed -// inside a [URL] query. -func QueryEscape(s string) string { - return escape(s, encodeQueryComponent) -} - -// PathEscape escapes the string so it can be safely placed inside a [URL] path segment, -// replacing special characters (including /) with %XX sequences as needed. -func PathEscape(s string) string { - return escape(s, encodePathSegment) -} - -func escape(s string, mode encoding) string { - spaceCount, hexCount := 0, 0 - for i := 0; i < len(s); i++ { - c := s[i] - if shouldEscape(c, mode) { - if c == ' ' && mode == encodeQueryComponent { - spaceCount++ - } else { - hexCount++ - } - } - } - - if spaceCount == 0 && hexCount == 0 { - return s - } - - var buf [64]byte - var t []byte - // Copyright 2009 The Go Authors. All rights reserved. - // Use of this source code is governed by a BSD-style - // license that can be found in the LICENSE file. - - // Package url parses URLs and implements query escaping. - package url - - // See RFC 3986. This package generally follows RFC 3986, except where - // it deviates for compatibility reasons. When sending changes, first - // search old issues for history on decisions. Unit tests should also - // contain references to issue numbers with details. - - import ( - "errors" - "fmt" - "path" - "slices" - "strconv" - "strings" - _ "unsafe" // for linkname - ) - - // Error reports an error and the operation and URL that caused it. - type Error struct { - Op string - URL string - Err error - } - - func (e *Error) Unwrap() error { return e.Err } - func (e *Error) Error() string { return fmt.Sprintf("%s %q: %s", e.Op, e.URL, e.Err) } - - func (e *Error) Timeout() bool { - t, ok := e.Err.(interface { - Timeout() bool - }) - return ok && t.Timeout() - } - - func (e *Error) Temporary() bool { - t, ok := e.Err.(interface { - Temporary() bool - }) - return ok && t.Temporary() - } - - const upperhex = "0123456789ABCDEF" - - func ishex(c byte) bool { - switch { - case '0' <= c && c <= '9': - return true - case 'a' <= c && c <= 'f': - return true - case 'A' <= c && c <= 'F': - return true - } - return false - } - - func unhex(c byte) byte { - switch { - case '0' <= c && c <= '9': - return c - '0' - case 'a' <= c && c <= 'f': - return c - 'a' + 10 - case 'A' <= c && c <= 'F': - return c - 'A' + 10 - } - return 0 - } - - type encoding int - - const ( - encodePath encoding = 1 + iota - encodePathSegment - encodeHost - encodeZone - encodeUserPassword - encodeQueryComponent - encodeFragment - ) - - type EscapeError string - - func (e EscapeError) Error() string { - return "invalid URL escape " + strconv.Quote(string(e)) - } - - type InvalidHostError string - - func (e InvalidHostError) Error() string { - return "invalid character " + strconv.Quote(string(e)) + " in host name" - } - - // Return true if the specified character should be escaped when - // appearing in a URL string, according to RFC 3986. - // - // Please be informed that for now shouldEscape does not check all - // reserved characters correctly. See golang.org/issue/5684. - func shouldEscape(c byte, mode encoding) bool { - // §2.3 Unreserved characters (alphanum) - if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { - return false - } - - if mode == encodeHost || mode == encodeZone { - // §3.2.2 Host allows - // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" - // as part of reg-name. - // We add : because we include :port as part of host. - // We add [ ] because we include [ipv6]:port as part of host. - // We add < > because they're the only characters left that - // we could possibly allow, and Parse will reject them if we - // escape them (because hosts can't use %-encoding for - // ASCII bytes). - switch c { - case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"': - return false - } - } - - switch c { - case '-', '_', '.', '~': // §2.3 Unreserved characters (mark) - return false - - case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved) - // Different sections of the URL allow a few of - // the reserved characters to appear unescaped. - switch mode { - case encodePath: // §3.3 - // The RFC allows : @ & = + $ but saves / ; , for assigning - // meaning to individual path segments. This package - // only manipulates the path as a whole, so we allow those - // last three as well. That leaves only ? to escape. - return c == '?' - - case encodePathSegment: // §3.3 - // The RFC allows : @ & = + $ but saves / ; , for assigning - // meaning to individual path segments. - return c == '/' || c == ';' || c == ',' || c == '?' - - case encodeUserPassword: // §3.2.1 - // The RFC allows ';', ':', '&', '=', '+', '$', and ',' in - // userinfo, so we must escape only '@', '/', and '?'. - // The parsing of userinfo treats ':' as special so we must escape - // that too. - return c == '@' || c == '/' || c == '?' || c == ':' - - case encodeQueryComponent: // §3.4 - // The RFC reserves (so we must escape) everything. - return true - - case encodeFragment: // §4.1 - // The RFC text is silent but the grammar allows - // everything, so escape nothing. - return false - } - } - - if mode == encodeFragment { - // RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are - // included in reserved from RFC 2396 §2.2. The remaining sub-delims do not - // need to be escaped. To minimize potential breakage, we apply two restrictions: - // (1) we always escape sub-delims outside of the fragment, and (2) we always - // escape single quote to avoid breaking callers that had previously assumed that - // single quotes would be escaped. See issue #19917. - switch c { - case '!', '(', ')', '*': - return false - } - } - - // Everything else must be escaped. - return true - } - - // QueryUnescape does the inverse transformation of [QueryEscape], - // converting each 3-byte encoded substring of the form "%AB" into the - // hex-decoded byte 0xAB. - // It returns an error if any % is not followed by two hexadecimal - // digits. - func QueryUnescape(s string) (string, error) { - return unescape(s, encodeQueryComponent) - } - - // PathUnescape does the inverse transformation of [PathEscape], - // converting each 3-byte encoded substring of the form "%AB" into the - // hex-decoded byte 0xAB. It returns an error if any % is not followed - // by two hexadecimal digits. - // - // PathUnescape is identical to [QueryUnescape] except that it does not - // unescape '+' to ' ' (space). - func PathUnescape(s string) (string, error) { - return unescape(s, encodePathSegment) - } - - // unescape unescapes a string; the mode specifies - // which section of the URL string is being unescaped. - func unescape(s string, mode encoding) (string, error) { - // Count %, check that they're well-formed. - n := 0 - hasPlus := false - for i := 0; i < len(s); { - switch s[i] { - case '%': - n++ - if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { - s = s[i:] - if len(s) > 3 { - s = s[:3] - } - return "", EscapeError(s) - } - // Per https://tools.ietf.org/html/rfc3986#page-21 - // in the host component %-encoding can only be used - // for non-ASCII bytes. - // But https://tools.ietf.org/html/rfc6874#section-2 - // introduces %25 being allowed to escape a percent sign - // in IPv6 scoped-address literals. Yay. - if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" { - return "", EscapeError(s[i : i+3]) - } - if mode == encodeZone { - // RFC 6874 says basically "anything goes" for zone identifiers - // and that even non-ASCII can be redundantly escaped, - // but it seems prudent to restrict %-escaped bytes here to those - // that are valid host name bytes in their unescaped form. - // That is, you can use escaping in the zone identifier but not - // to introduce bytes you couldn't just write directly. - // But Windows puts spaces here! Yay. - v := unhex(s[i+1])<<4 | unhex(s[i+2]) - if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) { - return "", EscapeError(s[i : i+3]) - } - } - i += 3 - case '+': - hasPlus = mode == encodeQueryComponent - i++ - default: - if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) { - return "", InvalidHostError(s[i : i+1]) - } - i++ - } - } - - if n == 0 && !hasPlus { - return s, nil - } - - var t strings.Builder - t.Grow(len(s) - 2*n) - for i := 0; i < len(s); i++ { - switch s[i] { - case '%': - t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2])) - i += 2 - case '+': - if mode == encodeQueryComponent { - t.WriteByte(' ') - } else { - t.WriteByte('+') - } - default: - t.WriteByte(s[i]) - } - } - return t.String(), nil - } - - // QueryEscape escapes the string so it can be safely placed - // inside a [URL] query. - func QueryEscape(s string) string { - return escape(s, encodeQueryComponent) - } - - // PathEscape escapes the string so it can be safely placed inside a [URL] path segment, - // replacing special characters (including /) with %XX sequences as needed. - func PathEscape(s string) string { - return escape(s, encodePathSegment) - } - - func escape(s string, mode encoding) string { - spaceCount, hexCount := 0, 0 - for i := 0; i < len(s); i++ { - c := s[i] - if shouldEscape(c, mode) { - if c == ' ' && mode == encodeQueryComponent { - spaceCount++ - } else { - hexCount++ - } - } - } - - if spaceCount == 0 && hexCount == 0 { - return s - } - - var buf [64]byte - var t []byte - - required := len(s) + 2*hexCount - if required <= len(buf) { - t = buf[:required] - } else { - t = make([]byte, required) - } - - if hexCount == 0 { - copy(t, s) - for i := 0; i < len(s); i++ { - if s[i] == ' ' { - t[i] = '+' - } - } - return string(t) - } - - j := 0 - for i := 0; i < len(s); i++ { - switch c := s[i]; { - case c == ' ' && mode == encodeQueryComponent: - t[j] = '+' - j++ - case shouldEscape(c, mode): - t[j] = '%' - t[j+1] = upperhex[c>>4] - t[j+2] = upperhex[c&15] - j += 3 - default: - t[j] = s[i] - j++ - } - } - return string(t) - } - - // A URL represents a parsed URL (technically, a URI reference). - // - // The general form represented is: - // - // [scheme:][//[userinfo@]host][/]path[?query][#fragment] - // - // URLs that do not start with a slash after the scheme are interpreted as: - // - // scheme:opaque[?query][#fragment] - // - // The Host field contains the host and port subcomponents of the URL. - // When the port is present, it is separated from the host with a colon. - // When the host is an IPv6 address, it must be enclosed in square brackets: - // "[fe80::1]:80". The [net.JoinHostPort] function combines a host and port - // into a string suitable for the Host field, adding square brackets to - // the host when necessary. - // - // Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/. - // A consequence is that it is impossible to tell which slashes in the Path were - // slashes in the raw URL and which were %2f. This distinction is rarely important, - // but when it is, the code should use the [URL.EscapedPath] method, which preserves - // the original encoding of Path. - // - // The RawPath field is an optional field which is only set when the default - // encoding of Path is different from the escaped path. See the EscapedPath method - // for more details. - // - // URL's String method uses the EscapedPath method to obtain the path. - type URL struct { - Scheme string - Opaque string // encoded opaque data - User *Userinfo // username and password information - Host string // host or host:port (see Hostname and Port methods) - Path string // path (relative paths may omit leading slash) - RawPath string // encoded path hint (see EscapedPath method) - OmitHost bool // do not emit empty host (authority) - ForceQuery bool // append a query ('?') even if RawQuery is empty - RawQuery string // encoded query values, without '?' - Fragment string // fragment for references, without '#' - RawFragment string // encoded fragment hint (see EscapedFragment method) - } - - // User returns a [Userinfo] containing the provided username - // and no password set. - func User(username string) *Userinfo { - return &Userinfo{username, "", false} - } - - // UserPassword returns a [Userinfo] containing the provided username - // and password. - // - // This functionality should only be used with legacy web sites. - // RFC 2396 warns that interpreting Userinfo this way - // “is NOT RECOMMENDED, because the passing of authentication - // information in clear text (such as URI) has proven to be a - // security risk in almost every case where it has been used.” - func UserPassword(username, password string) *Userinfo { - return &Userinfo{username, password, true} - } - - // The Userinfo type is an immutable encapsulation of username and - // password details for a [URL]. An existing Userinfo value is guaranteed - // to have a username set (potentially empty, as allowed by RFC 2396), - // and optionally a password. - type Userinfo struct { - username string - password string - passwordSet bool - } - - // Username returns the username. - func (u *Userinfo) Username() string { - if u == nil { - return "" - } - return u.username - } - - // Password returns the password in case it is set, and whether it is set. - func (u *Userinfo) Password() (string, bool) { - if u == nil { - return "", false - } - return u.password, u.passwordSet - } - - // String returns the encoded userinfo information in the standard form - // of "username[:password]". - func (u *Userinfo) String() string { - if u == nil { - return "" - } - s := escape(u.username, encodeUserPassword) - if u.passwordSet { - s += ":" + escape(u.password, encodeUserPassword) - } - return s - } - - // Maybe rawURL is of the form scheme:path. - // (Scheme must be [a-zA-Z][a-zA-Z0-9+.-]*) - // If so, return scheme, path; else return "", rawURL. - func getScheme(rawURL string) (scheme, path string, err error) { - for i := 0; i < len(rawURL); i++ { - c := rawURL[i] - switch { - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': - // do nothing - case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.': - if i == 0 { - return "", rawURL, nil - } - case c == ':': - if i == 0 { - return "", "", errors.New("missing protocol scheme") - } - return rawURL[:i], rawURL[i+1:], nil - default: - // we have encountered an invalid character, - // so there is no valid scheme - return "", rawURL, nil - } - } - return "", rawURL, nil - } - - // Parse parses a raw url into a [URL] structure. - // - // The url may be relative (a path, without a host) or absolute - // (starting with a scheme). Trying to parse a hostname and path - // without a scheme is invalid but may not necessarily return an - // error, due to parsing ambiguities. - func Parse(rawURL string) (*URL, error) { - // Cut off #frag - u, frag, _ := strings.Cut(rawURL, "#") - url, err := parse(u, false) - if err != nil { - return nil, &Error{"parse", u, err} - } - if frag == "" { - return url, nil - } - if err = url.setFragment(frag); err != nil { - return nil, &Error{"parse", rawURL, err} - } - return url, nil - } - - // ParseRequestURI parses a raw url into a [URL] structure. It assumes that - // url was received in an HTTP request, so the url is interpreted - // only as an absolute URI or an absolute path. - // The string url is assumed not to have a #fragment suffix. - // (Web browsers strip #fragment before sending the URL to a web server.) - func ParseRequestURI(rawURL string) (*URL, error) { - url, err := parse(rawURL, true) - if err != nil { - return nil, &Error{"parse", rawURL, err} - } - return url, nil - } - - // parse parses a URL from a string in one of two contexts. If - // viaRequest is true, the URL is assumed to have arrived via an HTTP request, - // in which case only absolute URLs or path-absolute relative URLs are allowed. - // If viaRequest is false, all forms of relative URLs are allowed. - func parse(rawURL string, viaRequest bool) (*URL, error) { - var rest string - var err error - - if stringContainsCTLByte(rawURL) { - return nil, errors.New("net/url: invalid control character in URL") - } - - if rawURL == "" && viaRequest { - return nil, errors.New("empty url") - } - url := new(URL) - - if rawURL == "*" { - url.Path = "*" - return url, nil - } - - // Split off possible leading "http:", "mailto:", etc. - // Cannot contain escaped characters. - if url.Scheme, rest, err = getScheme(rawURL); err != nil { - return nil, err - } - url.Scheme = strings.ToLower(url.Scheme) - - if strings.HasSuffix(rest, "?") && strings.Count(rest, "?") == 1 { - url.ForceQuery = true - rest = rest[:len(rest)-1] - } else { - rest, url.RawQuery, _ = strings.Cut(rest, "?") - } - - if !strings.HasPrefix(rest, "/") { - if url.Scheme != "" { - // We consider rootless paths per RFC 3986 as opaque. - url.Opaque = rest - return url, nil - } - if viaRequest { - return nil, errors.New("invalid URI for request") - } - - // Avoid confusion with malformed schemes, like cache_object:foo/bar. - // See golang.org/issue/16822. - // - // RFC 3986, §3.3: - // In addition, a URI reference (Section 4.1) may be a relative-path reference, - // in which case the first path segment cannot contain a colon (":") character. - if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") { - // First path segment has colon. Not allowed in relative URL. - return nil, errors.New("first path segment in URL cannot contain colon") - } - } - - if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { - var authority string - authority, rest = rest[2:], "" - if i := strings.Index(authority, "/"); i >= 0 { - authority, rest = authority[:i], authority[i:] - } - url.User, url.Host, err = parseAuthority(authority) - if err != nil { - return nil, err - } - } else if url.Scheme != "" && strings.HasPrefix(rest, "/") { - // OmitHost is set to true when rawURL has an empty host (authority). - // See golang.org/issue/46059. - url.OmitHost = true - } - - // Set Path and, optionally, RawPath. - // RawPath is a hint of the encoding of Path. We don't want to set it if - // the default escaping of Path is equivalent, to help make sure that people - // don't rely on it in general. - if err := url.setPath(rest); err != nil { - return nil, err - } - return url, nil - } - - func parseAuthority(authority string) (user *Userinfo, host string, err error) { - i := strings.LastIndex(authority, "@") - if i < 0 { - host, err = parseHost(authority) - } else { - host, err = parseHost(authority[i+1:]) - } - if err != nil { - return nil, "", err - } - if i < 0 { - return nil, host, nil - } - userinfo := authority[:i] - if !validUserinfo(userinfo) { - return nil, "", errors.New("net/url: invalid userinfo") - } - if !strings.Contains(userinfo, ":") { - if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil { - return nil, "", err - } - user = User(userinfo) - } else { - username, password, _ := strings.Cut(userinfo, ":") - if username, err = unescape(username, encodeUserPassword); err != nil { - return nil, "", err - } - if password, err = unescape(password, encodeUserPassword); err != nil { - return nil, "", err - } - user = UserPassword(username, password) - } - return user, host, nil - } - - // parseHost parses host as an authority without user - // information. That is, as host[:port]. - func parseHost(host string) (string, error) { - if strings.HasPrefix(host, "[") { - // Parse an IP-Literal in RFC 3986 and RFC 6874. - // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80". - i := strings.LastIndex(host, "]") - if i < 0 { - return "", errors.New("missing ']' in host") - } - colonPort := host[i+1:] - if !validOptionalPort(colonPort) { - return "", fmt.Errorf("invalid port %q after host", colonPort) - } - - // RFC 6874 defines that %25 (%-encoded percent) introduces - // the zone identifier, and the zone identifier can use basically - // any %-encoding it likes. That's different from the host, which - // can only %-encode non-ASCII bytes. - // We do impose some restrictions on the zone, to avoid stupidity - // like newlines. - zone := strings.Index(host[:i], "%25") - if zone >= 0 { - host1, err := unescape(host[:zone], encodeHost) - if err != nil { - return "", err - } - host2, err := unescape(host[zone:i], encodeZone) - if err != nil { - return "", err - } - host3, err := unescape(host[i:], encodeHost) - if err != nil { - return "", err - } - return host1 + host2 + host3, nil - } - } else if i := strings.LastIndex(host, ":"); i != -1 { - colonPort := host[i:] - if !validOptionalPort(colonPort) { - return "", fmt.Errorf("invalid port %q after host", colonPort) - } - } - - var err error - if host, err = unescape(host, encodeHost); err != nil { - return "", err - } - return host, nil - } - - // setPath sets the Path and RawPath fields of the URL based on the provided - // escaped path p. It maintains the invariant that RawPath is only specified - // when it differs from the default encoding of the path. - // For example: - // - setPath("/foo/bar") will set Path="/foo/bar" and RawPath="" - // - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar" - // setPath will return an error only if the provided path contains an invalid - // escaping. - // - // setPath should be an internal detail, - // but widely used packages access it using linkname. - // Notable members of the hall of shame include: - // - github.com/sagernet/sing - // - // Do not remove or change the type signature. - // See go.dev/issue/67401. - // - //go:linkname badSetPath net/url.(*URL).setPath - func (u *URL) setPath(p string) error { - path, err := unescape(p, encodePath) - if err != nil { - return err - } - u.Path = path - if escp := escape(path, encodePath); p == escp { - // Default encoding is fine. - u.RawPath = "" - } else { - u.RawPath = p - } - return nil - } - - // for linkname because we cannot linkname methods directly - func badSetPath(*URL, string) error - - // EscapedPath returns the escaped form of u.Path. - // In general there are multiple possible escaped forms of any path. - // EscapedPath returns u.RawPath when it is a valid escaping of u.Path. - // Otherwise EscapedPath ignores u.RawPath and computes an escaped - // form on its own. - // The [URL.String] and [URL.RequestURI] methods use EscapedPath to construct - // their results. - // In general, code should call EscapedPath instead of - // reading u.RawPath directly. - func (u *URL) EscapedPath() string { - if u.RawPath != "" && validEncoded(u.RawPath, encodePath) { - p, err := unescape(u.RawPath, encodePath) - if err == nil && p == u.Path { - return u.RawPath - } - } - if u.Path == "*" { - return "*" // don't escape (Issue 11202) - } - return escape(u.Path, encodePath) - } - - // validEncoded reports whether s is a valid encoded path or fragment, - // according to mode. - // It must not contain any bytes that require escaping during encoding. - func validEncoded(s string, mode encoding) bool { - for i := 0; i < len(s); i++ { - // RFC 3986, Appendix A. - // pchar = unreserved / pct-encoded / sub-delims / ":" / "@". - // shouldEscape is not quite compliant with the RFC, - // so we check the sub-delims ourselves and let - // shouldEscape handle the others. - switch s[i] { - case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '@': - // ok - case '[', ']': - // ok - not specified in RFC 3986 but left alone by modern browsers - case '%': - // ok - percent encoded, will decode - default: - if shouldEscape(s[i], mode) { - return false - } - } - } - return true - } - - // setFragment is like setPath but for Fragment/RawFragment. - func (u *URL) setFragment(f string) error { - frag, err := unescape(f, encodeFragment) - if err != nil { - return err - } - u.Fragment = frag - if escf := escape(frag, encodeFragment); f == escf { - // Default encoding is fine. - u.RawFragment = "" - } else { - u.RawFragment = f - } - return nil - } - - // EscapedFragment returns the escaped form of u.Fragment. - // In general there are multiple possible escaped forms of any fragment. - // EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment. - // Otherwise EscapedFragment ignores u.RawFragment and computes an escaped - // form on its own. - // The [URL.String] method uses EscapedFragment to construct its result. - // In general, code should call EscapedFragment instead of - // reading u.RawFragment directly. - func (u *URL) EscapedFragment() string { - if u.RawFragment != "" && validEncoded(u.RawFragment, encodeFragment) { - f, err := unescape(u.RawFragment, encodeFragment) - if err == nil && f == u.Fragment { - return u.RawFragment - } - } - return escape(u.Fragment, encodeFragment) - } - - // validOptionalPort reports whether port is either an empty string - // or matches /^:\d*$/ - func validOptionalPort(port string) bool { - if port == "" { - return true - } - if port[0] != ':' { - return false - } - for _, b := range port[1:] { - if b < '0' || b > '9' { - return false - } - } - return true - } - - // String reassembles the [URL] into a valid URL string. - // The general form of the result is one of: - // - // scheme:opaque?query#fragment - // scheme://userinfo@host/path?query#fragment - // - // If u.Opaque is non-empty, String uses the first form; - // otherwise it uses the second form. - // Any non-ASCII characters in host are escaped. - // To obtain the path, String uses u.EscapedPath(). - // - // In the second form, the following rules apply: - // - if u.Scheme is empty, scheme: is omitted. - // - if u.User is nil, userinfo@ is omitted. - // - if u.Host is empty, host/ is omitted. - // - if u.Scheme and u.Host are empty and u.User is nil, - // the entire scheme://userinfo@host/ is omitted. - // - if u.Host is non-empty and u.Path begins with a /, - // the form host/path does not add its own /. - // - if u.RawQuery is empty, ?query is omitted. - // - if u.Fragment is empty, #fragment is omitted. - func (u *URL) String() string { - var buf strings.Builder - - n := len(u.Scheme) - if u.Opaque != "" { - n += len(u.Opaque) - } else { - if !u.OmitHost && (u.Scheme != "" || u.Host != "" || u.User != nil) { - username := u.User.Username() - password, _ := u.User.Password() - n += len(username) + len(password) + len(u.Host) - } - n += len(u.Path) - } - n += len(u.RawQuery) + len(u.RawFragment) - n += len(":" + "//" + "//" + ":" + "@" + "/" + "./" + "?" + "#") - buf.Grow(n) - - if u.Scheme != "" { - buf.WriteString(u.Scheme) - buf.WriteByte(':') - } - if u.Opaque != "" { - buf.WriteString(u.Opaque) - } else { - if u.Scheme != "" || u.Host != "" || u.User != nil { - if u.OmitHost && u.Host == "" && u.User == nil { - // omit empty host - } else { - if u.Host != "" || u.Path != "" || u.User != nil { - buf.WriteString("//") - } - if ui := u.User; ui != nil { - buf.WriteString(ui.String()) - buf.WriteByte('@') - } - if h := u.Host; h != "" { - buf.WriteString(escape(h, encodeHost)) - } - } - } - path := u.EscapedPath() - if path != "" && path[0] != '/' && u.Host != "" { - buf.WriteByte('/') - } - if buf.Len() == 0 { - // RFC 3986 §4.2 - // A path segment that contains a colon character (e.g., "this:that") - // cannot be used as the first segment of a relative-path reference, as - // it would be mistaken for a scheme name. Such a segment must be - // preceded by a dot-segment (e.g., "./this:that") to make a relative- - // path reference. - if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") { - buf.WriteString("./") - } - } - buf.WriteString(path) - } - if u.ForceQuery || u.RawQuery != "" { - buf.WriteByte('?') - buf.WriteString(u.RawQuery) - } - if u.Fragment != "" { - buf.WriteByte('#') - buf.WriteString(u.EscapedFragment()) - } - return buf.String() - } - - // Redacted is like [URL.String] but replaces any password with "xxxxx". - // Only the password in u.User is redacted. - func (u *URL) Redacted() string { - if u == nil { - return "" - } - - ru := *u - if _, has := ru.User.Password(); has { - ru.User = UserPassword(ru.User.Username(), "xxxxx") - } - return ru.String() - } - - // Values maps a string key to a list of values. - // It is typically used for query parameters and form values. - // Unlike in the http.Header map, the keys in a Values map - // are case-sensitive. - type Values map[string][]string - - // Get gets the first value associated with the given key. - // If there are no values associated with the key, Get returns - // the empty string. To access multiple values, use the map - // directly. - func (v Values) Get(key string) string { - vs := v[key] - if len(vs) == 0 { - return "" - } - return vs[0] - } - - // Set sets the key to value. It replaces any existing - // values. - func (v Values) Set(key, value string) { - v[key] = []string{value} - } - - // Add adds the value to key. It appends to any existing - // values associated with key. - func (v Values) Add(key, value string) { - v[key] = append(v[key], value) - } - - // Del deletes the values associated with key. - func (v Values) Del(key string) { - delete(v, key) - } - - // Has checks whether a given key is set. - func (v Values) Has(key string) bool { - _, ok := v[key] - return ok - } - - // ParseQuery parses the URL-encoded query string and returns - // a map listing the values specified for each key. - // ParseQuery always returns a non-nil map containing all the - // valid query parameters found; err describes the first decoding error - // encountered, if any. - // - // Query is expected to be a list of key=value settings separated by ampersands. - // A setting without an equals sign is interpreted as a key set to an empty - // value. - // Settings containing a non-URL-encoded semicolon are considered invalid. - func ParseQuery(query string) (Values, error) { - m := make(Values) - err := parseQuery(m, query) - return m, err - } - - func parseQuery(m Values, query string) (err error) { - for query != "" { - var key string - key, query, _ = strings.Cut(query, "&") - if strings.Contains(key, ";") { - err = fmt.Errorf("invalid semicolon separator in query") - continue - } - if key == "" { - continue - } - key, value, _ := strings.Cut(key, "=") - key, err1 := QueryUnescape(key) - if err1 != nil { - if err == nil { - err = err1 - } - continue - } - value, err1 = QueryUnescape(value) - if err1 != nil { - if err == nil { - err = err1 - } - continue - } - m[key] = append(m[key], value) - } - return err - } - - // Encode encodes the values into “URL encoded” form - // ("bar=baz&foo=quux") sorted by key. - func (v Values) Encode() string { - if len(v) == 0 { - return "" - } - var buf strings.Builder - keys := make([]string, 0, len(v)) - for k := range v { - keys = append(keys, k) - } - slices.Sort(keys) - for _, k := range keys { - vs := v[k] - keyEscaped := QueryEscape(k) - for _, v := range vs { - if buf.Len() > 0 { - buf.WriteByte('&') - } - buf.WriteString(keyEscaped) - buf.WriteByte('=') - buf.WriteString(QueryEscape(v)) - } - } - return buf.String() - } - - // resolvePath applies special path segments from refs and applies - // them to base, per RFC 3986. - func resolvePath(base, ref string) string { - var full string - if ref == "" { - full = base - } else if ref[0] != '/' { - i := strings.LastIndex(base, "/") - full = base[:i+1] + ref - } else { - full = ref - } - if full == "" { - return "" - } - - var ( - elem string - dst strings.Builder - ) - first := true - remaining := full - // We want to return a leading '/', so write it now. - dst.WriteByte('/') - found := true - for found { - elem, remaining, found = strings.Cut(remaining, "/") - if elem == "." { - first = false - // drop - continue - } - - if elem == ".." { - // Ignore the leading '/' we already wrote. - str := dst.String()[1:] - index := strings.LastIndexByte(str, '/') - - dst.Reset() - dst.WriteByte('/') - if index == -1 { - first = true - } else { - dst.WriteString(str[:index]) - } - } else { - if !first { - dst.WriteByte('/') - } - dst.WriteString(elem) - first = false - } - } - - if elem == "." || elem == ".." { - dst.WriteByte('/') - } - - // We wrote an initial '/', but we don't want two. - r := dst.String() - if len(r) > 1 && r[1] == '/' { - r = r[1:] - } - return r - } - - // IsAbs reports whether the [URL] is absolute. - // Absolute means that it has a non-empty scheme. - func (u *URL) IsAbs() bool { - return u.Scheme != "" - } - - // Parse parses a [URL] in the context of the receiver. The provided URL - // may be relative or absolute. Parse returns nil, err on parse - // failure, otherwise its return value is the same as [URL.ResolveReference]. - func (u *URL) Parse(ref string) (*URL, error) { - refURL, err := Parse(ref) - if err != nil { - return nil, err - } - return u.ResolveReference(refURL), nil - } - - // ResolveReference resolves a URI reference to an absolute URI from - // an absolute base URI u, per RFC 3986 Section 5.2. The URI reference - // may be relative or absolute. ResolveReference always returns a new - // [URL] instance, even if the returned URL is identical to either the - // base or reference. If ref is an absolute URL, then ResolveReference - // ignores base and returns a copy of ref. - func (u *URL) ResolveReference(ref *URL) *URL { - url := *ref - if ref.Scheme == "" { - url.Scheme = u.Scheme - } - if ref.Scheme != "" || ref.Host != "" || ref.User != nil { - // The "absoluteURI" or "net_path" cases. - // We can ignore the error from setPath since we know we provided a - // validly-escaped path. - url.setPath(resolvePath(ref.EscapedPath(), "")) - return &url - } - if ref.Opaque != "" { - url.User = nil - url.Host = "" - url.Path = "" - return &url - } - if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" { - url.RawQuery = u.RawQuery - if ref.Fragment == "" { - url.Fragment = u.Fragment - url.RawFragment = u.RawFragment - } - } - if ref.Path == "" && u.Opaque != "" { - url.Opaque = u.Opaque - url.User = nil - url.Host = "" - url.Path = "" - return &url - } - // The "abs_path" or "rel_path" cases. - url.Host = u.Host - url.User = u.User - url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath())) - return &url - } - - // Query parses RawQuery and returns the corresponding values. - // It silently discards malformed value pairs. - // To check errors use [ParseQuery]. - func (u *URL) Query() Values { - v, _ := ParseQuery(u.RawQuery) - return v - } - - // RequestURI returns the encoded path?query or opaque?query - // string that would be used in an HTTP request for u. - func (u *URL) RequestURI() string { - result := u.Opaque - if result == "" { - result = u.EscapedPath() - if result == "" { - result = "/" - } - } else { - if strings.HasPrefix(result, "//") { - result = u.Scheme + ":" + result - } - } - if u.ForceQuery || u.RawQuery != "" { - result += "?" + u.RawQuery - } - return result - } - - // Hostname returns u.Host, stripping any valid port number if present. - // - // If the result is enclosed in square brackets, as literal IPv6 addresses are, - // the square brackets are removed from the result. - func (u *URL) Hostname() string { - host, _ := splitHostPort(u.Host) - return host - } - - // Port returns the port part of u.Host, without the leading colon. - // - // If u.Host doesn't contain a valid numeric port, Port returns an empty string. - func (u *URL) Port() string { - _, port := splitHostPort(u.Host) - return port - } - - // splitHostPort separates host and port. If the port is not valid, it returns - // the entire input as host, and it doesn't check the validity of the host. - // Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. - func splitHostPort(hostPort string) (host, port string) { - host = hostPort - - colon := strings.LastIndexByte(host, ':') - if colon != -1 && validOptionalPort(host[colon:]) { - host, port = host[:colon], host[colon+1:] - } - - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - host = host[1 : len(host)-1] - } - - return - } - - // Marshaling interface implementations. - // Would like to implement MarshalText/UnmarshalText but that will change the JSON representation of URLs. - - func (u *URL) MarshalBinary() (text []byte, err error) { - return u.AppendBinary(nil) - } - - func (u *URL) AppendBinary(b []byte) ([]byte, error) { - return append(b, u.String()...), nil - } - - func (u *URL) UnmarshalBinary(text []byte) error { - u1, err := Parse(string(text)) - if err != nil { - return err - } - *u = *u1 - return nil - } - - // JoinPath returns a new [URL] with the provided path elements joined to - // any existing path and the resulting path cleaned of any ./ or ../ elements. - // Any sequences of multiple / characters will be reduced to a single /. - func (u *URL) JoinPath(elem ...string) *URL { - elem = append([]string{u.EscapedPath()}, elem...) - var p string - if !strings.HasPrefix(elem[0], "/") { - // Return a relative path if u is relative, - // but ensure that it contains no ../ elements. - elem[0] = "/" + elem[0] - p = path.Join(elem...)[1:] - } else { - p = path.Join(elem...) - } - // path.Join will remove any trailing slashes. - // Preserve at least one. - if strings.HasSuffix(elem[len(elem)-1], "/") && !strings.HasSuffix(p, "/") { - p += "/" - } - url := *u - url.setPath(p) - return &url - } - - // validUserinfo reports whether s is a valid userinfo string per RFC 3986 - // Section 3.2.1: - // - // userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) - // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" - // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" - // / "*" / "+" / "," / ";" / "=" - // - // It doesn't validate pct-encoded. The caller does that via func unescape. - func validUserinfo(s string) bool { - for _, r := range s { - if 'A' <= r && r <= 'Z' { - continue - } - if 'a' <= r && r <= 'z' { - continue - } - if '0' <= r && r <= '9' { - continue - } - switch r { - case '-', '.', '_', ':', '~', '!', '$', '&', '\'', - '(', ')', '*', '+', ',', ';', '=', '%', '@': - continue - default: - return false - } - } - return true - } - - // stringContainsCTLByte reports whether s contains any ASCII control character. - func stringContainsCTLByte(s string) bool { - for i := 0; i < len(s); i++ { - b := s[i] - if b < ' ' || b == 0x7f { - return true - } - } - return false - } - - // JoinPath returns a [URL] string with the provided path elements joined to - // the existing path of base and the resulting path cleaned of any ./ or ../ elements. - func JoinPath(base string, elem ...string) (result string, err error) { - url, err := Parse(base) - if err != nil { - return - } - result = url.JoinPath(elem...).String() - return - } - - required := len(s) + 2*hexCount - if required <= len(buf) { - t = buf[:required] - } else { - t = make([]byte, required) - } - - if hexCount == 0 { - copy(t, s) - for i := 0; i < len(s); i++ { - if s[i] == ' ' { - t[i] = '+' - } - } - return string(t) - } - - j := 0 - for i := 0; i < len(s); i++ { - switch c := s[i]; { - case c == ' ' && mode == encodeQueryComponent: - t[j] = '+' - j++ - case shouldEscape(c, mode): - t[j] = '%' - t[j+1] = upperhex[c>>4] - t[j+2] = upperhex[c&15] - j += 3 - default: - t[j] = s[i] - j++ - } - } - return string(t) -} - -// A URL represents a parsed URL (technically, a URI reference). -// -// The general form represented is: -// -// [scheme:][//[userinfo@]host][/]path[?query][#fragment] -// -// URLs that do not start with a slash after the scheme are interpreted as: -// -// scheme:opaque[?query][#fragment] -// -// The Host field contains the host and port subcomponents of the URL. -// When the port is present, it is separated from the host with a colon. -// When the host is an IPv6 address, it must be enclosed in square brackets: -// "[fe80::1]:80". The [net.JoinHostPort] function combines a host and port -// into a string suitable for the Host field, adding square brackets to -// the host when necessary. -// -// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/. -// A consequence is that it is impossible to tell which slashes in the Path were -// slashes in the raw URL and which were %2f. This distinction is rarely important, -// but when it is, the code should use the [URL.EscapedPath] method, which preserves -// the original encoding of Path. -// -// The RawPath field is an optional field which is only set when the default -// encoding of Path is different from the escaped path. See the EscapedPath method -// for more details. -// -// URL's String method uses the EscapedPath method to obtain the path. -type URL struct { - Scheme string - Opaque string // encoded opaque data - User *Userinfo // username and password information - Host string // host or host:port (see Hostname and Port methods) - Path string // path (relative paths may omit leading slash) - RawPath string // encoded path hint (see EscapedPath method) - OmitHost bool // do not emit empty host (authority) - ForceQuery bool // append a query ('?') even if RawQuery is empty - RawQuery string // encoded query values, without '?' - Fragment string // fragment for references, without '#' - RawFragment string // encoded fragment hint (see EscapedFragment method) -} - -// User returns a [Userinfo] containing the provided username -// and no password set. -func User(username string) *Userinfo { - return &Userinfo{username, "", false} -} - -// UserPassword returns a [Userinfo] containing the provided username -// and password. -// -// This functionality should only be used with legacy web sites. -// RFC 2396 warns that interpreting Userinfo this way -// “is NOT RECOMMENDED, because the passing of authentication -// information in clear text (such as URI) has proven to be a -// security risk in almost every case where it has been used.” -func UserPassword(username, password string) *Userinfo { - return &Userinfo{username, password, true} -} - -// The Userinfo type is an immutable encapsulation of username and -// password details for a [URL]. An existing Userinfo value is guaranteed -// to have a username set (potentially empty, as allowed by RFC 2396), -// and optionally a password. -type Userinfo struct { - username string - password string - passwordSet bool -} - -// Username returns the username. -func (u *Userinfo) Username() string { - if u == nil { - return "" - } - return u.username -} - -// Password returns the password in case it is set, and whether it is set. -func (u *Userinfo) Password() (string, bool) { - if u == nil { - return "", false - } - return u.password, u.passwordSet -} - -// String returns the encoded userinfo information in the standard form -// of "username[:password]". -func (u *Userinfo) String() string { - if u == nil { - return "" - } - s := escape(u.username, encodeUserPassword) - if u.passwordSet { - s += ":" + escape(u.password, encodeUserPassword) - } - return s -} - -// Maybe rawURL is of the form scheme:path. -// (Scheme must be [a-zA-Z][a-zA-Z0-9+.-]*) -// If so, return scheme, path; else return "", rawURL. -func getScheme(rawURL string) (scheme, path string, err error) { - for i := 0; i < len(rawURL); i++ { - c := rawURL[i] - switch { - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': - // do nothing - case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.': - if i == 0 { - return "", rawURL, nil - } - case c == ':': - if i == 0 { - return "", "", errors.New("missing protocol scheme") - } - return rawURL[:i], rawURL[i+1:], nil - default: - // we have encountered an invalid character, - // so there is no valid scheme - return "", rawURL, nil - } - } - return "", rawURL, nil -} - -// Parse parses a raw url into a [URL] structure. -// -// The url may be relative (a path, without a host) or absolute -// (starting with a scheme). Trying to parse a hostname and path -// without a scheme is invalid but may not necessarily return an -// error, due to parsing ambiguities. -func Parse(rawURL string) (*URL, error) { - // Cut off #frag - u, frag, _ := strings.Cut(rawURL, "#") - url, err := parse(u, false) - if err != nil { - return nil, &Error{"parse", u, err} - } - if frag == "" { - return url, nil - } - if err = url.setFragment(frag); err != nil { - return nil, &Error{"parse", rawURL, err} - } - return url, nil -} - -// ParseRequestURI parses a raw url into a [URL] structure. It assumes that -// url was received in an HTTP request, so the url is interpreted -// only as an absolute URI or an absolute path. -// The string url is assumed not to have a #fragment suffix. -// (Web browsers strip #fragment before sending the URL to a web server.) -func ParseRequestURI(rawURL string) (*URL, error) { - url, err := parse(rawURL, true) - if err != nil { - return nil, &Error{"parse", rawURL, err} - } - return url, nil -} - -// parse parses a URL from a string in one of two contexts. If -// viaRequest is true, the URL is assumed to have arrived via an HTTP request, -// in which case only absolute URLs or path-absolute relative URLs are allowed. -// If viaRequest is false, all forms of relative URLs are allowed. -func parse(rawURL string, viaRequest bool) (*URL, error) { - var rest string - var err error - - if stringContainsCTLByte(rawURL) { - return nil, errors.New("net/url: invalid control character in URL") - } - - if rawURL == "" && viaRequest { - return nil, errors.New("empty url") - } - url := new(URL) - - if rawURL == "*" { - url.Path = "*" - return url, nil - } - - // Split off possible leading "http:", "mailto:", etc. - // Cannot contain escaped characters. - if url.Scheme, rest, err = getScheme(rawURL); err != nil { - return nil, err - } - url.Scheme = strings.ToLower(url.Scheme) - - if strings.HasSuffix(rest, "?") && strings.Count(rest, "?") == 1 { - url.ForceQuery = true - rest = rest[:len(rest)-1] - } else { - rest, url.RawQuery, _ = strings.Cut(rest, "?") - } - - if !strings.HasPrefix(rest, "/") { - if url.Scheme != "" { - // We consider rootless paths per RFC 3986 as opaque. - url.Opaque = rest - return url, nil - } - if viaRequest { - return nil, errors.New("invalid URI for request") - } - - // Avoid confusion with malformed schemes, like cache_object:foo/bar. - // See golang.org/issue/16822. - // - // RFC 3986, §3.3: - // In addition, a URI reference (Section 4.1) may be a relative-path reference, - // in which case the first path segment cannot contain a colon (":") character. - if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") { - // First path segment has colon. Not allowed in relative URL. - return nil, errors.New("first path segment in URL cannot contain colon") - } - } - - if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { - var authority string - authority, rest = rest[2:], "" - if i := strings.Index(authority, "/"); i >= 0 { - authority, rest = authority[:i], authority[i:] - } - url.User, url.Host, err = parseAuthority(authority) - if err != nil { - return nil, err - } - } else if url.Scheme != "" && strings.HasPrefix(rest, "/") { - // OmitHost is set to true when rawURL has an empty host (authority). - // See golang.org/issue/46059. - url.OmitHost = true - } - - // Set Path and, optionally, RawPath. - // RawPath is a hint of the encoding of Path. We don't want to set it if - // the default escaping of Path is equivalent, to help make sure that people - // don't rely on it in general. - if err := url.setPath(rest); err != nil { - return nil, err - } - return url, nil -} - -func parseAuthority(authority string) (user *Userinfo, host string, err error) { - i := strings.LastIndex(authority, "@") - if i < 0 { - host, err = parseHost(authority) - } else { - host, err = parseHost(authority[i+1:]) - } - if err != nil { - return nil, "", err - } - if i < 0 { - return nil, host, nil - } - userinfo := authority[:i] - if !validUserinfo(userinfo) { - return nil, "", errors.New("net/url: invalid userinfo") - } - if !strings.Contains(userinfo, ":") { - if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil { - return nil, "", err - } - user = User(userinfo) - } else { - username, password, _ := strings.Cut(userinfo, ":") - if username, err = unescape(username, encodeUserPassword); err != nil { - return nil, "", err - } - if password, err = unescape(password, encodeUserPassword); err != nil { - return nil, "", err - } - user = UserPassword(username, password) - } - return user, host, nil -} - -// parseHost parses host as an authority without user -// information. That is, as host[:port]. -func parseHost(host string) (string, error) { - if strings.HasPrefix(host, "[") { - // Parse an IP-Literal in RFC 3986 and RFC 6874. - // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80". - i := strings.LastIndex(host, "]") - if i < 0 { - return "", errors.New("missing ']' in host") - } - colonPort := host[i+1:] - if !validOptionalPort(colonPort) { - return "", fmt.Errorf("invalid port %q after host", colonPort) - } - - // RFC 6874 defines that %25 (%-encoded percent) introduces - // the zone identifier, and the zone identifier can use basically - // any %-encoding it likes. That's different from the host, which - // can only %-encode non-ASCII bytes. - // We do impose some restrictions on the zone, to avoid stupidity - // like newlines. - zone := strings.Index(host[:i], "%25") - if zone >= 0 { - host1, err := unescape(host[:zone], encodeHost) - if err != nil { - return "", err - } - host2, err := unescape(host[zone:i], encodeZone) - if err != nil { - return "", err - } - host3, err := unescape(host[i:], encodeHost) - if err != nil { - return "", err - } - return host1 + host2 + host3, nil - } - } else if i := strings.LastIndex(host, ":"); i != -1 { - colonPort := host[i:] - if !validOptionalPort(colonPort) { - return "", fmt.Errorf("invalid port %q after host", colonPort) - } - } - - var err error - if host, err = unescape(host, encodeHost); err != nil { - return "", err - } - return host, nil -} - -// setPath sets the Path and RawPath fields of the URL based on the provided -// escaped path p. It maintains the invariant that RawPath is only specified -// when it differs from the default encoding of the path. -// For example: -// - setPath("/foo/bar") will set Path="/foo/bar" and RawPath="" -// - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar" -// setPath will return an error only if the provided path contains an invalid -// escaping. -// -// setPath should be an internal detail, -// but widely used packages access it using linkname. -// Notable members of the hall of shame include: -// - github.com/sagernet/sing -// -// Do not remove or change the type signature. -// See go.dev/issue/67401. -// -//go:linkname badSetPath net/url.(*URL).setPath -func (u *URL) setPath(p string) error { - path, err := unescape(p, encodePath) - if err != nil { - return err - } - u.Path = path - if escp := escape(path, encodePath); p == escp { - // Default encoding is fine. - u.RawPath = "" - } else { - u.RawPath = p - } - return nil -} - -// for linkname because we cannot linkname methods directly -func badSetPath(*URL, string) error - -// EscapedPath returns the escaped form of u.Path. -// In general there are multiple possible escaped forms of any path. -// EscapedPath returns u.RawPath when it is a valid escaping of u.Path. -// Otherwise EscapedPath ignores u.RawPath and computes an escaped -// form on its own. -// The [URL.String] and [URL.RequestURI] methods use EscapedPath to construct -// their results. -// In general, code should call EscapedPath instead of -// reading u.RawPath directly. -func (u *URL) EscapedPath() string { - if u.RawPath != "" && validEncoded(u.RawPath, encodePath) { - p, err := unescape(u.RawPath, encodePath) - if err == nil && p == u.Path { - return u.RawPath - } - } - if u.Path == "*" { - return "*" // don't escape (Issue 11202) - } - return escape(u.Path, encodePath) -} - -// validEncoded reports whether s is a valid encoded path or fragment, -// according to mode. -// It must not contain any bytes that require escaping during encoding. -func validEncoded(s string, mode encoding) bool { - for i := 0; i < len(s); i++ { - // RFC 3986, Appendix A. - // pchar = unreserved / pct-encoded / sub-delims / ":" / "@". - // shouldEscape is not quite compliant with the RFC, - // so we check the sub-delims ourselves and let - // shouldEscape handle the others. - switch s[i] { - case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '@': - // ok - case '[', ']': - // ok - not specified in RFC 3986 but left alone by modern browsers - case '%': - // ok - percent encoded, will decode - default: - if shouldEscape(s[i], mode) { - return false - } - } - } - return true -} - -// setFragment is like setPath but for Fragment/RawFragment. -func (u *URL) setFragment(f string) error { - frag, err := unescape(f, encodeFragment) - if err != nil { - return err - } - u.Fragment = frag - if escf := escape(frag, encodeFragment); f == escf { - // Default encoding is fine. - u.RawFragment = "" - } else { - u.RawFragment = f - } - return nil -} - -// EscapedFragment returns the escaped form of u.Fragment. -// In general there are multiple possible escaped forms of any fragment. -// EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment. -// Otherwise EscapedFragment ignores u.RawFragment and computes an escaped -// form on its own. -// The [URL.String] method uses EscapedFragment to construct its result. -// In general, code should call EscapedFragment instead of -// reading u.RawFragment directly. -func (u *URL) EscapedFragment() string { - if u.RawFragment != "" && validEncoded(u.RawFragment, encodeFragment) { - f, err := unescape(u.RawFragment, encodeFragment) - if err == nil && f == u.Fragment { - return u.RawFragment - } - } - return escape(u.Fragment, encodeFragment) -} - -// validOptionalPort reports whether port is either an empty string -// or matches /^:\d*$/ -func validOptionalPort(port string) bool { - if port == "" { - return true - } - if port[0] != ':' { - return false - } - for _, b := range port[1:] { - if b < '0' || b > '9' { - return false - } - } - return true -} - -// String reassembles the [URL] into a valid URL string. -// The general form of the result is one of: -// -// scheme:opaque?query#fragment -// scheme://userinfo@host/path?query#fragment -// -// If u.Opaque is non-empty, String uses the first form; -// otherwise it uses the second form. -// Any non-ASCII characters in host are escaped. -// To obtain the path, String uses u.EscapedPath(). -// -// In the second form, the following rules apply: -// - if u.Scheme is empty, scheme: is omitted. -// - if u.User is nil, userinfo@ is omitted. -// - if u.Host is empty, host/ is omitted. -// - if u.Scheme and u.Host are empty and u.User is nil, -// the entire scheme://userinfo@host/ is omitted. -// - if u.Host is non-empty and u.Path begins with a /, -// the form host/path does not add its own /. -// - if u.RawQuery is empty, ?query is omitted. -// - if u.Fragment is empty, #fragment is omitted. -func (u *URL) String() string { - var buf strings.Builder - - n := len(u.Scheme) - if u.Opaque != "" { - n += len(u.Opaque) - } else { - if !u.OmitHost && (u.Scheme != "" || u.Host != "" || u.User != nil) { - username := u.User.Username() - password, _ := u.User.Password() - n += len(username) + len(password) + len(u.Host) - } - n += len(u.Path) - } - n += len(u.RawQuery) + len(u.RawFragment) - n += len(":" + "//" + "//" + ":" + "@" + "/" + "./" + "?" + "#") - buf.Grow(n) - - if u.Scheme != "" { - buf.WriteString(u.Scheme) - buf.WriteByte(':') - } - if u.Opaque != "" { - buf.WriteString(u.Opaque) - } else { - if u.Scheme != "" || u.Host != "" || u.User != nil { - if u.OmitHost && u.Host == "" && u.User == nil { - // omit empty host - } else { - if u.Host != "" || u.Path != "" || u.User != nil { - buf.WriteString("//") - } - if ui := u.User; ui != nil { - buf.WriteString(ui.String()) - buf.WriteByte('@') - } - if h := u.Host; h != "" { - buf.WriteString(escape(h, encodeHost)) - } - } - } - path := u.EscapedPath() - if path != "" && path[0] != '/' && u.Host != "" { - buf.WriteByte('/') - } - if buf.Len() == 0 { - // RFC 3986 §4.2 - // A path segment that contains a colon character (e.g., "this:that") - // cannot be used as the first segment of a relative-path reference, as - // it would be mistaken for a scheme name. Such a segment must be - // preceded by a dot-segment (e.g., "./this:that") to make a relative- - // path reference. - if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") { - buf.WriteString("./") - } - } - buf.WriteString(path) - } - if u.ForceQuery || u.RawQuery != "" { - buf.WriteByte('?') - buf.WriteString(u.RawQuery) - } - if u.Fragment != "" { - buf.WriteByte('#') - buf.WriteString(u.EscapedFragment()) - } - return buf.String() -} - -// Redacted is like [URL.String] but replaces any password with "xxxxx". -// Only the password in u.User is redacted. -func (u *URL) Redacted() string { - if u == nil { - return "" - } - - ru := *u - if _, has := ru.User.Password(); has { - ru.User = UserPassword(ru.User.Username(), "xxxxx") - } - return ru.String() -} - -// Values maps a string key to a list of values. -// It is typically used for query parameters and form values. -// Unlike in the http.Header map, the keys in a Values map -// are case-sensitive. -type Values map[string][]string - -// Get gets the first value associated with the given key. -// If there are no values associated with the key, Get returns -// the empty string. To access multiple values, use the map -// directly. -func (v Values) Get(key string) string { - vs := v[key] - if len(vs) == 0 { - return "" - } - return vs[0] -} - -// Set sets the key to value. It replaces any existing -// values. -func (v Values) Set(key, value string) { - v[key] = []string{value} -} - -// Add adds the value to key. It appends to any existing -// values associated with key. -func (v Values) Add(key, value string) { - v[key] = append(v[key], value) -} - -// Del deletes the values associated with key. -func (v Values) Del(key string) { - delete(v, key) -} - -// Has checks whether a given key is set. -func (v Values) Has(key string) bool { - _, ok := v[key] - return ok -} - -// ParseQuery parses the URL-encoded query string and returns -// a map listing the values specified for each key. -// ParseQuery always returns a non-nil map containing all the -// valid query parameters found; err describes the first decoding error -// encountered, if any. -// -// Query is expected to be a list of key=value settings separated by ampersands. -// A setting without an equals sign is interpreted as a key set to an empty -// value. -// Settings containing a non-URL-encoded semicolon are considered invalid. -func ParseQuery(query string) (Values, error) { - m := make(Values) - err := parseQuery(m, query) - return m, err -} - -func parseQuery(m Values, query string) (err error) { - for query != "" { - var key string - key, query, _ = strings.Cut(query, "&") - if strings.Contains(key, ";") { - err = fmt.Errorf("invalid semicolon separator in query") - continue - } - if key == "" { - continue - } - key, value, _ := strings.Cut(key, "=") - key, err1 := QueryUnescape(key) - if err1 != nil { - if err == nil { - err = err1 - } - continue - } - value, err1 = QueryUnescape(value) - if err1 != nil { - if err == nil { - err = err1 - } - continue - } - m[key] = append(m[key], value) - } - return err -} - -// Encode encodes the values into “URL encoded” form -// ("bar=baz&foo=quux") sorted by key. -func (v Values) Encode() string { - if len(v) == 0 { - return "" - } - var buf strings.Builder - keys := make([]string, 0, len(v)) - for k := range v { - keys = append(keys, k) - } - slices.Sort(keys) - for _, k := range keys { - vs := v[k] - keyEscaped := QueryEscape(k) - for _, v := range vs { - if buf.Len() > 0 { - buf.WriteByte('&') - } - buf.WriteString(keyEscaped) - buf.WriteByte('=') - buf.WriteString(QueryEscape(v)) - } - } - return buf.String() -} - -// resolvePath applies special path segments from refs and applies -// them to base, per RFC 3986. -func resolvePath(base, ref string) string { - var full string - if ref == "" { - full = base - } else if ref[0] != '/' { - i := strings.LastIndex(base, "/") - full = base[:i+1] + ref - } else { - full = ref - } - if full == "" { - return "" - } - - var ( - elem string - dst strings.Builder - ) - first := true - remaining := full - // We want to return a leading '/', so write it now. - dst.WriteByte('/') - found := true - for found { - elem, remaining, found = strings.Cut(remaining, "/") - if elem == "." { - first = false - // drop - continue - } - - if elem == ".." { - // Ignore the leading '/' we already wrote. - str := dst.String()[1:] - index := strings.LastIndexByte(str, '/') - - dst.Reset() - dst.WriteByte('/') - if index == -1 { - first = true - } else { - dst.WriteString(str[:index]) - } - } else { - if !first { - dst.WriteByte('/') - } - dst.WriteString(elem) - first = false - } - } - - if elem == "." || elem == ".." { - dst.WriteByte('/') - } - - // We wrote an initial '/', but we don't want two. - r := dst.String() - if len(r) > 1 && r[1] == '/' { - r = r[1:] - } - return r -} - -// IsAbs reports whether the [URL] is absolute. -// Absolute means that it has a non-empty scheme. -func (u *URL) IsAbs() bool { - return u.Scheme != "" -} - -// Parse parses a [URL] in the context of the receiver. The provided URL -// may be relative or absolute. Parse returns nil, err on parse -// failure, otherwise its return value is the same as [URL.ResolveReference]. -func (u *URL) Parse(ref string) (*URL, error) { - refURL, err := Parse(ref) - if err != nil { - return nil, err - } - return u.ResolveReference(refURL), nil -} - -// ResolveReference resolves a URI reference to an absolute URI from -// an absolute base URI u, per RFC 3986 Section 5.2. The URI reference -// may be relative or absolute. ResolveReference always returns a new -// [URL] instance, even if the returned URL is identical to either the -// base or reference. If ref is an absolute URL, then ResolveReference -// ignores base and returns a copy of ref. -func (u *URL) ResolveReference(ref *URL) *URL { - url := *ref - if ref.Scheme == "" { - url.Scheme = u.Scheme - } - if ref.Scheme != "" || ref.Host != "" || ref.User != nil { - // The "absoluteURI" or "net_path" cases. - // We can ignore the error from setPath since we know we provided a - // validly-escaped path. - url.setPath(resolvePath(ref.EscapedPath(), "")) - return &url - } - if ref.Opaque != "" { - url.User = nil - url.Host = "" - url.Path = "" - return &url - } - if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" { - url.RawQuery = u.RawQuery - if ref.Fragment == "" { - url.Fragment = u.Fragment - url.RawFragment = u.RawFragment - } - } - if ref.Path == "" && u.Opaque != "" { - url.Opaque = u.Opaque - url.User = nil - url.Host = "" - url.Path = "" - return &url - } - // The "abs_path" or "rel_path" cases. - url.Host = u.Host - url.User = u.User - url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath())) - return &url -} - -// Query parses RawQuery and returns the corresponding values. -// It silently discards malformed value pairs. -// To check errors use [ParseQuery]. -func (u *URL) Query() Values { - v, _ := ParseQuery(u.RawQuery) - return v -} - -// RequestURI returns the encoded path?query or opaque?query -// string that would be used in an HTTP request for u. -func (u *URL) RequestURI() string { - result := u.Opaque - if result == "" { - result = u.EscapedPath() - if result == "" { - result = "/" - } - } else { - if strings.HasPrefix(result, "//") { - result = u.Scheme + ":" + result - } - } - if u.ForceQuery || u.RawQuery != "" { - result += "?" + u.RawQuery - } - return result -} - -// Hostname returns u.Host, stripping any valid port number if present. -// -// If the result is enclosed in square brackets, as literal IPv6 addresses are, -// the square brackets are removed from the result. -func (u *URL) Hostname() string { - host, _ := splitHostPort(u.Host) - return host -} - -// Port returns the port part of u.Host, without the leading colon. -// -// If u.Host doesn't contain a valid numeric port, Port returns an empty string. -func (u *URL) Port() string { - _, port := splitHostPort(u.Host) - return port -} - -// splitHostPort separates host and port. If the port is not valid, it returns -// the entire input as host, and it doesn't check the validity of the host. -// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. -func splitHostPort(hostPort string) (host, port string) { - host = hostPort - - colon := strings.LastIndexByte(host, ':') - if colon != -1 && validOptionalPort(host[colon:]) { - host, port = host[:colon], host[colon+1:] - } - - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - host = host[1 : len(host)-1] - } - - return -} - -// Marshaling interface implementations. -// Would like to implement MarshalText/UnmarshalText but that will change the JSON representation of URLs. - -func (u *URL) MarshalBinary() (text []byte, err error) { - return u.AppendBinary(nil) -} - -func (u *URL) AppendBinary(b []byte) ([]byte, error) { - return append(b, u.String()...), nil -} - -func (u *URL) UnmarshalBinary(text []byte) error { - u1, err := Parse(string(text)) - if err != nil { - return err - } - *u = *u1 - return nil -} - -// JoinPath returns a new [URL] with the provided path elements joined to -// any existing path and the resulting path cleaned of any ./ or ../ elements. -// Any sequences of multiple / characters will be reduced to a single /. -func (u *URL) JoinPath(elem ...string) *URL { - elem = append([]string{u.EscapedPath()}, elem...) - var p string - if !strings.HasPrefix(elem[0], "/") { - // Return a relative path if u is relative, - // but ensure that it contains no ../ elements. - elem[0] = "/" + elem[0] - p = path.Join(elem...)[1:] - } else { - p = path.Join(elem...) - } - // path.Join will remove any trailing slashes. - // Preserve at least one. - if strings.HasSuffix(elem[len(elem)-1], "/") && !strings.HasSuffix(p, "/") { - p += "/" - } - url := *u - url.setPath(p) - return &url -} - -// validUserinfo reports whether s is a valid userinfo string per RFC 3986 -// Section 3.2.1: -// -// userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) -// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" -// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" -// / "*" / "+" / "," / ";" / "=" -// -// It doesn't validate pct-encoded. The caller does that via func unescape. -func validUserinfo(s string) bool { - for _, r := range s { - if 'A' <= r && r <= 'Z' { - continue - } - if 'a' <= r && r <= 'z' { - continue - } - if '0' <= r && r <= '9' { - continue - } - switch r { - case '-', '.', '_', ':', '~', '!', '$', '&', '\'', - '(', ')', '*', '+', ',', ';', '=', '%', '@': - continue - default: - return false - } - } - return true -} - -// stringContainsCTLByte reports whether s contains any ASCII control character. -func stringContainsCTLByte(s string) bool { - for i := 0; i < len(s); i++ { - b := s[i] - if b < ' ' || b == 0x7f { - return true - } - } - return false -} - -// JoinPath returns a [URL] string with the provided path elements joined to -// the existing path of base and the resulting path cleaned of any ./ or ../ elements. -func JoinPath(base string, elem ...string) (result string, err error) { - url, err := Parse(base) - if err != nil { - return - } - result = url.JoinPath(elem...).String() - return -} From 42e5881977d526eb16dc6fbcab43907b0c563ec9 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 16 Aug 2024 18:29:22 +0800 Subject: [PATCH 14/55] feat(x/net/http): Implement http server demo Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 7fd6b37..4a043da 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -1,13 +1,30 @@ -func main() { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %s!", r.URL) - }) - - server := http.NewServer(":8080") - server.Handler = mux - err := server.ListenAndServe() +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgo/x/net/http" +) + +func echoHandler(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) if err != nil { - fmt.Printf("Server error: %v\n", err) + http.Error(w, "Error reading request body", http.StatusInternalServerError) + return + } + defer r.Body.Close() + + w.Header().Set("Content-Type", "text/plain") + + w.Write(body) +} + +func main() { + http.HandleFunc("/echo", echoHandler) + + fmt.Println("Starting server on :8080") + if err := http.ListenAndServe(":8080", nil); err != nil { + panic(err) } -} \ No newline at end of file +} From 953f6182efc836588d183b4f0f82fcf38fc1942b Mon Sep 17 00:00:00 2001 From: hackerchai Date: Mon, 19 Aug 2024 18:21:16 +0800 Subject: [PATCH 15/55] fix(x/net/http): Fix demo and multiple fixes Signed-off-by: hackerchai --- x/net/bytealg.go | 19 +++ x/net/http/_demo/http.go | 2 +- x/net/http/header.go | 4 + x/net/http/server.go | 249 +++++++++++++++++++++++++++++++-------- x/net/http/servermux.go | 1 + x/net/ipsock.go | 65 ++++++++++ x/net/net.go | 20 ++++ 7 files changed, 310 insertions(+), 50 deletions(-) create mode 100644 x/net/bytealg.go create mode 100644 x/net/ipsock.go create mode 100644 x/net/net.go diff --git a/x/net/bytealg.go b/x/net/bytealg.go new file mode 100644 index 0000000..9c9ac68 --- /dev/null +++ b/x/net/bytealg.go @@ -0,0 +1,19 @@ +package net + +func LastIndexByteString(s string, c byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == c { + return i + } + } + return -1 +} + +func IndexByteString(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} \ No newline at end of file diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 4a043da..d0b3c0f 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -24,7 +24,7 @@ func main() { http.HandleFunc("/echo", echoHandler) fmt.Println("Starting server on :8080") - if err := http.ListenAndServe(":8080", nil); err != nil { + if err := http.ListenAndServe("127.0.0.1:1234", nil); err != nil { panic(err) } } diff --git a/x/net/http/header.go b/x/net/http/header.go index 39130e4..5579972 100644 --- a/x/net/http/header.go +++ b/x/net/http/header.go @@ -15,4 +15,8 @@ func (h Header) Get(key string) string { return v[0] } return "" +} + +func (h Header) Del(key string) { + delete(h, key) } \ No newline at end of file diff --git a/x/net/http/server.go b/x/net/http/server.go index f3196c3..833b4e9 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -3,16 +3,18 @@ package http import ( "fmt" "os" + "strconv" "sync" "sync/atomic" "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/libuv" - "github.com/goplus/llgo/c/net" + cnet "github.com/goplus/llgo/c/net" cos "github.com/goplus/llgo/c/os" "github.com/goplus/llgo/c/syscall" "github.com/goplus/llgo/rust/hyper" + "github.com/goplus/llgo/x/net" ) type Handler interface { @@ -48,10 +50,17 @@ type conn struct { Executor *hyper.Executor } +type serviceUserdata struct { + Server *Server + Conn *conn +} + func NewServer(addr string) *Server { + activeClients := make(map[*conn]struct{}) return &Server{ - Addr: addr, - Handler: DefaultServeMux, + Addr: addr, + Handler: DefaultServeMux, + activeConnections: activeClients, } } @@ -67,23 +76,34 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("failed to init TCP: %v", err) } - var sockaddr net.SockaddrIn - if err := libuv.Ip4Addr(c.AllocaCStr(srv.Addr), 0, &sockaddr); err != 0 { + host, port, err := net.SplitHostPort(srv.Addr) + if err != nil { + return fmt.Errorf("invalid address %q: %v", srv.Addr, err) + } + + portNum, err := strconv.Atoi(port) + if err != nil { + return fmt.Errorf("invalid port number: %v", err) + } + + var sockaddr cnet.SockaddrIn + if err := libuv.Ip4Addr(c.AllocaCStr(host), c.Int(portNum), &sockaddr); err != 0 { return fmt.Errorf("failed to create IP address: %v", err) } - if err := srv.uvServer.Bind((*net.SockAddr)(unsafe.Pointer(&sockaddr)), 0); err != 0 { + if err := srv.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); err != 0 { return fmt.Errorf("failed to bind: %v", err) } // Set SO_REUSEADDR yes := c.Int(1) - result := net.SetSockOpt(srv.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) + result := cnet.SetSockOpt(srv.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) if result != 0 { return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) } - - if err := (*libuv.Stream)(&srv.uvServer).Listen(128, srv.onNewConnection); err != 0 { + + (*libuv.Stream)(&srv.uvServer).Data = unsafe.Pointer(srv) + if err := (*libuv.Stream)(&srv.uvServer).Listen(128, onNewConnection); err != 0 { return fmt.Errorf("failed to listen: %v", err) } @@ -92,38 +112,49 @@ func (srv *Server) ListenAndServe() error { for { srv.uvLoop.Run(libuv.RUN_NOWAIT) - for conn := range srv.activeConnections { - task := conn.Executor.Poll() - for task != nil { - srv.handleTask(task) - task.Free() - task = conn.Executor.Poll() - } - } + // for conn := range srv.activeConnections { + // task := conn.Executor.Poll() + // for task != nil { + // srv.handleTask(task) + // task.Free() + // task = conn.Executor.Poll() + // } + // } } } -func (srv *Server) onNewConnection(serverStream *libuv.Stream, status c.Int) { +func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + DefaultServeMux.HandleFunc(pattern, handler) +} + +func onNewConnection(serverStream *libuv.Stream, status c.Int) { + fmt.Println("onNewConnection called") if status < 0 { fmt.Printf("New connection error: %s\n", libuv.Strerror(libuv.Errno(status))) return } - client := new(libuv.Tcp) - libuv.InitTcp(srv.uvLoop, client) + client := (*libuv.Tcp)(c.Malloc(unsafe.Sizeof(libuv.Tcp{}))) + libuv.InitTcp(libuv.DefaultLoop(), client) + srv := (*Server)(serverStream.Data) if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(client))) == 0 { - conn := createConnData(srv.uvLoop, client) + fmt.Println("Accepted new connection") + conn := createConnData(libuv.DefaultLoop(), client) if conn == nil { fmt.Fprintf(os.Stderr, "Failed to create Conn\n") (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) return } + fmt.Println("Conn created") srv.trackConn(conn, true) + fmt.Println("Conn tracked") + + userdata := createServiceUserdata(srv, conn) io := createIo(conn) - service := hyper.ServiceNew(srv.serverCallback) - service.SetUserdata(unsafe.Pointer(conn), freeConnData) + service := hyper.ServiceNew(serverCallback) + service.SetUserdata(unsafe.Pointer(userdata), freeServiceUserdata) http1Opts := hyper.Http1ServerconnOptionsNew(conn.Executor) http2Opts := hyper.Http2ServerconnOptionsNew(conn.Executor) @@ -134,19 +165,20 @@ func (srv *Server) onNewConnection(serverStream *libuv.Stream, status c.Int) { http1Opts.Free() http2Opts.Free() } else { + fmt.Println("Client not accepted") (*libuv.Handle)(unsafe.Pointer(client)).Close(nil) } } -func (srv *Server) serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { - conn := (*conn)(userdata) +func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { + userData := (*serviceUserdata)(userdata) if hyperReq == nil { fmt.Fprintf(os.Stderr, "Error: Received null request\n") return } - req, err := newRequest(conn, hyperReq) + req, err := newRequest(userData.Conn, hyperReq) if err != nil { fmt.Printf("Error creating request: %v\n", err) return @@ -154,27 +186,52 @@ func (srv *Server) serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Reque res := newResponse(channel) - srv.Handler.ServeHTTP(res, req) + userData.Server.Handler.ServeHTTP(res, req) res.finalize() } func (srv *Server) handleTask(task *hyper.Task) { + taskUserdata := task.Userdata() switch task.Type() { - case hyper.TaskServerconn: + case hyper.TaskEmpty: fmt.Println("New server connection") - case hyper.TaskResponse: - fmt.Println("Response sent") + if taskUserdata != nil { + conn := (*conn)(taskUserdata) + if conn.IsClosing == 0 { + conn.IsClosing = 1 + if (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) + } + if (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(closeConn) + } + } + } case hyper.TaskError: err := (*hyper.Error)(task.Value()) var errbuf [256]byte errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) fmt.Printf("Task error: %.*s\n", errlen, (*c.Char)(unsafe.Pointer(&errbuf[0]))) err.Free() + + case hyper.TaskClientConn: + fmt.Fprintf(os.Stderr, "Unexpected HYPER_TASK_CLIENTCONN in server context\n") + + case hyper.TaskResponse: + fmt.Println("Response task received") + + case hyper.TaskBuf: + fmt.Println("Buffer task received") + + case hyper.TaskServerconn: + fmt.Println("Server connection task received: ready for new connection...") + default: + fmt.Fprintf(os.Stderr, "Unknown task type: %d\n", task.Type()) } } -func (s *Server) trackConn(c *conn, add bool) { +func (s *Server)trackConn(c *conn, add bool) { s.mu.Lock() defer s.mu.Unlock() if s.activeConnections == nil { @@ -187,17 +244,6 @@ func (s *Server) trackConn(c *conn, add bool) { } } -func (srv *Server) Close() error { - srv.inShutdown.Store(true) - srv.mu.Lock() - defer srv.mu.Unlock() - - for c := range srv.activeConnections { - delete(srv.activeConnections, c) - } - return nil -} - func createIo(conn *conn) *hyper.Io { io := hyper.NewIo() io.SetUserdata(unsafe.Pointer(conn), freeConnData) @@ -206,9 +252,19 @@ func createIo(conn *conn) *hyper.Io { return io } +func createServiceUserdata(srv *Server, conn *conn) *serviceUserdata { + userdata := (*serviceUserdata)(c.Calloc(1, unsafe.Sizeof(serviceUserdata{}))) + if userdata == nil { + fmt.Fprintf(os.Stderr, "Failed to allocate service_userdata\n") + } + userdata.Server = srv + userdata.Conn = conn + return userdata +} + func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { conn := (*conn)(userdata) - ret := net.Recv(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + ret := cnet.Recv(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) if ret >= 0 { return uintptr(ret) @@ -235,7 +291,7 @@ func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintp func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { conn := (*conn)(userdata) - ret := net.Send(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + ret := cnet.Send(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) if ret >= 0 { return uintptr(ret) @@ -265,7 +321,7 @@ func onClose(handle *libuv.Handle) { } func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { - conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) + conn := (*conn)(handle.Data) if status < 0 { fmt.Fprintf(os.Stderr, "Poll error: %s\n", libuv.Strerror(libuv.Errno(status))) @@ -284,6 +340,7 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { } func updateConnRegistrations(conn *conn, create bool) bool { + fmt.Println("updateConnRegistrations called") events := c.Int(0) if conn.EventMask&c.Uint(libuv.READABLE) != 0 { events |= c.Int(libuv.READABLE) @@ -293,34 +350,42 @@ func updateConnRegistrations(conn *conn, create bool) bool { } r := conn.PollHandle.Start(events, onPoll) + fmt.Println("Poll handle started: %d", r) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", libuv.Strerror(libuv.Errno(r))) return false } + fmt.Println("Poll handle started") return true } -func createConnData(loop *libuv.Loop, client *libuv.Tcp) *conn { +func createConnData(loop *libuv.Loop,client *libuv.Tcp) *conn { conn := (*conn)(c.Calloc(1, unsafe.Sizeof(conn{}))) if conn == nil { fmt.Fprintf(os.Stderr, "Failed to allocate conn_data\n") return nil } + fmt.Println("Conn data created") c.Memcpy(unsafe.Pointer(&conn.Stream), unsafe.Pointer(client), unsafe.Sizeof(libuv.Tcp{})) conn.IsClosing = 0 + fmt.Println("Conn data initialized") + r := libuv.PollInit(loop, conn.PollHandle, libuv.OsFd(client.GetIoWatcherFd())) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) c.Free(unsafe.Pointer(conn)) return nil } + fmt.Println("Poll handle initialized") - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Data = unsafe.Pointer(conn) - conn.Stream.Data = unsafe.Pointer(conn) + //(*libuv.Handle)(unsafe.Pointer(conn.PollHandle)).Data = unsafe.Pointer(conn) + conn.PollHandle.Data = unsafe.Pointer(conn) + //TODO(hackerchai): fix nil pointer error + //conn.Stream.Data = unsafe.Pointer(conn) if !updateConnRegistrations(conn, true) { - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(conn.PollHandle)).Close(nil) c.Free(unsafe.Pointer(conn)) return nil } @@ -337,6 +402,62 @@ func freeConnData(userdata c.Pointer) { } } +func closeConn(handle *libuv.Handle) { + conn := (*conn)(handle.GetData()) + if conn != nil { + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } + if conn.ConnTask != nil { + conn.ConnTask.Free() + conn.ConnTask = nil + } + if conn.Executor != nil { + conn.Executor.Free() + conn.Executor = nil + } + c.Free(unsafe.Pointer(conn)) + } + c.Free(unsafe.Pointer(handle)) +} + +func freeServiceUserdata(userdata c.Pointer) { + castUserdata := (*serviceUserdata)(userdata) + if castUserdata != nil { + // Note: We don't free conn here because it's managed separately + freeConnData(unsafe.Pointer(castUserdata.Conn)) + c.Free(userdata) + } +} + +func closeWalkCb(handle *libuv.Handle, arg c.Pointer) { + if handle.IsClosing() == 0 { + handle.Close(nil) + } +} + +func (srv *Server) Close() error { + srv.inShutdown.Store(true) + srv.mu.Lock() + defer srv.mu.Unlock() + + for c := range srv.activeConnections { + delete(srv.activeConnections, c) + freeConnData(unsafe.Pointer(c)) + } + + srv.uvLoop.Walk(closeWalkCb, nil) + srv.uvLoop.Run(libuv.RUN_DEFAULT) + + srv.uvLoop.Close() + return nil +} + type HandlerFunc func(ResponseWriter, *Request) func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { @@ -349,3 +470,33 @@ func NotFound(w ResponseWriter, r *Request) { w.WriteHeader(404) w.Write([]byte("404 page not found")) } + +// Error replies to the request with the specified error message and HTTP code. +// It does not otherwise end the request; the caller should ensure no further +// writes are done to w. +// The error message should be plain text. +// +// Error deletes the Content-Length header, +// sets Content-Type to “text/plain; charset=utf-8”, +// and sets X-Content-Type-Options to “nosniff”. +// This configures the header properly for the error message, +// in case the caller had set it up expecting a successful output. +func Error(w ResponseWriter, error string, code int) { + h := w.Header() + + // Delete the Content-Length header, which might be for some other content. + // Assuming the error string fits in the writer's buffer, we'll figure + // out the correct Content-Length for it later. + // + // We don't delete Content-Encoding, because some middleware sets + // Content-Encoding: gzip and wraps the ResponseWriter to compress on-the-fly. + // See https://go.dev/issue/66343. + h.Del("Content-Length") + + // There might be content type already set, but we reset it to + // text/plain for the error message. + h.Set("Content-Type", "text/plain; charset=utf-8") + h.Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(code) + fmt.Fprintln(w, error) +} diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index a37b8a0..744d693 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -14,6 +14,7 @@ type muxEntry struct { pattern string } +// DefaultServeMux is the default [ServeMux] used by [Serve]. var DefaultServeMux = &ServeMux{m: make(map[string]muxEntry)} func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { diff --git a/x/net/ipsock.go b/x/net/ipsock.go new file mode 100644 index 0000000..269d2b5 --- /dev/null +++ b/x/net/ipsock.go @@ -0,0 +1,65 @@ +package net + +// SplitHostPort splits a network address of the form "host:port", +// "host%zone:port", "[host]:port" or "[host%zone]:port" into host or +// host%zone and port. +// +// A literal IPv6 address in hostport must be enclosed in square +// brackets, as in "[::1]:80", "[::1%lo0]:80". +// +// See func Dial for a description of the hostport parameter, and host +// and port results. +func SplitHostPort(hostport string) (host, port string, err error) { + const ( + missingPort = "missing port in address" + tooManyColons = "too many colons in address" + ) + addrErr := func(addr, why string) (host, port string, err error) { + return "", "", &AddrError{Err: why, Addr: addr} + } + j, k := 0, 0 + + // The port starts after the last colon. + i := LastIndexByteString(hostport, ':') + if i < 0 { + return addrErr(hostport, missingPort) + } + + if hostport[0] == '[' { + // Expect the first ']' just before the last ':'. + end := IndexByteString(hostport, ']') + if end < 0 { + return addrErr(hostport, "missing ']' in address") + } + switch end + 1 { + case len(hostport): + // There can't be a ':' behind the ']' now. + return addrErr(hostport, missingPort) + case i: + // The expected result. + default: + // Either ']' isn't followed by a colon, or it is + // followed by a colon that is not the last one. + if hostport[end+1] == ':' { + return addrErr(hostport, tooManyColons) + } + return addrErr(hostport, missingPort) + } + host = hostport[1:end] + j, k = 1, end+1 // there can't be a '[' resp. ']' before these positions + } else { + host = hostport[:i] + if IndexByteString(host, ':') >= 0 { + return addrErr(hostport, tooManyColons) + } + } + if IndexByteString(hostport[j:], '[') >= 0 { + return addrErr(hostport, "unexpected '[' in address") + } + if IndexByteString(hostport[k:], ']') >= 0 { + return addrErr(hostport, "unexpected ']' in address") + } + + port = hostport[i+1:] + return host, port, nil +} \ No newline at end of file diff --git a/x/net/net.go b/x/net/net.go new file mode 100644 index 0000000..3cf53df --- /dev/null +++ b/x/net/net.go @@ -0,0 +1,20 @@ +package net + +type AddrError struct { + Err string + Addr string +} + +func (e *AddrError) Error() string { + if e == nil { + return "" + } + s := e.Err + if e.Addr != "" { + s = "address " + e.Addr + ": " + s + } + return s +} + +func (e *AddrError) Timeout() bool { return false } +func (e *AddrError) Temporary() bool { return false } \ No newline at end of file From 1bad20a6ef7df2776b12daefb9bbe2e32b94ce2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 19 Aug 2024 18:32:58 +0800 Subject: [PATCH 16/55] WIP(x/http/client): http.PostForm() & some function improvements --- x/http/_demo/postform/postform.go | 31 ++ x/http/_demo/timeout/timeout.go | 4 +- x/http/client.go | 12 +- x/http/request.go | 85 +++- x/http/transport.go | 644 +++++++++++++++++++++++++++--- x/http/util.go | 87 ++++ 6 files changed, 788 insertions(+), 75 deletions(-) create mode 100644 x/http/_demo/postform/postform.go diff --git a/x/http/_demo/postform/postform.go b/x/http/_demo/postform/postform.go new file mode 100644 index 0000000..5315ca9 --- /dev/null +++ b/x/http/_demo/postform/postform.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "io" + "net/url" + + "github.com/goplus/llgo/x/http" +) + +func main() { + formData := url.Values{ + "name": {"John Doe"}, + "email": {"johndoe@example.com"}, + } + + resp, err := http.PostForm("http://httpbin.org/post", formData) + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/http/_demo/timeout/timeout.go b/x/http/_demo/timeout/timeout.go index 6eece04..42f8bf8 100644 --- a/x/http/_demo/timeout/timeout.go +++ b/x/http/_demo/timeout/timeout.go @@ -10,8 +10,8 @@ import ( func main() { client := &http.Client{ - //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - Timeout: time.Second * 5, + Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + //Timeout: time.Second * 5, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { diff --git a/x/http/client.go b/x/http/client.go index 31362a9..4fc6e41 100644 --- a/x/http/client.go +++ b/x/http/client.go @@ -68,6 +68,14 @@ func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, return c.Do(req) } +func PostForm(url string, data url.Values) (resp *Response, err error) { + return DefaultClient.PostForm(url, data) +} + +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + func (c *Client) Do(req *Request) (*Response, error) { return c.do(req) } @@ -474,7 +482,6 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi if deadline.IsZero() { return nop, alwaysFalse } - // TODO(spongehah) todo: map[string]github.com/goplus/llgo/x/http.RoundTripper //knownTransport := knownRoundTripperImpl(rt, req) oldCtx := req.Context() @@ -552,8 +559,7 @@ func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { return t.Before(d) } -/* -// knownRoundTripperImpl reports whether rt is a RoundTripper that's +/*// knownRoundTripperImpl reports whether rt is a RoundTripper that's // maintained by the Go team and known to implement the latest // optional semantics (notably contexts). The Request is used // to check whether this particular request is using an alternate protocol, diff --git a/x/http/request.go b/x/http/request.go index f6e6f16..38b74bb 100644 --- a/x/http/request.go +++ b/x/http/request.go @@ -45,7 +45,34 @@ type Request struct { var defaultChunkSize uintptr = 8192 -func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { +// NewRequest wraps NewRequestWithContext using context.Background. +func NewRequest(method, url string, body io.Reader) (*Request, error) { + return NewRequestWithContext(context.Background(), method, url, body) +} + +// NewRequestWithContext returns a new Request given a method, URL, and +// optional body. +// +// If the provided body is also an io.Closer, the returned +// Request.Body is set to body and will be closed by the Client +// methods Do, Post, and PostForm, and Transport.RoundTrip. +// +// NewRequestWithContext returns a Request suitable for use with +// Client.Do or Transport.RoundTrip. To create a request for use with +// testing a Server Handler, either use the NewRequest function in the +// net/http/httptest package, use ReadRequest, or manually update the +// Request fields. For an outgoing client request, the context +// controls the entire lifetime of a request and its response: +// obtaining a connection, sending the request, and reading the +// response headers and body. See the Request type's documentation for +// the difference between inbound and outbound request fields. +// +// If body is of type *bytes.Buffer, *bytes.Reader, or +// *strings.Reader, the returned request's ContentLength is set to its +// exact value (instead of -1), GetBody is populated (so 307 and 308 +// redirects can replay the body), and Body is set to NoBody if the +// ContentLength is 0. +func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.Reader) (*Request, error) { if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -69,7 +96,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) req := &Request{ - //ctx: ctx, + ctx: ctx, Method: method, URL: u, Proto: "HTTP/1.1", @@ -131,10 +158,49 @@ func printInformational(userdata c.Pointer, resp *hyper.Response) { fmt.Println("Informational (1xx): ", status) } +type postReq struct { + req *Request + buf []byte +} + func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - req := (*Request)(userdata) - buffer := make([]byte, defaultChunkSize) - n, err := req.Body.Read(buffer) + req := (*postReq)(userdata) + n, err := req.req.Body.Read(req.buf) + if err != nil { + if err == io.EOF { + *chunk = nil + return hyper.PollReady + } + fmt.Println("error reading upload file: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) + return hyper.PollReady + } + if n == 0 { + *chunk = nil + return hyper.PollReady + } + + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} + +func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + type buf struct { + data *uint8 + len uintptr + Unused [16]byte + } + req := (*postReq)(userdata) + buffer := &buf{ + data: &req.buf[0], + len: uintptr(len(req.buf)), + } + + *chunk = (*hyper.Buf)(c.Pointer(buffer)) + n, err := req.req.Body.Read(req.buf) if err != nil { if err == io.EOF { *chunk = nil @@ -144,7 +210,6 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In return hyper.PollError } if n > 0 { - *chunk = hyper.CopyBuf(&buffer[0], uintptr(n)) return hyper.PollReady } if n == 0 { @@ -152,7 +217,7 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In return hyper.PollReady } - fmt.Printf("error reading upload file: %s\n", c.GoString(c.Strerror(os.Errno))) + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) return hyper.PollError } @@ -180,7 +245,11 @@ func newHyperRequest(req *Request) (*hyper.Request, error) { hyperReq.OnInformational(printInformational, nil) hyperReqBody := hyper.NewBody() - hyperReqBody.SetUserdata(c.Pointer(req)) + reqData := &postReq{ + req: req, + buf: make([]byte, 3), + } + hyperReqBody.SetUserdata(c.Pointer(reqData)) hyperReqBody.SetDataFunc(setPostData) hyperReq.SetBody(hyperReqBody) } diff --git a/x/http/transport.go b/x/http/transport.go index f9bfa46..e54b357 100644 --- a/x/http/transport.go +++ b/x/http/transport.go @@ -1,9 +1,12 @@ package http import ( + "context" + "errors" "fmt" "io" "net/url" + "sync" "sync/atomic" "unsafe" @@ -26,9 +29,21 @@ type connData struct { } type Transport struct { - altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme + reqMu sync.Mutex + reqCanceler map[cancelKey]func(error) + //Proxy func(*Request) (*url.URL, error) + + // MaxConnsPerHost optionally limits the total number of + // connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + // + // Zero means no limit. + MaxConnsPerHost int } +var DefaultTransport RoundTripper = &Transport{} + // taskId The unique identifier of the next task polled from the executor type taskId c.Int @@ -43,15 +58,13 @@ const ( defaultHTTPPort = "80" ) -var DefaultTransport RoundTripper = &Transport{} - // persistConn wraps a connection, usually a persistent one // (but may be used for non-keep-alive requests as well) type persistConn struct { // alt optionally specifies the TLS NextProto RoundTripper. // This is used for HTTP/2 today and future protocols later. // If it's non-nil, the rest of the fields are unused. - //alt RoundTripper + alt RoundTripper //br *bufio.Reader // from conn //bw *bufio.Writer // to conn //nwrite int64 // bytes written @@ -94,47 +107,331 @@ type freeChan struct { freech chan struct{} } +// A cancelKey is the key of the reqCanceler map. +// We wrap the *Request in this type since we want to use the original request, +// not any transient one created by roundTrip. +type cancelKey struct { + req *Request +} + +// transportRequest is a wrapper around a *Request that adds +// optional extra headers to write and stores any error to return +// from roundTrip. +type transportRequest struct { + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil + //trace *httptrace.ClientTrace // optional + cancelKey cancelKey + + mu sync.Mutex // guards err + err error // first setError value for mapRoundTripError to consider +} + +// useRegisteredProtocol reports whether an alternate protocol (as registered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *Request) bool { + if req.URL.Scheme == "https" && req.requiresHTTP1() { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} + +// alternateRoundTripper returns the alternate RoundTripper to use +// for this request if the Request's URL scheme requires one, +// or nil for the normal case of using the Transport. +func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { + if !t.useRegisteredProtocol(req) { + return nil + } + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + return altProto[req.URL.Scheme] +} + func (t *Transport) RoundTrip(req *Request) (*Response, error) { - pconn, err := t.getConn(req) - if err != nil { + //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + //ctx := req.Context() + //trace := httptrace.ContextClientTrace(ctx) + + if req.URL == nil { + req.closeBody() + return nil, errors.New("http: nil Request.URL") + } + if req.Header == nil { + req.closeBody() + return nil, errors.New("http: nil Request.Header") + } + scheme := req.URL.Scheme + isHTTP := scheme == "http" || scheme == "https" + if isHTTP { + for k, vv := range req.Header { + if !ValidHeaderFieldName(k) { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid header field name %q", k) + } + for _, v := range vv { + if !ValidHeaderFieldValue(v) { + req.closeBody() + // Don't include the value in the error, because it may be sensitive. + return nil, fmt.Errorf("net/http: invalid header field value for %q", k) + } + } + } + } + + origReq := req + cancelKey := cancelKey{origReq} + req = setupRewindBody(req) + + if altRT := t.alternateRoundTripper(req); altRT != nil { + if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { + return resp, err + } + var err error + req, err = rewindBody(req) + if err != nil { + return nil, err + } + } + if !isHTTP { + req.closeBody() + return nil, badStringError("unsupported protocol scheme", scheme) + } + if req.Method != "" && !validMethod(req.Method) { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid method %q", req.Method) + } + if req.URL.Host == "" { + req.closeBody() + return nil, errors.New("http: no Host in request URL") + } + + for { + // TODO(spongehah) timeout: because of that ctx not initialized ( initialized in setRequestCancel() ) + //select { + //case <-ctx.Done(): + // req.closeBody() + // return nil, ctx.Err() + //default: + //} + + // treq gets modified by roundTrip, so we need to recreate for each retry. + //treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} + treq := &transportRequest{Request: req, cancelKey: cancelKey} + cm, err := t.connectMethodForRequest(treq) + if err != nil { + req.closeBody() + return nil, err + } + + // Get the cached or newly-created connection to either the + // host (for http or https), the http proxy, or the http proxy + // pre-CONNECTed to https server. In any case, we'll be ready + // to send it requests. + pconn, err := t.getConn(treq, cm) + if err != nil { + t.setReqCanceler(cancelKey, nil) + req.closeBody() + return nil, err + } + + var resp *Response + if pconn.alt != nil { + // HTTP/2 path. + t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest + resp, err = pconn.alt.RoundTrip(req) + } else { + resp, err = pconn.roundTrip(treq) + } + if err == nil { + resp.Request = origReq + return resp, nil + } + + // Failed. Clean up and determine whether to retry. + // TODO(spongehah) Retry & ConnPool return nil, err } - var resp *Response - resp, err = pconn.roundTrip(req) - if err != nil { +} + +func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { + //req := treq.Request + //trace := treq.trace + //ctx := req.Context() + //if trace != nil && trace.GetConn != nil { + // trace.GetConn(cm.addr()) + //} + + w := &wantConn{ + cm: cm, + key: cm.key(), + //ctx: ctx, + ready: make(chan struct{}, 1), + beforeDial: testHookPrePendingDial, + afterDial: testHookPostPendingDial, + } + defer func() { + if err != nil { + w.cancel(t, err) + } + }() + + // TODO(spongehah) ConnPool + //// Queue for idle connection. + //if delivered := t.queueForIdleConn(w); delivered { + // pc := w.pc + // // Trace only for HTTP/1. + // // HTTP/2 calls trace.GotConn itself. + // if pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) + // } + // // set request canceler to some non-nil function so we + // // can detect whether it was cleared between now and when + // // we enter roundTrip + // t.setReqCanceler(treq.cancelKey, func(error) {}) + // return pc, nil + //} + + cancelc := make(chan error, 1) + t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) + + // Queue for permission to dial. + t.queueForDial(w) + + // Wait for completion or cancellation. + select { + case <-w.ready: + // Trace success but only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + //if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) + //} + if w.err != nil { + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err + default: + // return below + } + } + return w.pc, w.err + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } return nil, err } - return resp, nil } -func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { - host := req.URL.Hostname() - port := req.URL.Port() +// queueForDial queues w to wait for permission to begin dialing. +// Once w receives permission to dial, it will do so in a separate goroutine. +func (t *Transport) queueForDial(w *wantConn) { + w.beforeDial() + + go t.dialConnFor(w) + // TODO(spongehah) MaxConnsPerHost + //if t.MaxConnsPerHost <= 0 { + // go t.dialConnFor(w) + // return + //} + + //t.connsPerHostMu.Lock() + //defer t.connsPerHostMu.Unlock() + // + //if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { + // if t.connsPerHost == nil { + // t.connsPerHost = make(map[connectMethodKey]int) + // } + // t.connsPerHost[w.key] = n + 1 + // go t.dialConnFor(w) + // return + //} + // + //if t.connsPerHostWait == nil { + // t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) + //} + //q := t.connsPerHostWait[w.key] + //q.cleanFront() + //q.pushBack(w) + //t.connsPerHostWait[w.key] = q +} + +// dialConnFor dials on behalf of w and delivers the result to w. +// dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. +// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. +func (t *Transport) dialConnFor(w *wantConn) { + defer w.afterDial() + + pc, err := t.dialConn(w.ctx, w.cm) + w.tryDeliver(pc, err) + // TODO(spongehah) ConnPool + //delivered := w.tryDeliver(pc, err) + //if err == nil && (!delivered || pc.alt != nil) { + // // pconn was not passed to w, + // // or it is HTTP/2 and can be shared. + // // Add to the idle connection pool. + // t.putOrCloseIdleConn(pc) + //} + //if err != nil { + // t.decConnsPerHost(w.key) + //} +} + +func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { + pconn = &persistConn{ + t: t, + reqch: make(chan requestAndChan, 1), + cancelch: make(chan freeChan, 1), + timeoutch: make(chan struct{}, 1), + //writech: make(chan writeRequest, 1), + //closech: make(chan struct{}), + } + + // TODO(spongehah) Proxy dialConn + + treq := cm.treq + host := treq.URL.Hostname() + port := treq.URL.Port() if port == "" { // Hyper only supports http port = defaultHTTPPort } loop := libuv.DefaultLoop() - //conn := (*ConnData)(c.Calloc(1, unsafe.Sizeof(ConnData{}))) conn := new(connData) + pconn.conn = conn if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } // If timeout is set, start the timer - timeoutch := make(chan struct{}, 1) - if req.timeout > 0 { + if treq.timeout > 0 { libuv.InitTimer(loop, &conn.TimeoutTimer) ct := &connAndTimeoutChan{ conn: conn, - timeoutch: timeoutch, + timeoutch: pconn.timeoutch, } (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) - conn.TimeoutTimer.Start(onTimeout, uint64(req.timeout.Milliseconds()), 0) + conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) } libuv.InitTcp(loop, &conn.TcpHandle) - //conn.TcpHandle.Data = c.Pointer(conn) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) var hints net.AddrInfo @@ -145,26 +442,16 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { var res *net.AddrInfo status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { - close(timeoutch) + close(pconn.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } - //conn.ConnectReq.Data = c.Pointer(conn) (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { - close(timeoutch) + close(pconn.timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } - pconn = &persistConn{ - conn: conn, - t: t, - reqch: make(chan requestAndChan, 1), - cancelch: make(chan freeChan, 1), - timeoutch: timeoutch, - //writech: make(chan writeRequest, 1), - //closech: make(chan struct{}), - } net.Freeaddrinfo(res) @@ -174,18 +461,19 @@ func (t *Transport) getConn(req *Request) (pconn *persistConn, err error) { return pconn, nil } -func (pc *persistConn) roundTrip(req *Request) (*Response, error) { +func (pc *persistConn) roundTrip(req *transportRequest) (*Response, error) { + testHookEnterRoundTrip() resc := make(chan responseAndError, 1) pc.reqch <- requestAndChan{ - req: req, + req: req.Request, ch: resc, } // Determine whether timeout has occurred if pc.conn.IsCompleted == 1 { rc := <-pc.reqch // blocking // Free the resources - FreeResources(nil, nil, nil, nil, pc, rc) + freeResources(nil, nil, nil, nil, pc, rc) return nil, fmt.Errorf("request timeout\n") } select { @@ -203,7 +491,6 @@ func (pc *persistConn) roundTrip(req *Request) (*Response, error) { freech: freech, } <-freech - close(freech) return nil, fmt.Errorf("request timeout\n") } } @@ -237,9 +524,9 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { select { case fc := <-pc.cancelch: // Free the resources - FreeResources(nil, respBody, bodyWriter, exec, pc, rc) + freeResources(nil, respBody, bodyWriter, exec, pc, rc) alive = false - fc.freech <- struct{}{} + close(fc.freech) return default: task := exec.Poll() @@ -253,7 +540,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -265,7 +552,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -276,7 +563,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if sendRes != hyper.OK { rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -287,7 +574,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -299,7 +586,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -327,7 +614,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } @@ -338,14 +625,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { if bodyWriter == nil { rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } _, err := bodyWriter.Write(bytes) // blocking if err != nil { rc.ch <- responseAndError{err: err} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } buf.Free() @@ -363,12 +650,12 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { c.Printf(c.Str("unexpected task type\n")) rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) return } // Free the resources - FreeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, exec, pc, rc) alive = false case notSet: @@ -544,7 +831,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { case sending: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake task error!\n")) - return Fail((*hyper.Error)(task.Value())) + return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskClientConn { return fmt.Errorf("unexpected task type\n") @@ -553,7 +840,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { case receiveResp: if task.Type() == hyper.TaskError { c.Printf(c.Str("send task error!\n")) - return Fail((*hyper.Error)(task.Value())) + return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskResponse { c.Printf(c.Str("unexpected task type\n")) @@ -563,7 +850,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { case receiveRespBody: if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) - return Fail((*hyper.Error)(task.Value())) + return fail((*hyper.Error)(task.Value())) } return nil case notSet: @@ -571,8 +858,8 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected TaskId\n") } -// Fail prints the error details and panics -func Fail(err *hyper.Error) error { +// fail prints the error details and panics +func fail(err *hyper.Error) error { if err != nil { c.Printf(c.Str("error code: %d\n"), err.Code()) // grab the error details @@ -588,8 +875,8 @@ func Fail(err *hyper.Error) error { return nil } -// FreeResources frees the resources -func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { +// freeResources frees the resources +func freeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { // Cleaning up before exiting if task != nil { task.Free() @@ -604,13 +891,13 @@ func FreeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWr exec.Free() } (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) - FreeConnData(pc.conn) + freeConnData(pc.conn) - CloseChannels(rc, pc) + closeChannels(rc, pc) } -// CloseChannels closes the channels -func CloseChannels(rc requestAndChan, pc *persistConn) { +// closeChannels closes the channels +func closeChannels(rc requestAndChan, pc *persistConn) { // Closing the channel close(rc.ch) close(pc.reqch) @@ -618,8 +905,8 @@ func CloseChannels(rc requestAndChan, pc *persistConn) { close(pc.cancelch) } -// FreeConnData frees the connection data -func FreeConnData(conn *connData) { +// freeConnData frees the connection data +func freeConnData(conn *connData) { if conn.ReadWaker != nil { conn.ReadWaker.Free() conn.ReadWaker = nil @@ -643,9 +930,23 @@ func (e *httpError) Error() string { return e.err } func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } +func nop() {} + +// ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol. +var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") + +var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") + var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} -func nop() {} +// errRequestCanceled is set to be identical to the one from h2 to facilitate +// testing. +var errRequestCanceled = http2errRequestCanceled + +// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not +// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. +var http2errRequestCanceled = errors.New("net/http: request canceled") +var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? /*// alternateRoundTripper returns the alternate RoundTripper to use // for this request if the Request's URL scheme requires one, @@ -678,4 +979,223 @@ func idnaASCIIFromURL(url *url.URL) string { addr = v } return addr -} \ No newline at end of file +} + +type readTrackingBody struct { + io.ReadCloser + didRead bool + didClose bool +} + +func (r *readTrackingBody) Read(data []byte) (int, error) { + r.didRead = true + return r.ReadCloser.Read(data) +} + +func (r *readTrackingBody) Close() error { + r.didClose = true + return r.ReadCloser.Close() +} + +// testHooks. Always non-nil. +var ( + testHookEnterRoundTrip = nop + testHookWaitResLoop = nop + testHookRoundTripRetried = nop + testHookPrePendingDial = nop + testHookPostPendingDial = nop + + testHookMu sync.Locker = fakeLocker{} // guards following + testHookReadLoopBeforeNextRead = nop +) + +// fakeLocker is a sync.Locker which does nothing. It's used to guard +// test-only fields when not under test, to avoid runtime atomic +// overhead. +type fakeLocker struct{} + +func (fakeLocker) Lock() {} +func (fakeLocker) Unlock() {} + +// setupRewindBody returns a new request with a custom body wrapper +// that can report whether the body needs rewinding. +// This lets rewindBody avoid an error result when the request +// does not have GetBody but the body hasn't been read at all yet. +func setupRewindBody(req *Request) *Request { + if req.Body == nil || req.Body == NoBody { + return req + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: req.Body} + return &newReq +} + +// rewindBody returns a new request with the body rewound. +// It returns req unmodified if the body does not need rewinding. +// rewindBody takes care of closing req.Body when appropriate +// (in all cases except when rewindBody returns req unmodified). +func rewindBody(req *Request) (rewound *Request, err error) { + if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { + return req, nil // nothing to rewind + } + if !req.Body.(*readTrackingBody).didClose { + req.closeBody() + } + if req.GetBody == nil { + return nil, errCannotRewind + } + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: body} + return &newReq, nil +} + +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[cancelKey]func(error)) + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } +} + +// connectMethod is the map key (in its String form) for keeping persistent +// TCP connections alive for subsequent HTTP requests. +// +// A connect method may be of the following types: +// +// connectMethod.key().String() Description +// ------------------------------ ------------------------- +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy +// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com +// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com +// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com +// https://proxy.com|http https to proxy, http to anywhere after that +type connectMethod struct { + _ incomparable + proxyURL *url.URL // nil for no proxy, else full proxy URL + targetScheme string // "http" or "https" + // If proxyURL specifies an http or https proxy, and targetScheme is http (not https), + // then targetAddr is not included in the connect method key, because the socket can + // be reused for different targetAddr values. + targetAddr string + treq *transportRequest // optional + onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + // TODO(spongehah) canonicalAddr & Proxy + //cm.targetAddr = canonicalAddr(treq.URL) + //if t.Proxy != nil { + // cm.proxyURL, err = t.Proxy(treq.Request) + //} + cm.treq = treq + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string + onlyH1 bool +} + +// A wantConn records state about a wanted connection +// (that is, an active call to getConn). +// The conn may be gotten by dialing or by finding an idle connection, +// or a cancellation may make the conn no longer wanted. +// These three options are racing against each other and use +// wantConn to coordinate and agree about the winning outcome. +type wantConn struct { + cm connectMethod + key connectMethodKey // cm.key() + ctx context.Context // context for dial + ready chan struct{} // closed when pc, err pair is delivered + + // hooks for testing to know when dials are done + // beforeDial is called in the getConn goroutine when the dial is queued. + // afterDial is called when the dial is completed or canceled. + beforeDial func() + afterDial func() + + mu sync.Mutex // protects pc, err, close(ready) + pc *persistConn + err error +} + +// waiting reports whether w is still waiting for an answer (connection or error). +func (w *wantConn) waiting() bool { + select { + case <-w.ready: + return false + default: + return true + } +} + +// tryDeliver attempts to deliver pc, err to w and reports whether it succeeded. +func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { + w.mu.Lock() + defer w.mu.Unlock() + + if w.pc != nil || w.err != nil { + return false + } + + w.pc = pc + w.err = err + if w.pc == nil && w.err == nil { + panic("net/http: internal error: misuse of tryDeliver") + } + close(w.ready) + return true +} + +// cancel marks w as no longer wanting a result (for example, due to cancellation). +// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. +func (w *wantConn) cancel(t *Transport, err error) { + w.mu.Lock() + if w.pc == nil && w.err == nil { + close(w.ready) // catch misbehavior in future delivery + } + //pc := w.pc + w.pc = nil + w.err = err + w.mu.Unlock() + + // TODO(spongehah) ConnPool + //if pc != nil { + // t.putOrCloseIdleConn(pc) + //} +} + +func (cm *connectMethod) key() connectMethodKey { + proxyStr := "" + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { + targetAddr = "" + } + } + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + onlyH1: cm.onlyH1, + } +} diff --git a/x/http/util.go b/x/http/util.go index 674f481..e5d2d03 100644 --- a/x/http/util.go +++ b/x/http/util.go @@ -94,6 +94,93 @@ func IsTokenRune(r rune) bool { // httpguts.IsTokenRune return i < len(isTokenTable) && isTokenTable[i] } +// ValidHeaderFieldName reports whether v is a valid HTTP/1.x header name. +// HTTP/2 imposes the additional restriction that uppercase ASCII +// letters are not allowed. +// +// RFC 7230 says: +// +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// token = 1*tchar +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +func ValidHeaderFieldName(v string) bool { // httpguts.ValidHeaderFieldName + if len(v) == 0 { + return false + } + for i := 0; i < len(v); i++ { + if !isTokenTable[v[i]] { + return false + } + } + return true +} + +// ValidHeaderFieldValue reports whether v is a valid "field-value" according to +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 : +// +// message-header = field-name ":" [ field-value ] +// field-value = *( field-content | LWS ) +// field-content = +// +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 : +// +// TEXT = +// LWS = [CRLF] 1*( SP | HT ) +// CTL = +// +// RFC 7230 says: +// +// field-value = *( field-content / obs-fold ) +// obj-fold = N/A to http2, and deprecated +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// VCHAR = "any visible [USASCII] character" +// +// http2 further says: "Similarly, HTTP/2 allows header field values +// that are not valid. While most of the values that can be encoded +// will not alter header field parsing, carriage return (CR, ASCII +// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII +// 0x0) might be exploited by an attacker if they are translated +// verbatim. Any request or response that contains a character not +// permitted in a header field value MUST be treated as malformed +// (Section 8.1.2.6). Valid characters are defined by the +// field-content ABNF rule in Section 3.2 of [RFC7230]." +// +// This function does not (yet?) properly handle the rejection of +// strings that begin or end with SP or HTAB. +func ValidHeaderFieldValue(v string) bool { // httpguts.ValidHeaderFieldValue + for i := 0; i < len(v); i++ { + b := v[i] + if isCTL(b) && !isLWS(b) { + return false + } + } + return true +} + +// isLWS reports whether b is linear white space, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// +// LWS = [CRLF] 1*( SP | HT ) +func isLWS(b byte) bool { return b == ' ' || b == '\t' } // httpguts.isLWS + +// isCTL reports whether b is a control byte, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// +// CTL = +func isCTL(b byte) bool { // httpguts.isCTL + const del = 0x7f // a CTL + return b < ' ' || b == del +} + // IsPrint returns whether s is ASCII and printable according to // https://tools.ietf.org/html/rfc20#section-4.2. func IsPrint(s string) bool { // ascii.IsPrint From d644d432ffe6ef80dffa221179f41aec73fc1ab3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 20 Aug 2024 10:23:08 +0800 Subject: [PATCH 17/55] refactor(x/http/client): Move the file directory --- x/http/_demo/upload/upload.go | 24 ------------- x/{ => net}/http/_demo/get/get.go | 2 +- x/{ => net}/http/_demo/headers/headers.go | 2 +- x/{ => net}/http/_demo/post/post.go | 2 +- x/{ => net}/http/_demo/postform/postform.go | 2 +- x/{ => net}/http/_demo/redirect/redirect.go | 2 +- .../http/_demo/server/redirectServer.go | 0 x/{ => net}/http/_demo/timeout/timeout.go | 2 +- x/{ => net}/http/_demo/upload/example.txt | 0 x/net/http/_demo/upload/upload.go | 35 +++++++++++++++++++ x/{ => net}/http/client.go | 0 x/{ => net}/http/clone.go | 0 x/{ => net}/http/cookie.go | 0 x/{ => net}/http/header.go | 0 x/{ => net}/http/http.go | 0 x/{ => net}/http/jar.go | 0 x/{ => net}/http/request.go | 0 x/{ => net}/http/response.go | 0 x/{ => net}/http/transfer.go | 0 x/{ => net}/http/transport.go | 0 x/{ => net}/http/util.go | 0 21 files changed, 41 insertions(+), 30 deletions(-) delete mode 100644 x/http/_demo/upload/upload.go rename x/{ => net}/http/_demo/get/get.go (89%) rename x/{ => net}/http/_demo/headers/headers.go (96%) rename x/{ => net}/http/_demo/post/post.go (91%) rename x/{ => net}/http/_demo/postform/postform.go (92%) rename x/{ => net}/http/_demo/redirect/redirect.go (90%) rename x/{ => net}/http/_demo/server/redirectServer.go (100%) rename x/{ => net}/http/_demo/timeout/timeout.go (92%) rename x/{ => net}/http/_demo/upload/example.txt (100%) create mode 100644 x/net/http/_demo/upload/upload.go rename x/{ => net}/http/client.go (100%) rename x/{ => net}/http/clone.go (100%) rename x/{ => net}/http/cookie.go (100%) rename x/{ => net}/http/header.go (100%) rename x/{ => net}/http/http.go (100%) rename x/{ => net}/http/jar.go (100%) rename x/{ => net}/http/request.go (100%) rename x/{ => net}/http/response.go (100%) rename x/{ => net}/http/transfer.go (100%) rename x/{ => net}/http/transport.go (100%) rename x/{ => net}/http/util.go (100%) diff --git a/x/http/_demo/upload/upload.go b/x/http/_demo/upload/upload.go deleted file mode 100644 index c6bb391..0000000 --- a/x/http/_demo/upload/upload.go +++ /dev/null @@ -1,24 +0,0 @@ -package main - -import ( - "fmt" - "io" - - "github.com/goplus/llgoexamples/x/http" -) - -func main() { - resp, err := http.Post("http://httpbin.org/post", "", nil) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(resp.Status) - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(string(body)) - defer resp.Body.Close() -} diff --git a/x/http/_demo/get/get.go b/x/net/http/_demo/get/get.go similarity index 89% rename from x/http/_demo/get/get.go rename to x/net/http/_demo/get/get.go index bff1bd1..79c18ba 100644 --- a/x/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go similarity index 96% rename from x/http/_demo/headers/headers.go rename to x/net/http/_demo/headers/headers.go index 2672a66..71d42b7 100644 --- a/x/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/post/post.go b/x/net/http/_demo/post/post.go similarity index 91% rename from x/http/_demo/post/post.go rename to x/net/http/_demo/post/post.go index 4958a8e..f169dfc 100644 --- a/x/http/_demo/post/post.go +++ b/x/net/http/_demo/post/post.go @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/postform/postform.go b/x/net/http/_demo/postform/postform.go similarity index 92% rename from x/http/_demo/postform/postform.go rename to x/net/http/_demo/postform/postform.go index 5315ca9..1636786 100644 --- a/x/http/_demo/postform/postform.go +++ b/x/net/http/_demo/postform/postform.go @@ -5,7 +5,7 @@ import ( "io" "net/url" - "github.com/goplus/llgo/x/http" + "github.com/goplus/llgo/x/net/http" ) func main() { diff --git a/x/http/_demo/redirect/redirect.go b/x/net/http/_demo/redirect/redirect.go similarity index 90% rename from x/http/_demo/redirect/redirect.go rename to x/net/http/_demo/redirect/redirect.go index 48465b7..e4fdb92 100644 --- a/x/http/_demo/redirect/redirect.go +++ b/x/net/http/_demo/redirect/redirect.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/server/redirectServer.go b/x/net/http/_demo/server/redirectServer.go similarity index 100% rename from x/http/_demo/server/redirectServer.go rename to x/net/http/_demo/server/redirectServer.go diff --git a/x/http/_demo/timeout/timeout.go b/x/net/http/_demo/timeout/timeout.go similarity index 92% rename from x/http/_demo/timeout/timeout.go rename to x/net/http/_demo/timeout/timeout.go index 42f8bf8..ddb2d25 100644 --- a/x/http/_demo/timeout/timeout.go +++ b/x/net/http/_demo/timeout/timeout.go @@ -5,7 +5,7 @@ import ( "io" "time" - "github.com/goplus/llgoexamples/x/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/http/_demo/upload/example.txt b/x/net/http/_demo/upload/example.txt similarity index 100% rename from x/http/_demo/upload/example.txt rename to x/net/http/_demo/upload/example.txt diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go new file mode 100644 index 0000000..3dab514 --- /dev/null +++ b/x/net/http/_demo/upload/upload.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "io" + "os" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + url := "http://httpbin.org/post" + filePath := "/Users/spongehah/go/src/llgo/x/http/_demo/upload/example.txt" // Replace with your file path + + file, err := os.Open(filePath) + if err != nil { + fmt.Println("Error opening file:", err) + return + } + defer file.Close() + + resp, err := http.Post(url, "application/octet-stream", file) + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(respBody)) +} diff --git a/x/http/client.go b/x/net/http/client.go similarity index 100% rename from x/http/client.go rename to x/net/http/client.go diff --git a/x/http/clone.go b/x/net/http/clone.go similarity index 100% rename from x/http/clone.go rename to x/net/http/clone.go diff --git a/x/http/cookie.go b/x/net/http/cookie.go similarity index 100% rename from x/http/cookie.go rename to x/net/http/cookie.go diff --git a/x/http/header.go b/x/net/http/header.go similarity index 100% rename from x/http/header.go rename to x/net/http/header.go diff --git a/x/http/http.go b/x/net/http/http.go similarity index 100% rename from x/http/http.go rename to x/net/http/http.go diff --git a/x/http/jar.go b/x/net/http/jar.go similarity index 100% rename from x/http/jar.go rename to x/net/http/jar.go diff --git a/x/http/request.go b/x/net/http/request.go similarity index 100% rename from x/http/request.go rename to x/net/http/request.go diff --git a/x/http/response.go b/x/net/http/response.go similarity index 100% rename from x/http/response.go rename to x/net/http/response.go diff --git a/x/http/transfer.go b/x/net/http/transfer.go similarity index 100% rename from x/http/transfer.go rename to x/net/http/transfer.go diff --git a/x/http/transport.go b/x/net/http/transport.go similarity index 100% rename from x/http/transport.go rename to x/net/http/transport.go diff --git a/x/http/util.go b/x/net/http/util.go similarity index 100% rename from x/http/util.go rename to x/net/http/util.go From 6f115f12557853396bbdbca708945840970b1b00 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Tue, 20 Aug 2024 17:32:43 +0800 Subject: [PATCH 18/55] fix(x/net/http): Fix request receive logic Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 7 +- x/net/http/request.go | 50 ++++++++++--- x/net/http/response.go | 42 +++++++---- x/net/http/server.go | 154 +++++++++++++++++++++++++++++++-------- x/net/http/servermux.go | 2 + 5 files changed, 197 insertions(+), 58 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index d0b3c0f..34240cd 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -13,7 +13,7 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Error reading request body", http.StatusInternalServerError) return } - defer r.Body.Close() + //defer r.Body.Close() w.Header().Set("Content-Type", "text/plain") @@ -23,8 +23,9 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { func main() { http.HandleFunc("/echo", echoHandler) - fmt.Println("Starting server on :8080") - if err := http.ListenAndServe("127.0.0.1:1234", nil); err != nil { + fmt.Println("Starting server on :1234") + server := http.NewServer("127.0.0.1:1234") + if err := server.ListenAndServe(); err != nil { panic(err) } } diff --git a/x/net/http/request.go b/x/net/http/request.go index 0404815..946200f 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -28,24 +28,47 @@ type Request struct { timeout time.Duration } -func newRequest(conn *conn, hyperReq *hyper.Request) (*Request, error) { +func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Request, error) { method := make([]byte, 32) - methodLen := uintptr(len(method)) + methodLen := unsafe.Sizeof(method) if err := hyperReq.Method(&method[0], &methodLen); err != hyper.OK { return nil, fmt.Errorf("failed to get method: %v", err) } methodStr := string(method[:methodLen]) + fmt.Printf("Method: %s\n", methodStr) var scheme, authority, pathAndQuery [1024]byte - schemeLen, authorityLen, pathAndQueryLen := uintptr(len(scheme)), uintptr(len(authority)), uintptr(len(pathAndQuery)) - if err := hyperReq.URIParts(&scheme[0], &schemeLen, &authority[0], &authorityLen, &pathAndQuery[0], &pathAndQueryLen); err != hyper.OK { - return nil, fmt.Errorf("failed to get URI parts: %v", err) + schemeLen, authorityLen, pathAndQueryLen := unsafe.Sizeof(scheme), unsafe.Sizeof(authority), unsafe.Sizeof(pathAndQuery) + uriResult := hyperReq.URIParts(&scheme[0], &schemeLen, &authority[0], &authorityLen, &pathAndQuery[0], &pathAndQueryLen); + if uriResult != hyper.OK { + return nil, fmt.Errorf("failed to get URI parts: %v", uriResult) } + var schemeStr, authorityStr, pathAndQueryStr string + if schemeLen == 0 { + schemeStr = "http" + } else { + schemeStr = string(scheme[:schemeLen]) + } + + if authorityLen == 0 { + authorityStr = ListenAddr + } else { + authorityStr = string(authority[:authorityLen]) + } + + if pathAndQueryLen == 0 { + return nil, fmt.Errorf("failed to get URI path and query: %v", uriResult) + } else { + pathAndQueryStr = string(pathAndQuery[:pathAndQueryLen]) + } + + var proto string var protoMajor, protoMinor int version := hyperReq.Version() + fmt.Printf("Version: %d\n", version) switch version { case hyper.HTTPVersion10: proto = "HTTP/1.0" @@ -67,26 +90,27 @@ func newRequest(conn *conn, hyperReq *hyper.Request) (*Request, error) { return nil, fmt.Errorf("unknown HTTP version: %d", version) } - urlStr := fmt.Sprintf("%s://%s%s", string(scheme[:schemeLen]), string(authority[:authorityLen]), string(pathAndQuery[:pathAndQueryLen])) + urlStr := fmt.Sprintf("%s://%s%s", schemeStr, authorityStr, pathAndQueryStr) + fmt.Printf("URL: %s\n", urlStr) url, err := url.Parse(urlStr) if err != nil { return nil, err } - req := &Request{ + req := Request{ Method: methodStr, URL: url, Proto: proto, ProtoMajor: protoMajor, ProtoMinor: protoMinor, Header: make(Header), - Host: string(authority[:authorityLen]), + Host: authorityStr, timeout: 0, } headers := hyperReq.Headers() if headers != nil { - headers.Foreach(addHeader, c.Pointer(req)) + headers.Foreach(addHeader, unsafe.Pointer(&req)) } else { return nil, fmt.Errorf("failed to get request headers") } @@ -94,10 +118,11 @@ func newRequest(conn *conn, hyperReq *hyper.Request) (*Request, error) { if methodStr == "POST" || methodStr == "PUT" || methodStr == "PATCH" { body := hyperReq.Body() if body != nil { - var bodyWriter *io.PipeWriter + bodyWriter := new(io.PipeWriter) req.Body, bodyWriter = io.Pipe() - task := body.Foreach(getBodyChunk, c.Pointer(bodyWriter), freeBodyWriter) + + task := body.Foreach(getBodyChunk, c.Pointer(&bodyWriter), freeBodyWriter) if task != nil { r := conn.Executor.Push(task) if r != hyper.OK { @@ -113,7 +138,7 @@ func newRequest(conn *conn, hyperReq *hyper.Request) (*Request, error) { } } - return req, nil + return &req, nil } func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, valueLen uintptr) c.Int { @@ -132,6 +157,7 @@ func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, va } func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { + fmt.Printf("getBodyChunk called\n") writer := (*io.PipeWriter)(userdata) buf := chunk.Bytes() len := chunk.Len() diff --git a/x/net/http/response.go b/x/net/http/response.go index b3c110c..14273c6 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -15,6 +15,7 @@ type response struct { written bool body []byte channel *hyper.ResponseChannel + resp *hyper.Response } type body struct { @@ -27,10 +28,12 @@ var DefaultChunkSize uintptr = 8192 func newResponse(channel *hyper.ResponseChannel) *response { - return &response{ + fmt.Printf("newResponse called\n") + resp := response{ header: make(Header), channel: channel, } + return &resp } func (r *response) Header() Header { @@ -46,22 +49,24 @@ func (r *response) Write(data []byte) (int, error) { } func (r *response) WriteHeader(statusCode int) { + fmt.Printf("WriteHeader called\n") if r.written { return } r.written = true r.statusCode = statusCode - resp := hyper.NewResponse() - resp.SetStatus(uint16(statusCode)) + newResp := hyper.NewResponse() - headers := resp.Headers() + newResp.SetStatus(uint16(statusCode)) + + headers := newResp.Headers() for key, values := range r.header { valueLen := len(values) if valueLen > 1 { for _, value := range values { if headers.Add(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(value)[0], c.Strlen(c.AllocaCStr(value))) != hyper.OK { - return + return } } } else if valueLen == 1 { @@ -72,34 +77,45 @@ func (r *response) WriteHeader(statusCode int) { return } } - - r.channel.Send(resp) + r.resp = newResp } func (r *response) finalize() error { + fmt.Printf("finalize called\n") if !r.written { r.WriteHeader(200) } - bodyData := &body{ + bodyData := body{ data: r.body, len: uintptr(len(r.body)), readLen: 0, } + fmt.Printf("bodyData constructed\n") body := hyper.NewBody() - body.SetUserdata(unsafe.Pointer(bodyData), nil) - + if body == nil { + return fmt.Errorf("failed to create body") + } body.SetDataFunc(setBodyDataFunc) + body.SetUserdata(unsafe.Pointer(&bodyData), nil) + fmt.Printf("bodyData userdata set\n") + + fmt.Printf("bodyData set\n") - resp := hyper.NewResponse() - resp.SetBody(body) + resBody := r.resp.SetBody(body) + if resBody != hyper.OK { + return fmt.Errorf("failed to set body") + } + fmt.Printf("body set\n") - r.channel.Send(resp) + r.channel.Send(r.resp) + fmt.Printf("response sent\n") return nil } func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + fmt.Printf("setBodyDataFunc called\n") body := (*body)(userdata) if body.len > 0 { if body.len > DefaultChunkSize { diff --git a/x/net/http/server.go b/x/net/http/server.go index 833b4e9..909c8f6 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -51,8 +51,11 @@ type conn struct { } type serviceUserdata struct { + Host [128]c.Char + Port [8]c.Char Server *Server - Conn *conn + Conn *conn + ListenAddr string } func NewServer(addr string) *Server { @@ -71,6 +74,9 @@ func ListenAndServe(addr string, handler Handler) error { func (srv *Server) ListenAndServe() error { srv.uvLoop = libuv.DefaultLoop() + if srv.uvLoop == nil { + return fmt.Errorf("failed to get default loop") + } if err := libuv.InitTcp(srv.uvLoop, &srv.uvServer); err != 0 { return fmt.Errorf("failed to init TCP: %v", err) @@ -101,8 +107,9 @@ func (srv *Server) ListenAndServe() error { if result != 0 { return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) } - - (*libuv.Stream)(&srv.uvServer).Data = unsafe.Pointer(srv) + + //(*libuv.Stream)(&srv.uvServer).Data = unsafe.Pointer(srv) + (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).SetData(unsafe.Pointer(srv)) if err := (*libuv.Stream)(&srv.uvServer).Listen(128, onNewConnection); err != 0 { return fmt.Errorf("failed to listen: %v", err) } @@ -110,17 +117,25 @@ func (srv *Server) ListenAndServe() error { fmt.Printf("Listening on %s\n", srv.Addr) for { - srv.uvLoop.Run(libuv.RUN_NOWAIT) + res := srv.uvLoop.Run(libuv.RUN_NOWAIT) + if res < 0 { + fmt.Fprintf(os.Stderr, "uv_loop_run error: %s\n", libuv.Strerror(libuv.Errno(res))) + break + } - // for conn := range srv.activeConnections { - // task := conn.Executor.Poll() - // for task != nil { - // srv.handleTask(task) - // task.Free() - // task = conn.Executor.Poll() - // } - // } + for conn := range srv.activeConnections { + fmt.Printf("Active connection found\n") + if conn.Executor != nil { + task := conn.Executor.Poll() + for task != nil { + srv.handleTask(task) + task.Free() + task = conn.Executor.Poll() + } + } + } } + return nil } func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { @@ -134,24 +149,68 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { return } - client := (*libuv.Tcp)(c.Malloc(unsafe.Sizeof(libuv.Tcp{}))) - libuv.InitTcp(libuv.DefaultLoop(), client) srv := (*Server)(serverStream.Data) + if srv == nil { + fmt.Fprintf(os.Stderr, "Server is nil\n") + return + } + + client := (*libuv.Tcp)(c.Malloc(unsafe.Sizeof(libuv.Tcp{}))) + libuv.InitTcp(srv.uvLoop, client) if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(client))) == 0 { fmt.Println("Accepted new connection") - conn := createConnData(libuv.DefaultLoop(), client) + userdata := createServiceUserdata() + userdata.Server = srv + if userdata == nil { + fmt.Fprintf(os.Stderr, "Failed to create service userdata\n") + (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) + freeServiceUserdata(unsafe.Pointer(userdata)) + return + } + fmt.Printf("ListenAddr: %s\n", srv.Addr) + userdata.ListenAddr = srv.Addr + + var addr cnet.SockaddrStorage + addrlen := c.Int(unsafe.Sizeof(addr)) + client.Getpeername((*cnet.SockAddr)(c.Pointer(&addr)), &addrlen) + + if addr.Family == cnet.AF_INET { + s := (*cnet.SockaddrIn)(unsafe.Pointer(&addr)) + libuv.Ip4Name(s, (*c.Char)(&userdata.Host[0]), unsafe.Sizeof(userdata.Host)) + c.Snprintf((*c.Char)(&userdata.Port[0]), unsafe.Sizeof(userdata.Port), c.Str("%d"), cnet.Ntohs(s.Port)) + } else if addr.Family == cnet.AF_INET6 { + s := (*cnet.SockaddrIn6)(unsafe.Pointer(&addr)) + libuv.Ip6Name(s, (*c.Char)(&userdata.Host[0]), unsafe.Sizeof(userdata.Host)) + c.Snprintf((*c.Char)(&userdata.Port[0]), unsafe.Sizeof(userdata.Port), c.Str("%d"), cnet.Ntohs(s.Port)) + } + + fmt.Printf("New incoming connection from (%s:%s)\n", c.GoString((*c.Char)(&userdata.Host[0])), + c.GoString((*c.Char)(&userdata.Port[0]))) + + conn := createConnData(srv.uvLoop, client) if conn == nil { fmt.Fprintf(os.Stderr, "Failed to create Conn\n") (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) + freeServiceUserdata(unsafe.Pointer(userdata)) + return + } + + executor := hyper.NewExecutor() + if executor == nil { + fmt.Fprintf(os.Stderr, "Failed to create Executor\n") + (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) + freeServiceUserdata(unsafe.Pointer(userdata)) return } + conn.Executor = executor + + userdata.Conn = conn + fmt.Println("Conn created") srv.trackConn(conn, true) fmt.Println("Conn tracked") - userdata := createServiceUserdata(srv, conn) - io := createIo(conn) service := hyper.ServiceNew(serverCallback) service.SetUserdata(unsafe.Pointer(userdata), freeServiceUserdata) @@ -178,13 +237,14 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - req, err := newRequest(userData.Conn, hyperReq) + req, err := newRequest(userData.ListenAddr, userData.Conn, hyperReq) if err != nil { fmt.Printf("Error creating request: %v\n", err) return } res := newResponse(channel) + fmt.Printf("Response created\n") userData.Server.Handler.ServeHTTP(res, req) @@ -231,7 +291,7 @@ func (srv *Server) handleTask(task *hyper.Task) { } } -func (s *Server)trackConn(c *conn, add bool) { +func (s *Server) trackConn(c *conn, add bool) { s.mu.Lock() defer s.mu.Unlock() if s.activeConnections == nil { @@ -252,13 +312,11 @@ func createIo(conn *conn) *hyper.Io { return io } -func createServiceUserdata(srv *Server, conn *conn) *serviceUserdata { +func createServiceUserdata() *serviceUserdata { userdata := (*serviceUserdata)(c.Calloc(1, unsafe.Sizeof(serviceUserdata{}))) if userdata == nil { fmt.Fprintf(os.Stderr, "Failed to allocate service_userdata\n") } - userdata.Server = srv - userdata.Conn = conn return userdata } @@ -280,9 +338,11 @@ func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintp if conn.EventMask&c.Uint(libuv.READABLE) == 0 { conn.EventMask |= c.Uint(libuv.READABLE) + fmt.Printf("ReadCb Event mask: %d\n", conn.EventMask) if !updateConnRegistrations(conn, false) { return hyper.IoError } + fmt.Printf("ReadCb updateConnRegistrations\n") } conn.ReadWaker = ctx.Waker() @@ -307,6 +367,7 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint if conn.EventMask&c.Uint(libuv.WRITABLE) == 0 { conn.EventMask |= c.Uint(libuv.WRITABLE) + fmt.Printf("WriteCb Event mask: %d\n", conn.EventMask) if !updateConnRegistrations(conn, false) { return hyper.IoError } @@ -321,7 +382,8 @@ func onClose(handle *libuv.Handle) { } func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { - conn := (*conn)(handle.Data) + fmt.Printf("onPoll called\n") + conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) if status < 0 { fmt.Fprintf(os.Stderr, "Poll error: %s\n", libuv.Strerror(libuv.Errno(status))) @@ -341,7 +403,17 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { func updateConnRegistrations(conn *conn, create bool) bool { fmt.Println("updateConnRegistrations called") + if conn == nil || conn.PollHandle == nil { + fmt.Fprintf(os.Stderr, "Poll handle is nil\n") + return false + } + events := c.Int(0) + if conn.EventMask == 0 { + fmt.Println("No events to poll, skipping poll start.") + return true + } + fmt.Printf("Event mask: %d\n", conn.EventMask) if conn.EventMask&c.Uint(libuv.READABLE) != 0 { events |= c.Int(libuv.READABLE) } @@ -349,17 +421,22 @@ func updateConnRegistrations(conn *conn, create bool) bool { events |= c.Int(libuv.WRITABLE) } + fmt.Printf("Starting poll with events: %d\n", events) + if conn.PollHandle == nil { + fmt.Fprintf(os.Stderr, "Poll handle is nil\n") + return false + } r := conn.PollHandle.Start(events, onPoll) - fmt.Println("Poll handle started: %d", r) + //fmt.Println("Poll handle started: %d", r) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", libuv.Strerror(libuv.Errno(r))) return false } - fmt.Println("Poll handle started") + fmt.Printf("Poll handle started: %d\n", r) return true } -func createConnData(loop *libuv.Loop,client *libuv.Tcp) *conn { +func createConnData(loop *libuv.Loop, client *libuv.Tcp) *conn { conn := (*conn)(c.Calloc(1, unsafe.Sizeof(conn{}))) if conn == nil { fmt.Fprintf(os.Stderr, "Failed to allocate conn_data\n") @@ -368,9 +445,25 @@ func createConnData(loop *libuv.Loop,client *libuv.Tcp) *conn { fmt.Println("Conn data created") c.Memcpy(unsafe.Pointer(&conn.Stream), unsafe.Pointer(client), unsafe.Sizeof(libuv.Tcp{})) conn.IsClosing = 0 + conn.EventMask = 0 fmt.Println("Conn data initialized") + conn.PollHandle = (*libuv.Poll)(c.Malloc(unsafe.Sizeof(libuv.Poll{}))) + if conn.PollHandle == nil { + fmt.Fprintf(os.Stderr, "Failed to allocate poll handle\n") + c.Free(unsafe.Pointer(conn)) + return nil + } + fmt.Println("Poll handle allocated") + + fmt.Printf("Io Watcher Fd: %d\n", client.GetIoWatcherFd()) + fd := client.GetIoWatcherFd() + if fd < 0 { + fmt.Fprintf(os.Stderr, "Invalid file descriptor\n") + c.Free(unsafe.Pointer(conn)) + return nil + } r := libuv.PollInit(loop, conn.PollHandle, libuv.OsFd(client.GetIoWatcherFd())) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) @@ -379,10 +472,10 @@ func createConnData(loop *libuv.Loop,client *libuv.Tcp) *conn { } fmt.Println("Poll handle initialized") - //(*libuv.Handle)(unsafe.Pointer(conn.PollHandle)).Data = unsafe.Pointer(conn) - conn.PollHandle.Data = unsafe.Pointer(conn) - //TODO(hackerchai): fix nil pointer error - //conn.Stream.Data = unsafe.Pointer(conn) + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).SetData(unsafe.Pointer(conn)) + fmt.Println("Poll handle data set") + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).SetData(unsafe.Pointer(conn)) + fmt.Println("Stream data set") if !updateConnRegistrations(conn, true) { (*libuv.Handle)(unsafe.Pointer(conn.PollHandle)).Close(nil) @@ -461,6 +554,7 @@ func (srv *Server) Close() error { type HandlerFunc func(ResponseWriter, *Request) func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { + fmt.Printf("ServeHTTP called\n") f(w, r) } diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index 744d693..ece33c6 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -1,6 +1,7 @@ package http import ( + "fmt" "sync" ) @@ -18,6 +19,7 @@ type muxEntry struct { var DefaultServeMux = &ServeMux{m: make(map[string]muxEntry)} func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { + fmt.Printf("ServeHTTP called\n") h, _ := mux.Handler(r) h.ServeHTTP(w, r) } From 2944a9d8dc73633ea3afba7362969b75c7b3c651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Tue, 20 Aug 2024 18:14:05 +0800 Subject: [PATCH 19/55] WIP(x/http/client): http proxy & 100-continue & KeepAlive & gzip & some code improvement --- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 32 + x/net/http/_demo/postform/postform.go | 2 +- x/net/http/_demo/upload/upload.go | 12 +- x/net/http/request.go | 42 +- x/net/http/response.go | 7 + x/net/http/transfer.go | 25 - x/net/http/transport.go | 931 ++++++++++++++---- x/net/http/util.go | 25 + x/net/ipsock.go | 24 + 9 files changed, 894 insertions(+), 206 deletions(-) create mode 100644 x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go create mode 100644 x/net/ipsock.go diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go new file mode 100644 index 0000000..63cedbc --- /dev/null +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + client := &http.Client{ + //Transport: &http.Transport{ + // MaxConnsPerHost: 2, + //}, + } + req, err := http.NewRequest("GET", "https://www.baidu.com", nil) + resp, err := client.Do(req) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + fmt.Println(resp.Proto) + resp.PrintHeaders() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + defer resp.Body.Close() +} diff --git a/x/net/http/_demo/postform/postform.go b/x/net/http/_demo/postform/postform.go index 1636786..eae4d6e 100644 --- a/x/net/http/_demo/postform/postform.go +++ b/x/net/http/_demo/postform/postform.go @@ -5,7 +5,7 @@ import ( "io" "net/url" - "github.com/goplus/llgo/x/net/http" + "github.com/goplus/llgoexamples/x/net/http" ) func main() { diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go index 3dab514..86b57e9 100644 --- a/x/net/http/_demo/upload/upload.go +++ b/x/net/http/_demo/upload/upload.go @@ -10,7 +10,7 @@ import ( func main() { url := "http://httpbin.org/post" - filePath := "/Users/spongehah/go/src/llgo/x/http/_demo/upload/example.txt" // Replace with your file path + filePath := "/Users/spongehah/go/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path file, err := os.Open(filePath) if err != nil { @@ -19,7 +19,15 @@ func main() { } defer file.Close() - resp, err := http.Post(url, "application/octet-stream", file) + client := &http.Client{} + req, err := http.NewRequest("POST", url, file) + if err != nil { + fmt.Println(err) + return + } + req.Header.Set("expect", "100-continue") + resp, err := client.Do(req) + if err != nil { fmt.Println(err) return diff --git a/x/net/http/request.go b/x/net/http/request.go index 38b74bb..b44b00b 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -3,6 +3,7 @@ package http import ( "bytes" "context" + "errors" "fmt" "io" "net/textproto" @@ -73,6 +74,16 @@ func NewRequest(method, url string, body io.Reader) (*Request, error) { // redirects can replay the body), and Body is set to NoBody if the // ContentLength is 0. func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.Reader) (*Request, error) { + // TODO(spongehah) Hyper only supports http + isHttpPrefix := strings.HasPrefix(urlStr, "http://") + isHttpsPrefix := strings.HasPrefix(urlStr, "https://") + if !isHttpPrefix && !isHttpsPrefix { + urlStr = "http://" + urlStr + } + if isHttpsPrefix { + urlStr = "http://" + strings.TrimPrefix(urlStr, "https://") + } + if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -82,9 +93,9 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R if !validMethod(method) { return nil, fmt.Errorf("net/http: invalid method %q", method) } - //if ctx == nil { - // return nil, errors.New("net/http: nil Context") - //} + if ctx == nil { + return nil, errors.New("net/http: nil Context") + } u, err := url.Parse(urlStr) if err != nil { return nil, err @@ -241,8 +252,10 @@ func newHyperRequest(req *Request) (*hyper.Request, error) { } if method == "POST" && req.Body != nil { - req.Header.Set("expect", "100-continue") - hyperReq.OnInformational(printInformational, nil) + // 100-continue + if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() { + hyperReq.OnInformational(printInformational, nil) + } hyperReqBody := hyper.NewBody() reqData := &postReq{ @@ -285,6 +298,17 @@ func (req *Request) setHeaders(hyperReq *hyper.Request) error { return nil } +func (r *Request) expectsContinue() bool { + return hasToken(r.Header.get("Expect"), "100-continue") +} + +func (r *Request) wantsClose() bool { + if r.Close { + return true + } + return hasToken(r.Header.get("Connection"), "close") +} + func (r *Request) closeBody() error { if r.Body == nil { return nil @@ -354,6 +378,14 @@ func (r *Request) Cookies() []*Cookie { return readCookies(r.Header, "") } +// ProtoAtLeast reports whether the HTTP protocol used +// in the request is at least major.minor. +func (r *Request) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + + // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // diff --git a/x/net/http/response.go b/x/net/http/response.go index 174d2fc..32b5723 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -82,3 +82,10 @@ func fixPragmaCacheControl(header Header) { func (r *Response) Cookies() []*Cookie { return readSetCookies(r.Header) } + +// isProtocolSwitchHeader reports whether the request or response header +// is for a protocol switch. +func isProtocolSwitchHeader(h Header) bool { + return h.Get("Upgrade") != "" && + HeaderValuesContainsToken(h["Connection"], "Upgrade") +} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index ac50296..0787270 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -388,31 +388,6 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { return hasClose } -// HeaderValuesContainsToken reports whether any string in values -// contains the provided token, ASCII case-insensitively. -func HeaderValuesContainsToken(values []string, token string) bool { - for _, v := range values { - if headerValueContainsToken(v, token) { - return true - } - } - return false -} - -// headerValueContainsToken reports whether v (assumed to be a -// 0#element, in the ABNF extension described in RFC 7230 section 7) -// contains token amongst its comma-separated tokens, ASCII -// case-insensitively. -func headerValueContainsToken(v string, token string) bool { - for comma := strings.IndexByte(v, ','); comma != -1; comma = strings.IndexByte(v, ',') { - if tokenEqual(trimOWS(v[:comma]), token) { - return true - } - v = v[comma+1:] - } - return tokenEqual(trimOWS(v), token) -} - // tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively. func tokenEqual(t1, t2 string) bool { if len(t1) != len(t2) { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index e54b357..e5467e2 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -14,25 +14,36 @@ import ( "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" + xnet "github.com/goplus/llgoexamples/x/net" "github.com/goplus/llgoexamples/rust/hyper" ) -type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - TimeoutTimer libuv.Timer - IsCompleted int - ReadBufFilled uintptr - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} - type Transport struct { altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) - //Proxy func(*Request) (*url.URL, error) + Proxy func(*Request) (*url.URL, error) + + connsPerHostMu sync.Mutex + connsPerHost map[connectMethodKey]int + connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns + + // DisableKeepAlives, if true, disables HTTP keep-alives and + // will only use the connection to the server for a single + // HTTP request. + // + // This is unrelated to the similarly named TCP keep-alives. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool // MaxConnsPerHost optionally limits the total number of // connections per host, including connections in the dialing, @@ -42,17 +53,15 @@ type Transport struct { MaxConnsPerHost int } -var DefaultTransport RoundTripper = &Transport{} - -// taskId The unique identifier of the next task polled from the executor -type taskId c.Int - -const ( - notSet taskId = iota - sending - receiveResp - receiveRespBody -) +// DefaultTransport is the default implementation of Transport and is +// used by DefaultClient. It establishes network connections as needed +// and caches them for reuse by subsequent calls. It uses HTTP proxies +// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY +// and NO_PROXY (or the lowercase versions thereof). +var DefaultTransport RoundTripper = &Transport{ + //Proxy: ProxyFromEnvironment, + Proxy: nil, +} const ( defaultHTTPPort = "80" @@ -65,16 +74,34 @@ type persistConn struct { // This is used for HTTP/2 today and future protocols later. // If it's non-nil, the rest of the fields are unused. alt RoundTripper + //br *bufio.Reader // from conn //bw *bufio.Writer // to conn //nwrite int64 // bytes written //writech chan writeRequest // written by roundTrip; read by writeLoop //closech chan struct{} // closed when conn closed - conn *connData - t *Transport - reqch chan requestAndChan // written by roundTrip; read by readLoop + + t *Transport + cacheKey connectMethodKey + conn *connData + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; read by readLoop + closech chan struct{} // closed when conn closed + writeLoopDone chan struct{} // closed when write loop ends + cancelch chan freeChan timeoutch chan struct{} + + isProxy bool + mu sync.Mutex // guards following fields + numExpectedResponses int + closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled + broken bool // an error has happened on this connection; marked broken so it's not reused. + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -83,9 +110,17 @@ type persistConn struct { type incomparable [0]func() type requestAndChan struct { - _ incomparable - req *Request - ch chan responseAndError // unbuffered; always send in select on callerGone + _ incomparable + req *Request + cancelKey cancelKey + ch chan responseAndError // unbuffered; always send in select on callerGone + + // whether the Transport (as opposed to the user client code) + // added the Accept-Encoding gzip header. If the Transport + // set it, only then do we transparently decode the gzip. + addedGzip bool + + callerGone <-chan struct{} // closed when roundTrip caller has returned } // responseAndError is how the goroutine reading from an HTTP/1 server @@ -127,6 +162,13 @@ type transportRequest struct { err error // first setError value for mapRoundTripError to consider } +func (tr *transportRequest) extraHeaders() Header { + if tr.extra == nil { + tr.extra = make(Header) + } + return tr.extra +} + // useRegisteredProtocol reports whether an alternate protocol (as registered // with Transport.RegisterProtocol) should be respected for this request. func (t *Transport) useRegisteredProtocol(req *Request) bool { @@ -328,6 +370,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } } return w.pc, w.err + // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn //case <-req.Context().Done(): @@ -345,32 +388,30 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi func (t *Transport) queueForDial(w *wantConn) { w.beforeDial() - go t.dialConnFor(w) - // TODO(spongehah) MaxConnsPerHost - //if t.MaxConnsPerHost <= 0 { - // go t.dialConnFor(w) - // return - //} + if t.MaxConnsPerHost <= 0 { + go t.dialConnFor(w) + return + } - //t.connsPerHostMu.Lock() - //defer t.connsPerHostMu.Unlock() - // - //if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { - // if t.connsPerHost == nil { - // t.connsPerHost = make(map[connectMethodKey]int) - // } - // t.connsPerHost[w.key] = n + 1 - // go t.dialConnFor(w) - // return - //} - // - //if t.connsPerHostWait == nil { - // t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) - //} - //q := t.connsPerHostWait[w.key] - //q.cleanFront() - //q.pushBack(w) - //t.connsPerHostWait[w.key] = q + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + + if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { + if t.connsPerHost == nil { + t.connsPerHost = make(map[connectMethodKey]int) + } + t.connsPerHost[w.key] = n + 1 + go t.dialConnFor(w) + return + } + + if t.connsPerHostWait == nil { + t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) + } + q := t.connsPerHostWait[w.key] + q.cleanFront() + q.pushBack(w) + t.connsPerHostWait[w.key] = q } // dialConnFor dials on behalf of w and delivers the result to w. @@ -383,39 +424,176 @@ func (t *Transport) dialConnFor(w *wantConn) { w.tryDeliver(pc, err) // TODO(spongehah) ConnPool //delivered := w.tryDeliver(pc, err) + // Handle undelivered or shareable connections //if err == nil && (!delivered || pc.alt != nil) { // // pconn was not passed to w, // // or it is HTTP/2 and can be shared. // // Add to the idle connection pool. // t.putOrCloseIdleConn(pc) //} + + // TODO(spongehah) decConnsPerHost + // If an error occurs during the dialing process, the connection count for that host is decreased. + // This ensures that the connection count remains accurate even in cases where the dial attempt fails. //if err != nil { // t.decConnsPerHost(w.key) //} } +// decConnsPerHost decrements the per-host connection count for key, +// which may in turn give a different waiting goroutine permission to dial. +//func (t *Transport) decConnsPerHost(key connectMethodKey) { +// if t.MaxConnsPerHost <= 0 { +// return +// } +// +// t.connsPerHostMu.Lock() +// defer t.connsPerHostMu.Unlock() +// n := t.connsPerHost[key] +// if n == 0 { +// // Shouldn't happen, but if it does, the counting is buggy and could +// // easily lead to a silent deadlock, so report the problem loudly. +// panic("net/http: internal error: connCount underflow") +// } +// +// // Can we hand this count to a goroutine still waiting to dial? +// // (Some goroutines on the wait list may have timed out or +// // gotten a connection another way. If they're all gone, +// // we don't want to kick off any spurious dial operations.) +// if q := t.connsPerHostWait[key]; q.len() > 0 { +// done := false +// for q.len() > 0 { +// w := q.popFront() +// if w.waiting() { +// go t.dialConnFor(w) +// done = true +// break +// } +// } +// if q.len() == 0 { +// delete(t.connsPerHostWait, key) +// } else { +// // q is a value (like a slice), so we have to store +// // the updated q back into the map. +// t.connsPerHostWait[key] = q +// } +// if done { +// return +// } +// } +// +// // Otherwise, decrement the recorded count. +// if n--; n == 0 { +// delete(t.connsPerHost, key) +// } else { +// t.connsPerHost[key] = n +// } +//} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ - t: t, - reqch: make(chan requestAndChan, 1), - cancelch: make(chan freeChan, 1), - timeoutch: make(chan struct{}, 1), + t: t, + cacheKey: cm.key(), + reqch: make(chan requestAndChan, 1), + cancelch: make(chan freeChan, 1), + timeoutch: make(chan struct{}, 1), + closech: make(chan struct{}, 1), + writeLoopDone: make(chan struct{}, 1), //writech: make(chan writeRequest, 1), //closech: make(chan struct{}), } - // TODO(spongehah) Proxy dialConn + //if cm.scheme() == "https" && t.hasCustomTLSDialer() { + // var err error + // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) + // if err != nil { + // return nil, wrapErr(err) + // } + // if tc, ok := pconn.conn.(*tls.Conn); ok { + // // Handshake here, in case DialTLS didn't. TLSNextProto below + // // depends on it for knowing the connection state. + // if trace != nil && trace.TLSHandshakeStart != nil { + // trace.TLSHandshakeStart() + // } + // if err := tc.HandshakeContext(ctx); err != nil { + // go pconn.conn.Close() + // if trace != nil && trace.TLSHandshakeDone != nil { + // trace.TLSHandshakeDone(tls.ConnectionState{}, err) + // } + // return nil, err + // } + // cs := tc.ConnectionState() + // if trace != nil && trace.TLSHandshakeDone != nil { + // trace.TLSHandshakeDone(cs, nil) + // } + // pconn.tlsState = &cs + // } + //} else { + //conn, err := t.dial(ctx, "tcp", cm.addr()) + conn, err := t.dial(ctx, pconn, cm) + if err != nil { + return nil, err + } + pconn.conn = conn + //if cm.scheme() == "https" { + // var firstTLSHost string + // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { + // return nil, wrapErr(err) + // } + // if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { + // return nil, wrapErr(err) + // } + //} + //} + + // TODO(spongehah) Proxy(https/sock5) + // Proxy setup. + switch { + case cm.proxyURL == nil: + // Do nothing. Not using a proxy. + // case cm.proxyURL.Scheme == "socks5": + case cm.targetScheme == "http": + pconn.isProxy = true + if pa := cm.proxyAuth(); pa != "" { + pconn.mutateHeaderFunc = func(h Header) { + h.Set("Proxy-Authorization", pa) + } + } + // case cm.targetScheme == "https": + } + //if cm.proxyURL != nil && cm.targetScheme == "https" { + // if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { + // return nil, err + // } + //} + // + //if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { + // if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { + // alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) + // if e, ok := alt.(erringRoundTripper); ok { + // // pconn.conn was closed by next (http2configureTransports.upgradeFn). + // return nil, e.RoundTripErr() + // } + // return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil + // } + //} + + if conn.IsCompleted != 1 { + go pconn.readWriteLoop(libuv.DefaultLoop()) + } + return pconn, nil +} + +func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMethod) (*connData, error) { treq := cm.treq host := treq.URL.Hostname() port := treq.URL.Port() if port == "" { - // Hyper only supports http port = defaultHTTPPort } loop := libuv.DefaultLoop() conn := new(connData) - pconn.conn = conn if conn == nil { return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") } @@ -454,44 +632,136 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } net.Freeaddrinfo(res) - - if pconn.conn.IsCompleted != 1 { - go pconn.readWriteLoop(loop) - } - return pconn, nil + return conn, nil } -func (pc *persistConn) roundTrip(req *transportRequest) (*Response, error) { +func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() - resc := make(chan responseAndError, 1) + if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { + // TODO(spongehah) ConnPool + //pc.t.putOrCloseIdleConn(pc) + return nil, errRequestCanceled + } + pc.mu.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.mu.Unlock() - pc.reqch <- requestAndChan{ - req: req.Request, - ch: resc, + if headerFn != nil { + headerFn(req.extraHeaders()) } - // Determine whether timeout has occurred - if pc.conn.IsCompleted == 1 { - rc := <-pc.reqch // blocking - // Free the resources - freeResources(nil, nil, nil, nil, pc, rc) - return nil, fmt.Errorf("request timeout\n") + + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + req.Method != "HEAD" { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // https://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + requestedGzip = true + req.extraHeaders().Set("Accept-Encoding", "gzip") } - select { - case re := <-resc: - if (re.res == nil) == (re.err == nil) { - return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) + + // The 100-continue operation in Hyper is handled in the newHyperRequest function. + + // Keep-Alive + if pc.t.DisableKeepAlives && + !req.wantsClose() && + !isProtocolSwitchHeader(req.Header) { + req.extraHeaders().Set("Connection", "close") + } + + gone := make(chan struct{}) + defer close(gone) + + defer func() { + if err != nil { + pc.t.setReqCanceler(req.cancelKey, nil) } - if re.err != nil { - return nil, re.err + }() + + const debugRoundTrip = false // Debug switch provided for developers + + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + + // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). + startBytesWritten := pc.nwrite + + // Send the request to readWriteLoop(). + resc := make(chan responseAndError, 1) + + pc.reqch <- requestAndChan{ + req: req.Request, + cancelKey: req.cancelKey, + ch: resc, + addedGzip: requestedGzip, + callerGone: gone, + } + + //var respHeaderTimer <-chan time.Time + //cancelChan := req.Request.Cancel + //ctxDoneChan := req.Context().Done() + pcClosed := pc.closech + canceled := false + + for { + testHookWaitResLoop() + + // Determine whether timeout has occurred + if pc.conn.IsCompleted == 1 { + rc := <-pc.reqch // blocking + // Free the resources + freeResources(nil, nil, nil, nil, pc, rc) + return nil, fmt.Errorf("request timeout\n") } - return re.res, nil - case <-pc.timeoutch: - freech := make(chan struct{}, 1) - pc.cancelch <- freeChan{ - freech: freech, + select { + //case err := <-writeErrCh: + case <-pcClosed: + pcClosed = nil + if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { + if debugRoundTrip { + //req.logf("closech recv: %T %#v", pc.closed, pc.closed) + } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) + } + //case <-respHeaderTimer: + case re := <-resc: + if (re.res == nil) == (re.err == nil) { + return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) + } + if debugRoundTrip { + //req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) + } + if re.err != nil { + return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) + } + return re.res, nil + // TODO(spongehah) cancel(pc.roundTrip) + //case <-cancelChan: + case <-pc.timeoutch: + freech := make(chan struct{}, 1) + pc.cancelch <- freeChan{ + freech: freech, + } + <-freech + return nil, fmt.Errorf("request timeout\n") } - <-freech - return nil, fmt.Errorf("request timeout\n") } } @@ -667,6 +937,22 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} } +type connData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + TimeoutTimer libuv.Timer + IsCompleted int + ReadBufFilled uintptr + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + +func (conn *connData) Close() error { + freeConnData(conn) + return nil +} + // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { //conn := (*ConnData)(req.Data) @@ -819,6 +1105,16 @@ func newIoWithConnReadWrite(connData *connData) *hyper.Io { return hyperIo } +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int + +const ( + notSet taskId = iota + sending + receiveResp + receiveRespBody +) + // setTaskId Set taskId to the task's userdata as a unique identifier func setTaskId(task *hyper.Task, userData taskId) { var data = userData @@ -921,24 +1217,30 @@ func freeConnData(conn *connData) { } } -type httpError struct { - err string - timeout bool -} - -func (e *httpError) Error() string { return e.err } -func (e *httpError) Timeout() bool { return e.timeout } -func (e *httpError) Temporary() bool { return true } +// ---------------------------------------------------------- -func nop() {} +// error values for debugging and testing, not seen by users. +var ( + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") + + // errServerClosedIdle is not seen by users for idempotent requests, but may be + // seen by a user if the server shuts down an idle connection and sends its FIN + // in flight with already-written POST body bytes from the client. + // See https://github.com/golang/go/issues/19943#issuecomment-355607646 + errServerClosedIdle = errors.New("http: server closed idle connection") +) // ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol. var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") - var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") -var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} - // errRequestCanceled is set to be identical to the one from h2 to facilitate // testing. var errRequestCanceled = http2errRequestCanceled @@ -947,6 +1249,67 @@ var errRequestCanceled = http2errRequestCanceled // exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. var http2errRequestCanceled = errors.New("net/http: request canceled") var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? +// errCallerOwnsConn is an internal sentinel error used when we hand +// off a writable response.Body to the caller. We use this to prevent +// closing a net.Conn that is now owned by the caller. +var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn") + +type httpError struct { + err string + timeout bool +} + +func (e *httpError) Error() string { return e.err } +func (e *httpError) Timeout() bool { return e.timeout } +func (e *httpError) Temporary() bool { return true } + +// fakeLocker is a sync.Locker which does nothing. It's used to guard +// test-only fields when not under test, to avoid runtime atomic +// overhead. +type fakeLocker struct{} + +func (fakeLocker) Lock() {} +func (fakeLocker) Unlock() {} + +// nothingWrittenError wraps a write errors which ended up writing zero bytes. +type nothingWrittenError struct { + error +} + +func (nwe nothingWrittenError) Unwrap() error { + return nwe.error +} + +// transportReadFromServerError is used by Transport.readLoop when the +// 1 byte peek read fails and we're actually anticipating a response. +// Usually this is just due to the inherent keep-alive shut down race, +// where the server closed the connection at the same time the client +// wrote. The underlying err field is usually io.EOF or some +// ECONNRESET sort of thing which varies by platform. But it might be +// the user's custom net.Conn.Read error too, so we carry it along for +// them to return from Transport.RoundTrip. +type transportReadFromServerError struct { + err error +} + +func (e transportReadFromServerError) Unwrap() error { return e.err } +func (e transportReadFromServerError) Error() string { + return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err) +} + +func nop() {} + +// testHooks. Always non-nil. +var ( + testHookEnterRoundTrip = nop + testHookWaitResLoop = nop + testHookRoundTripRetried = nop + testHookPrePendingDial = nop + testHookPostPendingDial = nop + + testHookMu sync.Locker = fakeLocker{} // guards following + testHookReadLoopBeforeNextRead = nop +) /*// alternateRoundTripper returns the alternate RoundTripper to use // for this request if the Request's URL scheme requires one, @@ -973,12 +1336,160 @@ func (t *Transport) useRegisteredProtocol(req *Request) bool { } */ -func idnaASCIIFromURL(url *url.URL) string { - addr := url.Hostname() - if v, err := idnaASCII(addr); err == nil { - addr = v +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) } - return addr + cm.treq = treq + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[cancelKey]func(error)) + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } +} + +// replaceReqCanceler replaces an existing cancel function. If there is no cancel function +// for the request, we don't set the function and return false. +// Since CancelRequest will clear the canceler, we can use the return value to detect if +// the request was canceled since the last setReqCancel call. +func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { + t.reqMu.Lock() + defer t.reqMu.Unlock() + _, ok := t.reqCanceler[key] + if !ok { + return false + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } + return true +} + +func (pc *persistConn) cancelRequest(err error) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.canceledErr = err + pc.closeLocked(errRequestCanceled) +} + +// close closes the underlying TCP connection and closes +// the pc.closech channel. +// +// The provided err is only for testing and debugging; in normal +// circumstances it should never be seen by users. +func (pc *persistConn) close(err error) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.closeLocked(err) +} + +func (pc *persistConn) closeLocked(err error) { + if err == nil { + panic("nil error") + } + pc.broken = true + if pc.closed == nil { + pc.closed = err + // TODO(spongehah) decConnsPerHost + //pc.t.decConnsPerHost(pc.cacheKey) + // Close HTTP/1 (pc.alt == nil) connection. + // HTTP/2 closes its connection itself. + if pc.alt == nil { + if err != errCallerOwnsConn { + pc.conn.Close() + } + close(pc.closech) + } + } + pc.mutateHeaderFunc = nil +} + +// mapRoundTripError returns the appropriate error value for +// persistConn.roundTrip. +// +// The provided err is the first error that (*persistConn).roundTrip +// happened to receive from its select statement. +// +// The startBytesWritten value should be the value of pc.nwrite before the roundTrip +// started writing the request. +func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error { + if err == nil { + return nil + } + + // Wait for the writeLoop goroutine to terminate to avoid data + // races on callers who mutate the request on failure. + // + // When resc in pc.roundTrip and hence rc.ch receives a responseAndError + // with a non-nil error it implies that the persistConn is either closed + // or closing. Waiting on pc.writeLoopDone is hence safe as all callers + // close closech which in turn ensures writeLoop returns. + <-pc.writeLoopDone + + // If the request was canceled, that's better than network + // failures that were likely the result of tearing down the + // connection. + if cerr := pc.canceled(); cerr != nil { + return cerr + } + + // See if an error was set explicitly. + req.mu.Lock() + reqErr := req.err + req.mu.Unlock() + if reqErr != nil { + return reqErr + } + + if err == errServerClosedIdle { + // Don't decorate + return err + } + + if _, ok := err.(transportReadFromServerError); ok { + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + // Don't decorate + return err + } + if pc.isBroken() { + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err) + } + return err +} + +// canceled returns non-nil if the connection was closed due to +// CancelRequest or due to context cancellation. +func (pc *persistConn) canceled() error { + pc.mu.Lock() + defer pc.mu.Unlock() + return pc.canceledErr +} + +// isBroken reports whether this connection is in a known broken state. +func (pc *persistConn) isBroken() bool { + pc.mu.Lock() + b := pc.closed != nil + pc.mu.Unlock() + return b } type readTrackingBody struct { @@ -997,26 +1508,6 @@ func (r *readTrackingBody) Close() error { return r.ReadCloser.Close() } -// testHooks. Always non-nil. -var ( - testHookEnterRoundTrip = nop - testHookWaitResLoop = nop - testHookRoundTripRetried = nop - testHookPrePendingDial = nop - testHookPostPendingDial = nop - - testHookMu sync.Locker = fakeLocker{} // guards following - testHookReadLoopBeforeNextRead = nop -) - -// fakeLocker is a sync.Locker which does nothing. It's used to guard -// test-only fields when not under test, to avoid runtime atomic -// overhead. -type fakeLocker struct{} - -func (fakeLocker) Lock() {} -func (fakeLocker) Unlock() {} - // setupRewindBody returns a new request with a custom body wrapper // that can report whether the body needs rewinding. // This lets rewindBody avoid an error result when the request @@ -1053,17 +1544,27 @@ func rewindBody(req *Request) (rewound *Request, err error) { return &newReq, nil } -func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - if t.reqCanceler == nil { - t.reqCanceler = make(map[cancelKey]func(error)) +var portMap = map[string]string{ + "http": "80", + "https": "443", + "socks5": "1080", +} + +func idnaASCIIFromURL(url *url.URL) string { + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) + return addr +} + +// canonicalAddr returns url.Host but always with a ":port" suffix. +func canonicalAddr(url *url.URL) string { + port := url.Port() + if port == "" { + port = portMap[url.Scheme] } + return xnet.JoinHostPort(idnaASCIIFromURL(url), port) } // connectMethod is the map key (in its String form) for keeping persistent @@ -1094,16 +1595,51 @@ type connectMethod struct { onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { - cm.targetScheme = treq.URL.Scheme - // TODO(spongehah) canonicalAddr & Proxy - //cm.targetAddr = canonicalAddr(treq.URL) - //if t.Proxy != nil { - // cm.proxyURL, err = t.Proxy(treq.Request) - //} - cm.treq = treq - cm.onlyH1 = treq.requiresHTTP1() - return cm, err +func (cm *connectMethod) key() connectMethodKey { + proxyStr := "" + targetAddr := cm.targetAddr + if cm.proxyURL != nil { + proxyStr = cm.proxyURL.String() + if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { + targetAddr = "" + } + } + return connectMethodKey{ + proxy: proxyStr, + scheme: cm.targetScheme, + addr: targetAddr, + onlyH1: cm.onlyH1, + } +} + +// scheme returns the first hop scheme: http, https, or socks5 +func (cm *connectMethod) scheme() string { + if cm.proxyURL != nil { + return cm.proxyURL.Scheme + } + return cm.targetScheme +} + +// addr returns the first hop "host:port" to which we need to TCP connect. +func (cm *connectMethod) addr() string { + if cm.proxyURL != nil { + return canonicalAddr(cm.proxyURL) + } + return cm.targetAddr +} + +// proxyAuth returns the Proxy-Authorization header to set +// on requests, if applicable. +func (cm *connectMethod) proxyAuth() string { + if cm.proxyURL == nil { + return "" + } + if u := cm.proxyURL.User; u != nil { + username := u.Username() + password, _ := u.Password() + return "Basic " + basicAuth(username, password) + } + return "" } // connectMethodKey is the map key version of connectMethod, with a @@ -1137,6 +1673,24 @@ type wantConn struct { err error } +// cancel marks w as no longer wanting a result (for example, due to cancellation). +// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. +func (w *wantConn) cancel(t *Transport, err error) { + w.mu.Lock() + if w.pc == nil && w.err == nil { + close(w.ready) // catch misbehavior in future delivery + } + //pc := w.pc + w.pc = nil + w.err = err + w.mu.Unlock() + + // TODO(spongehah) ConnPool + //if pc != nil { + // t.putOrCloseIdleConn(pc) + //} +} + // waiting reports whether w is still waiting for an answer (connection or error). func (w *wantConn) waiting() bool { select { @@ -1165,37 +1719,68 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { return true } -// cancel marks w as no longer wanting a result (for example, due to cancellation). -// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn. -func (w *wantConn) cancel(t *Transport, err error) { - w.mu.Lock() - if w.pc == nil && w.err == nil { - close(w.ready) // catch misbehavior in future delivery - } - //pc := w.pc - w.pc = nil - w.err = err - w.mu.Unlock() +// A wantConnQueue is a queue of wantConns. +type wantConnQueue struct { + // This is a queue, not a deque. + // It is split into two stages - head[headPos:] and tail. + // popFront is trivial (headPos++) on the first stage, and + // pushBack is trivial (append) on the second stage. + // If the first stage is empty, popFront can swap the + // first and second stages to remedy the situation. + // + // This two-stage split is analogous to the use of two lists + // in Okasaki's purely functional queue but without the + // overhead of reversing the list when swapping stages. + head []*wantConn + headPos int + tail []*wantConn +} - // TODO(spongehah) ConnPool - //if pc != nil { - // t.putOrCloseIdleConn(pc) - //} +// len returns the number of items in the queue. +func (q *wantConnQueue) len() int { + return len(q.head) - q.headPos + len(q.tail) } -func (cm *connectMethod) key() connectMethodKey { - proxyStr := "" - targetAddr := cm.targetAddr - if cm.proxyURL != nil { - proxyStr = cm.proxyURL.String() - if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" { - targetAddr = "" +// pushBack adds w to the back of the queue. +func (q *wantConnQueue) pushBack(w *wantConn) { + q.tail = append(q.tail, w) +} + +// popFront removes and returns the wantConn at the front of the queue. +func (q *wantConnQueue) popFront() *wantConn { + if q.headPos >= len(q.head) { + if len(q.tail) == 0 { + return nil } + // Pick up tail as new head, clear tail. + q.head, q.headPos, q.tail = q.tail, 0, q.head[:0] } - return connectMethodKey{ - proxy: proxyStr, - scheme: cm.targetScheme, - addr: targetAddr, - onlyH1: cm.onlyH1, + w := q.head[q.headPos] + q.head[q.headPos] = nil + q.headPos++ + return w +} + +// peekFront returns the wantConn at the front of the queue without removing it. +func (q *wantConnQueue) peekFront() *wantConn { + if q.headPos < len(q.head) { + return q.head[q.headPos] + } + if len(q.tail) > 0 { + return q.tail[0] + } + return nil +} + +// cleanFront pops any wantConns that are no longer waiting from the head of the +// queue, reporting whether any were popped. +func (q *wantConnQueue) cleanFront() (cleaned bool) { + for { + w := q.peekFront() + if w == nil || w.waiting() { + return cleaned + } + q.popFront() + cleaned = true } } diff --git a/x/net/http/util.go b/x/net/http/util.go index e5d2d03..f2efb70 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -181,6 +181,31 @@ func isCTL(b byte) bool { // httpguts.isCTL return b < ' ' || b == del } +// HeaderValuesContainsToken reports whether any string in values +// contains the provided token, ASCII case-insensitively. +func HeaderValuesContainsToken(values []string, token string) bool { // httpguts.HeaderValuesContainsToken + for _, v := range values { + if headerValueContainsToken(v, token) { + return true + } + } + return false +} + +// headerValueContainsToken reports whether v (assumed to be a +// 0#element, in the ABNF extension described in RFC 7230 section 7) +// contains token amongst its comma-separated tokens, ASCII +// case-insensitively. +func headerValueContainsToken(v string, token string) bool { // httpguts.headerValueContainsToken + for comma := strings.IndexByte(v, ','); comma != -1; comma = strings.IndexByte(v, ',') { + if tokenEqual(trimOWS(v[:comma]), token) { + return true + } + v = v[comma+1:] + } + return tokenEqual(trimOWS(v), token) +} + // IsPrint returns whether s is ASCII and printable according to // https://tools.ietf.org/html/rfc20#section-4.2. func IsPrint(s string) bool { // ascii.IsPrint diff --git a/x/net/ipsock.go b/x/net/ipsock.go new file mode 100644 index 0000000..855a864 --- /dev/null +++ b/x/net/ipsock.go @@ -0,0 +1,24 @@ +package net + +// JoinHostPort combines host and port into a network address of the +// form "host:port". If host contains a colon, as found in literal +// IPv6 addresses, then JoinHostPort returns "[host]:port". +// +// See func Dial for a description of the host and port parameters. +func JoinHostPort(host, port string) string { + // We assume that host is a literal IPv6 address if host has + // colons. + if IndexByteString(host, ':') >= 0 { + return "[" + host + "]:" + port + } + return host + ":" + port +} + +func IndexByteString(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} From d05548e7adde869260c1a48fc2432cbdeffec0b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Wed, 21 Aug 2024 18:58:13 +0800 Subject: [PATCH 20/55] WIP(x/http/client): Optimize readWriteLoop and make some code adjustments --- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 6 +- x/net/http/request.go | 6 +- x/net/http/response.go | 26 + x/net/http/transfer.go | 2 +- x/net/http/transport.go | 822 ++++++++++-------- 5 files changed, 504 insertions(+), 358 deletions(-) diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go index 63cedbc..882bdc1 100644 --- a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -9,9 +9,9 @@ import ( func main() { client := &http.Client{ - //Transport: &http.Transport{ - // MaxConnsPerHost: 2, - //}, + Transport: &http.Transport{ + MaxConnsPerHost: 2, + }, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) resp, err := client.Do(req) diff --git a/x/net/http/request.go b/x/net/http/request.go index b44b00b..cb50936 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -385,7 +385,6 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } - // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // @@ -425,6 +424,11 @@ func readCookies(h Header, filter string) []*Cookie { return cookies } +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + func idnaASCII(v string) (string, error) { // TODO: Consider removing this check after verifying performance is okay. // Right now punycode verification, length checks, context checks, and the diff --git a/x/net/http/response.go b/x/net/http/response.go index 32b5723..d2a7dd5 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "strconv" + "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -63,6 +64,19 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { headers.Foreach(appendToResponseHeader, c.Pointer(resp)) } +// appendToResponseBody BodyForeachCallback function: Process the response body +func appendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { + writer := (*io.PipeWriter)(userdata) + bufLen := chunk.Len() + bytes := unsafe.Slice(chunk.Bytes(), bufLen) + _, err := writer.Write(bytes) + if err != nil { + fmt.Println("Error writing to response body:", err) + return hyper.IterBreak + } + return hyper.IterContinue +} + // RFC 7234, section 5.4: Should treat // // Pragma: no-cache @@ -89,3 +103,15 @@ func isProtocolSwitchHeader(h Header) bool { return h.Get("Upgrade") != "" && HeaderValuesContainsToken(h["Connection"], "Upgrade") } + +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func (r *Response) bodyIsWritable() bool { + _, ok := r.Body.(io.Writer) + return ok +} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 0787270..823ef7d 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -367,7 +367,7 @@ func bodyAllowedForStatus(status int) bool { return true } -// Determine whether to hang up after sending a request and body, or +// Determine whether to hang up after write a request and body, or // receiving a response and body // 'header' is the request headers. func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index e5467e2..99003eb 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -14,10 +14,25 @@ import ( "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - xnet "github.com/goplus/llgoexamples/x/net" + xnet "github.com/goplus/llgo/x/net" "github.com/goplus/llgoexamples/rust/hyper" ) +// DefaultTransport is the default implementation of Transport and is +// used by DefaultClient. It establishes network connections as needed +// and caches them for reuse by subsequent calls. It uses HTTP proxies +// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY +// and NO_PROXY (or the lowercase versions thereof). +var DefaultTransport RoundTripper = &Transport{ + //Proxy: ProxyFromEnvironment, + Proxy: nil, +} + +// DefaultMaxIdleConnsPerHost is the default value of Transport's +// MaxIdleConnsPerHost. +const DefaultMaxIdleConnsPerHost = 2 +const defaultHTTPPort = "80" + type Transport struct { altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex @@ -53,55 +68,11 @@ type Transport struct { MaxConnsPerHost int } -// DefaultTransport is the default implementation of Transport and is -// used by DefaultClient. It establishes network connections as needed -// and caches them for reuse by subsequent calls. It uses HTTP proxies -// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY -// and NO_PROXY (or the lowercase versions thereof). -var DefaultTransport RoundTripper = &Transport{ - //Proxy: ProxyFromEnvironment, - Proxy: nil, -} - -const ( - defaultHTTPPort = "80" -) - -// persistConn wraps a connection, usually a persistent one -// (but may be used for non-keep-alive requests as well) -type persistConn struct { - // alt optionally specifies the TLS NextProto RoundTripper. - // This is used for HTTP/2 today and future protocols later. - // If it's non-nil, the rest of the fields are unused. - alt RoundTripper - - //br *bufio.Reader // from conn - //bw *bufio.Writer // to conn - //nwrite int64 // bytes written - //writech chan writeRequest // written by roundTrip; read by writeLoop - //closech chan struct{} // closed when conn closed - - t *Transport - cacheKey connectMethodKey - conn *connData - nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; read by readLoop - closech chan struct{} // closed when conn closed - writeLoopDone chan struct{} // closed when write loop ends - - cancelch chan freeChan - timeoutch chan struct{} - - isProxy bool - mu sync.Mutex // guards following fields - numExpectedResponses int - closed error // set non-nil when conn is closed, before closech is closed - canceledErr error // set non-nil if conn is canceled - broken bool // an error has happened on this connection; marked broken so it's not reused. - // mutateHeaderFunc is an optional func to modify extra - // headers on each outbound request before it's written. (the - // original Request given to RoundTrip is not modified) - mutateHeaderFunc func(Header) +// A cancelKey is the key of the reqCanceler map. +// We wrap the *Request in this type since we want to use the original request, +// not any transient one created by roundTrip. +type cancelKey struct { + req *Request } // incomparable is a zero-width, non-comparable type. Adding it to a struct @@ -123,6 +94,15 @@ type requestAndChan struct { callerGone <-chan struct{} // closed when roundTrip caller has returned } +// A writeRequest is sent by the caller's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error +} + // responseAndError is how the goroutine reading from an HTTP/1 server // communicates with the goroutine doing the RoundTrip. type responseAndError struct { @@ -142,11 +122,56 @@ type freeChan struct { freech chan struct{} } -// A cancelKey is the key of the reqCanceler map. -// We wrap the *Request in this type since we want to use the original request, -// not any transient one created by roundTrip. -type cancelKey struct { - req *Request +type readTrackingBody struct { + io.ReadCloser + didRead bool + didClose bool +} + +func (r *readTrackingBody) Read(data []byte) (int, error) { + r.didRead = true + return r.ReadCloser.Read(data) +} + +func (r *readTrackingBody) Close() error { + r.didClose = true + return r.ReadCloser.Close() +} + +// setupRewindBody returns a new request with a custom body wrapper +// that can report whether the body needs rewinding. +// This lets rewindBody avoid an error result when the request +// does not have GetBody but the body hasn't been readRespLineAndHeader at all yet. +func setupRewindBody(req *Request) *Request { + if req.Body == nil || req.Body == NoBody { + return req + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: req.Body} + return &newReq +} + +// rewindBody returns a new request with the body rewound. +// It returns req unmodified if the body does not need rewinding. +// rewindBody takes care of closing req.Body when appropriate +// (in all cases except when rewindBody returns req unmodified). +func rewindBody(req *Request) (rewound *Request, err error) { + if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { + return req, nil // nothing to rewind + } + if !req.Body.(*readTrackingBody).didClose { + req.closeBody() + } + if req.GetBody == nil { + return nil, errCannotRewind + } + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: body} + return &newReq, nil } // transportRequest is a wrapper around a *Request that adds @@ -169,16 +194,54 @@ func (tr *transportRequest) extraHeaders() Header { return tr.extra } -// useRegisteredProtocol reports whether an alternate protocol (as registered -// with Transport.RegisterProtocol) should be respected for this request. -func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. +func (tr *transportRequest) setError(err error) { + tr.mu.Lock() + if tr.err == nil { + tr.err = err + } + tr.mu.Unlock() +} + +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) + } + cm.treq = treq + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[cancelKey]func(error)) + } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } +} + +// replaceReqCanceler replaces an existing cancel function. If there is no cancel function +// for the request, we don't set the function and return false. +// Since CancelRequest will clear the canceler, we can use the return value to detect if +// the request was canceled since the last setReqCancel call. +func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { + t.reqMu.Lock() + defer t.reqMu.Unlock() + _, ok := t.reqCanceler[key] + if !ok { return false } + if fn != nil { + t.reqCanceler[key] = fn + } else { + delete(t.reqCanceler, key) + } return true } @@ -193,6 +256,21 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { return altProto[req.URL.Scheme] } +// useRegisteredProtocol reports whether an alternate protocol (as registered +// with Transport.RegisterProtocol) should be respected for this request. +func (t *Transport) useRegisteredProtocol(req *Request) bool { + if req.URL.Scheme == "https" && req.requiresHTTP1() { + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return false + } + return true +} + +// ---------------------------------------------------------- + func (t *Transport) RoundTrip(req *Request) (*Response, error) { //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() @@ -432,69 +510,69 @@ func (t *Transport) dialConnFor(w *wantConn) { // t.putOrCloseIdleConn(pc) //} - // TODO(spongehah) decConnsPerHost // If an error occurs during the dialing process, the connection count for that host is decreased. // This ensures that the connection count remains accurate even in cases where the dial attempt fails. - //if err != nil { - // t.decConnsPerHost(w.key) - //} + if err != nil { + t.decConnsPerHost(w.key) + } } // decConnsPerHost decrements the per-host connection count for key, // which may in turn give a different waiting goroutine permission to dial. -//func (t *Transport) decConnsPerHost(key connectMethodKey) { -// if t.MaxConnsPerHost <= 0 { -// return -// } -// -// t.connsPerHostMu.Lock() -// defer t.connsPerHostMu.Unlock() -// n := t.connsPerHost[key] -// if n == 0 { -// // Shouldn't happen, but if it does, the counting is buggy and could -// // easily lead to a silent deadlock, so report the problem loudly. -// panic("net/http: internal error: connCount underflow") -// } -// -// // Can we hand this count to a goroutine still waiting to dial? -// // (Some goroutines on the wait list may have timed out or -// // gotten a connection another way. If they're all gone, -// // we don't want to kick off any spurious dial operations.) -// if q := t.connsPerHostWait[key]; q.len() > 0 { -// done := false -// for q.len() > 0 { -// w := q.popFront() -// if w.waiting() { -// go t.dialConnFor(w) -// done = true -// break -// } -// } -// if q.len() == 0 { -// delete(t.connsPerHostWait, key) -// } else { -// // q is a value (like a slice), so we have to store -// // the updated q back into the map. -// t.connsPerHostWait[key] = q -// } -// if done { -// return -// } -// } -// -// // Otherwise, decrement the recorded count. -// if n--; n == 0 { -// delete(t.connsPerHost, key) -// } else { -// t.connsPerHost[key] = n -// } -//} +func (t *Transport) decConnsPerHost(key connectMethodKey) { + if t.MaxConnsPerHost <= 0 { + return + } + + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + n := t.connsPerHost[key] + if n == 0 { + // Shouldn't happen, but if it does, the counting is buggy and could + // easily lead to a silent deadlock, so report the problem loudly. + panic("net/http: internal error: connCount underflow") + } + + // Can we hand this count to a goroutine still waiting to dial? + // (Some goroutines on the wait list may have timed out or + // gotten a connection another way. If they're all gone, + // we don't want to kick off any spurious dial operations.) + if q := t.connsPerHostWait[key]; q.len() > 0 { + done := false + for q.len() > 0 { + w := q.popFront() + if w.waiting() { + go t.dialConnFor(w) + done = true + break + } + } + if q.len() == 0 { + delete(t.connsPerHostWait, key) + } else { + // q is a value (like a slice), so we have to store + // the updated q back into the map. + t.connsPerHostWait[key] = q + } + if done { + return + } + } + + // Otherwise, decrement the recorded count. + if n--; n == 0 { + delete(t.connsPerHost, key) + } else { + t.connsPerHost[key] = n + } +} func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, cacheKey: cm.key(), reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), cancelch: make(chan freeChan, 1), timeoutch: make(chan struct{}, 1), closech: make(chan struct{}, 1), @@ -535,6 +613,24 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return nil, err } pconn.conn = conn + + // hyper specific + // Hookup the IO + hyperIo := newIoWithConnReadWrite(conn) + // We need an executor generally to poll futures + exec := hyper.NewExecutor() + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(exec) + pconn.io = hyperIo + pconn.exec = exec + pconn.opts = opts + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + setTaskId(handshakeTask, write) + // Let's wait for the handshake to finish... + exec.Push(handshakeTask) + //if cm.scheme() == "https" { // var firstTLSHost string // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { @@ -702,10 +798,11 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). startBytesWritten := pc.nwrite + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh} // Send the request to readWriteLoop(). resc := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{ req: req.Request, cancelKey: req.cancelKey, @@ -731,7 +828,22 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, fmt.Errorf("request timeout\n") } select { - //case err := <-writeErrCh: + case err := <-writeErrCh: + if debugRoundTrip { + //req.logf("writeErrCh resv: %T/%#v", err, err) + } + if err != nil { + pc.close(fmt.Errorf("write error: %w", err)) + return nil, pc.mapRoundTripError(req, startBytesWritten, err) + } + //if d := pc.t.ResponseHeaderTimeout; d > 0 { + // if debugRoundTrip { + // //req.logf("starting timer for %v", d) + // } + // timer := time.NewTimer(d) + // defer timer.Stop() // prevent leaks + // respHeaderTimer = timer.C + //} case <-pcClosed: pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { @@ -768,49 +880,52 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { - // Hookup the IO - hyperIo := newIoWithConnReadWrite(pc.conn) - - // We need an executor generally to poll futures - exec := hyper.NewExecutor() - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(exec) + defer close(pc.writeLoopDone) - handshakeTask := hyper.Handshake(hyperIo, opts) - setTaskId(handshakeTask, sending) - - // Let's wait for the handshake to finish... - exec.Push(handshakeTask) + const debugReadWriteLoop = true // Debug switch provided for developers // The polling state machine! - //for { // Poll all ready tasks and act on them... - rc := <-pc.reqch // blocking alive := true var bodyWriter *io.PipeWriter - var respBody *hyper.Body = nil for alive { select { case fc := <-pc.cancelch: + if debugReadWriteLoop { + println("cancelch") + } // Free the resources - freeResources(nil, respBody, bodyWriter, exec, pc, rc) + //freeResources(nil, respBody, bodyWriter, pc.exec, pc, rc) alive = false + pc.close(errors.New("timeout error")) close(fc.freech) return + case <-pc.closech: + if debugReadWriteLoop { + println("closech") + } + return default: - task := exec.Poll() + task := pc.exec.Poll() if task == nil { loop.Run(libuv.RUN_ONCE) continue } switch (taskId)(uintptr(task.Userdata())) { - case sending: - err := checkTaskType(task, sending) + case write: + if debugReadWriteLoop { + println("write") + } + wc := <-pc.writech // blocking + + startBytesWritten := pc.nwrite + + err := checkTaskType(task, write) if err != nil { - rc.ch <- responseAndError{err: err} + wc.ch <- err // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) + pc.close(err) return } @@ -818,53 +933,110 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { task.Free() // Prepare the hyper.Request - hyperReq, err := newHyperRequest(rc.req) + hyperReq, err := newHyperRequest(wc.req.Request) + if bre, ok := err.(requestBodyReadError); ok { + err = bre.error + // Errors reading from the user's + // Request.Body are high priority. + // Set it here before sending on the + // channels below or calling + // pc.close() which tears down + // connections and causes other + // errors. + wc.req.setError(err) + } if err != nil { - rc.ch <- responseAndError{err: err} + if pc.nwrite == startBytesWritten { + err = nothingWrittenError{err} + } + wc.ch <- err // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) + pc.close(err) return } // Send it! sendTask := client.Send(hyperReq) - setTaskId(sendTask, receiveResp) - sendRes := exec.Push(sendTask) + setTaskId(sendTask, readRespLineAndHeader) + sendRes := pc.exec.Push(sendTask) if sendRes != hyper.OK { - rc.ch <- responseAndError{err: fmt.Errorf("failed to send request")} + wc.ch <- err // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) + pc.close(err) return } // For this example, no longer need the client client.Free() - case receiveResp: - err := checkTaskType(task, receiveResp) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + case readRespLineAndHeader: + if debugReadWriteLoop { + println("readRespLineAndHeader") + } + rc := <-pc.reqch // blocking + + closeErr := errReadLoopExiting // default value, if not changed below + defer func() { + pc.close(closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //pc.t.removeIdleConn(pc) + }() + + //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + // if err := pc.t.tryPutIdleConn(pc); err != nil { + // closeErr = err + // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + // } + // return false + // } + // if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + // } + // return true + //} + + // Read this once, before loop starts. (to avoid races in tests) + testHookMu.Lock() + testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead + testHookMu.Unlock() + + pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.closeLocked(errServerClosedIdle) + pc.mu.Unlock() return } + pc.mu.Unlock() + err := checkTaskType(task, readRespLineAndHeader) // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() - resp, err := ReadResponse(hyperResp, rc.req) + var resp *Response + var respBody *hyper.Body + if err == nil { + resp, err = ReadResponse(hyperResp, rc.req) + respBody = hyperResp.Body() + resp.Body, bodyWriter = io.Pipe() + } else { + err = transportReadFromServerError{err} + closeErr = err + } + if err != nil { - rc.ch <- responseAndError{err: err} + select { + case rc.ch <- responseAndError{err: err}: + case <-rc.callerGone: + return + } // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) return } - respBody = hyperResp.Body() - resp.Body, bodyWriter = io.Pipe() - - rc.ch <- responseAndError{res: resp} - // Response has been returned, stop the timer pc.conn.IsCompleted = 1 // Stop the timer @@ -873,70 +1045,92 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) } - dataTask := respBody.Data() - setTaskId(dataTask, receiveRespBody) - exec.Push(dataTask) + pc.mu.Lock() + pc.numExpectedResponses-- + pc.mu.Unlock() - // No longer need the response - hyperResp.Free() - case receiveRespBody: - err := checkTaskType(task, receiveRespBody) - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) - return + bodyWritable := resp.bodyIsWritable() + hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + + if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + alive = false } - if task.Type() == hyper.TaskBuf { - buf := (*hyper.Buf)(task.Value()) - bufLen := buf.Len() - bytes := unsafe.Slice((*byte)(buf.Bytes()), bufLen) - if bodyWriter == nil { - rc.ch <- responseAndError{err: fmt.Errorf("ResponseBodyWriter is nil")} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) - return + if !hasBody || bodyWritable { + //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) + pc.t.replaceReqCanceler(rc.cancelKey, nil) + + // TODO(spongehah) ConnPool(readWriteLoop) + //// Put the idle conn back into the pool before we send the response + //// so if they process it quickly and make another request, they'll + //// get this same conn. But we use the unbuffered channel 'rc' + //// to guarantee that persistConn.roundTrip got out of its select + //// potentially waiting for this persistConn to close. + //alive = alive && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + if bodyWritable { + closeErr = errCallerOwnsConn } - _, err := bodyWriter.Write(bytes) // blocking - if err != nil { - rc.ch <- responseAndError{err: err} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + + select { + case rc.ch <- responseAndError{res: resp}: + case <-rc.callerGone: return } - buf.Free() - task.Free() - dataTask := respBody.Data() - setTaskId(dataTask, receiveRespBody) - exec.Push(dataTask) + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + testHookReadLoopBeforeNextRead() + continue + } + + bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) + setTaskId(bodyForeachTask, readRespBody) + pc.exec.Push(bodyForeachTask) - break + rc.ch <- responseAndError{res: resp} + + // No longer need the response + hyperResp.Free() + case readRespBody: + // A background task of reading the response body is completed + if debugReadWriteLoop { + println("readRespBody") + } + err := checkTaskType(task, readRespBody) + if err != nil { + fmt.Println(err) + pc.close(err) + return } - // We are done with the response body if task.Type() != hyper.TaskEmpty { - c.Printf(c.Str("unexpected task type\n")) - rc.ch <- responseAndError{err: fmt.Errorf("unexpected task type\n")} - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) + err = errors.New("unexpected task type\n") + fmt.Println(err) + pc.close(err) return } - // Free the resources - freeResources(task, respBody, bodyWriter, exec, pc, rc) - - alive = false + // free the task + task.Free() + bodyWriter.Close() case notSet: // A background task for hyper_client completed... task.Free() } } } - //} } +// ---------------------------------------------------------- + type connData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect @@ -981,7 +1175,7 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { } // onRead is the libuv callback for reading from a socket -// This callback function is called when data is available to be read +// This callback function is called when data is available to be readRespLineAndHeader func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) @@ -1002,7 +1196,7 @@ func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { } } -// readCallBack read callback function for Hyper library +// readCallBack readRespLineAndHeader callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) conn := (*connData)(userdata) @@ -1096,7 +1290,7 @@ func onTimeout(handle *libuv.Timer) { (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) } -// newIoWithConnReadWrite creates a new IO with read and write callbacks +// newIoWithConnReadWrite creates a new IO with readRespLineAndHeader and write callbacks func newIoWithConnReadWrite(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) @@ -1110,9 +1304,9 @@ type taskId c.Int const ( notSet taskId = iota - sending - receiveResp - receiveRespBody + write + readRespLineAndHeader + readRespBody ) // setTaskId Set taskId to the task's userdata as a unique identifier @@ -1124,7 +1318,7 @@ func setTaskId(task *hyper.Task, userData taskId) { // checkTaskType checks the task type func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { - case sending: + case write: if task.Type() == hyper.TaskError { c.Printf(c.Str("handshake task error!\n")) return fail((*hyper.Error)(task.Value())) @@ -1133,7 +1327,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case receiveResp: + case readRespLineAndHeader: if task.Type() == hyper.TaskError { c.Printf(c.Str("send task error!\n")) return fail((*hyper.Error)(task.Value())) @@ -1143,7 +1337,7 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case receiveRespBody: + case readRespBody: if task.Type() == hyper.TaskError { c.Printf(c.Str("body error!\n")) return fail((*hyper.Error)(task.Value())) @@ -1281,7 +1475,7 @@ func (nwe nothingWrittenError) Unwrap() error { } // transportReadFromServerError is used by Transport.readLoop when the -// 1 byte peek read fails and we're actually anticipating a response. +// 1 byte peek readRespLineAndHeader fails and we're actually anticipating a response. // Usually this is just due to the inherent keep-alive shut down race, // where the server closed the connection at the same time the client // wrote. The underlying err field is usually io.EOF or some @@ -1311,72 +1505,70 @@ var ( testHookReadLoopBeforeNextRead = nop ) -/*// alternateRoundTripper returns the alternate RoundTripper to use -// for this request if the Request's URL scheme requires one, -// or nil for the normal case of using the Transport. -func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { - if !t.useRegisteredProtocol(req) { - return nil - } - altProto, _ := t.altProto.Load().(map[string]RoundTripper) - return altProto[req.URL.Scheme] +var portMap = map[string]string{ + "http": "80", + "https": "443", + "socks5": "1080", } -// useRegisteredProtocol reports whether an alternate protocol (as registered -// with Transport.RegisterProtocol) should be respected for this request. -func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. - return false +func idnaASCIIFromURL(url *url.URL) string { + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v } - return true + return addr } -*/ -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { - cm.targetScheme = treq.URL.Scheme - cm.targetAddr = canonicalAddr(treq.URL) - if t.Proxy != nil { - cm.proxyURL, err = t.Proxy(treq.Request) +// canonicalAddr returns url.Host but always with a ":port" suffix. +func canonicalAddr(url *url.URL) string { + port := url.Port() + if port == "" { + port = portMap[url.Scheme] } - cm.treq = treq - cm.onlyH1 = treq.requiresHTTP1() - return cm, err + return xnet.JoinHostPort(idnaASCIIFromURL(url), port) } -func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - if t.reqCanceler == nil { - t.reqCanceler = make(map[cancelKey]func(error)) - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } -} +// persistConn wraps a connection, usually a persistent one +// (but may be used for non-keep-alive requests as well) +type persistConn struct { + // alt optionally specifies the TLS NextProto RoundTripper. + // This is used for HTTP/2 today and future protocols later. + // If it's non-nil, the rest of the fields are unused. + alt RoundTripper -// replaceReqCanceler replaces an existing cancel function. If there is no cancel function -// for the request, we don't set the function and return false. -// Since CancelRequest will clear the canceler, we can use the return value to detect if -// the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { - t.reqMu.Lock() - defer t.reqMu.Unlock() - _, ok := t.reqCanceler[key] - if !ok { - return false - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } - return true + //br *bufio.Reader // from conn + //bw *bufio.Writer // to conn + //nwrite int64 // bytes written + //writech chan writeRequest // written by roundTrip; read by writeLoop + //closech chan struct{} // closed when conn closed + + t *Transport + cacheKey connectMethodKey + conn *connData + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; readRespLineAndHeader by readWriteLoop + writech chan writeRequest // written by roundTrip; readRespLineAndHeader by writeLoop(Already merged into reqch) + closech chan struct{} // closed when conn closed + writeLoopDone chan struct{} // closed when write loop ends + + cancelch chan freeChan + timeoutch chan struct{} + + isProxy bool + mu sync.Mutex // guards following fields + numExpectedResponses int + closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled + broken bool // an error has happened on this connection; marked broken so it's not reused. + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) + + // hyper specific + exec *hyper.Executor + opts *hyper.ClientConnOptions + io *hyper.Io } func (pc *persistConn) cancelRequest(err error) { @@ -1404,8 +1596,7 @@ func (pc *persistConn) closeLocked(err error) { pc.broken = true if pc.closed == nil { pc.closed = err - // TODO(spongehah) decConnsPerHost - //pc.t.decConnsPerHost(pc.cacheKey) + pc.t.decConnsPerHost(pc.cacheKey) // Close HTTP/1 (pc.alt == nil) connection. // HTTP/2 closes its connection itself. if pc.alt == nil { @@ -1492,81 +1683,6 @@ func (pc *persistConn) isBroken() bool { return b } -type readTrackingBody struct { - io.ReadCloser - didRead bool - didClose bool -} - -func (r *readTrackingBody) Read(data []byte) (int, error) { - r.didRead = true - return r.ReadCloser.Read(data) -} - -func (r *readTrackingBody) Close() error { - r.didClose = true - return r.ReadCloser.Close() -} - -// setupRewindBody returns a new request with a custom body wrapper -// that can report whether the body needs rewinding. -// This lets rewindBody avoid an error result when the request -// does not have GetBody but the body hasn't been read at all yet. -func setupRewindBody(req *Request) *Request { - if req.Body == nil || req.Body == NoBody { - return req - } - newReq := *req - newReq.Body = &readTrackingBody{ReadCloser: req.Body} - return &newReq -} - -// rewindBody returns a new request with the body rewound. -// It returns req unmodified if the body does not need rewinding. -// rewindBody takes care of closing req.Body when appropriate -// (in all cases except when rewindBody returns req unmodified). -func rewindBody(req *Request) (rewound *Request, err error) { - if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { - return req, nil // nothing to rewind - } - if !req.Body.(*readTrackingBody).didClose { - req.closeBody() - } - if req.GetBody == nil { - return nil, errCannotRewind - } - body, err := req.GetBody() - if err != nil { - return nil, err - } - newReq := *req - newReq.Body = &readTrackingBody{ReadCloser: body} - return &newReq, nil -} - -var portMap = map[string]string{ - "http": "80", - "https": "443", - "socks5": "1080", -} - -func idnaASCIIFromURL(url *url.URL) string { - addr := url.Hostname() - if v, err := idnaASCII(addr); err == nil { - addr = v - } - return addr -} - -// canonicalAddr returns url.Host but always with a ":port" suffix. -func canonicalAddr(url *url.URL) string { - port := url.Port() - if port == "" { - port = portMap[url.Scheme] - } - return xnet.JoinHostPort(idnaASCIIFromURL(url), port) -} - // connectMethod is the map key (in its String form) for keeping persistent // TCP connections alive for subsequent HTTP requests. // @@ -1595,6 +1711,14 @@ type connectMethod struct { onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 } +// connectMethodKey is the map key version of connectMethod, with a +// stringified proxy URL (or the empty string) instead of a pointer to +// a URL. +type connectMethodKey struct { + proxy, scheme, addr string + onlyH1 bool +} + func (cm *connectMethod) key() connectMethodKey { proxyStr := "" targetAddr := cm.targetAddr @@ -1642,14 +1766,6 @@ func (cm *connectMethod) proxyAuth() string { return "" } -// connectMethodKey is the map key version of connectMethod, with a -// stringified proxy URL (or the empty string) instead of a pointer to -// a URL. -type connectMethodKey struct { - proxy, scheme, addr string - onlyH1 bool -} - // A wantConn records state about a wanted connection // (that is, an active call to getConn). // The conn may be gotten by dialing or by finding an idle connection, From c2eb82d09421d8d58ef154c18e9ae1357fcd832b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 22 Aug 2024 18:49:43 +0800 Subject: [PATCH 21/55] WIP(x/http/client): bodyEOFSignal packaging & optimized readLoop() & adjusted timeout logic --- x/net/http/client.go | 6 +- x/net/http/request.go | 4 +- x/net/http/response.go | 1 + x/net/http/transport.go | 384 ++++++++++++++++++++++++++++++---------- 4 files changed, 295 insertions(+), 100 deletions(-) diff --git a/x/net/http/client.go b/x/net/http/client.go index 4fc6e41..6bf72c9 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -307,7 +307,9 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d } // TODO(spongehah) timeout + req.timeoutch = make(chan struct{}, 1) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) + sub := deadline.Sub(time.Now()) req.timeout = sub resp, err = rt.RoundTrip(req) @@ -504,7 +506,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) } - cancel := make(chan struct{}) + cancel := make(chan struct{}, 1) req.Cancel = cancel doCancel := func() { @@ -518,7 +520,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi } } - stopTimerCh := make(chan struct{}) + stopTimerCh := make(chan struct{}, 1) var once sync.Once stopTimer = func() { once.Do(func() { diff --git a/x/net/http/request.go b/x/net/http/request.go index cb50936..8a1fb88 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -38,7 +38,9 @@ type Request struct { RemoteAddr string RequestURI string //TLS *tls.ConnectionState - Cancel <-chan struct{} + Cancel <-chan struct{} + timeoutch chan struct{} //optional + Response *Response timeout time.Duration ctx context.Context diff --git a/x/net/http/response.go b/x/net/http/response.go index d2a7dd5..c647f20 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -21,6 +21,7 @@ type Response struct { ContentLength int64 TransferEncoding []string Close bool + Uncompressed bool //Trailer Header Request *Request } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 99003eb..52b7134 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -1,6 +1,7 @@ package http import ( + "compress/gzip" "context" "errors" "fmt" @@ -117,11 +118,6 @@ type connAndTimeoutChan struct { timeoutch chan struct{} } -type freeChan struct { - _ incomparable - freech chan struct{} -} - type readTrackingBody struct { io.ReadCloser didRead bool @@ -141,7 +137,7 @@ func (r *readTrackingBody) Close() error { // setupRewindBody returns a new request with a custom body wrapper // that can report whether the body needs rewinding. // This lets rewindBody avoid an error result when the request -// does not have GetBody but the body hasn't been readRespLineAndHeader at all yet. +// does not have GetBody but the body hasn't been read at all yet. func setupRewindBody(req *Request) *Request { if req.Body == nil || req.Body == NoBody { return req @@ -269,6 +265,32 @@ func (t *Transport) useRegisteredProtocol(req *Request) bool { return true } +// CancelRequest cancels an in-flight request by closing its connection. +// CancelRequest should only be called after RoundTrip has returned. +// +// Deprecated: Use Request.WithContext to create a request with a +// cancelable context instead. CancelRequest cannot cancel HTTP/2 +// requests. +func (t *Transport) CancelRequest(req *Request) { + t.cancelRequest(cancelKey{req}, errRequestCanceled) +} + +// Cancel an in-flight request, recording the error value. +// Returns whether the request was canceled. +func (t *Transport) cancelRequest(key cancelKey, err error) bool { + // This function must not return until the cancel func has completed. + // See: https://golang.org/issue/34658 + t.reqMu.Lock() + defer t.reqMu.Unlock() + cancel := t.reqCanceler[key] + delete(t.reqCanceler, key) + if cancel != nil { + cancel(err) + } + + return cancel != nil +} + // ---------------------------------------------------------- func (t *Transport) RoundTrip(req *Request) (*Response, error) { @@ -451,6 +473,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn + case <-treq.Request.timeoutch: + return nil, fmt.Errorf("request timeout\n") //case <-req.Context().Done(): // return nil, req.Context().Err() case err := <-cancelc: @@ -573,12 +597,8 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers cacheKey: cm.key(), reqch: make(chan requestAndChan, 1), writech: make(chan writeRequest, 1), - cancelch: make(chan freeChan, 1), - timeoutch: make(chan struct{}, 1), closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), - //writech: make(chan writeRequest, 1), - //closech: make(chan struct{}), } //if cm.scheme() == "https" && t.hasCustomTLSDialer() { @@ -675,9 +695,8 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} - if conn.IsCompleted != 1 { - go pconn.readWriteLoop(libuv.DefaultLoop()) - } + go pconn.readWriteLoop(libuv.DefaultLoop()) + return pconn, nil } @@ -699,7 +718,7 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth libuv.InitTimer(loop, &conn.TimeoutTimer) ct := &connAndTimeoutChan{ conn: conn, - timeoutch: pconn.timeoutch, + timeoutch: treq.Request.timeoutch, } (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) @@ -716,14 +735,14 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth var res *net.AddrInfo status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { - close(pconn.timeoutch) + close(treq.Request.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { - close(pconn.timeoutch) + close(treq.Request.timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } @@ -781,7 +800,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err req.extraHeaders().Set("Connection", "close") } - gone := make(chan struct{}) + gone := make(chan struct{}, 1) defer close(gone) defer func() { @@ -799,7 +818,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). startBytesWritten := pc.nwrite writeErrCh := make(chan error, 1) - pc.writech <- writeRequest{req, writeErrCh} + pc.writech <- writeRequest{req: req, ch: writeErrCh} // Send the request to readWriteLoop(). resc := make(chan responseAndError, 1) @@ -820,13 +839,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err for { testHookWaitResLoop() - // Determine whether timeout has occurred - if pc.conn.IsCompleted == 1 { - rc := <-pc.reqch // blocking - // Free the resources - freeResources(nil, nil, nil, nil, pc, rc) - return nil, fmt.Errorf("request timeout\n") - } select { case err := <-writeErrCh: if debugRoundTrip { @@ -855,23 +867,21 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err //case <-respHeaderTimer: case re := <-resc: if (re.res == nil) == (re.err == nil) { + println(1) return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } if debugRoundTrip { + println(2) //req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) } if re.err != nil { + println(3) return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil // TODO(spongehah) cancel(pc.roundTrip) //case <-cancelChan: - case <-pc.timeoutch: - freech := make(chan struct{}, 1) - pc.cancelch <- freeChan{ - freech: freech, - } - <-freech + case <-req.Request.timeoutch: return nil, fmt.Errorf("request timeout\n") } } @@ -884,22 +894,17 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { const debugReadWriteLoop = true // Debug switch provided for developers + if debugReadWriteLoop { + println("readWriteLoop start") + } + // The polling state machine! // Poll all ready tasks and act on them... alive := true var bodyWriter *io.PipeWriter + var rw readWaiter for alive { select { - case fc := <-pc.cancelch: - if debugReadWriteLoop { - println("cancelch") - } - // Free the resources - //freeResources(nil, respBody, bodyWriter, pc.exec, pc, rc) - alive = false - pc.close(errors.New("timeout error")) - close(fc.freech) - return case <-pc.closech: if debugReadWriteLoop { println("closech") @@ -911,18 +916,22 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { loop.Run(libuv.RUN_ONCE) continue } - switch (taskId)(uintptr(task.Userdata())) { + taskId := (taskId)(uintptr(task.Userdata())) + if debugReadWriteLoop { + println(taskId) + } + switch taskId { case write: if debugReadWriteLoop { println("write") } - wc := <-pc.writech // blocking + wr := <-pc.writech // blocking startBytesWritten := pc.nwrite err := checkTaskType(task, write) if err != nil { - wc.ch <- err + wr.ch <- err // Free the resources //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) pc.close(err) @@ -933,7 +942,16 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { task.Free() // Prepare the hyper.Request - hyperReq, err := newHyperRequest(wc.req.Request) + hyperReq, err := newHyperRequest(wr.req.Request) + if err == nil { + // Send it! + sendTask := client.Send(hyperReq) + setTaskId(sendTask, read) + sendRes := pc.exec.Push(sendTask) + if sendRes != hyper.OK { + err = errors.New("failed to send the request") + } + } if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's @@ -943,25 +961,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // pc.close() which tears down // connections and causes other // errors. - wc.req.setError(err) + wr.req.setError(err) } if err != nil { if pc.nwrite == startBytesWritten { err = nothingWrittenError{err} } - wc.ch <- err - // Free the resources - //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) - pc.close(err) - return - } - - // Send it! - sendTask := client.Send(hyperReq) - setTaskId(sendTask, readRespLineAndHeader) - sendRes := pc.exec.Push(sendTask) - if sendRes != hyper.OK { - wc.ch <- err + //pc.writeErrCh <- err // to the body reader, which might recycle us + wr.ch <- err // to the roundTrip function // Free the resources //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) pc.close(err) @@ -970,10 +977,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // For this example, no longer need the client client.Free() - case readRespLineAndHeader: if debugReadWriteLoop { - println("readRespLineAndHeader") + println("write end") + } + case read: + if debugReadWriteLoop { + println("read") } + rc := <-pc.reqch // blocking closeErr := errReadLoopExiting // default value, if not changed below @@ -997,6 +1008,12 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // return true //} + // eofc is used to block caller goroutines reading from Response.Body + // at EOF until this goroutines has (potentially) added the connection + // back to the idle pool. + eofc := make(chan struct{}, 1) + defer close(eofc) // unblock reader on errors + // Read this once, before loop starts. (to avoid races in tests) testHookMu.Lock() testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead @@ -1010,7 +1027,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } pc.mu.Unlock() - err := checkTaskType(task, readRespLineAndHeader) + err := checkTaskType(task, read) // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() @@ -1038,8 +1055,6 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } // Response has been returned, stop the timer - pc.conn.IsCompleted = 1 - // Stop the timer if rc.req.timeout > 0 { pc.conn.TimeoutTimer.Stop() (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) @@ -1091,24 +1106,71 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { continue } + waitForBodyRead := make(chan bool, 2) + body := &bodyEOFSignal{ + body: resp.Body, + earlyCloseFn: func() error { + waitForBodyRead <- false + <-eofc // will be closed by deferred call at the end of the function + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + waitForBodyRead <- isEOF + if isEOF { + <-eofc // see comment above eofc declaration + } else if err != nil { + if cerr := pc.canceled(); cerr != nil { + return cerr + } + } + return err + }, + } + resp.Body = body + + // TODO(spongehah) gzip fail + if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + resp.Body = &gzipReader{body: body} + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Uncompressed = true + } + + rw.waitForBodyRead = waitForBodyRead + rw.rc = rc + rw.eofc = eofc bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) - setTaskId(bodyForeachTask, readRespBody) + setTaskId(bodyForeachTask, readDone) pc.exec.Push(bodyForeachTask) + // TODO(spongehah) select blocking + //select { + //case rc.ch <- responseAndError{res: resp}: + //case <-rc.callerGone: + // return + //} rc.ch <- responseAndError{res: resp} // No longer need the response hyperResp.Free() - case readRespBody: + if debugReadWriteLoop { + println("read end") + } + + //pc.t.replaceReqCanceler(rc.cancelKey, nil) + //eofc <- struct{}{} + case readDone: // A background task of reading the response body is completed if debugReadWriteLoop { - println("readRespBody") + println("readDone") } - err := checkTaskType(task, readRespBody) + err := checkTaskType(task, readDone) if err != nil { fmt.Println(err) pc.close(err) - return + alive = false } if task.Type() != hyper.TaskEmpty { @@ -1121,6 +1183,39 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // free the task task.Free() bodyWriter.Close() + + // Before looping back to the top of this function and peeking on + // the bufio.Reader, wait for the caller goroutine to finish + // reading the response body. (or for cancellation or death) + rc := rw.rc + select { + //case bodyEOF := <-rw.waitForBodyRead: + case <-rw.waitForBodyRead: + //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + // TODO(spongehah) ConnPool(readWriteLoop) + //alive = alive && + // bodyEOF && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + rw.eofc <- struct{}{} + // TODO(spongehah) cancel(pc.readWriteLoop) + //case <-rc.req.Cancel: + // alive = false + // pc.t.CancelRequest(rc.req) + //case <-rc.req.Context().Done(): + // alive = false + // pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) + case <-pc.closech: + alive = false + } + + testHookReadLoopBeforeNextRead() + if debugReadWriteLoop { + println("readDone end") + } case notSet: // A background task for hyper_client completed... task.Free() @@ -1136,7 +1231,6 @@ type connData struct { ConnectReq libuv.Connect ReadBuf libuv.Buf TimeoutTimer libuv.Timer - IsCompleted int ReadBufFilled uintptr ReadWaker *hyper.Waker WriteWaker *hyper.Waker @@ -1175,12 +1269,10 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { } // onRead is the libuv callback for reading from a socket -// This callback function is called when data is available to be readRespLineAndHeader +// This callback function is called when data is available to be read func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - //conn := (*ConnData)(stream.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(stream)).data // If data was read (nread > 0) if nread > 0 { @@ -1196,7 +1288,7 @@ func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { } } -// readCallBack readRespLineAndHeader callback function for Hyper library +// readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { // Get the user data (connection data) conn := (*connData)(userdata) @@ -1236,8 +1328,6 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin func onWrite(req *libuv.Write, status c.Int) { // Get the connection data associated with the write request conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data // If there's a pending write waker if conn.WriteWaker != nil { @@ -1254,11 +1344,9 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui conn := (*connData)(userdata) // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) - //req := (*libuv.Write)(c.Malloc(unsafe.Sizeof(libuv.Write{}))) req := &libuv.Write{} // Associate the connection data with the write request (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - //req.Data = c.Pointer(conn) // Perform the asynchronous write operation ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) @@ -1282,15 +1370,12 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui // onTimeout is the libuv callback for a timeout func onTimeout(handle *libuv.Timer) { ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) - if ct.conn.IsCompleted != 1 { - ct.conn.IsCompleted = 1 - ct.timeoutch <- struct{}{} - } + close(ct.timeoutch) // Close the timer (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) } -// newIoWithConnReadWrite creates a new IO with readRespLineAndHeader and write callbacks +// newIoWithConnReadWrite creates a new IO with read and write callbacks func newIoWithConnReadWrite(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) @@ -1305,10 +1390,16 @@ type taskId c.Int const ( notSet taskId = iota write - readRespLineAndHeader - readRespBody + read + readDone ) +type readWaiter struct { + rc requestAndChan + waitForBodyRead chan bool + eofc chan struct{} +} + // setTaskId Set taskId to the task's userdata as a unique identifier func setTaskId(task *hyper.Task, userData taskId) { var data = userData @@ -1327,9 +1418,9 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case readRespLineAndHeader: + case read: if task.Type() == hyper.TaskError { - c.Printf(c.Str("send task error!\n")) + c.Printf(c.Str("write task error!\n")) return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskResponse { @@ -1337,9 +1428,9 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { return fmt.Errorf("unexpected task type\n") } return nil - case readRespBody: + case readDone: if task.Type() == hyper.TaskError { - c.Printf(c.Str("body error!\n")) + c.Printf(c.Str("read error!\n")) return fail((*hyper.Error)(task.Value())) } return nil @@ -1391,8 +1482,6 @@ func closeChannels(rc requestAndChan, pc *persistConn) { // Closing the channel close(rc.ch) close(pc.reqch) - close(pc.timeoutch) - close(pc.cancelch) } // freeConnData frees the connection data @@ -1475,7 +1564,7 @@ func (nwe nothingWrittenError) Unwrap() error { } // transportReadFromServerError is used by Transport.readLoop when the -// 1 byte peek readRespLineAndHeader fails and we're actually anticipating a response. +// 1 byte peek read fails and we're actually anticipating a response. // Usually this is just due to the inherent keep-alive shut down race, // where the server closed the connection at the same time the client // wrote. The underlying err field is usually io.EOF or some @@ -1546,14 +1635,11 @@ type persistConn struct { cacheKey connectMethodKey conn *connData nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; readRespLineAndHeader by readWriteLoop - writech chan writeRequest // written by roundTrip; readRespLineAndHeader by writeLoop(Already merged into reqch) + reqch chan requestAndChan // written by roundTrip; read by readWriteLoop + writech chan writeRequest // written by roundTrip; read by writeLoop(Already merged into reqch) closech chan struct{} // closed when conn closed writeLoopDone chan struct{} // closed when write loop ends - cancelch chan freeChan - timeoutch chan struct{} - isProxy bool mu sync.Mutex // guards following fields numExpectedResponses int @@ -1900,3 +1986,107 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { cleaned = true } } + +// bodyEOFSignal is used by the HTTP/1 transport when reading response +// bodies to make sure we see the end of a response body before +// proceeding and reading on the connection again. +// +// It wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before its final (error-producing) Read or Close call +// returns. fn should return the new error to return from Read or Close. +// +// If earlyCloseFn is non-nil and Close is called before io.EOF is +// seen, earlyCloseFn is called instead of fn, and its return value is +// the return value from Close. +type bodyEOFSignal struct { + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) error // err will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen +} + +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errReadOnClosedResBody + } + if rerr != nil { + return 0, rerr + } + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + err = es.condfn(err) + } + return +} + +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { + return nil + } + es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } + err := es.body.Close() + return es.condfn(err) +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) error { + if es.fn == nil { + return err + } + err = es.fn(err) + es.fn = nil + return err +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + _ incomparable + body *bodyEOFSignal // underlying HTTP/1 response body framing + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // any error from gzip.NewReader; sticky +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + if gz.zerr == nil { + gz.zr, gz.zerr = gzip.NewReader(gz.body) + } + if gz.zerr != nil { + return 0, gz.zerr + } + } + + gz.body.mu.Lock() + if gz.body.closed { + err = errReadOnClosedResBody + } + gz.body.mu.Unlock() + + if err != nil { + return 0, err + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} From a68bc29fada2c1b6789a890769641191fa01622c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Fri, 23 Aug 2024 18:28:37 +0800 Subject: [PATCH 22/55] WIP(x/http/client): Use body to wrap readCloser & optimize req.write code --- x/net/http/client.go | 8 +- x/net/http/http.go | 11 ++ x/net/http/request.go | 386 ++++++++++++++++++++++++++-------------- x/net/http/response.go | 6 +- x/net/http/server.go | 12 ++ x/net/http/transfer.go | 334 ++++++++++++++++++++++++++-------- x/net/http/transport.go | 326 ++++++++++++++------------------- x/net/http/util.go | 98 ++++++++++ x/net/ipsock.go | 79 +++++++- x/net/net.go | 20 +++ x/net/parse.go | 12 ++ 11 files changed, 882 insertions(+), 410 deletions(-) create mode 100644 x/net/http/server.go create mode 100644 x/net/net.go create mode 100644 x/net/parse.go diff --git a/x/net/http/client.go b/x/net/http/client.go index 6bf72c9..bf1bfd4 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -139,9 +139,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { resp.closeBody() return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) } - // TODO(spongehah) redirect: Why use host := "" - //host := "" - host := u.Host + host := "" if req.Host != "" && req.Host != req.URL.Host { // If the caller specified a custom Host header and the @@ -239,7 +237,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { // didTimeout is non-nil only if err != nil. func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { - // TODO(spongehah) cookie + // TODO(spongehah) cookie(c.send) if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) @@ -306,7 +304,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) timeout + // TODO(spongehah) timeout(send) req.timeoutch = make(chan struct{}, 1) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) diff --git a/x/net/http/http.go b/x/net/http/http.go index f668906..2f5d5ab 100644 --- a/x/net/http/http.go +++ b/x/net/http/http.go @@ -13,6 +13,17 @@ func isNotToken(r rune) bool { return !IsTokenRune(r) } +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s string) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} + // removeEmptyPort strips the empty port in ":port" to "" // as mandated by RFC 3986 Section 6.2.3. func removeEmptyPort(host string) string { diff --git a/x/net/http/request.go b/x/net/http/request.go index 8a1fb88..be81f0e 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -46,7 +46,7 @@ type Request struct { ctx context.Context } -var defaultChunkSize uintptr = 8192 +const defaultChunkSize = 8192 // NewRequest wraps NewRequestWithContext using context.Background. func NewRequest(method, url string, body io.Reader) (*Request, error) { @@ -128,6 +128,7 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R r := bytes.NewReader(buf) return io.NopCloser(r), nil } + case *bytes.Reader: req.ContentLength = int64(v.Len()) snapshot := *v @@ -166,122 +167,36 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R return req, nil } -func printInformational(userdata c.Pointer, resp *hyper.Response) { - status := resp.Status() - fmt.Println("Informational (1xx): ", status) -} - -type postReq struct { - req *Request - buf []byte -} - -func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - req := (*postReq)(userdata) - n, err := req.req.Body.Read(req.buf) - if err != nil { - if err == io.EOF { - *chunk = nil - return hyper.PollReady - } - fmt.Println("error reading upload file: ", err) - return hyper.PollError - } - if n > 0 { - *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) - return hyper.PollReady - } - if n == 0 { - *chunk = nil - return hyper.PollReady - } - - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) - return hyper.PollError -} - -func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - type buf struct { - data *uint8 - len uintptr - Unused [16]byte - } - req := (*postReq)(userdata) - buffer := &buf{ - data: &req.buf[0], - len: uintptr(len(req.buf)), - } - - *chunk = (*hyper.Buf)(c.Pointer(buffer)) - n, err := req.req.Body.Read(req.buf) - if err != nil { - if err == io.EOF { - *chunk = nil - return hyper.PollReady - } - fmt.Println("error reading upload file: ", err) - return hyper.PollError - } - if n > 0 { - return hyper.PollReady - } - if n == 0 { - *chunk = nil - return hyper.PollReady - } - - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) - return hyper.PollError -} - -func newHyperRequest(req *Request) (*hyper.Request, error) { - host := req.Host - uri := req.URL.RequestURI() - method := req.Method - // Prepare the request - hyperReq := hyper.NewRequest() - // Set the request method and uri - if hyperReq.SetMethod(&[]byte(method)[0], c.Strlen(c.AllocaCStr(method))) != hyper.OK { - return nil, fmt.Errorf("error setting method %s\n", method) - } - if hyperReq.SetURI(&[]byte(uri)[0], c.Strlen(c.AllocaCStr(uri))) != hyper.OK { - return nil, fmt.Errorf("error setting uri %s\n", uri) - } - // Set the request headers - reqHeaders := hyperReq.Headers() - if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { - return nil, fmt.Errorf("error setting header: Host: %s\n", host) - } - - if method == "POST" && req.Body != nil { - // 100-continue - if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() { - hyperReq.OnInformational(printInformational, nil) - } - - hyperReqBody := hyper.NewBody() - reqData := &postReq{ - req: req, - buf: make([]byte, 3), - } - hyperReqBody.SetUserdata(c.Pointer(reqData)) - hyperReqBody.SetDataFunc(setPostData) - hyperReq.SetBody(hyperReqBody) - } - - // Add user-defined request headers to hyper.Request - err := req.setHeaders(hyperReq) - if err != nil { - return nil, err - } - - return hyperReq, nil -} +//func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { +// req := (*postReq)(userdata) +// buf := req.hyperBuf.Bytes() +// len := req.hyperBuf.Len() +// n, err := req.req.Body.Read(unsafe.Slice(buf, len)) +// if err != nil { +// if err == io.EOF { +// *chunk = nil +// return hyper.PollReady +// } +// fmt.Println("error reading upload file: ", err) +// return hyper.PollError +// } +// if n > 0 { +// *chunk = req.hyperBuf +// return hyper.PollReady +// } +// if n == 0 { +// *chunk = nil +// return hyper.PollReady +// } +// +// fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) +// return hyper.PollError +//} // setHeaders sets the headers of the request -func (req *Request) setHeaders(hyperReq *hyper.Request) error { +func (r *Request) setHeaders(hyperReq *hyper.Request) error { headers := hyperReq.Headers() - for key, values := range req.Header { + for key, values := range r.Header { valueLen := len(values) if valueLen > 1 { for _, value := range values { @@ -318,23 +233,6 @@ func (r *Request) closeBody() error { return r.Body.Close() } -func validMethod(method string) bool { - /* - Method = "OPTIONS" ; Section 9.2 - | "GET" ; Section 9.3 - | "HEAD" ; Section 9.4 - | "POST" ; Section 9.5 - | "PUT" ; Section 9.6 - | "DELETE" ; Section 9.7 - | "TRACE" ; Section 9.8 - | "CONNECT" ; Section 9.9 - | extension-method - extension-method = token - token = 1* - */ - return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 -} - // Context returns the request's context. To change the context, use // Clone or WithContext. // @@ -387,6 +285,215 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// errMissingHost is returned by Write when there is no Host or URL present in +// the Request. +var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") + +// extraHeaders may be nil +// waitForContinue may be nil +// always closes body +func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.ClientConn, exec *hyper.Executor) (err error) { + //trace := httptrace.ContextClientTrace(r.Context()) + //if trace != nil && trace.WroteRequest != nil { + // defer func() { + // trace.WroteRequest(httptrace.WroteRequestInfo{ + // Err: err, + // }) + // }() + //} + + //closed := false + //defer func() { + // if closed { + // return + // } + // if closeErr := r.closeBody(); closeErr != nil && err == nil { + // err = closeErr + // } + //}() + + // Prepare the hyper.Request + hyperReq, err := r.newHyperRequest(usingProxy, extraHeader) + if err != nil { + return err + } + // Send it! + sendTask := client.Send(hyperReq) + setTaskId(sendTask, read) + sendRes := exec.Push(sendTask) + if sendRes != hyper.OK { + err = errors.New("failed to send the request") + } + return err +} + +func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.Request, error) { + // Find the target host. Prefer the Host: header, but if that + // is not given, use the host from the request URL. + // + // Clean the host, in case it arrives with unexpected stuff in it. + host := r.Host + if host == "" { + if r.URL == nil { + return nil, errMissingHost + } + host = r.URL.Host + } + host, err := PunycodeHostPort(host) + if err != nil { + return nil, err + } + // Validate that the Host header is a valid header in general, + // but don't validate the host itself. This is sufficient to avoid + // header or request smuggling via the Host field. + // The server can (and will, if it's a net/http server) reject + // the request if it doesn't consider the host valid. + if !ValidHostHeader(host) { + // Historically, we would truncate the Host header after '/' or ' '. + // Some users have relied on this truncation to convert a network + // address such as Unix domain socket path into a valid, ignored + // Host header (see https://go.dev/issue/61431). + // + // We don't preserve the truncation, because sending an altered + // header field opens a smuggling vector. Instead, zero out the + // Host header entirely if it isn't valid. (An empty Host is valid; + // see RFC 9112 Section 3.2.) + // + // Return an error if we're sending to a proxy, since the proxy + // probably can't do anything useful with an empty Host header. + if !usingProxy { + host = "" + } else { + return nil, errors.New("http: invalid Host header") + } + } + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + host = removeZone(host) + + ruri := r.URL.RequestURI() + if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" { + ruri = r.URL.Scheme + "://" + host + ruri + } else if r.Method == "CONNECT" && r.URL.Path == "" { + // CONNECT requests normally give just the host and port, not a full URL. + ruri = host + if r.URL.Opaque != "" { + ruri = r.URL.Opaque + } + } + if stringContainsCTLByte(ruri) { + return nil, errors.New("net/http: can't write control character in Request.URL") + } + + + + + // Prepare the hyper request + hyperReq := hyper.NewRequest() + + // Set the request line, default HTTP/1.1 + if hyperReq.SetMethod(&[]byte(r.Method)[0], c.Strlen(c.AllocaCStr(r.Method))) != hyper.OK { + return nil, fmt.Errorf("error setting method %s\n", r.Method) + } + if hyperReq.SetURI(&[]byte(ruri)[0], c.Strlen(c.AllocaCStr(ruri))) != hyper.OK { + return nil, fmt.Errorf("error setting uri %s\n", ruri) + } + if hyperReq.SetVersion(c.Int(hyper.HTTPVersion11)) != hyper.OK { + return nil, fmt.Errorf("error setting httpversion %s\n", "HTTP/1.1") + } + + // Set the request headers + reqHeaders := hyperReq.Headers() + if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { + return nil, fmt.Errorf("error setting header: Host: %s\n", host) + } + err = r.setHeaders(hyperReq) + if err != nil { + return nil, err + } + + if r.Body != nil { + // 100-continue + if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { + hyperReq.OnInformational(printInformational, nil) + } + + hyperReqBody := hyper.NewBody() + //buf := make([]byte, 2) + //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(2)) + reqData := &postReq{ + req: r, + buf: make([]byte, defaultChunkSize), + //hyperBuf: hyperBuf, + } + hyperReqBody.SetUserdata(c.Pointer(reqData)) + hyperReqBody.SetDataFunc(setPostData) + //hyperReqBody.SetDataFunc(setPostDataNoCopy) + hyperReq.SetBody(hyperReqBody) + } + + return hyperReq, nil +} + +func printInformational(userdata c.Pointer, resp *hyper.Response) { + status := resp.Status() + fmt.Println("Informational (1xx): ", status) +} + +type postReq struct { + req *Request + buf []byte + //hyperBuf *hyper.Buf +} + +func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + req := (*postReq)(userdata) + n, err := req.req.Body.Read(req.buf) + if err != nil { + if err == io.EOF { + println("EOF") + *chunk = nil + req.req.Body.Close() + return hyper.PollReady + } + fmt.Println("error reading request body: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) + return hyper.PollReady + } + if n == 0 { + println("n == 0") + *chunk = nil + req.req.Body.Close() + return hyper.PollReady + } + + req.req.Body.Close() + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // @@ -446,3 +553,20 @@ func idnaASCII(v string) (string, error) { } return idna.Lookup.ToASCII(v) } + +// removeZone removes IPv6 zone identifier from host. +// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" +func removeZone(host string) string { + if !strings.HasPrefix(host, "[") { + return host + } + i := strings.LastIndex(host, "]") + if i < 0 { + return host + } + j := strings.LastIndex(host[:i], "%") + if j < 0 { + return host + } + return host[:j] + host[i:] +} diff --git a/x/net/http/response.go b/x/net/http/response.go index c647f20..8151ac2 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -32,7 +32,7 @@ func (r *Response) closeBody() { } } -func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { +func ReadResponse(r *io.PipeReader, req *Request, hyperResp *hyper.Response) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), @@ -42,7 +42,7 @@ func ReadResponse(hyperResp *hyper.Response, req *Request) (*Response, error) { fixPragmaCacheControl(req.Header) - err := readTransfer(resp) + err := readTransfer(resp, r) if err != nil { return nil, err } @@ -54,9 +54,9 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { rp := hyperResp.ReasonPhrase() rpLen := hyperResp.ReasonPhraseLen() + // Parse the first line of the response. resp.Status = strconv.Itoa(int(hyperResp.Status())) + " " + c.GoString((*int8)(c.Pointer(rp)), rpLen) resp.StatusCode = int(hyperResp.Status()) - version := int(hyperResp.Version()) resp.ProtoMajor, resp.ProtoMinor = splitTwoDigitNumber(version) resp.Proto = fmt.Sprintf("HTTP/%d.%d", resp.ProtoMajor, resp.ProtoMinor) diff --git a/x/net/http/server.go b/x/net/http/server.go new file mode 100644 index 0000000..f38cbd0 --- /dev/null +++ b/x/net/http/server.go @@ -0,0 +1,12 @@ +package http + +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 823ef7d..cf96f84 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -1,11 +1,13 @@ package http import ( + "errors" "fmt" "io" "net/textproto" "strconv" "strings" + "sync" "unicode/utf8" ) @@ -24,13 +26,49 @@ type transferReader struct { Trailer Header } -// unsupportedTEError reports unsupported transfer-encodings. -type unsupportedTEError struct { - err string +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] + if !present { + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. + if !t.protoAtLeast(1, 1) { + return nil + } + + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} + } + if !EqualFold(raw[0], "chunked") { + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + + t.Chunked = true + return nil } -func (uste *unsupportedTEError) Error() string { - return uste.err +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) } // NoBody is an io.ReadCloser with no bytes. Read always returns EOF @@ -45,7 +83,17 @@ func (noBody) Read([]byte) (int, error) { return 0, io.EOF } func (noBody) Close() error { return nil } func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } -func readTransfer(msg any) (err error) { +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// msg is *Request or *Response. +func readTransfer(msg any, r *io.PipeReader) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -74,6 +122,11 @@ func readTransfer(msg any) (err error) { panic("unexpected type") } + // Default to HTTP/1.1 + if t.ProtoMajor == 0 && t.ProtoMinor == 0 { + t.ProtoMajor, t.ProtoMinor = 1, 1 + } + // Transfer-Encoding: chunked, and overriding Content-Length. if err = t.parseTransferEncoding(); err != nil { return err @@ -93,6 +146,7 @@ func readTransfer(msg any) (err error) { t.ContentLength = realLength } + // TODO(spongehah) Trailer(readTransfer) // Trailer //t.Trailer, err = fixTrailer(t.Header, t.Chunked) @@ -109,40 +163,42 @@ func readTransfer(msg any) (err error) { // Prepare body reader. ContentLength < 0 means chunked encoding // or close connection when finished, since multipart is not supported yet - //switch { - //case t.Chunked: - // if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { - // t.Body = NoBody - // } else { - // t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} - // } - //case realLength == 0: - // t.Body = NoBody - //case realLength > 0: - // t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} - //default: - // // realLength < 0, i.e. "Content-Length" not mentioned in header - // if t.Close { - // // Close semantics (i.e. HTTP/1.0) - // t.Body = &body{src: r, closing: t.Close} - // } else { - // // Persistent connection (i.e. HTTP/1.1) - // t.Body = NoBody - // } - //} + switch { + case t.Chunked: + if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { + t.Body = NoBody + } else { + // TODO(spongehah) ChunkReader(readTransfer) + //t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: r, hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = NoBody + case realLength > 0: + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + default: + // realLength < 0, i.e. "Content-Length" not mentioned in header + if t.Close { + // Close semantics (i.e. HTTP/1.0) + t.Body = &body{src: r, closing: t.Close} + } else { + // Persistent connection (i.e. HTTP/1.1) + t.Body = NoBody + } + } // Unify output switch rr := msg.(type) { case *Request: - //rr.Body = t.Body - //rr.ContentLength = t.ContentLength - //if t.Chunked { - // rr.TransferEncoding = []string{"chunked"} - //} + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } rr.Close = t.Close //rr.Trailer = t.Trailer case *Response: - //rr.Body = t.Body + rr.Body = t.Body rr.ContentLength = t.ContentLength if t.Chunked { rr.TransferEncoding = []string{"chunked"} @@ -154,51 +210,6 @@ func readTransfer(msg any) (err error) { return nil } -// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. -func (t *transferReader) parseTransferEncoding() error { - raw, present := t.Header["Transfer-Encoding"] - if !present { - return nil - } - delete(t.Header, "Transfer-Encoding") - - // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. - if !t.protoAtLeast(1, 1) { - return nil - } - - // Like nginx, we only support a single Transfer-Encoding header field, and - // only if set to "chunked". This is one of the most security sensitive - // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it - // strict and simple. - if len(raw) != 1 { - return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} - } - if !EqualFold(raw[0], "chunked") { - return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} - } - - // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field - // in any message that contains a Transfer-Encoding header field." - // - // but also: "If a message is received with both a Transfer-Encoding and a - // Content-Length header field, the Transfer-Encoding overrides the - // Content-Length. Such a message might indicate an attempt to perform - // request smuggling (Section 9.5) or response splitting (Section 9.4) and - // ought to be handled as an error. A sender MUST remove the received - // Content-Length field prior to forwarding such a message downstream." - // - // Reportedly, these appear in the wild. - delete(t.Header, "Content-Length") - - t.Chunked = true - return nil -} - -func (t *transferReader) protoAtLeast(m, n int) bool { - return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) -} - // Determine the expected body length, using RFC 7230 Section 3.3. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. @@ -329,6 +340,173 @@ func fixTrailer(header Header, chunked bool) (Header, error) { return trailer, nil } +// body turns a Reader into a ReadCloser. +// Close ensures that the body has been fully read +// and then reads the trailer if necessary. +type body struct { + src io.Reader + hdr any // non-nil (Response or Request) value means read trailer + //r *bufio.Reader // underlying wire-format reader for the trailer + r io.Reader // underlying wire-format reader for the trailer + closing bool // is the connection to be closed after reading body? + doEarlyClose bool // whether Close should stop early + + mu sync.Mutex // guards following, and calls to Read and Close + sawEOF bool + closed bool + earlyClose bool // Close called and we didn't read to the end of src + onHitEOF func() // if non-nil, func to call when EOF is Read +} + +// ErrBodyReadAfterClose is returned when reading a Request or Response +// Body after the body has been closed. This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") + +func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return 0, ErrBodyReadAfterClose + } + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + if b.sawEOF { + return 0, io.EOF + } + n, err = b.src.Read(p) + + if err == io.EOF { + b.sawEOF = true + // Chunked case. Read the trailer. + if b.hdr != nil { + // TODO(spongehah) Trailer(b.readLocked) + //if e := b.readTrailer(); e != nil { + // err = e + // // Something went wrong in the trailer, we must not allow any + // // further reads of any kind to succeed from body, nor any + // // subsequent requests on the server connection. See + // // golang.org/issue/12027 + // b.sawEOF = false + // b.closed = true + //} + b.hdr = nil + } else { + // If the server declared the Content-Length, our body is a LimitedReader + // and we need to check whether this EOF arrived early. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { + err = io.ErrUnexpectedEOF + } + } + } + + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.Reader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + b.sawEOF = true + } + } + + if b.sawEOF && b.onHitEOF != nil { + b.onHitEOF() + } + + return n, err +} + +// unreadDataSizeLocked returns the number of bytes of unread input. +// It returns -1 if unknown. +// b.mu must be held. +func (b *body) unreadDataSizeLocked() int64 { + if lr, ok := b.src.(*io.LimitedReader); ok { + return lr.N + } + return -1 +} + +func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil + } + var err error + switch { + case b.sawEOF: + // Already saw EOF, so no need going to look for it. + case b.hdr == nil && b.closing: + // no trailer and closing the connection next. + // no point in reading to EOF. + case b.doEarlyClose: + // Read up to maxPostHandlerReadBytes bytes of the body, looking + // for EOF (and trailers), so we can re-use this connection. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > maxPostHandlerReadBytes { + // There was a declared Content-Length, and we have more bytes remaining + // than our maxPostHandlerReadBytes tolerance. So, give up. + b.earlyClose = true + } else { + var n int64 + // Consume the body, or, which will also lead to us reading + // the trailer headers after the body, if present. + n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + if err == io.EOF { + err = nil + } + if n == maxPostHandlerReadBytes { + b.earlyClose = true + } + } + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(io.Discard, bodyLocked{b}) + } + b.closed = true + return err +} + +// bodyLocked is an io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + +func (b *body) didEarlyClose() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.earlyClose +} + +// bodyRemains reports whether future Read calls might +// yield data. +func (b *body) bodyRemains() bool { + b.mu.Lock() + defer b.mu.Unlock() + return !b.sawEOF +} + +func (b *body) registerOnHitEOF(fn func()) { + b.mu.Lock() + defer b.mu.Unlock() + b.onHitEOF = fn +} + // foreachHeaderElement splits v according to the "#rule" construction // in RFC 7230 section 7 and calls fn for each non-empty element. func foreachHeaderElement(v string, fn func(string)) { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 52b7134..fe3efc3 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log" "net/url" "sync" "sync/atomic" @@ -13,9 +14,9 @@ import ( "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/libuv" - "github.com/goplus/llgo/c/net" + cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - xnet "github.com/goplus/llgo/x/net" + "github.com/goplus/llgoexamples/x/net" "github.com/goplus/llgoexamples/rust/hyper" ) @@ -204,6 +205,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } + // TODO(spongehah) cm.treq(connectMethod) cm.treq = treq cm.onlyH1 = treq.requiresHTTP1() return cm, err @@ -352,7 +354,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout: because of that ctx not initialized ( initialized in setRequestCancel() ) + // TODO(spongehah) timeout(t.RoundTrip): because of that ctx not initialized ( initialized in setRequestCancel() ) //select { //case <-ctx.Done(): // req.closeBody() @@ -394,7 +396,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) Retry & ConnPool + // TODO(spongehah) Retry & ConnPool(t.RoundTrip) return nil, err } } @@ -421,7 +423,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } }() - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(t.getConn) //// Queue for idle connection. //if delivered := t.queueForIdleConn(w); delivered { // pc := w.pc @@ -524,7 +526,7 @@ func (t *Transport) dialConnFor(w *wantConn) { pc, err := t.dialConn(w.ctx, w.cm) w.tryDeliver(pc, err) - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(t.dialConnFor) //delivered := w.tryDeliver(pc, err) // Handle undelivered or shareable connections //if err == nil && (!delivered || pc.alt != nil) { @@ -601,6 +603,15 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers writeLoopDone: make(chan struct{}, 1), } + //trace := httptrace.ContextClientTrace(ctx) + //wrapErr := func(err error) error { + // if cm.proxyURL != nil { + // // Return a typed error, per Issue 16997 + // return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} + // } + // return err + //} + //if cm.scheme() == "https" && t.hasCustomTLSDialer() { // var err error // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) @@ -628,7 +639,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} else { //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(ctx, pconn, cm) + conn, err := t.dial(ctx, cm) if err != nil { return nil, err } @@ -642,9 +653,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // Prepare client options opts := hyper.NewClientConnOptions() opts.Exec(exec) - pconn.io = hyperIo pconn.exec = exec - pconn.opts = opts // send the handshake handshakeTask := hyper.Handshake(hyperIo, opts) setTaskId(handshakeTask, write) @@ -662,7 +671,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers //} //} - // TODO(spongehah) Proxy(https/sock5) + // TODO(spongehah) Proxy(https/sock5)(t.dialConn) // Proxy setup. switch { case cm.proxyURL == nil: @@ -700,7 +709,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return pconn, nil } -func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMethod) (*connData, error) { +func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, error) { treq := cm.treq host := treq.URL.Hostname() port := treq.URL.Port() @@ -724,16 +733,17 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) } + libuv.InitTcp(loop, &conn.TcpHandle) libuv.InitTcp(loop, &conn.TcpHandle) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) - var hints net.AddrInfo + var hints cnet.AddrInfo c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) hints.Family = syscall.AF_UNSPEC hints.SockType = syscall.SOCK_STREAM - var res *net.AddrInfo - status := net.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) + var res *cnet.AddrInfo + status := cnet.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { close(treq.Request.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") @@ -746,14 +756,14 @@ func (t *Transport) dial(ctx context.Context, pconn *persistConn, cm connectMeth return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } - net.Freeaddrinfo(res) + cnet.Freeaddrinfo(res) return conn, nil } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(pc.roundTrip) //pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -816,7 +826,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // request body. // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). - startBytesWritten := pc.nwrite + startBytesWritten := pc.conn.nwrite writeErrCh := make(chan error, 1) pc.writech <- writeRequest{req: req, ch: writeErrCh} @@ -890,8 +900,42 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { + // writeLoop related defer close(pc.writeLoopDone) + // readLoop related + closeErr := errReadLoopExiting // default value, if not changed below + defer func() { + pc.close(closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //pc.t.removeIdleConn(pc) + }() + + //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + // if err := pc.t.tryPutIdleConn(pc); err != nil { + // closeErr = err + // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + // } + // return false + // } + // if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + // } + // return true + //} + + // eofc is used to block caller goroutines reading from Response.Body + // at EOF until this goroutines has (potentially) added the connection + // back to the idle pool. + eofc := make(chan struct{}, 1) + defer close(eofc) // unblock reader on errors + + // Read this once, before loop starts. (to avoid races in tests) + testHookMu.Lock() + testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead + testHookMu.Unlock() + const debugReadWriteLoop = true // Debug switch provided for developers if debugReadWriteLoop { @@ -918,7 +962,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } taskId := (taskId)(uintptr(task.Userdata())) if debugReadWriteLoop { - println(taskId) + println("taskId: ", taskId) } switch taskId { case write: @@ -927,31 +971,16 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } wr := <-pc.writech // blocking - startBytesWritten := pc.nwrite - + startBytesWritten := pc.conn.nwrite err := checkTaskType(task, write) - if err != nil { - wr.ch <- err - // Free the resources - //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) - pc.close(err) - return - } - client := (*hyper.ClientConn)(task.Value()) task.Free() - - // Prepare the hyper.Request - hyperReq, err := newHyperRequest(wr.req.Request) if err == nil { - // Send it! - sendTask := client.Send(hyperReq) - setTaskId(sendTask, read) - sendRes := pc.exec.Push(sendTask) - if sendRes != hyper.OK { - err = errors.New("failed to send the request") - } + // TODO(spongehah) Proxy(writeLoop) + err = wr.req.Request.write(pc.isProxy, wr.req.extra, client, pc.exec) } + // For this request, no longer need the client + client.Free() if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's @@ -964,19 +993,15 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { wr.req.setError(err) } if err != nil { - if pc.nwrite == startBytesWritten { + if pc.conn.nwrite == startBytesWritten { err = nothingWrittenError{err} } //pc.writeErrCh <- err // to the body reader, which might recycle us wr.ch <- err // to the roundTrip function - // Free the resources - //freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) pc.close(err) return } - // For this example, no longer need the client - client.Free() if debugReadWriteLoop { println("write end") } @@ -985,39 +1010,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { println("read") } - rc := <-pc.reqch // blocking - - closeErr := errReadLoopExiting // default value, if not changed below - defer func() { - pc.close(closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //pc.t.removeIdleConn(pc) - }() - - //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { - // if err := pc.t.tryPutIdleConn(pc); err != nil { - // closeErr = err - // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - // } - // return false - // } - // if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - // } - // return true - //} - - // eofc is used to block caller goroutines reading from Response.Body - // at EOF until this goroutines has (potentially) added the connection - // back to the idle pool. - eofc := make(chan struct{}, 1) - defer close(eofc) // unblock reader on errors - - // Read this once, before loop starts. (to avoid races in tests) - testHookMu.Lock() - testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - testHookMu.Unlock() + err := checkTaskType(task, read) pc.mu.Lock() if pc.numExpectedResponses == 0 { @@ -1027,7 +1020,9 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } pc.mu.Unlock() - err := checkTaskType(task, read) + rc := <-pc.reqch // blocking + //trace := httptrace.ContextClientTrace(rc.req.Context()) + // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() @@ -1035,22 +1030,24 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { var resp *Response var respBody *hyper.Body if err == nil { - resp, err = ReadResponse(hyperResp, rc.req) + var pr *io.PipeReader + pr, bodyWriter = io.Pipe() + resp, err = ReadResponse(pr, rc.req, hyperResp) respBody = hyperResp.Body() - resp.Body, bodyWriter = io.Pipe() } else { err = transportReadFromServerError{err} closeErr = err } + // No longer need the response + hyperResp.Free() + if err != nil { select { case rc.ch <- responseAndError{err: err}: case <-rc.callerGone: return } - // Free the resources - freeResources(task, respBody, bodyWriter, pc.exec, pc, rc) return } @@ -1129,7 +1126,7 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } resp.Body = body - // TODO(spongehah) gzip fail + // TODO(spongehah) gzip fail(readWriteLoop) if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { resp.Body = &gzipReader{body: body} resp.Header.Del("Content-Encoding") @@ -1140,12 +1137,11 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { rw.waitForBodyRead = waitForBodyRead rw.rc = rc - rw.eofc = eofc bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) setTaskId(bodyForeachTask, readDone) pc.exec.Push(bodyForeachTask) - // TODO(spongehah) select blocking + // TODO(spongehah) select blocking(readWriteLoop) //select { //case rc.ch <- responseAndError{res: resp}: //case <-rc.callerGone: @@ -1153,46 +1149,31 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { //} rc.ch <- responseAndError{res: resp} - // No longer need the response - hyperResp.Free() if debugReadWriteLoop { println("read end") } - - //pc.t.replaceReqCanceler(rc.cancelKey, nil) - //eofc <- struct{}{} case readDone: // A background task of reading the response body is completed if debugReadWriteLoop { println("readDone") } - err := checkTaskType(task, readDone) - if err != nil { - fmt.Println(err) - pc.close(err) - alive = false - } - - if task.Type() != hyper.TaskEmpty { - err = errors.New("unexpected task type\n") - fmt.Println(err) - pc.close(err) - return + if bodyWriter != nil { + bodyWriter.Close() } + checkTaskType(task, readDone) + hyperBodyEOF := task.Type() == hyper.TaskEmpty // free the task task.Free() - bodyWriter.Close() // Before looping back to the top of this function and peeking on // the bufio.Reader, wait for the caller goroutine to finish // reading the response body. (or for cancellation or death) - rc := rw.rc select { - //case bodyEOF := <-rw.waitForBodyRead: - case <-rw.waitForBodyRead: + case bodyEOF := <-rw.waitForBodyRead: + bodyEOF = bodyEOF && hyperBodyEOF //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool - pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + pc.t.replaceReqCanceler(rw.rc.cancelKey, nil) // before pc might return to idle pool // TODO(spongehah) ConnPool(readWriteLoop) //alive = alive && // bodyEOF && @@ -1200,14 +1181,14 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { // pc.wroteRequest() && // replaced && tryPutIdleConn(trace) - rw.eofc <- struct{}{} + eofc <- struct{}{} // TODO(spongehah) cancel(pc.readWriteLoop) - //case <-rc.req.Cancel: + //case <-rw.rc.req.Cancel: // alive = false - // pc.t.CancelRequest(rc.req) - //case <-rc.req.Context().Done(): + // pc.t.CancelRequest(rw.rc.req) + //case <-rw.rc.req.Context().Done(): // alive = false - // pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) + // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) case <-pc.closech: alive = false } @@ -1232,12 +1213,25 @@ type connData struct { ReadBuf libuv.Buf TimeoutTimer libuv.Timer ReadBufFilled uintptr + nwrite int64 // bytes written(Replaced from persistConn's nwrite) ReadWaker *hyper.Waker WriteWaker *hyper.Waker } func (conn *connData) Close() error { - freeConnData(conn) + if conn.ReadWaker != nil { + conn.ReadWaker.Free() + conn.ReadWaker = nil + } + if conn.WriteWaker != nil { + conn.WriteWaker.Free() + conn.WriteWaker = nil + } + if conn.ReadBuf.Base != nil { + c.Free(c.Pointer(conn.ReadBuf.Base)) + conn.ReadBuf.Base = nil + } + (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) return nil } @@ -1352,6 +1346,7 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) // If the write operation was successfully initiated if ret >= 0 { + conn.nwrite += int64(bufLen) // Return the number of bytes to be written return bufLen } @@ -1397,7 +1392,6 @@ const ( type readWaiter struct { rc requestAndChan waitForBodyRead chan bool - eofc chan struct{} } // setTaskId Set taskId to the task's userdata as a unique identifier @@ -1411,95 +1405,51 @@ func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { case write: if task.Type() == hyper.TaskError { - c.Printf(c.Str("handshake task error!\n")) + log.Printf("[readWriteLoop::write]handshake task error!\n") return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("unexpected task type\n") + return fmt.Errorf("[readWriteLoop::write]unexpected task type\n") } return nil case read: if task.Type() == hyper.TaskError { - c.Printf(c.Str("write task error!\n")) + log.Printf("[readWriteLoop::read]write task error!\n") return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("unexpected task type\n")) - return fmt.Errorf("unexpected task type\n") + c.Printf(c.Str("[readWriteLoop::read]unexpected task type\n")) + return errors.New("[readWriteLoop::read]unexpected task type\n") } return nil case readDone: if task.Type() == hyper.TaskError { - c.Printf(c.Str("read error!\n")) + log.Printf("[readWriteLoop::readDone]read response body error!\n") return fail((*hyper.Error)(task.Value())) } return nil case notSet: } - return fmt.Errorf("unexpected TaskId\n") + return errors.New("[readWriteLoop]unexpected task type\n") } // fail prints the error details and panics func fail(err *hyper.Error) error { if err != nil { - c.Printf(c.Str("error code: %d\n"), err.Code()) + c.Printf(c.Str("[readWriteLoop]error code: %d\n"), err.Code()) // grab the error details var errBuf [256]c.Char errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - c.Printf(c.Str("details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + c.Printf(c.Str("[readWriteLoop]details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) // clean up the error err.Free() - return fmt.Errorf("hyper request error, error code: %d\n", int(err.Code())) + return fmt.Errorf("[readWriteLoop]hyper request error, error code: %d\n", int(err.Code())) } return nil } -// freeResources frees the resources -func freeResources(task *hyper.Task, respBody *hyper.Body, bodyWriter *io.PipeWriter, exec *hyper.Executor, pc *persistConn, rc requestAndChan) { - // Cleaning up before exiting - if task != nil { - task.Free() - } - if respBody != nil { - respBody.Free() - } - if bodyWriter != nil { - bodyWriter.Close() - } - if exec != nil { - exec.Free() - } - (*libuv.Handle)(c.Pointer(&pc.conn.TcpHandle)).Close(nil) - freeConnData(pc.conn) - - closeChannels(rc, pc) -} - -// closeChannels closes the channels -func closeChannels(rc requestAndChan, pc *persistConn) { - // Closing the channel - close(rc.ch) - close(pc.reqch) -} - -// freeConnData frees the connection data -func freeConnData(conn *connData) { - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil - } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil - } -} - // ---------------------------------------------------------- // error values for debugging and testing, not seen by users. @@ -1614,7 +1564,7 @@ func canonicalAddr(url *url.URL) string { if port == "" { port = portMap[url.Scheme] } - return xnet.JoinHostPort(idnaASCIIFromURL(url), port) + return net.JoinHostPort(idnaASCIIFromURL(url), port) } // persistConn wraps a connection, usually a persistent one @@ -1625,22 +1575,18 @@ type persistConn struct { // If it's non-nil, the rest of the fields are unused. alt RoundTripper - //br *bufio.Reader // from conn - //bw *bufio.Writer // to conn - //nwrite int64 // bytes written - //writech chan writeRequest // written by roundTrip; read by writeLoop - //closech chan struct{} // closed when conn closed - - t *Transport - cacheKey connectMethodKey - conn *connData - nwrite int64 // bytes written - reqch chan requestAndChan // written by roundTrip; read by readWriteLoop - writech chan writeRequest // written by roundTrip; read by writeLoop(Already merged into reqch) - closech chan struct{} // closed when conn closed - writeLoopDone chan struct{} // closed when write loop ends - - isProxy bool + t *Transport + cacheKey connectMethodKey + conn *connData + //tlsState *tls.ConnectionState + //nwrite int64 // bytes written(Replaced by connData.nwrite) + reqch chan requestAndChan // written by roundTrip; read by readWriteLoop + writech chan writeRequest // written by roundTrip; read by readWriteLoop + closech chan struct{} // closed when conn closed + isProxy bool + + writeLoopDone chan struct{} // closed when readWriteLoop ends + mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed @@ -1653,8 +1599,6 @@ type persistConn struct { // hyper specific exec *hyper.Executor - opts *hyper.ClientConnOptions - io *hyper.Io } func (pc *persistConn) cancelRequest(err error) { @@ -1693,6 +1637,10 @@ func (pc *persistConn) closeLocked(err error) { } } pc.mutateHeaderFunc = nil + // hyper related + if pc.exec != nil { + pc.exec.Free() + } } // mapRoundTripError returns the appropriate error value for @@ -1738,14 +1686,14 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte } if _, ok := err.(transportReadFromServerError); ok { - if pc.nwrite == startBytesWritten { + if pc.conn.nwrite == startBytesWritten { return nothingWrittenError{err} } // Don't decorate return err } if pc.isBroken() { - if pc.nwrite == startBytesWritten { + if pc.conn.nwrite == startBytesWritten { return nothingWrittenError{err} } return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err) @@ -1887,7 +1835,7 @@ func (w *wantConn) cancel(t *Transport, err error) { w.err = err w.mu.Unlock() - // TODO(spongehah) ConnPool + // TODO(spongehah) ConnPool(w.cancel) //if pc != nil { // t.putOrCloseIdleConn(pc) //} diff --git a/x/net/http/util.go b/x/net/http/util.go index f2efb70..bfd9fc3 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -3,6 +3,11 @@ package http import ( "strings" "unicode" + "unicode/utf8" + + "golang.org/x/net/idna" + + "github.com/goplus/llgoexamples/x/net" ) /** @@ -206,6 +211,99 @@ func headerValueContainsToken(v string, token string) bool { // httpguts.headerV return tokenEqual(trimOWS(v), token) } +// PunycodeHostPort returns the IDNA Punycode version +// of the provided "host" or "host:port" string. +func PunycodeHostPort(v string) (string, error) { // httpguts.PunycodeHostPort + if isASCII(v) { + return v, nil + } + + host, port, err := net.SplitHostPort(v) + if err != nil { + // The input 'v' argument was just a "host" argument, + // without a port. This error should not be returned + // to the caller. + host = v + port = "" + } + host, err = idna.ToASCII(host) + if err != nil { + // Non-UTF-8? Not representable in Punycode, in any + // case. + return "", err + } + if port == "" { + return host, nil + } + return net.JoinHostPort(host, port), nil +} + +func isASCII(s string) bool { // httpguts.isASCII + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} + +// ValidHostHeader reports whether h is a valid host header. +func ValidHostHeader(h string) bool { // httpguts.ValidHostHeader + // The latest spec is actually this: + // + // http://tools.ietf.org/html/rfc7230#section-5.4 + // Host = uri-host [ ":" port ] + // + // Where uri-host is: + // http://tools.ietf.org/html/rfc3986#section-3.2.2 + // + // But we're going to be much more lenient for now and just + // search for any byte that's not a valid byte in any of those + // expressions. + for i := 0; i < len(h); i++ { + if !validHostByte[h[i]] { + return false + } + } + return true +} + +// See the validHostHeader comment. +var validHostByte = [256]bool{ // httpguts.validHostByte + '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, + '8': true, '9': true, + + 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true, + 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true, + 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true, + 'y': true, 'z': true, + + 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, + 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true, + 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true, + 'Y': true, 'Z': true, + + '!': true, // sub-delims + '$': true, // sub-delims + '%': true, // pct-encoded (and used in IPv6 zones) + '&': true, // sub-delims + '(': true, // sub-delims + ')': true, // sub-delims + '*': true, // sub-delims + '+': true, // sub-delims + ',': true, // sub-delims + '-': true, // unreserved + '.': true, // unreserved + ':': true, // IPv6address + Host expression's optional port + ';': true, // sub-delims + '=': true, // sub-delims + '[': true, + '\'': true, // sub-delims + ']': true, + '_': true, // unreserved + '~': true, // unreserved +} + // IsPrint returns whether s is ASCII and printable according to // https://tools.ietf.org/html/rfc20#section-4.2. func IsPrint(s string) bool { // ascii.IsPrint diff --git a/x/net/ipsock.go b/x/net/ipsock.go index 855a864..55e1b45 100644 --- a/x/net/ipsock.go +++ b/x/net/ipsock.go @@ -1,5 +1,11 @@ package net +import ( + "unsafe" + + "github.com/goplus/llgo/c" +) + // JoinHostPort combines host and port into a network address of the // form "host:port". If host contains a colon, as found in literal // IPv6 addresses, then JoinHostPort returns "[host]:port". @@ -8,17 +14,82 @@ package net func JoinHostPort(host, port string) string { // We assume that host is a literal IPv6 address if host has // colons. + if IndexByteString(host, ':') >= 0 { return "[" + host + "]:" + port } return host + ":" + port } -func IndexByteString(s string, c byte) int { - for i := 0; i < len(s); i++ { - if s[i] == c { - return i +// SplitHostPort splits a network address of the form "host:port", +// "host%zone:port", "[host]:port" or "[host%zone]:port" into host or +// host%zone and port. +// +// A literal IPv6 address in hostport must be enclosed in square +// brackets, as in "[::1]:80", "[::1%lo0]:80". +// +// See func Dial for a description of the hostport parameter, and host +// and port results. +func SplitHostPort(hostport string) (host, port string, err error) { + const ( + missingPort = "missing port in address" + tooManyColons = "too many colons in address" + ) + addrErr := func(addr, why string) (host, port string, err error) { + return "", "", &AddrError{Err: why, Addr: addr} + } + j, k := 0, 0 + + // The port starts after the last colon. + i := last(hostport, ':') + if i < 0 { + return addrErr(hostport, missingPort) + } + + if hostport[0] == '[' { + // Expect the first ']' just before the last ':'. + end := IndexByteString(hostport, ']') + if end < 0 { + return addrErr(hostport, "missing ']' in address") } + switch end + 1 { + case len(hostport): + // There can't be a ':' behind the ']' now. + return addrErr(hostport, missingPort) + case i: + // The expected result. + default: + // Either ']' isn't followed by a colon, or it is + // followed by a colon that is not the last one. + if hostport[end+1] == ':' { + return addrErr(hostport, tooManyColons) + } + return addrErr(hostport, missingPort) + } + host = hostport[1:end] + j, k = 1, end+1 // there can't be a '[' resp. ']' before these positions + } else { + host = hostport[:i] + if IndexByteString(host, ':') >= 0 { + return addrErr(hostport, tooManyColons) + } + } + if IndexByteString(hostport[j:], '[') >= 0 { + return addrErr(hostport, "unexpected '[' in address") + } + if IndexByteString(hostport[k:], ']') >= 0 { + return addrErr(hostport, "unexpected ']' in address") + } + + port = hostport[i+1:] + return host, port, nil +} + +func IndexByteString(s string, ch byte) int { // bytealg.IndexByteString + ptr := unsafe.Pointer(unsafe.StringData(s)) + ret := c.Memchr(ptr, c.Int(ch), uintptr(len(s))) + if ret != nil { + return int(uintptr(ret) - uintptr(ptr)) } return -1 } diff --git a/x/net/net.go b/x/net/net.go new file mode 100644 index 0000000..3267d90 --- /dev/null +++ b/x/net/net.go @@ -0,0 +1,20 @@ +package net + +type AddrError struct { + Err string + Addr string +} + +func (e *AddrError) Error() string { + if e == nil { + return "" + } + s := e.Err + if e.Addr != "" { + s = "address " + e.Addr + ": " + s + } + return s +} + +func (e *AddrError) Timeout() bool { return false } +func (e *AddrError) Temporary() bool { return false } diff --git a/x/net/parse.go b/x/net/parse.go new file mode 100644 index 0000000..c110fcf --- /dev/null +++ b/x/net/parse.go @@ -0,0 +1,12 @@ +package net + +// Index of rightmost occurrence of b in s. +func last(s string, b byte) int { + i := len(s) + for i--; i >= 0; i-- { + if s[i] == b { + break + } + } + return i +} From cebff180c6ce1c390da2bb17ed0ad3372cfbf18d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Mon, 26 Aug 2024 17:46:17 +0800 Subject: [PATCH 23/55] WIP(x/http/client): Categorize and write req.headers & perform unwrapBody operation on req.Reader --- x/net/http/_demo/get/get.go | 1 - x/net/http/_demo/headers/headers.go | 3 +- x/net/http/_demo/upload/upload.go | 5 +- x/net/http/header.go | 91 +++++++++++++++ x/net/http/request.go | 172 ++++++++++------------------ x/net/http/server.go | 7 ++ x/net/http/transfer.go | 110 ++++++++++++++++++ x/net/http/transport.go | 59 +++++----- 8 files changed, 302 insertions(+), 146 deletions(-) diff --git a/x/net/http/_demo/get/get.go b/x/net/http/_demo/get/get.go index 79c18ba..6bc5b06 100644 --- a/x/net/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -14,7 +14,6 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - fmt.Println(resp.Proto) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/net/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go index 71d42b7..aa2e5d6 100644 --- a/x/net/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -16,7 +16,7 @@ func main() { } //req.Header.Set("accept", "*/*") - req.Header.Set("accept-encoding", "identity") + req.Header.Set("accept-encoding", "gzip") //req.Header.Set("cache-control", "no-cache") //req.Header.Set("pragma", "no-cache") //req.Header.Set("priority", "u=0, i") @@ -36,6 +36,7 @@ func main() { println(err.Error()) return } + fmt.Println(resp.Status) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go index 86b57e9..fe7256b 100644 --- a/x/net/http/_demo/upload/upload.go +++ b/x/net/http/_demo/upload/upload.go @@ -10,7 +10,9 @@ import ( func main() { url := "http://httpbin.org/post" + //url := "http://localhost:8080" filePath := "/Users/spongehah/go/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path + //filePath := "/Users/spongehah/Downloads/xiaoshuo.txt" // Replace with your file path file, err := os.Open(filePath) if err != nil { @@ -33,7 +35,8 @@ func main() { return } defer resp.Body.Close() - + fmt.Println("Status:", resp.Status) + resp.PrintHeaders() respBody, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/header.go b/x/net/http/header.go index 6515c48..0d1e2cc 100644 --- a/x/net/http/header.go +++ b/x/net/http/header.go @@ -3,6 +3,9 @@ package http import ( "fmt" "net/textproto" + "sort" + "strings" + "sync" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -108,6 +111,94 @@ func (h Header) Clone() Header { return h2 } +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} + +// Write writes a header in wire format. +func (h Header) Write(reqHeaders *hyper.Headers) error { + return h.write(reqHeaders) +} + +func (h Header) write(reqHeaders *hyper.Headers) error { + return h.writeSubset(reqHeaders, nil) +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +// Keys are not canonicalized before checking the exclude map. +func (h Header) WriteSubset(reqHeaders *hyper.Headers, exclude map[string]bool) error { + return h.writeSubset(reqHeaders, exclude) +} + +func (h Header) writeSubset(reqHeaders *hyper.Headers, exclude map[string]bool) error { + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { + if !ValidHeaderFieldName(kv.key) { + // This could be an error. In the common case of + // writing response headers, however, we have no good + // way to provide the error back to the server + // handler, so just drop invalid headers instead. + continue + } + for _, v := range kv.values { + v = headerNewlineToSpace.Replace(v) + v = textproto.TrimString(v) + if reqHeaders.Add(&[]byte(kv.key)[0], c.Strlen(c.AllocaCStr(kv.key)), &[]byte(v)[0], c.Strlen(c.AllocaCStr(v))) != hyper.OK { + headerSorterPool.Put(sorter) + return fmt.Errorf("error adding header %s: %s\n", kv.key, v) + } + //if trace != nil && trace.WroteHeaderField != nil { + // formattedVals = append(formattedVals, v) + //} + } + //if trace != nil && trace.WroteHeaderField != nil { + // trace.WroteHeaderField(kv.key, formattedVals) + // formattedVals = nil + //} + } + + headerSorterPool.Put(sorter) + return nil +} + // hasToken reports whether token appears with v, ASCII // case-insensitive, with space or comma boundaries. // token must be all lowercase. diff --git a/x/net/http/request.go b/x/net/http/request.go index be81f0e..6d74296 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -14,7 +14,6 @@ import ( "golang.org/x/net/idna" "github.com/goplus/llgo/c" - "github.com/goplus/llgo/c/os" "github.com/goplus/llgoexamples/rust/hyper" ) @@ -167,54 +166,6 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R return req, nil } -//func setPostDataNoCopy(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { -// req := (*postReq)(userdata) -// buf := req.hyperBuf.Bytes() -// len := req.hyperBuf.Len() -// n, err := req.req.Body.Read(unsafe.Slice(buf, len)) -// if err != nil { -// if err == io.EOF { -// *chunk = nil -// return hyper.PollReady -// } -// fmt.Println("error reading upload file: ", err) -// return hyper.PollError -// } -// if n > 0 { -// *chunk = req.hyperBuf -// return hyper.PollReady -// } -// if n == 0 { -// *chunk = nil -// return hyper.PollReady -// } -// -// fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) -// return hyper.PollError -//} - -// setHeaders sets the headers of the request -func (r *Request) setHeaders(hyperReq *hyper.Request) error { - headers := hyperReq.Headers() - for key, values := range r.Header { - valueLen := len(values) - if valueLen > 1 { - for _, value := range values { - if headers.Add(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(value)[0], c.Strlen(c.AllocaCStr(value))) != hyper.OK { - return fmt.Errorf("error adding header %s: %s\n", key, value) - } - } - } else if valueLen == 1 { - if headers.Set(&[]byte(key)[0], c.Strlen(c.AllocaCStr(key)), &[]byte(values[0])[0], c.Strlen(c.AllocaCStr(values[0]))) != hyper.OK { - return fmt.Errorf("error setting header %s: %s\n", key, values[0]) - } - } else { - return fmt.Errorf("error setting header %s: empty value\n", key) - } - } - return nil -} - func (r *Request) expectsContinue() bool { return hasToken(r.Header.get("Expect"), "100-continue") } @@ -289,6 +240,21 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { // the Request. var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed at the time of Go 1.1 release because the former User-Agent +// had ended up blocked by some intrusion detection systems. +// See https://codereview.appspot.com/7532043. +const defaultUserAgent = "Go-http-client/1.1" + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + // extraHeaders may be nil // waitForContinue may be nil // always closes body @@ -302,16 +268,6 @@ func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.Clien // }() //} - //closed := false - //defer func() { - // if closed { - // return - // } - // if closeErr := r.closeBody(); closeErr != nil && err == nil { - // err = closeErr - // } - //}() - // Prepare the hyper.Request hyperReq, err := r.newHyperRequest(usingProxy, extraHeader) if err != nil { @@ -387,9 +343,6 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R return nil, errors.New("net/http: can't write control character in Request.URL") } - - - // Prepare the hyper request hyperReq := hyper.NewRequest() @@ -409,29 +362,55 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R if reqHeaders.Set(&[]byte("Host")[0], c.Strlen(c.Str("Host")), &[]byte(host)[0], c.Strlen(c.AllocaCStr(host))) != hyper.OK { return nil, fmt.Errorf("error setting header: Host: %s\n", host) } - err = r.setHeaders(hyperReq) + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. + userAgent := defaultUserAgent + if r.Header.has("User-Agent") { + userAgent = r.Header.Get("User-Agent") + } + if userAgent != "" { + if reqHeaders.Set(&[]byte("User-Agent")[0], c.Strlen(c.Str("User-Agent")), &[]byte(userAgent)[0], c.Strlen(c.AllocaCStr(userAgent))) != hyper.OK { + return nil, fmt.Errorf("error setting header: User-Agent: %s\n", userAgent) + } + } + + // Process Body,ContentLength,Close,Trailer + //tw, err := newTransferWriter(r) + //if err != nil { + // return err + //} + //err = tw.writeHeader(w, trace) + err = r.writeHeader(reqHeaders) if err != nil { return nil, err } - if r.Body != nil { - // 100-continue - if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { - hyperReq.OnInformational(printInformational, nil) - } + err = r.Header.writeSubset(reqHeaders, reqWriteExcludeHeader) + if err != nil { + return nil, err + } - hyperReqBody := hyper.NewBody() - //buf := make([]byte, 2) - //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(2)) - reqData := &postReq{ - req: r, - buf: make([]byte, defaultChunkSize), - //hyperBuf: hyperBuf, + if extraHeader != nil { + err = extraHeader.write(reqHeaders) + if err != nil { + return nil, err } - hyperReqBody.SetUserdata(c.Pointer(reqData)) - hyperReqBody.SetDataFunc(setPostData) - //hyperReqBody.SetDataFunc(setPostDataNoCopy) - hyperReq.SetBody(hyperReqBody) + } + + //if trace != nil && trace.WroteHeaders != nil { + // trace.WroteHeaders() + //} + + // Wait for 100-continue if expected. + if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { + hyperReq.OnInformational(printInformational, nil) + } + + // Write body and trailer + err = r.writeBody(hyperReq) + if err != nil { + return nil, err } return hyperReq, nil @@ -442,41 +421,6 @@ func printInformational(userdata c.Pointer, resp *hyper.Response) { fmt.Println("Informational (1xx): ", status) } -type postReq struct { - req *Request - buf []byte - //hyperBuf *hyper.Buf -} - -func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - req := (*postReq)(userdata) - n, err := req.req.Body.Read(req.buf) - if err != nil { - if err == io.EOF { - println("EOF") - *chunk = nil - req.req.Body.Close() - return hyper.PollReady - } - fmt.Println("error reading request body: ", err) - return hyper.PollError - } - if n > 0 { - *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) - return hyper.PollReady - } - if n == 0 { - println("n == 0") - *chunk = nil - req.req.Body.Close() - return hyper.PollReady - } - - req.req.Body.Close() - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) - return hyper.PollError -} - func validMethod(method string) bool { /* Method = "OPTIONS" ; Section 9.2 diff --git a/x/net/http/server.go b/x/net/http/server.go index f38cbd0..5c4c58d 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -10,3 +10,10 @@ package http // size is anyway. (if we have the bytes on the machine, we might as // well read them) const maxPostHandlerReadBytes = 256 << 10 + +type readResult struct { + _ incomparable + n int + err error + b byte // byte read, if n == 1 +} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index cf96f84..b0a52fb 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -5,10 +5,15 @@ import ( "fmt" "io" "net/textproto" + "reflect" "strconv" "strings" "sync" "unicode/utf8" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/os" + "github.com/goplus/llgoexamples/rust/hyper" ) type transferReader struct { @@ -610,3 +615,108 @@ func lowerASCII(b byte) byte { // isOWS reports whether b is an optional whitespace byte, as defined // by RFC 7230 section 3.2.3. func isOWS(b byte) bool { return b == ' ' || b == '\t' } + +// writeHeader Write Content-Length and/or Transfer-Encoding and/or Trailer header +func (r *Request) writeHeader(reqHeaders *hyper.Headers) error { + if r.Close && !hasToken(r.Header.get("Connection"), "close") { + if reqHeaders.Set(&[]byte("Connection")[0], c.Strlen(c.Str("Connection")), &[]byte("close")[0], c.Strlen(c.Str("close"))) != hyper.OK { + return fmt.Errorf("error setting header: Connection: %s\n", "close") + } + } + + // 'Content-Length' and 'Transfer-Encoding:chunked' are already handled by hyper + + // Write Trailer header + // TODO(spongehah) Trailer(writeHeader) + + return nil +} + +var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) +var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { + io.Reader + io.WriterTo +}{})) + +// unwrapNopCloser return the underlying reader and true if r is a NopCloser +// else it return false. +func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) { + switch reflect.TypeOf(r) { + case nopCloserType, nopCloserWriterToType: + return reflect.ValueOf(r).Field(0).Interface().(io.Reader), true + default: + return nil, false + } +} + +// unwrapBody unwraps the body's inner reader if it's a +// nopCloser. This is to ensure that body writes sourced from local +// files (*os.File types) are properly optimized. +// +// This function is only intended for use in writeBody. +func (req *Request) unwrapBody() io.Reader { + if r, ok := unwrapNopCloser(req.Body); ok { + return r + } + if r, ok := req.Body.(*readTrackingBody); ok { + r.didRead = true + return r.ReadCloser + } + return req.Body +} + +func (r *Request) writeBody(hyperReq *hyper.Request) error { + if r.Body != nil { + var body = r.unwrapBody() + hyperReqBody := hyper.NewBody() + buf := make([]byte, defaultChunkSize) + //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(defaultChunkSize)) + reqData := &bodyReq{ + body: body, + buf: buf, + //hyperBuf: hyperBuf, + closeBody: r.closeBody, + } + hyperReqBody.SetUserdata(c.Pointer(reqData)) + hyperReqBody.SetDataFunc(setPostData) + hyperReq.SetBody(hyperReqBody) + } + return nil +} + +type bodyReq struct { + body io.Reader + buf []byte + //hyperBuf *hyper.Buf + closeBody func() error +} + +func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { + req := (*bodyReq)(userdata) + n, err := req.body.Read(req.buf) + //buf := req.hyperBuf.Bytes() + //bufLen := req.hyperBuf.Len() + //n, err := req.body.Read(unsafe.Slice(buf, bufLen)) + if err != nil { + if err == io.EOF { + *chunk = nil + req.closeBody() + return hyper.PollReady + } + fmt.Println("error reading request body: ", err) + return hyper.PollError + } + if n > 0 { + *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) + //*chunk = req.hyperBuf + return hyper.PollReady + } + if n == 0 { + *chunk = nil + req.closeBody() + return hyper.PollReady + } + req.closeBody() + fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + return hyper.PollError +} diff --git a/x/net/http/transport.go b/x/net/http/transport.go index fe3efc3..5d7d48a 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -16,8 +16,8 @@ import ( "github.com/goplus/llgo/c/libuv" cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - "github.com/goplus/llgoexamples/x/net" "github.com/goplus/llgoexamples/rust/hyper" + "github.com/goplus/llgoexamples/x/net" ) // DefaultTransport is the default implementation of Transport and is @@ -733,7 +733,6 @@ func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, erro conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) } - libuv.InitTcp(loop, &conn.TcpHandle) libuv.InitTcp(loop, &conn.TcpHandle) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) @@ -781,25 +780,26 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // uncompress the gzip stream if we were the layer that // requested it. requestedGzip := false - if !pc.t.DisableCompression && - req.Header.Get("Accept-Encoding") == "" && - req.Header.Get("Range") == "" && - req.Method != "HEAD" { - // Request gzip only, not deflate. Deflate is ambiguous and - // not as universally supported anyway. - // See: https://zlib.net/zlib_faq.html#faq39 - // - // Note that we don't request this for HEAD requests, - // due to a bug in nginx: - // https://trac.nginx.org/nginx/ticket/358 - // https://golang.org/issue/5522 - // - // We don't request gzip if the request is for a range, since - // auto-decoding a portion of a gzipped document will just fail - // anyway. See https://golang.org/issue/8923 - requestedGzip = true - req.extraHeaders().Set("Accept-Encoding", "gzip") - } + // TODO(spongehah) gzip(pc.roundTrip) + //if !pc.t.DisableCompression && + // req.Header.Get("Accept-Encoding") == "" && + // req.Header.Get("Range") == "" && + // req.Method != "HEAD" { + // // Request gzip only, not deflate. Deflate is ambiguous and + // // not as universally supported anyway. + // // See: https://zlib.net/zlib_faq.html#faq39 + // // + // // Note that we don't request this for HEAD requests, + // // due to a bug in nginx: + // // https://trac.nginx.org/nginx/ticket/358 + // // https://golang.org/issue/5522 + // // + // // We don't request gzip if the request is for a range, since + // // auto-decoding a portion of a gzipped document will just fail + // // anyway. See https://golang.org/issue/8923 + // requestedGzip = true + // req.extraHeaders().Set("Accept-Encoding", "gzip") + //} // The 100-continue operation in Hyper is handled in the newHyperRequest function. @@ -1126,14 +1126,15 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { } resp.Body = body - // TODO(spongehah) gzip fail(readWriteLoop) - if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - resp.Body = &gzipReader{body: body} - resp.Header.Del("Content-Encoding") - resp.Header.Del("Content-Length") - resp.ContentLength = -1 - resp.Uncompressed = true - } + // TODO(spongehah) gzip(pc.readWriteLoop) + //if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // resp.Body = &gzipReader{body: body} + // resp.Header.Del("Content-Encoding") + // resp.Header.Del("Content-Length") + // resp.ContentLength = -1 + // resp.Uncompressed = true + //} rw.waitForBodyRead = waitForBodyRead rw.rc = rc From c4d73157830519230bb04a52934fab473b241094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Thu, 29 Aug 2024 11:16:42 +0800 Subject: [PATCH 24/55] WIP(x/net/http/client): Extract outwards libuv.Loop and timeout logic --- go.mod | 2 +- go.sum | 4 +- x/net/http/_demo/timeout/timeout.go | 4 +- x/net/http/client.go | 50 +- x/net/http/request.go | 17 +- x/net/http/response.go | 1 + x/net/http/transfer.go | 15 +- x/net/http/transport.go | 989 ++++++++++++++++------------ x/net/http/util.go | 2 +- 9 files changed, 639 insertions(+), 445 deletions(-) diff --git a/go.mod b/go.mod index f961f75..e893515 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/goplus/llgoexamples go 1.20 require ( - github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 + github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b golang.org/x/net v0.28.0 ) diff --git a/go.sum b/go.sum index 4c64063..5d7faad 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4 h1:fqqbWhWaoseSplLJF8OTkNGl4Kruqm1wQWT/Yooq6E4= -github.com/goplus/llgo v0.9.7-0.20240816085229-53d2d080f4c4/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b h1:iC0vVA8F2DNJ9wVyHI9fP9U0nM+si3LSQJ1TtGftXyo= +github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= diff --git a/x/net/http/_demo/timeout/timeout.go b/x/net/http/_demo/timeout/timeout.go index ddb2d25..62b2c9d 100644 --- a/x/net/http/_demo/timeout/timeout.go +++ b/x/net/http/_demo/timeout/timeout.go @@ -10,8 +10,8 @@ import ( func main() { client := &http.Client{ - Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - //Timeout: time.Second * 5, + //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + Timeout: time.Second * 5, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { diff --git a/x/net/http/client.go b/x/net/http/client.go index bf1bfd4..d56f5f2 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -8,6 +8,7 @@ import ( "io" "log" "net/url" + "reflect" "sort" "strings" "sync" @@ -158,6 +159,9 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { Host: host, Cancel: ireq.Cancel, ctx: ireq.ctx, + + timer: ireq.timer, + timeoutch: ireq.timeoutch, } if includeBody && ireq.GetBody != nil { req.Body, err = ireq.GetBody() @@ -305,11 +309,14 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d } // TODO(spongehah) timeout(send) - req.timeoutch = make(chan struct{}, 1) + req.deadline = deadline + if deadline.IsZero() { + didTimeout = alwaysFalse + } else { + didTimeout = func() bool { return req.timer.GetDueIn() == 0 } + } //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) - sub := deadline.Sub(time.Now()) - req.timeout = sub resp, err = rt.RoundTrip(req) if err != nil { //stopTimer() @@ -469,6 +476,34 @@ func (b *cancelTimerBody) Close() error { return err } +// knownRoundTripperImpl reports whether rt is a RoundTripper that's +// maintained by the Go team and known to implement the latest +// optional semantics (notably contexts). The Request is used +// to check whether this particular request is using an alternate protocol, +// in which case we need to check the RoundTripper for that protocol. +func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { + switch t := rt.(type) { + case *Transport: + if altRT := t.alternateRoundTripper(req); altRT != nil { + return knownRoundTripperImpl(altRT, req) + } + return true + // TODO(spongehah) + //case *http2Transport, http2noDialH2RoundTripper: + // return true + } + // There's a very minor chance of a false positive with this. + // Instead of detecting our golang.org/x/net/http2.Transport, + // it might detect a Transport type in a different http2 + // package. But I know of none, and the only problem would be + // some temporarily leaked goroutines if the transport didn't + // support contexts. So this is a good enough heuristic: + if reflect.TypeOf(rt).String() == "*http2.Transport" { + return true + } + return false +} + // setRequestCancel sets req.Cancel and adds a deadline context to req // if deadline is non-zero. The RoundTripper's type is used to // determine whether the legacy CancelRequest behavior should be used. @@ -482,11 +517,10 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi if deadline.IsZero() { return nop, alwaysFalse } - //knownTransport := knownRoundTripperImpl(rt, req) + knownTransport := knownRoundTripperImpl(rt, req) oldCtx := req.Context() - //if req.Cancel == nil && knownTransport { - if req.Cancel == nil { + if req.Cancel == nil && knownTransport { // If they already had a Request.Context that's // expiring sooner, do nothing: if !timeBeforeContextDeadline(deadline, oldCtx) { @@ -504,7 +538,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) } - cancel := make(chan struct{}, 1) + cancel := make(chan struct{}) req.Cancel = cancel doCancel := func() { @@ -518,7 +552,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi } } - stopTimerCh := make(chan struct{}, 1) + stopTimerCh := make(chan struct{}) var once sync.Once stopTimer = func() { once.Do(func() { diff --git a/x/net/http/request.go b/x/net/http/request.go index 6d74296..658b033 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/goplus/llgo/c/libuv" "golang.org/x/net/idna" "github.com/goplus/llgo/c" @@ -37,12 +38,14 @@ type Request struct { RemoteAddr string RequestURI string //TLS *tls.ConnectionState - Cancel <-chan struct{} - timeoutch chan struct{} //optional + Cancel <-chan struct{} Response *Response - timeout time.Duration ctx context.Context + + deadline time.Time + timeoutch chan struct{} //tmp timeout + timer *libuv.Timer } const defaultChunkSize = 8192 @@ -117,6 +120,7 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R Header: make(Header), Body: rc, Host: u.Host, + timer: nil, } if body != nil { switch v := body.(type) { @@ -258,7 +262,7 @@ var reqWriteExcludeHeader = map[string]bool{ // extraHeaders may be nil // waitForContinue may be nil // always closes body -func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.ClientConn, exec *hyper.Executor) (err error) { +func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hyper.Executor) (err error) { //trace := httptrace.ContextClientTrace(r.Context()) //if trace != nil && trace.WroteRequest != nil { // defer func() { @@ -269,13 +273,14 @@ func (r *Request) write(usingProxy bool, extraHeader Header, client *hyper.Clien //} // Prepare the hyper.Request - hyperReq, err := r.newHyperRequest(usingProxy, extraHeader) + hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra) if err != nil { return err } // Send it! sendTask := client.Send(hyperReq) - setTaskId(sendTask, read) + taskData.taskId = read + sendTask.SetUserdata(c.Pointer(taskData)) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { err = errors.New("failed to send the request") diff --git a/x/net/http/response.go b/x/net/http/response.go index 8151ac2..6ff5b3d 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -73,6 +73,7 @@ func appendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { _, err := writer.Write(bytes) if err != nil { fmt.Println("Error writing to response body:", err) + writer.Close() return hyper.IterBreak } return hyper.IterContinue diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index b0a52fb..103200c 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -670,11 +670,9 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { var body = r.unwrapBody() hyperReqBody := hyper.NewBody() buf := make([]byte, defaultChunkSize) - //hyperBuf := hyper.CopyBuf(&buf[0], uintptr(defaultChunkSize)) reqData := &bodyReq{ - body: body, - buf: buf, - //hyperBuf: hyperBuf, + body: body, + buf: buf, closeBody: r.closeBody, } hyperReqBody.SetUserdata(c.Pointer(reqData)) @@ -685,18 +683,14 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { } type bodyReq struct { - body io.Reader - buf []byte - //hyperBuf *hyper.Buf + body io.Reader + buf []byte closeBody func() error } func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { req := (*bodyReq)(userdata) n, err := req.body.Read(req.buf) - //buf := req.hyperBuf.Bytes() - //bufLen := req.hyperBuf.Len() - //n, err := req.body.Read(unsafe.Slice(buf, bufLen)) if err != nil { if err == io.EOF { *chunk = nil @@ -708,7 +702,6 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In } if n > 0 { *chunk = hyper.CopyBuf(&req.buf[0], uintptr(n)) - //*chunk = req.hyperBuf return hyper.PollReady } if n == 0 { diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 5d7d48a..062cbe9 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -10,14 +10,15 @@ import ( "net/url" "sync" "sync/atomic" + "time" "unsafe" "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/libuv" cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" + "github.com/goplus/llgo/x/net" "github.com/goplus/llgoexamples/rust/hyper" - "github.com/goplus/llgoexamples/x/net" ) // DefaultTransport is the default implementation of Transport and is @@ -33,7 +34,7 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -const defaultHTTPPort = "80" +const debugSwitch = true type Transport struct { altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme @@ -68,6 +69,11 @@ type Transport struct { // // Zero means no limit. MaxConnsPerHost int + + // libuv and hyper related + loopInitOnce sync.Once + loop *libuv.Loop + exec *hyper.Executor } // A cancelKey is the key of the reqCanceler map. @@ -82,29 +88,6 @@ type cancelKey struct { // any size (as long as it's first). type incomparable [0]func() -type requestAndChan struct { - _ incomparable - req *Request - cancelKey cancelKey - ch chan responseAndError // unbuffered; always send in select on callerGone - - // whether the Transport (as opposed to the user client code) - // added the Accept-Encoding gzip header. If the Transport - // set it, only then do we transparently decode the gzip. - addedGzip bool - - callerGone <-chan struct{} // closed when roundTrip caller has returned -} - -// A writeRequest is sent by the caller's goroutine to the -// writeLoop's goroutine to write a request while the read loop -// concurrently waits on both the write response and the server's -// reply. -type writeRequest struct { - req *transportRequest - ch chan<- error -} - // responseAndError is how the goroutine reading from an HTTP/1 server // communicates with the goroutine doing the RoundTrip. type responseAndError struct { @@ -113,10 +96,9 @@ type responseAndError struct { err error } -type connAndTimeoutChan struct { - _ incomparable - conn *connData +type timeoutData struct { timeoutch chan struct{} + taskData *taskData } type readTrackingBody struct { @@ -205,8 +187,6 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } - // TODO(spongehah) cm.treq(connectMethod) - cm.treq = treq cm.onlyH1 = treq.requiresHTTP1() return cm, err } @@ -293,9 +273,110 @@ func (t *Transport) cancelRequest(key cancelKey, err error) bool { return cancel != nil } +func (t *Transport) close(err error) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + t.closeLocked(err) +} + +func (t *Transport) closeLocked(err error) { + if err != nil { + fmt.Println(err) + } + if t.loop != nil { + t.loop.Close() + } + if t.exec != nil { + t.exec.Free() + } +} + +func getMilliseconds(deadline time.Time) uint64 { + microseconds := deadline.Sub(time.Now()).Microseconds() + milliseconds := microseconds / 1e3 + if microseconds%1e3 != 0 { + milliseconds += 1 + } + return uint64(milliseconds) +} + // ---------------------------------------------------------- func (t *Transport) RoundTrip(req *Request) (*Response, error) { + if debugSwitch { + println("RoundTrip start") + defer println("RoundTrip end") + } + t.loopInitOnce.Do(func() { + t.loop = libuv.LoopNew() + t.exec = hyper.NewExecutor() + + //idle := &libuv.Idle{} + //libuv.InitIdle(t.loop, idle) + //(*libuv.Handle)(c.Pointer(idle)).SetData(c.Pointer(t)) + //idle.Start(readWriteLoop) + + checker := &libuv.Check{} + libuv.InitCheck(t.loop, checker) + (*libuv.Handle)(c.Pointer(checker)).SetData(c.Pointer(t)) + checker.Start(readWriteLoop) + + go t.loop.Run(libuv.RUN_DEFAULT) + }) + + // If timeout is set, start the timer + var didTimeout func() bool + var stopTimer func() + // Only the first request will initialize the timer + if req.timer == nil && !req.deadline.IsZero() { + req.timer = &libuv.Timer{} + req.timeoutch = make(chan struct{}, 1) + libuv.InitTimer(t.loop, req.timer) + ch := &timeoutData{ + timeoutch: req.timeoutch, + taskData: nil, + } + (*libuv.Handle)(c.Pointer(req.timer)).SetData(c.Pointer(ch)) + + req.timer.Start(onTimeout, getMilliseconds(req.deadline), 0) + if debugSwitch { + println("timer start") + } + didTimeout = func() bool { return req.timer.GetDueIn() == 0 } + stopTimer = func() { + close(req.timeoutch) + req.timer.Stop() + (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) + if debugSwitch { + println("timer close") + } + } + } else { + didTimeout = alwaysFalse + stopTimer = nop + } + + resp, err := t.doRoundTrip(req) + if err != nil { + stopTimer() + return nil, err + } + + if !req.deadline.IsZero() { + resp.Body = &cancelTimerBody{ + stop: stopTimer, + rc: resp.Body, + reqDidTimeout: didTimeout, + } + } + return resp, nil +} + +func (t *Transport) doRoundTrip(req *Request) (*Response, error) { + if debugSwitch { + println("doRoundTrip start") + defer println("doRoundTrip end") + } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() //trace := httptrace.ContextClientTrace(ctx) @@ -354,13 +435,19 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout(t.RoundTrip): because of that ctx not initialized ( initialized in setRequestCancel() ) + // TODO(spongehah) timeout(t.doRoundTrip) //select { //case <-ctx.Done(): // req.closeBody() // return nil, ctx.Err() //default: //} + select { + case <-req.timeoutch: + req.closeBody() + return nil, errors.New("request timeout!") + default: + } // treq gets modified by roundTrip, so we need to recreate for each retry. //treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} @@ -376,6 +463,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // pre-CONNECTed to https server. In any case, we'll be ready // to send it requests. pconn, err := t.getConn(treq, cm) + if err != nil { t.setReqCanceler(cancelKey, nil) req.closeBody() @@ -390,19 +478,24 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } else { resp, err = pconn.roundTrip(treq) } + if err == nil { resp.Request = origReq return resp, nil } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) Retry & ConnPool(t.RoundTrip) + // TODO(spongehah) Retry & ConnPool(t.doRoundTrip) return nil, err } } func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { - //req := treq.Request + if debugSwitch { + println("getConn start") + defer println("getConn end") + } + req := treq.Request //trace := treq.trace //ctx := req.Context() //if trace != nil && trace.GetConn != nil { @@ -413,6 +506,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi cm: cm, key: cm.key(), //ctx: ctx, + timeoutch: treq.timeoutch, ready: make(chan struct{}, 1), beforeDial: testHookPrePendingDial, afterDial: testHookPostPendingDial, @@ -458,10 +552,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // what caused w.err; if so, prefer to return the // cancellation error (see golang.org/issue/16049). select { + // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn //case <-req.Context().Done(): // return nil, req.Context().Err() + case <-req.timeoutch: + return nil, errors.New("timeout: req.Context().Err()") case err := <-cancelc: if err == errRequestCanceled { err = errRequestCanceledConn @@ -475,10 +572,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // TODO(spongehah) cancel(t.getConn) //case <-req.Cancel: // return nil, errRequestCanceledConn - case <-treq.Request.timeoutch: - return nil, fmt.Errorf("request timeout\n") //case <-req.Context().Done(): - // return nil, req.Context().Err() + // return nil, + case <-req.timeoutch: + if debugSwitch { + println("getConn: timeoutch") + } + return nil, errors.New("timeout: req.Context().Err()\n") case err := <-cancelc: if err == errRequestCanceled { err = errRequestCanceledConn @@ -490,6 +590,10 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // queueForDial queues w to wait for permission to begin dialing. // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { + if debugSwitch { + println("queueForDial start") + defer println("queueForDial end") + } w.beforeDial() if t.MaxConnsPerHost <= 0 { @@ -522,9 +626,13 @@ func (t *Transport) queueForDial(w *wantConn) { // dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { + if debugSwitch { + println("dialConnFor start") + defer println("dialConnFor end") + } defer w.afterDial() - pc, err := t.dialConn(w.ctx, w.cm) + pc, err := t.dialConn(w.timeoutch, w.cm) w.tryDeliver(pc, err) // TODO(spongehah) ConnPool(t.dialConnFor) //delivered := w.tryDeliver(pc, err) @@ -593,12 +701,19 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { } } -func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { +func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn *persistConn, err error) { + if debugSwitch { + println("dialConn start") + defer println("dialConn end") + } + select { + case <-timeoutch: + return + default: + } pconn = &persistConn{ t: t, cacheKey: cm.key(), - reqch: make(chan requestAndChan, 1), - writech: make(chan writeRequest, 1), closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), } @@ -611,7 +726,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } // return err //} - + // //if cm.scheme() == "https" && t.hasCustomTLSDialer() { // var err error // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) @@ -639,27 +754,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} else { //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(ctx, cm) + conn, err := t.dial(timeoutch, cm.addr()) if err != nil { return nil, err } pconn.conn = conn - // hyper specific - // Hookup the IO - hyperIo := newIoWithConnReadWrite(conn) - // We need an executor generally to poll futures - exec := hyper.NewExecutor() - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(exec) - pconn.exec = exec - // send the handshake - handshakeTask := hyper.Handshake(hyperIo, opts) - setTaskId(handshakeTask, write) - // Let's wait for the handshake to finish... - exec.Push(handshakeTask) - //if cm.scheme() == "https" { // var firstTLSHost string // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { @@ -670,7 +770,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} //} - + select { + case <-timeoutch: + conn.Close() + return + default: + } // TODO(spongehah) Proxy(https/sock5)(t.dialConn) // Proxy setup. switch { @@ -704,36 +809,28 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers // } //} - go pconn.readWriteLoop(libuv.DefaultLoop()) - + select { + case <-timeoutch: + conn.Close() + return + default: + } return pconn, nil } -func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, error) { - treq := cm.treq - host := treq.URL.Hostname() - port := treq.URL.Port() - if port == "" { - port = defaultHTTPPort +func (t *Transport) dial(timeoutch chan struct{}, addr string) (*connData, error) { + if debugSwitch { + println("dial start") + defer println("dial end") } - loop := libuv.DefaultLoop() - conn := new(connData) - if conn == nil { - return nil, fmt.Errorf("Failed to allocate memory for conn_data\n") + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err } - // If timeout is set, start the timer - if treq.timeout > 0 { - libuv.InitTimer(loop, &conn.TimeoutTimer) - ct := &connAndTimeoutChan{ - conn: conn, - timeoutch: treq.Request.timeoutch, - } - (*libuv.Handle)(c.Pointer(&conn.TimeoutTimer)).SetData(c.Pointer(ct)) - conn.TimeoutTimer.Start(onTimeout, uint64(treq.timeout.Milliseconds()), 0) - } + conn := new(connData) - libuv.InitTcp(loop, &conn.TcpHandle) + libuv.InitTcp(t.loop, &conn.TcpHandle) (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) var hints cnet.AddrInfo @@ -744,14 +841,12 @@ func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, erro var res *cnet.AddrInfo status := cnet.Getaddrinfo(c.AllocaCStr(host), c.AllocaCStr(port), &hints, &res) if status != 0 { - close(treq.Request.timeoutch) return nil, fmt.Errorf("getaddrinfo error\n") } (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) if status != 0 { - close(treq.Request.timeoutch) return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } @@ -760,6 +855,10 @@ func (t *Transport) dial(ctx context.Context, cm connectMethod) (*connData, erro } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { + if debugSwitch { + println("roundTrip start") + defer println("roundTrip end") + } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { // TODO(spongehah) ConnPool(pc.roundTrip) @@ -819,43 +918,61 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } }() - const debugRoundTrip = false // Debug switch provided for developers - // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. - - // In Hyper, the writeLoop() and readLoop() are combined together --> readWriteLoop(). startBytesWritten := pc.conn.nwrite writeErrCh := make(chan error, 1) - pc.writech <- writeRequest{req: req, ch: writeErrCh} - - // Send the request to readWriteLoop(). resc := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{ - req: req.Request, - cancelKey: req.cancelKey, - ch: resc, + + // Hookup the IO + hyperIo := newIoWithConnReadWrite(pc.conn) + // We need an executor generally to poll futures + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(pc.t.exec) + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + taskData := &taskData{ + taskId: write, + req: req, + pc: pc, addedGzip: requestedGzip, + writeErrCh: writeErrCh, callerGone: gone, + resc: resc, } + handshakeTask.SetUserdata(c.Pointer(taskData)) + // Send the request to readWriteLoop(). + // Let's wait for the handshake to finish... + + pc.t.exec.Push(handshakeTask) + async := &libuv.Async{} + pc.t.loop.Async(async, asyncCb) + async.Send() //var respHeaderTimer <-chan time.Time //cancelChan := req.Request.Cancel //ctxDoneChan := req.Context().Done() + timeoutch := req.timeoutch pcClosed := pc.closech canceled := false for { testHookWaitResLoop() - + if debugSwitch { + println("roundTrip for") + } select { case err := <-writeErrCh: - if debugRoundTrip { - //req.logf("writeErrCh resv: %T/%#v", err, err) + if debugSwitch { + println("roundTrip: writeErrch") } if err != nil { pc.close(fmt.Errorf("write error: %w", err)) + if pc.conn.nwrite == startBytesWritten { + err = nothingWrittenError{err} + } return nil, pc.mapRoundTripError(req, startBytesWritten, err) } //if d := pc.t.ResponseHeaderTimeout; d > 0 { @@ -867,69 +984,53 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // respHeaderTimer = timer.C //} case <-pcClosed: + if debugSwitch { + println("roundTrip: pcClosed") + } pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { - if debugRoundTrip { - //req.logf("closech recv: %T %#v", pc.closed, pc.closed) - } return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) } //case <-respHeaderTimer: case re := <-resc: + if debugSwitch { + println("roundTrip: resc") + } if (re.res == nil) == (re.err == nil) { - println(1) return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } - if debugRoundTrip { - println(2) - //req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) - } if re.err != nil { - println(3) return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil // TODO(spongehah) cancel(pc.roundTrip) //case <-cancelChan: - case <-req.Request.timeoutch: - return nil, fmt.Errorf("request timeout\n") + // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) + // cancelChan = nil + //case <-ctxDoneChan: + // canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) + // cancelChan = nil + // ctxDoneChan = nil + case <-timeoutch: + if debugSwitch { + println("roundTrip: timeoutch") + } + canceled = pc.t.cancelRequest(req.cancelKey, errors.New("timeout: req.Context().Err()")) + timeoutch = nil + return nil, errors.New("request timeout") } } } +func asyncCb(async *libuv.Async) { + println("async called") +} + // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. -func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { - // writeLoop related - defer close(pc.writeLoopDone) - - // readLoop related - closeErr := errReadLoopExiting // default value, if not changed below - defer func() { - pc.close(closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //pc.t.removeIdleConn(pc) - }() - - //tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { - // if err := pc.t.tryPutIdleConn(pc); err != nil { - // closeErr = err - // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - // } - // return false - // } - // if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - // } - // return true - //} - - // eofc is used to block caller goroutines reading from Response.Body - // at EOF until this goroutines has (potentially) added the connection - // back to the idle pool. - eofc := make(chan struct{}, 1) - defer close(eofc) // unblock reader on errors +func readWriteLoop(idle *libuv.Check) { + println("polling") + t := (*Transport)((*libuv.Handle)(c.Pointer(idle)).GetData()) // Read this once, before loop starts. (to avoid races in tests) testHookMu.Lock() @@ -938,281 +1039,334 @@ func (pc *persistConn) readWriteLoop(loop *libuv.Loop) { const debugReadWriteLoop = true // Debug switch provided for developers - if debugReadWriteLoop { - println("readWriteLoop start") - } - // The polling state machine! // Poll all ready tasks and act on them... - alive := true - var bodyWriter *io.PipeWriter - var rw readWaiter - for alive { - select { - case <-pc.closech: - if debugReadWriteLoop { - println("closech") - } + for { + task := t.exec.Poll() + if task == nil { return - default: - task := pc.exec.Poll() - if task == nil { - loop.Run(libuv.RUN_ONCE) - continue - } - taskId := (taskId)(uintptr(task.Userdata())) + } + taskData := (*taskData)(task.Userdata()) + var taskId taskId + if taskData != nil { + taskId = taskData.taskId + } else { + taskId = notSet + } + if debugReadWriteLoop { + println("taskId: ", taskId) + } + switch taskId { + case write: if debugReadWriteLoop { - println("taskId: ", taskId) + println("write") } - switch taskId { - case write: - if debugReadWriteLoop { - println("write") - } - wr := <-pc.writech // blocking - startBytesWritten := pc.conn.nwrite - err := checkTaskType(task, write) - client := (*hyper.ClientConn)(task.Value()) + select { + case <-taskData.pc.closech: task.Free() - if err == nil { - // TODO(spongehah) Proxy(writeLoop) - err = wr.req.Request.write(pc.isProxy, wr.req.extra, client, pc.exec) - } - // For this request, no longer need the client - client.Free() - if bre, ok := err.(requestBodyReadError); ok { - err = bre.error - // Errors reading from the user's - // Request.Body are high priority. - // Set it here before sending on the - // channels below or calling - // pc.close() which tears down - // connections and causes other - // errors. - wr.req.setError(err) - } - if err != nil { - if pc.conn.nwrite == startBytesWritten { - err = nothingWrittenError{err} - } - //pc.writeErrCh <- err // to the body reader, which might recycle us - wr.ch <- err // to the roundTrip function - pc.close(err) - return - } + continue + default: + } - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } + err := checkTaskType(task, write) + client := (*hyper.ClientConn)(task.Value()) + task.Free() - err := checkTaskType(task, read) + if err == nil { + // TODO(spongehah) Proxy(writeLoop) + err = taskData.req.Request.write(client, taskData, t.exec) + } + // For this request, no longer need the client + client.Free() + if bre, ok := err.(requestBodyReadError); ok { + err = bre.error + // Errors reading from the user's + // Request.Body are high priority. + // Set it here before sending on the + // channels below or calling + // pc.close() which tears down + // connections and causes other + // errors. + taskData.req.setError(err) + } + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + taskData.pc.close(err) + continue + } - pc.mu.Lock() - if pc.numExpectedResponses == 0 { - pc.closeLocked(errServerClosedIdle) - pc.mu.Unlock() - return - } - pc.mu.Unlock() + if debugReadWriteLoop { + println("write end") + } + case read: + if debugReadWriteLoop { + println("read") + } - rc := <-pc.reqch // blocking - //trace := httptrace.ContextClientTrace(rc.req.Context()) + if taskData.pc.closeErr == nil { + taskData.pc.closeErr = errReadLoopExiting + } + // TODO(spongehah) ConnPool(readWriteLoop) + //if taskData.pc.tryPutIdleConn == nil { + // //taskData.pc.tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + // // if err := pc.t.tryPutIdleConn(pc); err != nil { + // // closeErr = err + // // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // // trace.PutIdleConn(err) + // // } + // // return false + // // } + // // if trace != nil && trace.PutIdleConn != nil { + // // trace.PutIdleConn(nil) + // // } + // // return true + // //} + //} - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() + err := checkTaskType(task, read) - var resp *Response - var respBody *hyper.Body - if err == nil { - var pr *io.PipeReader - pr, bodyWriter = io.Pipe() - resp, err = ReadResponse(pr, rc.req, hyperResp) - respBody = hyperResp.Body() - } else { - err = transportReadFromServerError{err} - closeErr = err - } + taskData.pc.mu.Lock() + if taskData.pc.numExpectedResponses == 0 { + taskData.pc.closeLocked(errServerClosedIdle) + taskData.pc.mu.Unlock() - // No longer need the response - hyperResp.Free() + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue + } + taskData.pc.mu.Unlock() + + //trace := httptrace.ContextClientTrace(rc.req.Context()) + + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() + + var resp *Response + var respBody *hyper.Body + if err == nil { + var pr *io.PipeReader + pr, taskData.bodyWriter = io.Pipe() + resp, err = ReadResponse(pr, taskData.req.Request, hyperResp) + respBody = hyperResp.Body() + } else { + err = transportReadFromServerError{err} + taskData.pc.closeErr = err + } - if err != nil { - select { - case rc.ch <- responseAndError{err: err}: - case <-rc.callerGone: - return - } - return - } + // No longer need the response + hyperResp.Free() - // Response has been returned, stop the timer - if rc.req.timeout > 0 { - pc.conn.TimeoutTimer.Stop() - (*libuv.Handle)(c.Pointer(&pc.conn.TimeoutTimer)).Close(nil) + if err != nil { + select { + case taskData.resc <- responseAndError{err: err}: + case <-taskData.callerGone: + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue } + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue + } - pc.mu.Lock() - pc.numExpectedResponses-- - pc.mu.Unlock() + taskData.pc.mu.Lock() + taskData.pc.numExpectedResponses-- + taskData.pc.mu.Unlock() - bodyWritable := resp.bodyIsWritable() - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + bodyWritable := resp.bodyIsWritable() + hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 - if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { - // Don't do keep-alive on error if either party requested a close - // or we get an unexpected informational (1xx) response. - // StatusCode 100 is already handled above. - alive = false - } + if resp.Close || taskData.req.Close || resp.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + taskData.pc.alive = false + } - if !hasBody || bodyWritable { - //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) - pc.t.replaceReqCanceler(rc.cancelKey, nil) + if !hasBody || bodyWritable { + //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) + t.replaceReqCanceler(taskData.req.cancelKey, nil) + + // TODO(spongehah) ConnPool(readWriteLoop) + //// Put the idle conn back into the pool before we send the response + //// so if they process it quickly and make another request, they'll + //// get this same conn. But we use the unbuffered channel 'rc' + //// to guarantee that persistConn.roundTrip got out of its select + //// potentially waiting for this persistConn to close. + //taskData.pc.alive = taskData.pc.alive && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + if bodyWritable { + taskData.pc.closeErr = errCallerOwnsConn + } + select { + case taskData.resc <- responseAndError{res: resp}: + case <-taskData.callerGone: + // defer + taskData.pc.close(taskData.pc.closeErr) // TODO(spongehah) ConnPool(readWriteLoop) - //// Put the idle conn back into the pool before we send the response - //// so if they process it quickly and make another request, they'll - //// get this same conn. But we use the unbuffered channel 'rc' - //// to guarantee that persistConn.roundTrip got out of its select - //// potentially waiting for this persistConn to close. - //alive = alive && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - if bodyWritable { - closeErr = errCallerOwnsConn - } - - select { - case rc.ch <- responseAndError{res: resp}: - case <-rc.callerGone: - return - } - - // Now that they've read from the unbuffered channel, they're safely - // out of the select that also waits on this goroutine to die, so - // we're allowed to exit now if needed (if alive is false) - testHookReadLoopBeforeNextRead() + //t.removeIdleConn(pc) continue } + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + testHookReadLoopBeforeNextRead() + if taskData.pc.alive == false { + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + } + continue + } - waitForBodyRead := make(chan bool, 2) - body := &bodyEOFSignal{ - body: resp.Body, - earlyCloseFn: func() error { - waitForBodyRead <- false - <-eofc // will be closed by deferred call at the end of the function - return nil - }, - fn: func(err error) error { - isEOF := err == io.EOF - waitForBodyRead <- isEOF - if isEOF { - <-eofc // see comment above eofc declaration - } else if err != nil { - if cerr := pc.canceled(); cerr != nil { - return cerr - } + body := &bodyEOFSignal{ + body: resp.Body, + earlyCloseFn: func() error { + taskData.bodyWriter.Close() + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + if !isEOF { + if cerr := taskData.pc.canceled(); cerr != nil { + return cerr } - return err - }, - } - resp.Body = body - - // TODO(spongehah) gzip(pc.readWriteLoop) - //if rc.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - // println("gzip reader") - // resp.Body = &gzipReader{body: body} - // resp.Header.Del("Content-Encoding") - // resp.Header.Del("Content-Length") - // resp.ContentLength = -1 - // resp.Uncompressed = true - //} - - rw.waitForBodyRead = waitForBodyRead - rw.rc = rc - bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(bodyWriter)) - setTaskId(bodyForeachTask, readDone) - pc.exec.Push(bodyForeachTask) - - // TODO(spongehah) select blocking(readWriteLoop) - //select { - //case rc.ch <- responseAndError{res: resp}: - //case <-rc.callerGone: - // return - //} - rc.ch <- responseAndError{res: resp} - - if debugReadWriteLoop { - println("read end") - } - case readDone: - // A background task of reading the response body is completed - if debugReadWriteLoop { - println("readDone") - } - if bodyWriter != nil { - bodyWriter.Close() - } - checkTaskType(task, readDone) + } + return err + }, + } + resp.Body = body + + // TODO(spongehah) gzip(pc.readWriteLoop) + //if taskData.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // resp.Body = &gzipReader{body: body} + // resp.Header.Del("Content-Encoding") + // resp.Header.Del("Content-Length") + // resp.ContentLength = -1 + // resp.Uncompressed = true + //} - hyperBodyEOF := task.Type() == hyper.TaskEmpty - // free the task - task.Free() + bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(taskData.bodyWriter)) + taskData.taskId = readDone + bodyForeachTask.SetUserdata(c.Pointer(taskData)) + t.exec.Push(bodyForeachTask) + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + + // TODO(spongehah) select blocking(readWriteLoop) + //select { + //case taskData.resc <- responseAndError{res: resp}: + //case <-taskData.callerGone: + // // defer + // taskData.pc.close(taskData.pc.closeErr) + // // TODO(spongehah) ConnPool(readWriteLoop) + // //t.removeIdleConn(pc) + // continue + //} + select { + case <-taskData.callerGone: + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + continue + default: + } + taskData.resc <- responseAndError{res: resp} - // Before looping back to the top of this function and peeking on - // the bufio.Reader, wait for the caller goroutine to finish - // reading the response body. (or for cancellation or death) - select { - case bodyEOF := <-rw.waitForBodyRead: - bodyEOF = bodyEOF && hyperBodyEOF - //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool - pc.t.replaceReqCanceler(rw.rc.cancelKey, nil) // before pc might return to idle pool - // TODO(spongehah) ConnPool(readWriteLoop) - //alive = alive && - // bodyEOF && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - eofc <- struct{}{} - // TODO(spongehah) cancel(pc.readWriteLoop) - //case <-rw.rc.req.Cancel: - // alive = false - // pc.t.CancelRequest(rw.rc.req) - //case <-rw.rc.req.Context().Done(): - // alive = false - // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - case <-pc.closech: - alive = false - } + if debugReadWriteLoop { + println("read end") + } + case readDone: + // A background task of reading the response body is completed + if debugReadWriteLoop { + println("readDone") + } + if taskData.bodyWriter != nil { + taskData.bodyWriter.Close() + } + checkTaskType(task, readDone) + + //bodyEOF := task.Type() == hyper.TaskEmpty + // free the task + task.Free() + + t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + // TODO(spongehah) ConnPool(readWriteLoop) + //taskData.pc.alive = taskData.pc.alive && + // bodyEOF && + // !pc.sawEOF && + // pc.wroteRequest() && + // replaced && tryPutIdleConn(trace) + + // TODO(spongehah) cancel(pc.readWriteLoop) + //case <-rw.rc.req.Cancel: + // taskData.pc.alive = false + // pc.t.CancelRequest(rw.rc.req) + //case <-rw.rc.req.Context().Done(): + // taskData.pc.alive = false + // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) + //case <-taskData.pc.closech: + // taskData.pc.alive = false + //} - testHookReadLoopBeforeNextRead() - if debugReadWriteLoop { - println("readDone end") - } - case notSet: - // A background task for hyper_client completed... - task.Free() + select { + case <-taskData.req.timeoutch: + continue + case <-taskData.pc.closech: + taskData.pc.alive = false + default: + } + + if taskData.pc.alive == false { + // defer + taskData.pc.close(taskData.pc.closeErr) + // TODO(spongehah) ConnPool(readWriteLoop) + //t.removeIdleConn(pc) + } + + testHookReadLoopBeforeNextRead() + if debugReadWriteLoop { + println("readDone end") } + case notSet: + // A background task for hyper_client completed... + task.Free() } } } // ---------------------------------------------------------- +type taskData struct { + taskId taskId + bodyWriter *io.PipeWriter + req *transportRequest + pc *persistConn + addedGzip bool + writeErrCh chan error + callerGone chan struct{} + resc chan responseAndError +} + type connData struct { TcpHandle libuv.Tcp ConnectReq libuv.Connect ReadBuf libuv.Buf - TimeoutTimer libuv.Timer ReadBufFilled uintptr nwrite int64 // bytes written(Replaced from persistConn's nwrite) ReadWaker *hyper.Waker @@ -1238,8 +1392,10 @@ func (conn *connData) Close() error { // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { - //conn := (*ConnData)(req.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(req)).data + if debugSwitch { + println("connect start") + defer println("connect end") + } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) if status < 0 { @@ -1364,11 +1520,25 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui } // onTimeout is the libuv callback for a timeout -func onTimeout(handle *libuv.Timer) { - ct := (*connAndTimeoutChan)((*libuv.Handle)(c.Pointer(handle)).GetData()) - close(ct.timeoutch) - // Close the timer - (*libuv.Handle)(c.Pointer(&ct.conn.TimeoutTimer)).Close(nil) +func onTimeout(timer *libuv.Timer) { + if debugSwitch { + println("onTimeout start") + defer println("onTimeout end") + } + data := (*timeoutData)((*libuv.Handle)(c.Pointer(timer)).GetData()) + close(data.timeoutch) + timer.Stop() + + taskData := data.taskData + if taskData != nil { + pc := taskData.pc + pc.alive = false + pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) + // defer + pc.close(pc.closeErr) + // TODO(spongehah) ConnPool(onTimeout) + //t.removeIdleConn(pc) + } } // newIoWithConnReadWrite creates a new IO with read and write callbacks @@ -1390,17 +1560,6 @@ const ( readDone ) -type readWaiter struct { - rc requestAndChan - waitForBodyRead chan bool -} - -// setTaskId Set taskId to the task's userdata as a unique identifier -func setTaskId(task *hyper.Task, userData taskId) { - var data = userData - task.SetUserdata(unsafe.Pointer(uintptr(data))) -} - // checkTaskType checks the task type func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { @@ -1455,14 +1614,15 @@ func fail(err *hyper.Error) error { // error values for debugging and testing, not seen by users. var ( - errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") - errConnBroken = errors.New("http: putIdleConn: connection is in bad state") - errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") - errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") - errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") - errCloseIdleConns = errors.New("http: CloseIdleConnections called") - errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") - errIdleConnTimeout = errors.New("http: idle connection timeout") + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") + errReadWriteLoopExiting = errors.New("http: Transport.readWriteLoop exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") // errServerClosedIdle is not seen by users for idempotent requests, but may be // seen by a user if the server shuts down an idle connection and sends its FIN @@ -1581,9 +1741,7 @@ type persistConn struct { conn *connData //tlsState *tls.ConnectionState //nwrite int64 // bytes written(Replaced by connData.nwrite) - reqch chan requestAndChan // written by roundTrip; read by readWriteLoop - writech chan writeRequest // written by roundTrip; read by readWriteLoop - closech chan struct{} // closed when conn closed + closech chan struct{} // closed when conn closed isProxy bool writeLoopDone chan struct{} // closed when readWriteLoop ends @@ -1598,8 +1756,9 @@ type persistConn struct { // original Request given to RoundTrip is not modified) mutateHeaderFunc func(Header) - // hyper specific - exec *hyper.Executor + // other + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop } func (pc *persistConn) cancelRequest(err error) { @@ -1635,13 +1794,10 @@ func (pc *persistConn) closeLocked(err error) { pc.conn.Close() } close(pc.closech) + close(pc.writeLoopDone) } } pc.mutateHeaderFunc = nil - // hyper related - if pc.exec != nil { - pc.exec.Free() - } } // mapRoundTripError returns the appropriate error value for @@ -1742,8 +1898,7 @@ type connectMethod struct { // then targetAddr is not included in the connect method key, because the socket can // be reused for different targetAddr values. targetAddr string - treq *transportRequest // optional - onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 + onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 } // connectMethodKey is the map key version of connectMethod, with a @@ -1808,10 +1963,11 @@ func (cm *connectMethod) proxyAuth() string { // These three options are racing against each other and use // wantConn to coordinate and agree about the winning outcome. type wantConn struct { - cm connectMethod - key connectMethodKey // cm.key() - ctx context.Context // context for dial - ready chan struct{} // closed when pc, err pair is delivered + cm connectMethod + key connectMethodKey // cm.key() + ctx context.Context // context for dial + timeoutch chan struct{} // tmp timeout to replace ctx + ready chan struct{} // closed when pc, err pair is delivered // hooks for testing to know when dials are done // beforeDial is called in the getConn goroutine when the dial is queued. @@ -1866,6 +2022,11 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { if w.pc == nil && w.err == nil { panic("net/http: internal error: misuse of tryDeliver") } + select { + case <-w.timeoutch: + pc.close(errors.New("request timeout: dialConn timeout")) + default: + } close(w.ready) return true } diff --git a/x/net/http/util.go b/x/net/http/util.go index bfd9fc3..bec22a8 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -7,7 +7,7 @@ import ( "golang.org/x/net/idna" - "github.com/goplus/llgoexamples/x/net" + "github.com/goplus/llgo/x/net" ) /** From 09cd8fe8873df5894a0e0e7285af3c280933e1c8 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Tue, 3 Sep 2024 19:27:10 +0800 Subject: [PATCH 25/55] refactor(x/net/http/demo): Update http demo Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 34240cd..36920d0 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -2,22 +2,26 @@ package main import ( "fmt" - "io" + //"io" "github.com/goplus/llgo/x/net/http" ) func echoHandler(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Error reading request body", http.StatusInternalServerError) - return - } - //defer r.Body.Close() + fmt.Printf("echoHandler called\n") + //TODO: read body and echo + // body, err := io.ReadAll(r.Body) + // if err != nil { + // http.Error(w, "Error reading request body", http.StatusInternalServerError) + // return + // } + // defer r.Body.Close() + // fmt.Printf("body: %s\n", string(body)) + //w.Header().Set("Content-Type", "text/plain") + //w.Write(body) w.Header().Set("Content-Type", "text/plain") - - w.Write(body) + w.Write([]byte("echoHandler called\n")) } func main() { From 15e84fb7bb8494ea3fd2c934a71edba61f150022 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Tue, 3 Sep 2024 19:27:57 +0800 Subject: [PATCH 26/55] refactor(x/net/http): Update new server logic Signed-off-by: hackerchai --- x/net/http/server.go | 350 +++++++++++++++++++--------------------- x/net/http/servermux.go | 4 +- 2 files changed, 173 insertions(+), 181 deletions(-) diff --git a/x/net/http/server.go b/x/net/http/server.go index 909c8f6..c0578fd 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -1,6 +1,7 @@ package http import ( + "errors" "fmt" "os" "strconv" @@ -31,30 +32,33 @@ type Server struct { Addr string Handler Handler - uvLoop *libuv.Loop - uvServer libuv.Tcp - inShutdown atomic.Bool + uvLoop *libuv.Loop + uvServer libuv.Tcp + inShutdown atomic.Bool + http1Opts *hyper.Http1ServerconnOptions + http2Opts *hyper.Http2ServerconnOptions + checkHandle libuv.Check mu sync.Mutex activeConnections map[*conn]struct{} } type conn struct { - Stream *libuv.Tcp - PollHandle *libuv.Poll - EventMask c.Uint - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker - ConnTask *hyper.Task - IsClosing c.Int - Executor *hyper.Executor + Stream libuv.Tcp + PollHandle libuv.Poll + EventMask c.Uint + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker + IsClosing atomic.Bool + ClosedHandles int32 + Executor *hyper.Executor } type serviceUserdata struct { - Host [128]c.Char - Port [8]c.Char - Server *Server - Conn *conn + Host [128]c.Char + Port [8]c.Char + Conn *conn + Server *Server ListenAddr string } @@ -67,6 +71,10 @@ func NewServer(addr string) *Server { } } +// ErrServerClosed is returned by the [Server.Serve], [ServeTLS], [ListenAndServe], +// and [ListenAndServeTLS] methods after a call to [Server.Shutdown] or [Server.Close]. +var ErrServerClosed = errors.New("http: Server closed") + func ListenAndServe(addr string, handler Handler) error { server := &Server{Addr: addr, Handler: handler} return server.ListenAndServe() @@ -78,8 +86,8 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("failed to get default loop") } - if err := libuv.InitTcp(srv.uvLoop, &srv.uvServer); err != 0 { - return fmt.Errorf("failed to init TCP: %v", err) + if r := libuv.InitTcp(srv.uvLoop, &srv.uvServer); r != 0 { + return fmt.Errorf("failed to init TCP: %v", libuv.Strerror(libuv.Errno(r))) } host, port, err := net.SplitHostPort(srv.Addr) @@ -93,12 +101,12 @@ func (srv *Server) ListenAndServe() error { } var sockaddr cnet.SockaddrIn - if err := libuv.Ip4Addr(c.AllocaCStr(host), c.Int(portNum), &sockaddr); err != 0 { - return fmt.Errorf("failed to create IP address: %v", err) + if r := libuv.Ip4Addr(c.AllocaCStr(host), c.Int(portNum), &sockaddr); r != 0 { + return fmt.Errorf("failed to create IP address: %v", libuv.Strerror(libuv.Errno(r))) } - if err := srv.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); err != 0 { - return fmt.Errorf("failed to bind: %v", err) + if r := srv.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); r != 0 { + return fmt.Errorf("failed to bind: %v", libuv.Strerror(libuv.Errno(r))) } // Set SO_REUSEADDR @@ -114,6 +122,18 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("failed to listen: %v", err) } + if r := libuv.InitCheck(srv.uvLoop, &srv.checkHandle); r != 0 { + fmt.Fprintf(os.Stderr, "Failed to initialize check handler: %d\n", r) + os.Exit(1) + } + + (*libuv.Handle)(unsafe.Pointer(&srv.checkHandle)).SetData(unsafe.Pointer(srv)) + + if r := srv.checkHandle.Start(onCheck); r != 0 { + fmt.Fprintf(os.Stderr, "Failed to start check handler: %d\n", r) + os.Exit(1) + } + fmt.Printf("Listening on %s\n", srv.Addr) for { @@ -123,16 +143,8 @@ func (srv *Server) ListenAndServe() error { break } - for conn := range srv.activeConnections { - fmt.Printf("Active connection found\n") - if conn.Executor != nil { - task := conn.Executor.Poll() - for task != nil { - srv.handleTask(task) - task.Free() - task = conn.Executor.Poll() - } - } + if srv.shuttingDown() { + break } } return nil @@ -155,17 +167,37 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { return } - client := (*libuv.Tcp)(c.Malloc(unsafe.Sizeof(libuv.Tcp{}))) - libuv.InitTcp(srv.uvLoop, client) + conn, err := createConnData() + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to create Conn: %v\n", err) + return + } + + libuv.InitTcp(srv.uvLoop, &conn.Stream) + conn.Stream.Data = unsafe.Pointer(conn) - if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(client))) == 0 { + if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(&conn.Stream))) == 0 { fmt.Println("Accepted new connection") + r := libuv.PollInit(srv.uvLoop, &conn.PollHandle, libuv.OsFd(conn.Stream.GetIoWatcherFd())) + if r < 0 { + fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + return + } + + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Data = unsafe.Pointer(conn) + + if !updateConnRegistrations(conn, true) { + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + return + } + userdata := createServiceUserdata() userdata.Server = srv if userdata == nil { fmt.Fprintf(os.Stderr, "Failed to create service userdata\n") - (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) - freeServiceUserdata(unsafe.Pointer(userdata)) + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) return } fmt.Printf("ListenAddr: %s\n", srv.Addr) @@ -173,7 +205,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { var addr cnet.SockaddrStorage addrlen := c.Int(unsafe.Sizeof(addr)) - client.Getpeername((*cnet.SockAddr)(c.Pointer(&addr)), &addrlen) + conn.Stream.Getpeername((*cnet.SockAddr)(c.Pointer(&addr)), &addrlen) if addr.Family == cnet.AF_INET { s := (*cnet.SockaddrIn)(unsafe.Pointer(&addr)) @@ -185,50 +217,81 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { c.Snprintf((*c.Char)(&userdata.Port[0]), unsafe.Sizeof(userdata.Port), c.Str("%d"), cnet.Ntohs(s.Port)) } - fmt.Printf("New incoming connection from (%s:%s)\n", c.GoString((*c.Char)(&userdata.Host[0])), - c.GoString((*c.Char)(&userdata.Port[0]))) - - conn := createConnData(srv.uvLoop, client) - if conn == nil { - fmt.Fprintf(os.Stderr, "Failed to create Conn\n") - (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) - freeServiceUserdata(unsafe.Pointer(userdata)) - return - } - executor := hyper.NewExecutor() if executor == nil { fmt.Fprintf(os.Stderr, "Failed to create Executor\n") - (*libuv.Handle)(unsafe.Pointer(client)).Close(onClose) - freeServiceUserdata(unsafe.Pointer(userdata)) + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) return } conn.Executor = executor - userdata.Conn = conn - fmt.Println("Conn created") srv.trackConn(conn, true) fmt.Println("Conn tracked") + userdata.Conn = conn + io := createIo(conn) service := hyper.ServiceNew(serverCallback) - service.SetUserdata(unsafe.Pointer(userdata), freeServiceUserdata) + service.SetUserdata(unsafe.Pointer(userdata), nil) http1Opts := hyper.Http1ServerconnOptionsNew(conn.Executor) + if http1Opts == nil { + fmt.Fprintf(os.Stderr, "Failed to create http1_opts\n") + os.Exit(1) + } + result := http1Opts.HeaderReadTimeout(5 * 1000) + if result != hyper.OK { + fmt.Fprintf(os.Stderr, "Failed to set header read timeout for http1_opts\n") + os.Exit(1) + } + srv.http1Opts = http1Opts + http2Opts := hyper.Http2ServerconnOptionsNew(conn.Executor) + if http2Opts == nil { + fmt.Fprintf(os.Stderr, "Failed to create http2_opts\n") + os.Exit(1) + } + result = http2Opts.KeepAliveInterval(5) + if result != hyper.OK { + fmt.Fprintf(os.Stderr, "Failed to set keep alive interval for http2_opts\n") + os.Exit(1) + } + result = http2Opts.KeepAliveTimeout(5) + if result != hyper.OK { + fmt.Fprintf(os.Stderr, "Failed to set keep alive timeout for http2_opts\n") + os.Exit(1) + } + srv.http2Opts = http2Opts serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) conn.Executor.Push(serverconn) - - http1Opts.Free() - http2Opts.Free() } else { fmt.Println("Client not accepted") - (*libuv.Handle)(unsafe.Pointer(client)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) } } +func onCheck(handle *libuv.Check) { + srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) + for conn := range srv.activeConnections { + if conn.Executor != nil { + task := conn.Executor.Poll() + for task != nil { + srv.handleTask(task) + task = conn.Executor.Poll() + } + } + } + + if srv.shuttingDown() { + fmt.Println("Shutdown initiated, cleaning up...") + handle.Stop() + } +} + func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { userData := (*serviceUserdata)(userdata) @@ -252,42 +315,25 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h } func (srv *Server) handleTask(task *hyper.Task) { - taskUserdata := task.Userdata() - switch task.Type() { - case hyper.TaskEmpty: - fmt.Println("New server connection") - if taskUserdata != nil { - conn := (*conn)(taskUserdata) - if conn.IsClosing == 0 { - conn.IsClosing = 1 - if (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).IsClosing() == 0 { - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) - } - if (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).IsClosing() == 0 { - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(closeConn) - } - } - } - case hyper.TaskError: + taskType := task.Type() + if taskType == hyper.TaskError { + fmt.Println("hyper task failed with error!") + err := (*hyper.Error)(task.Value()) + fmt.Printf("error code: %d\n", err.Code()) + var errbuf [256]byte errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) - fmt.Printf("Task error: %.*s\n", errlen, (*c.Char)(unsafe.Pointer(&errbuf[0]))) - err.Free() - - case hyper.TaskClientConn: - fmt.Fprintf(os.Stderr, "Unexpected HYPER_TASK_CLIENTCONN in server context\n") + fmt.Printf("details: %s\n", errbuf[:errlen]) - case hyper.TaskResponse: - fmt.Println("Response task received") - - case hyper.TaskBuf: - fmt.Println("Buffer task received") - - case hyper.TaskServerconn: - fmt.Println("Server connection task received: ready for new connection...") - default: - fmt.Fprintf(os.Stderr, "Unknown task type: %d\n", task.Type()) + err.Free() + task.Free() + } else if taskType == hyper.TaskEmpty { + fmt.Println("internal hyper task complete") + task.Free() + } else if taskType == hyper.TaskServerconn { + fmt.Println("server connection task complete") + task.Free() } } @@ -313,11 +359,11 @@ func createIo(conn *conn) *hyper.Io { } func createServiceUserdata() *serviceUserdata { - userdata := (*serviceUserdata)(c.Calloc(1, unsafe.Sizeof(serviceUserdata{}))) - if userdata == nil { - fmt.Fprintf(os.Stderr, "Failed to allocate service_userdata\n") - } - return userdata + userdata := &serviceUserdata{} + if userdata == nil { + fmt.Fprintf(os.Stderr, "Failed to allocate service_userdata\n") + } + return userdata } func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { @@ -377,9 +423,6 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint return hyper.IoPending } -func onClose(handle *libuv.Handle) { - c.Free(unsafe.Pointer(handle)) -} func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { fmt.Printf("onPoll called\n") @@ -403,10 +446,6 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { func updateConnRegistrations(conn *conn, create bool) bool { fmt.Println("updateConnRegistrations called") - if conn == nil || conn.PollHandle == nil { - fmt.Fprintf(os.Stderr, "Poll handle is nil\n") - return false - } events := c.Int(0) if conn.EventMask == 0 { @@ -422,10 +461,6 @@ func updateConnRegistrations(conn *conn, create bool) bool { } fmt.Printf("Starting poll with events: %d\n", events) - if conn.PollHandle == nil { - fmt.Fprintf(os.Stderr, "Poll handle is nil\n") - return false - } r := conn.PollHandle.Start(events, onPoll) //fmt.Println("Poll handle started: %d", r) if r < 0 { @@ -436,68 +471,21 @@ func updateConnRegistrations(conn *conn, create bool) bool { return true } -func createConnData(loop *libuv.Loop, client *libuv.Tcp) *conn { - conn := (*conn)(c.Calloc(1, unsafe.Sizeof(conn{}))) +func createConnData() (*conn, error) { + conn := &conn{} if conn == nil { - fmt.Fprintf(os.Stderr, "Failed to allocate conn_data\n") - return nil - } - fmt.Println("Conn data created") - c.Memcpy(unsafe.Pointer(&conn.Stream), unsafe.Pointer(client), unsafe.Sizeof(libuv.Tcp{})) - conn.IsClosing = 0 - conn.EventMask = 0 - - fmt.Println("Conn data initialized") - - conn.PollHandle = (*libuv.Poll)(c.Malloc(unsafe.Sizeof(libuv.Poll{}))) - if conn.PollHandle == nil { - fmt.Fprintf(os.Stderr, "Failed to allocate poll handle\n") - c.Free(unsafe.Pointer(conn)) - return nil - } - fmt.Println("Poll handle allocated") - - fmt.Printf("Io Watcher Fd: %d\n", client.GetIoWatcherFd()) - fd := client.GetIoWatcherFd() - if fd < 0 { - fmt.Fprintf(os.Stderr, "Invalid file descriptor\n") - c.Free(unsafe.Pointer(conn)) - return nil - } - r := libuv.PollInit(loop, conn.PollHandle, libuv.OsFd(client.GetIoWatcherFd())) - if r < 0 { - fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) - c.Free(unsafe.Pointer(conn)) - return nil - } - fmt.Println("Poll handle initialized") - - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).SetData(unsafe.Pointer(conn)) - fmt.Println("Poll handle data set") - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).SetData(unsafe.Pointer(conn)) - fmt.Println("Stream data set") - - if !updateConnRegistrations(conn, true) { - (*libuv.Handle)(unsafe.Pointer(conn.PollHandle)).Close(nil) - c.Free(unsafe.Pointer(conn)) - return nil + return nil, fmt.Errorf("failed to allocate conn_data") } + conn.IsClosing.Store(false) + conn.ClosedHandles = 0 - return conn + return conn, nil } func freeConnData(userdata c.Pointer) { conn := (*conn)(userdata) - if conn != nil && conn.IsClosing == 0 { - conn.IsClosing = 1 - // We don't immediately close the connection here. - // Instead, we'll let the main loop handle the closure when appropriate. - } -} - -func closeConn(handle *libuv.Handle) { - conn := (*conn)(handle.GetData()) - if conn != nil { + if conn != nil && !conn.IsClosing.Swap(true){ + fmt.Printf("Closing connection...\n") if conn.ReadWaker != nil { conn.ReadWaker.Free() conn.ReadWaker = nil @@ -506,28 +494,18 @@ func closeConn(handle *libuv.Handle) { conn.WriteWaker.Free() conn.WriteWaker = nil } - if conn.ConnTask != nil { - conn.ConnTask.Free() - conn.ConnTask = nil - } - if conn.Executor != nil { - conn.Executor.Free() - conn.Executor = nil - } - c.Free(unsafe.Pointer(conn)) - } - c.Free(unsafe.Pointer(handle)) -} -func freeServiceUserdata(userdata c.Pointer) { - castUserdata := (*serviceUserdata)(userdata) - if castUserdata != nil { - // Note: We don't free conn here because it's managed separately - freeConnData(unsafe.Pointer(castUserdata.Conn)) - c.Free(userdata) + if (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) + } + + if (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + } } } + func closeWalkCb(handle *libuv.Handle, arg c.Pointer) { if handle.IsClosing() == 0 { handle.Close(nil) @@ -540,21 +518,33 @@ func (srv *Server) Close() error { defer srv.mu.Unlock() for c := range srv.activeConnections { + if c.Executor != nil { + c.Executor.Free() + } delete(srv.activeConnections, c) - freeConnData(unsafe.Pointer(c)) } srv.uvLoop.Walk(closeWalkCb, nil) srv.uvLoop.Run(libuv.RUN_DEFAULT) srv.uvLoop.Close() + + if srv.http1Opts != nil { + srv.http1Opts.Free() + } + if srv.http2Opts != nil { + srv.http2Opts.Free() + } return nil } +func (s *Server) shuttingDown() bool { + return s.inShutdown.Load() +} + type HandlerFunc func(ResponseWriter, *Request) func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { - fmt.Printf("ServeHTTP called\n") f(w, r) } diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index ece33c6..21d9b20 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -20,7 +20,8 @@ var DefaultServeMux = &ServeMux{m: make(map[string]muxEntry)} func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { fmt.Printf("ServeHTTP called\n") - h, _ := mux.Handler(r) + h, pattern := mux.Handler(r) + fmt.Printf("Handler found for pattern: %s\n", pattern) h.ServeHTTP(w, r) } @@ -40,6 +41,7 @@ func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Re } func (mux *ServeMux) Handle(pattern string, handler Handler) { + fmt.Printf("Handle called with pattern: %s\n", pattern) mux.mu.Lock() defer mux.mu.Unlock() From 6b6cfdd6d96a47211ad8a5cdfd49ea4174fa135d Mon Sep 17 00:00:00 2001 From: hackerchai Date: Tue, 3 Sep 2024 19:28:41 +0800 Subject: [PATCH 27/55] refactor(x/net/http): Rewrite request implementation use host header Signed-off-by: hackerchai --- x/net/http/request.go | 71 +++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/x/net/http/request.go b/x/net/http/request.go index 946200f..b991420 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -29,6 +29,36 @@ type Request struct { } func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Request, error) { + req := Request{ + Header: make(Header), + timeout: 0, + } + + headers := hyperReq.Headers() + if headers != nil { + headers.Foreach(addHeader, unsafe.Pointer(&req)) + } else { + return nil, fmt.Errorf("failed to get request headers") + } + + fmt.Printf("Headers:\n") + for key, values := range req.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } + + var host string + for key, values := range req.Header { + if strings.EqualFold(key, "Host") { + if len(values) > 0 { + host = values[0] + break + } + } + + } + method := make([]byte, 32) methodLen := unsafe.Sizeof(method) if err := hyperReq.Method(&method[0], &methodLen); err != hyper.OK { @@ -53,7 +83,7 @@ func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Reques } if authorityLen == 0 { - authorityStr = ListenAddr + authorityStr = host } else { authorityStr = string(authority[:authorityLen]) } @@ -63,7 +93,8 @@ func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Reques } else { pathAndQueryStr = string(pathAndQuery[:pathAndQueryLen]) } - + req.Host = authorityStr + req.Method = methodStr var proto string var protoMajor, protoMinor int @@ -89,43 +120,28 @@ func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Reques default: return nil, fmt.Errorf("unknown HTTP version: %d", version) } + req.Proto = proto + req.ProtoMajor = protoMajor + req.ProtoMinor = protoMinor - urlStr := fmt.Sprintf("%s://%s%s", schemeStr, authorityStr, pathAndQueryStr) + urlStr := fmt.Sprintf("%s://%s%s", schemeStr, host, pathAndQueryStr) fmt.Printf("URL: %s\n", urlStr) url, err := url.Parse(urlStr) if err != nil { return nil, err } - - req := Request{ - Method: methodStr, - URL: url, - Proto: proto, - ProtoMajor: protoMajor, - ProtoMinor: protoMinor, - Header: make(Header), - Host: authorityStr, - timeout: 0, - } - - headers := hyperReq.Headers() - if headers != nil { - headers.Foreach(addHeader, unsafe.Pointer(&req)) - } else { - return nil, fmt.Errorf("failed to get request headers") - } + req.URL = url if methodStr == "POST" || methodStr == "PUT" || methodStr == "PATCH" { body := hyperReq.Body() if body != nil { - bodyWriter := new(io.PipeWriter) + var bodyWriter *io.PipeWriter req.Body, bodyWriter = io.Pipe() - - - task := body.Foreach(getBodyChunk, c.Pointer(&bodyWriter), freeBodyWriter) + task := body.Foreach(getBodyChunk, c.Pointer(&bodyWriter), nil) if task != nil { r := conn.Executor.Push(task) if r != hyper.OK { + fmt.Printf("failed to push body foreach task: %d\n", r) task.Free() return nil, fmt.Errorf("failed to push body foreach task: %v", r) } @@ -165,8 +181,3 @@ func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { return hyper.IterContinue } - -func freeBodyWriter(userdata c.Pointer) { - writer := (*io.PipeWriter)(userdata) - writer.Close() -} From 0d8cc277f2688593442b82f3f9f317bd4b9271b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E8=8B=B1=E6=9D=B0?= <2635879218@qq.com> Date: Wed, 4 Sep 2024 18:09:04 +0800 Subject: [PATCH 28/55] WIP(x/net/http/client): Implement IdleConnPool --- x/net/http/_demo/get/get.go | 2 +- x/net/http/_demo/headers/headers.go | 2 +- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 2 +- x/net/http/_demo/post/post.go | 2 +- x/net/http/_demo/postform/postform.go | 2 - x/net/http/_demo/redirect/redirect.go | 2 +- x/net/http/_demo/reuseConn/reuseConn.go | 42 + x/net/http/_demo/timeout/timeout.go | 12 +- x/net/http/client.go | 8 +- x/net/http/request.go | 73 +- x/net/http/transport.go | 974 +++++++++++++----- x/net/http/util.go | 2 +- 12 files changed, 832 insertions(+), 291 deletions(-) create mode 100644 x/net/http/_demo/reuseConn/reuseConn.go diff --git a/x/net/http/_demo/get/get.go b/x/net/http/_demo/get/get.go index 6bc5b06..6e91bd4 100644 --- a/x/net/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -13,6 +13,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) @@ -21,5 +22,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go index aa2e5d6..5538923 100644 --- a/x/net/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -36,6 +36,7 @@ func main() { println(err.Error()) return } + defer resp.Body.Close() fmt.Println(resp.Status) resp.PrintHeaders() body, err := io.ReadAll(resp.Body) @@ -44,5 +45,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go index 882bdc1..5662251 100644 --- a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -19,6 +19,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) resp.PrintHeaders() @@ -28,5 +29,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/post/post.go b/x/net/http/_demo/post/post.go index f169dfc..fd756b3 100644 --- a/x/net/http/_demo/post/post.go +++ b/x/net/http/_demo/post/post.go @@ -15,6 +15,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status) body, err := io.ReadAll(resp.Body) if err != nil { @@ -22,5 +23,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/postform/postform.go b/x/net/http/_demo/postform/postform.go index eae4d6e..232c15d 100644 --- a/x/net/http/_demo/postform/postform.go +++ b/x/net/http/_demo/postform/postform.go @@ -20,12 +20,10 @@ func main() { return } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/redirect/redirect.go b/x/net/http/_demo/redirect/redirect.go index e4fdb92..f189255 100644 --- a/x/net/http/_demo/redirect/redirect.go +++ b/x/net/http/_demo/redirect/redirect.go @@ -13,6 +13,7 @@ func main() { fmt.Println(err) return } + defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) resp.PrintHeaders() @@ -22,5 +23,4 @@ func main() { return } fmt.Println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/_demo/reuseConn/reuseConn.go b/x/net/http/_demo/reuseConn/reuseConn.go new file mode 100644 index 0000000..bb460ce --- /dev/null +++ b/x/net/http/_demo/reuseConn/reuseConn.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + // Send request first time + resp, err := http.Get("https://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + resp.PrintHeaders() + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + resp.Body.Close() + + // Send request second time + resp, err = http.Get("https://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + resp.PrintHeaders() + body, err = io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) + resp.Body.Close() +} diff --git a/x/net/http/_demo/timeout/timeout.go b/x/net/http/_demo/timeout/timeout.go index 62b2c9d..a6930b1 100644 --- a/x/net/http/_demo/timeout/timeout.go +++ b/x/net/http/_demo/timeout/timeout.go @@ -10,24 +10,24 @@ import ( func main() { client := &http.Client{ - //Timeout: time.Millisecond, // Set a small timeout to ensure it will time out - Timeout: time.Second * 5, + Timeout: time.Millisecond, // Set a small timeout to ensure it will time out + //Timeout: time.Second, } req, err := http.NewRequest("GET", "https://www.baidu.com", nil) if err != nil { - fmt.Println(err.Error()) + fmt.Println(err) return } resp, err := client.Do(req) if err != nil { - fmt.Println(err.Error()) + fmt.Println(err) return } + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - fmt.Println(err.Error()) + fmt.Println(err) return } println(string(body)) - defer resp.Body.Close() } diff --git a/x/net/http/client.go b/x/net/http/client.go index d56f5f2..002397a 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -241,7 +241,6 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { // didTimeout is non-nil only if err != nil. func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { - // TODO(spongehah) cookie(c.send) if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) @@ -309,13 +308,16 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d } // TODO(spongehah) timeout(send) + //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) + req.timeoutch = make(chan struct{}, 1) req.deadline = deadline + req.ctx.Done() if deadline.IsZero() { didTimeout = alwaysFalse + defer close(req.timeoutch) } else { didTimeout = func() bool { return req.timer.GetDueIn() == 0 } } - //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) resp, err = rt.RoundTrip(req) if err != nil { @@ -488,7 +490,7 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return knownRoundTripperImpl(altRT, req) } return true - // TODO(spongehah) + // TODO(spongehah) http2 //case *http2Transport, http2noDialH2RoundTripper: // return true } diff --git a/x/net/http/request.go b/x/net/http/request.go index 658b033..c5146ed 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -50,6 +50,30 @@ type Request struct { const defaultChunkSize = 8192 +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed at the time of Go 1.1 release because the former User-Agent +// had ended up blocked by some intrusion detection systems. +// See https://codereview.appspot.com/7532043. +const defaultUserAgent = "Go-http-client/1.1" + +// errMissingHost is returned by Write when there is no Host or URL present in +// the Request. +var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + // NewRequest wraps NewRequestWithContext using context.Background. func NewRequest(method, url string, body io.Reader) (*Request, error) { return NewRequestWithContext(context.Background(), method, url, body) @@ -188,6 +212,22 @@ func (r *Request) closeBody() error { return r.Body.Close() } +func (r *Request) isReplayable() bool { + if r.Body == nil || r.Body == NoBody || r.GetBody != nil { + switch valueOrDefault(r.Method, "GET") { + case "GET", "HEAD", "OPTIONS", "TRACE": + return true + } + // The Idempotency-Key, while non-standard, is widely used to + // mean a POST or other request is idempotent. See + // https://golang.org/issue/19943#issuecomment-421092421 + if r.Header.has("Idempotency-Key") || r.Header.has("X-Idempotency-Key") { + return true + } + } + return false +} + // Context returns the request's context. To change the context, use // Clone or WithContext. // @@ -240,25 +280,6 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } -// errMissingHost is returned by Write when there is no Host or URL present in -// the Request. -var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") - -// NOTE: This is not intended to reflect the actual Go version being used. -// It was changed at the time of Go 1.1 release because the former User-Agent -// had ended up blocked by some intrusion detection systems. -// See https://codereview.appspot.com/7532043. -const defaultUserAgent = "Go-http-client/1.1" - -// Headers that Request.Write handles itself and should be skipped. -var reqWriteExcludeHeader = map[string]bool{ - "Host": true, // not in Header map anyway - "User-Agent": true, - "Content-Length": true, - "Transfer-Encoding": true, - "Trailer": true, -} - // extraHeaders may be nil // waitForContinue may be nil // always closes body @@ -279,7 +300,6 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype } // Send it! sendTask := client.Send(hyperReq) - taskData.taskId = read sendTask.SetUserdata(c.Pointer(taskData)) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { @@ -482,11 +502,6 @@ func readCookies(h Header, filter string) []*Cookie { return cookies } -// requestBodyReadError wraps an error from (*Request).write to indicate -// that the error came from a Read call on the Request.Body. -// This error type should not escape the net/http package to users. -type requestBodyReadError struct{ error } - func idnaASCII(v string) (string, error) { // TODO: Consider removing this check after verifying performance is okay. // Right now punycode verification, length checks, context checks, and the @@ -519,3 +534,11 @@ func removeZone(host string) string { } return host[:j] + host[i:] } + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 062cbe9..44d721d 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -2,6 +2,7 @@ package http import ( "compress/gzip" + "container/list" "context" "errors" "fmt" @@ -17,8 +18,8 @@ import ( "github.com/goplus/llgo/c/libuv" cnet "github.com/goplus/llgo/c/net" "github.com/goplus/llgo/c/syscall" - "github.com/goplus/llgo/x/net" "github.com/goplus/llgoexamples/rust/hyper" + "github.com/goplus/llgoexamples/x/net" ) // DefaultTransport is the default implementation of Transport and is @@ -28,7 +29,9 @@ import ( // and NO_PROXY (or the lowercase versions thereof). var DefaultTransport RoundTripper = &Transport{ //Proxy: ProxyFromEnvironment, - Proxy: nil, + Proxy: nil, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, } // DefaultMaxIdleConnsPerHost is the default value of Transport's @@ -37,6 +40,12 @@ const DefaultMaxIdleConnsPerHost = 2 const debugSwitch = true type Transport struct { + idleMu sync.Mutex + closeIdle bool // user has requested to close all idle conns + idleConn map[connectMethodKey][]*persistConn // most recently used at end + idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns + idleLRU connLRU + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) @@ -63,6 +72,15 @@ type Transport struct { // uncompressed. DisableCompression bool + // MaxIdleConns controls the maximum number of idle (keep-alive) + // connections across all hosts. Zero means no limit. + MaxIdleConns int + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) connections to keep per-host. If zero, + // DefaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int + // MaxConnsPerHost optionally limits the total number of // connections per host, including connections in the dialing, // active, and idle states. On limit violation, dials will block. @@ -70,9 +88,16 @@ type Transport struct { // Zero means no limit. MaxConnsPerHost int + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // libuv and hyper related loopInitOnce sync.Once loop *libuv.Loop + async *libuv.Async exec *hyper.Executor } @@ -181,14 +206,258 @@ func (tr *transportRequest) setError(err error) { tr.mu.Unlock() } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { - cm.targetScheme = treq.URL.Scheme - cm.targetAddr = canonicalAddr(treq.URL) - if t.Proxy != nil { - cm.proxyURL, err = t.Proxy(treq.Request) +func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { + if err := t.tryPutIdleConn(pconn); err != nil { + pconn.close(err) + } +} + +func (t *Transport) maxIdleConnsPerHost() int { + if v := t.MaxIdleConnsPerHost; v != 0 { + return v + } + return DefaultMaxIdleConnsPerHost +} + +// tryPutIdleConn adds pconn to the list of idle persistent connections awaiting +// a new request. +// If pconn is no longer needed or not in a good state, tryPutIdleConn returns +// an error explaining why it wasn't registered. +// tryPutIdleConn does not close pconn. Use putOrCloseIdleConn instead for that. +func (t *Transport) tryPutIdleConn(pconn *persistConn) error { + if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { + return errKeepAlivesDisabled + } + if pconn.isBroken() { + return errConnBroken + } + pconn.markReused() + + t.idleMu.Lock() + defer t.idleMu.Unlock() + + // HTTP/2 (pconn.alt != nil) connections do not come out of the idle list, + // because multiple goroutines can use them simultaneously. + // If this is an HTTP/2 connection being “returned,” we're done. + if pconn.alt != nil && t.idleLRU.m[pconn] != nil { + return nil + } + + // Deliver pconn to goroutine waiting for idle connection, if any. + // (They may be actively dialing, but this conn is ready first. + // Chrome calls this socket late binding. + // See https://www.chromium.org/developers/design-documents/network-stack#TOC-Connection-Management.) + key := pconn.cacheKey + if q, ok := t.idleConnWait[key]; ok { + done := false + if pconn.alt == nil { + // HTTP/1. + // Loop over the waiting list until we find a w that isn't done already, and hand it pconn. + for q.len() > 0 { + w := q.popFront() + if w.tryDeliver(pconn, nil) { + done = true + break + } + } + } else { + // HTTP/2. + // Can hand the same pconn to everyone in the waiting list, + // and we still won't be done: we want to put it in the idle + // list unconditionally, for any future clients too. + for q.len() > 0 { + w := q.popFront() + w.tryDeliver(pconn, nil) + } + } + if q.len() == 0 { + delete(t.idleConnWait, key) + } else { + t.idleConnWait[key] = q + } + if done { + return nil + } + } + + if t.closeIdle { + return errCloseIdle + } + if t.idleConn == nil { + t.idleConn = make(map[connectMethodKey][]*persistConn) + } + idles := t.idleConn[key] + if len(idles) >= t.maxIdleConnsPerHost() { + return errTooManyIdleHost + } + for _, exist := range idles { + if exist == pconn { + log.Fatalf("dup idle pconn %p in freelist", pconn) + } + } + t.idleConn[key] = append(idles, pconn) + t.idleLRU.add(pconn) + if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns { + oldest := t.idleLRU.removeOldest() + oldest.close(errTooManyIdle) + t.removeIdleConnLocked(oldest) + } + + // Set idle timer, but only for HTTP/1 (pconn.alt == nil). + // The HTTP/2 implementation manages the idle timer itself + // (see idleConnTimeout in h2_bundle.go). + idleConnTimeout := uint64(t.IdleConnTimeout.Milliseconds()) + if t.IdleConnTimeout > 0 && pconn.alt == nil { + if pconn.idleTimer != nil { + pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) + } else { + pconn.idleTimer = &libuv.Timer{} + libuv.InitTimer(t.loop, pconn.idleTimer) + (*libuv.Handle)(c.Pointer(pconn.idleTimer)).SetData(c.Pointer(pconn)) + pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) + } + } + pconn.idleAt = time.Now() + return nil +} + +func onIdleConnTimeout(timer *libuv.Timer) { + pconn := (*persistConn)((*libuv.Handle)(c.Pointer(timer)).GetData()) + isClose := pconn.closeConnIfStillIdle() + if isClose { + timer.Stop() + } else { + timer.Start(onIdleConnTimeout, 0, 0) } - cm.onlyH1 = treq.requiresHTTP1() - return cm, err +} + +// queueForIdleConn queues w to receive the next idle connection for w.cm. +// As an optimization hint to the caller, queueForIdleConn reports whether +// it successfully delivered an already-idle connection. +func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { + if t.DisableKeepAlives { + return false + } + + t.idleMu.Lock() + defer t.idleMu.Unlock() + + // Stop closing connections that become idle - we might want one. + // (That is, undo the effect of t.CloseIdleConnections.) + t.closeIdle = false + + if w == nil { + // Happens in test hook. + return false + } + + // If IdleConnTimeout is set, calculate the oldest + // persistConn.idleAt time we're willing to use a cached idle + // conn. + var oldTime time.Time + if t.IdleConnTimeout > 0 { + oldTime = time.Now().Add(-t.IdleConnTimeout) + } + // Look for most recently-used idle connection. + if list, ok := t.idleConn[w.key]; ok { + stop := false + delivered := false + for len(list) > 0 && !stop { + pconn := list[len(list)-1] + + // See whether this connection has been idle too long, considering + // only the wall time (the Round(0)), in case this is a laptop or VM + // coming out of suspend with previously cached idle connections. + tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + if tooOld { + // Async cleanup. Launch in its own goroutine (as if a + // time.AfterFunc called it); it acquires idleMu, which we're + // holding, and does a synchronous net.Conn.Close. + pconn.closeConnIfStillIdleLocked() + } + if pconn.isBroken() || tooOld { + // If either persistConn.readLoop has marked the connection + // broken, but Transport.removeIdleConn has not yet removed it + // from the idle list, or if this persistConn is too old (it was + // idle too long), then ignore it and look for another. In both + // cases it's already in the process of being closed. + list = list[:len(list)-1] + continue + } + delivered = w.tryDeliver(pconn, nil) + if delivered { + if pconn.alt != nil { + // HTTP/2: multiple clients can share pconn. + // Leave it in the list. + } else { + // HTTP/1: only one client can use pconn. + // Remove it from the list. + t.idleLRU.remove(pconn) + list = list[:len(list)-1] + } + } + stop = true + } + if len(list) > 0 { + t.idleConn[w.key] = list + } else { + delete(t.idleConn, w.key) + } + if stop { + return delivered + } + } + + // Register to receive next connection that becomes idle. + if t.idleConnWait == nil { + t.idleConnWait = make(map[connectMethodKey]wantConnQueue) + } + q := t.idleConnWait[w.key] + q.cleanFront() + q.pushBack(w) + t.idleConnWait[w.key] = q + return false +} + +// removeIdleConn marks pconn as dead. +func (t *Transport) removeIdleConn(pconn *persistConn) bool { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return t.removeIdleConnLocked(pconn) +} + +// t.idleMu must be held. +func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { + if pconn.idleTimer != nil { + pconn.idleTimer.Stop() + (*libuv.Handle)(c.Pointer(pconn.idleTimer)).Close(nil) + } + t.idleLRU.remove(pconn) + key := pconn.cacheKey + pconns := t.idleConn[key] + var removed bool + switch len(pconns) { + case 0: + // Nothing + case 1: + if pconns[0] == pconn { + delete(t.idleConn, key) + removed = true + } + default: + for i, v := range pconns { + if v != pconn { + continue + } + // Slide down, keeping most recently-used + // conns at the end. + copy(pconns[i:], pconns[i+1:]) + t.idleConn[key] = pconns[:len(pconns)-1] + removed = true + break + } + } + return removed } func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { @@ -223,6 +492,16 @@ func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { return true } +func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + cm.targetScheme = treq.URL.Scheme + cm.targetAddr = canonicalAddr(treq.URL) + if t.Proxy != nil { + cm.proxyURL, err = t.Proxy(treq.Request) + } + cm.onlyH1 = treq.requiresHTTP1() + return cm, err +} + // alternateRoundTripper returns the alternate RoundTripper to use // for this request if the Request's URL scheme requires one, // or nil for the normal case of using the Transport. @@ -286,6 +565,9 @@ func (t *Transport) closeLocked(err error) { if t.loop != nil { t.loop.Close() } + if t.async != nil { + t.async.Close(nil) + } if t.exec != nil { t.exec.Free() } @@ -308,13 +590,12 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { defer println("RoundTrip end") } t.loopInitOnce.Do(func() { + println("init loop") t.loop = libuv.LoopNew() + t.async = &libuv.Async{} t.exec = hyper.NewExecutor() - //idle := &libuv.Idle{} - //libuv.InitIdle(t.loop, idle) - //(*libuv.Handle)(c.Pointer(idle)).SetData(c.Pointer(t)) - //idle.Start(readWriteLoop) + t.loop.Async(t.async, nil) checker := &libuv.Check{} libuv.InitCheck(t.loop, checker) @@ -330,7 +611,6 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // Only the first request will initialize the timer if req.timer == nil && !req.deadline.IsZero() { req.timer = &libuv.Timer{} - req.timeoutch = make(chan struct{}, 1) libuv.InitTimer(t.loop, req.timer) ch := &timeoutData{ timeoutch: req.timeoutch, @@ -473,9 +753,10 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest + t.setReqCanceler(cancelKey, nil) // HTTP/2 not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { + // HTTP/1.X path. resp, err = pconn.roundTrip(treq) } @@ -485,8 +766,35 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) Retry & ConnPool(t.doRoundTrip) - return nil, err + // TODO(spongehah) ConnPool(t.doRoundTrip) + if http2isNoCachedConnError(err) { + if t.removeIdleConn(pconn) { + t.decConnsPerHost(pconn.cacheKey) + } + } else if !pconn.shouldRetryRequest(req, err) { + // Issue 16465: return underlying net.Conn.Read error from peek, + // as we've historically done. + if e, ok := err.(nothingWrittenError); ok { + err = e.error + } + if e, ok := err.(transportReadFromServerError); ok { + err = e.err + } + if b, ok := req.Body.(*readTrackingBody); ok && !b.didClose { + // Issue 49621: Close the request body if pconn.roundTrip + // didn't do so already. This can happen if the pconn + // write loop exits without reading the write request. + req.closeBody() + } + return nil, err + } + testHookRoundTripRetried() + + // Rewind the body if we're able to. + req, err = rewindBody(req) + if err != nil { + return nil, err + } } } @@ -507,7 +815,6 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi key: cm.key(), //ctx: ctx, timeoutch: treq.timeoutch, - ready: make(chan struct{}, 1), beforeDial: testHookPrePendingDial, afterDial: testHookPostPendingDial, } @@ -518,20 +825,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi }() // TODO(spongehah) ConnPool(t.getConn) - //// Queue for idle connection. - //if delivered := t.queueForIdleConn(w); delivered { - // pc := w.pc - // // Trace only for HTTP/1. - // // HTTP/2 calls trace.GotConn itself. - // if pc.alt == nil && trace != nil && trace.GotConn != nil { - // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) - // } - // // set request canceler to some non-nil function so we - // // can detect whether it was cleared between now and when - // // we enter roundTrip - // t.setReqCanceler(treq.cancelKey, func(error) {}) - // return pc, nil - //} + // Queue for idle connection. + if delivered := t.queueForIdleConn(w); delivered { + pc := w.pc + // Trace only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + // TODO(spongehah) trace(t.getConn) + //if pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) + //} + // set request canceler to some non-nil function so we + // can detect whether it was cleared between now and when + // we enter roundTrip + t.setReqCanceler(treq.cancelKey, func(error) {}) + return pc, nil + } cancelc := make(chan error, 1) t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) @@ -539,52 +847,36 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Queue for permission to dial. t.queueForDial(w) - // Wait for completion or cancellation. - select { - case <-w.ready: - // Trace success but only for HTTP/1. - // HTTP/2 calls trace.GotConn itself. - //if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { - // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) - //} - if w.err != nil { - // If the request has been canceled, that's probably - // what caused w.err; if so, prefer to return the - // cancellation error (see golang.org/issue/16049). - select { - // TODO(spongehah) cancel(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() - case <-req.timeoutch: - return nil, errors.New("timeout: req.Context().Err()") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err - default: - // return below + // Trace success but only for HTTP/1. + // HTTP/2 calls trace.GotConn itself. + //if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil { + // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) + //} + if w.err != nil { + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + // TODO(spongehah) timeout(t.getConn) + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case <-req.timeoutch: + if debugSwitch { + println("getConn: timeoutch") } + return nil, errors.New("timeout: req.Context().Err()") + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err + default: + // return below } - return w.pc, w.err - // TODO(spongehah) cancel(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, - case <-req.timeoutch: - if debugSwitch { - println("getConn: timeoutch") - } - return nil, errors.New("timeout: req.Context().Err()\n") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err } + return w.pc, w.err } // queueForDial queues w to wait for permission to begin dialing. @@ -597,7 +889,7 @@ func (t *Transport) queueForDial(w *wantConn) { w.beforeDial() if t.MaxConnsPerHost <= 0 { - go t.dialConnFor(w) + t.dialConnFor(w) return } @@ -609,7 +901,7 @@ func (t *Transport) queueForDial(w *wantConn) { t.connsPerHost = make(map[connectMethodKey]int) } t.connsPerHost[w.key] = n + 1 - go t.dialConnFor(w) + t.dialConnFor(w) return } @@ -633,17 +925,16 @@ func (t *Transport) dialConnFor(w *wantConn) { defer w.afterDial() pc, err := t.dialConn(w.timeoutch, w.cm) - w.tryDeliver(pc, err) // TODO(spongehah) ConnPool(t.dialConnFor) - //delivered := w.tryDeliver(pc, err) - // Handle undelivered or shareable connections - //if err == nil && (!delivered || pc.alt != nil) { - // // pconn was not passed to w, - // // or it is HTTP/2 and can be shared. - // // Add to the idle connection pool. - // t.putOrCloseIdleConn(pc) - //} - + delivered := w.tryDeliver(pc, err) + // If the connection was successfully established but was not passed to w, + // or is a shareable HTTP/2 connection + if err == nil && (!delivered || pc.alt != nil) { + // pconn was not passed to w, + // or it is HTTP/2 and can be shared. + // Add to the idle connection pool. + t.putOrCloseIdleConn(pc) + } // If an error occurs during the dialing process, the connection count for that host is decreased. // This ensures that the connection count remains accurate even in cases where the dial attempt fails. if err != nil { @@ -676,7 +967,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { for q.len() > 0 { w := q.popFront() if w.waiting() { - go t.dialConnFor(w) + t.dialConnFor(w) done = true break } @@ -708,6 +999,7 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * } select { case <-timeoutch: + err = errors.New("[t.dialConn] request timeout") return default: } @@ -716,6 +1008,7 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * cacheKey: cm.key(), closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), + alive: true, } //trace := httptrace.ContextClientTrace(ctx) @@ -754,7 +1047,7 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // } //} else { //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(timeoutch, cm.addr()) + conn, err := t.dial(cm.addr()) if err != nil { return nil, err } @@ -811,14 +1104,15 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * select { case <-timeoutch: - conn.Close() - return + err = errors.New("[t.dialConn] request timeout") + pconn.close(err) + return nil, err default: } return pconn, nil } -func (t *Transport) dial(timeoutch chan struct{}, addr string) (*connData, error) { +func (t *Transport) dial(addr string) (*connData, error) { if debugSwitch { println("dial start") defer println("dial end") @@ -862,7 +1156,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { // TODO(spongehah) ConnPool(pc.roundTrip) - //pc.t.putOrCloseIdleConn(pc) + pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } pc.mu.Lock() @@ -925,16 +1219,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err writeErrCh := make(chan error, 1) resc := make(chan responseAndError, 1) - // Hookup the IO - hyperIo := newIoWithConnReadWrite(pc.conn) - // We need an executor generally to poll futures - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(pc.t.exec) - // send the handshake - handshakeTask := hyper.Handshake(hyperIo, opts) taskData := &taskData{ - taskId: write, req: req, pc: pc, addedGzip: requestedGzip, @@ -942,14 +1227,32 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err callerGone: gone, resc: resc, } - handshakeTask.SetUserdata(c.Pointer(taskData)) - // Send the request to readWriteLoop(). - // Let's wait for the handshake to finish... - pc.t.exec.Push(handshakeTask) - async := &libuv.Async{} - pc.t.loop.Async(async, asyncCb) - async.Send() + if pc.client == nil && !pc.isReused() { + println("first") + // Hookup the IO + hyperIo := newIoWithConnReadWrite(pc.conn) + // We need an executor generally to poll futures + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(pc.t.exec) + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + taskData.taskId = handshake + handshakeTask.SetUserdata(c.Pointer(taskData)) + // Send the request to readWriteLoop(). + pc.t.exec.Push(handshakeTask) + } else { + println("second") + taskData.taskId = read + err = req.write(pc.client, taskData, pc.t.exec) + if err != nil { + writeErrCh <- err + } + } + + // Wake up libuv. Loop + pc.t.async.Send() //var respHeaderTimer <-chan time.Time //cancelChan := req.Request.Cancel @@ -1003,7 +1306,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - // TODO(spongehah) cancel(pc.roundTrip) + // TODO(spongehah) timeout(pc.roundTrip) //case <-cancelChan: // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) // cancelChan = nil @@ -1022,20 +1325,15 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } } -func asyncCb(async *libuv.Async) { - println("async called") -} - // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. -func readWriteLoop(idle *libuv.Check) { - println("polling") - t := (*Transport)((*libuv.Handle)(c.Pointer(idle)).GetData()) +func readWriteLoop(checker *libuv.Check) { + t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) // Read this once, before loop starts. (to avoid races in tests) - testHookMu.Lock() - testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - testHookMu.Unlock() + //testHookMu.Lock() + //testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead + //testHookMu.Unlock() const debugReadWriteLoop = true // Debug switch provided for developers @@ -1043,6 +1341,9 @@ func readWriteLoop(idle *libuv.Check) { // Poll all ready tasks and act on them... for { task := t.exec.Poll() + if debugSwitch { + println("polling") + } if task == nil { return } @@ -1057,28 +1358,51 @@ func readWriteLoop(idle *libuv.Check) { println("taskId: ", taskId) } switch taskId { - case write: + case handshake: if debugReadWriteLoop { println("write") } + err := checkTaskType(task, handshake) + if err != nil { + taskData.writeErrCh <- err + task.Free() + continue + } + + pc := taskData.pc select { - case <-taskData.pc.closech: + case <-pc.closech: task.Free() continue default: } - err := checkTaskType(task, write) - client := (*hyper.ClientConn)(task.Value()) + pc.client = (*hyper.ClientConn)(task.Value()) task.Free() - if err == nil { - // TODO(spongehah) Proxy(writeLoop) - err = taskData.req.Request.write(client, taskData, t.exec) + // TODO(spongehah) Proxy(writeLoop) + taskData.taskId = read + err = taskData.req.Request.write(pc.client, taskData, t.exec) + + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + pc.close(err) + continue + } + + if debugReadWriteLoop { + println("write end") + } + case read: + if debugReadWriteLoop { + println("read") } - // For this request, no longer need the client - client.Free() + + pc := taskData.pc + + err := checkTaskType(task, read) if bre, ok := err.(requestBodyReadError); ok { err = bre.error // Errors reading from the user's @@ -1093,59 +1417,48 @@ func readWriteLoop(idle *libuv.Check) { if err != nil { //pc.writeErrCh <- err // to the body reader, which might recycle us taskData.writeErrCh <- err // to the roundTrip function - taskData.pc.close(err) + pc.close(err) continue } - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } - - if taskData.pc.closeErr == nil { - taskData.pc.closeErr = errReadLoopExiting + if pc.closeErr == nil { + pc.closeErr = errReadLoopExiting } // TODO(spongehah) ConnPool(readWriteLoop) - //if taskData.pc.tryPutIdleConn == nil { - // //taskData.pc.tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { - // // if err := pc.t.tryPutIdleConn(pc); err != nil { - // // closeErr = err - // // if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // // trace.PutIdleConn(err) - // // } - // // return false - // // } - // // if trace != nil && trace.PutIdleConn != nil { - // // trace.PutIdleConn(nil) - // // } - // // return true - // //} - //} + if pc.tryPutIdleConn == nil { + pc.tryPutIdleConn = func() bool { + if err := pc.t.tryPutIdleConn(pc); err != nil { + pc.closeErr = err + // TODO(spongehah) trace(readWriteLoop) + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + } - err := checkTaskType(task, read) + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() - taskData.pc.mu.Lock() - if taskData.pc.numExpectedResponses == 0 { - taskData.pc.closeLocked(errServerClosedIdle) - taskData.pc.mu.Unlock() + pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.readLoopPeekFailLocked(hyperResp, err) + pc.mu.Unlock() // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } - taskData.pc.mu.Unlock() + pc.mu.Unlock() //trace := httptrace.ContextClientTrace(rc.req.Context()) - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() - var resp *Response var respBody *hyper.Body if err == nil { @@ -1155,7 +1468,7 @@ func readWriteLoop(idle *libuv.Check) { respBody = hyperResp.Body() } else { err = transportReadFromServerError{err} - taskData.pc.closeErr = err + pc.closeErr = err } // No longer need the response @@ -1166,21 +1479,17 @@ func readWriteLoop(idle *libuv.Check) { case taskData.resc <- responseAndError{err: err}: case <-taskData.callerGone: // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } - taskData.pc.mu.Lock() - taskData.pc.numExpectedResponses-- - taskData.pc.mu.Unlock() + pc.mu.Lock() + pc.numExpectedResponses-- + pc.mu.Unlock() bodyWritable := resp.bodyIsWritable() hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 @@ -1189,46 +1498,43 @@ func readWriteLoop(idle *libuv.Check) { // Don't do keep-alive on error if either party requested a close // or we get an unexpected informational (1xx) response. // StatusCode 100 is already handled above. - taskData.pc.alive = false + pc.alive = false } if !hasBody || bodyWritable { - //replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) - t.replaceReqCanceler(taskData.req.cancelKey, nil) + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) // TODO(spongehah) ConnPool(readWriteLoop) - //// Put the idle conn back into the pool before we send the response - //// so if they process it quickly and make another request, they'll - //// get this same conn. But we use the unbuffered channel 'rc' - //// to guarantee that persistConn.roundTrip got out of its select - //// potentially waiting for this persistConn to close. - //taskData.pc.alive = taskData.pc.alive && + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() + //pc.alive = pc.alive && // !pc.sawEOF && // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) + // replaced && pc.tryPutIdleConn() if bodyWritable { - taskData.pc.closeErr = errCallerOwnsConn + pc.closeErr = errCallerOwnsConn } select { case taskData.resc <- responseAndError{res: resp}: case <-taskData.callerGone: // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue } // Now that they've read from the unbuffered channel, they're safely // out of the select that also waits on this goroutine to die, so // we're allowed to exit now if needed (if alive is false) - testHookReadLoopBeforeNextRead() - if taskData.pc.alive == false { + //testHookReadLoopBeforeNextRead() + if pc.alive == false { // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) } continue } @@ -1242,7 +1548,7 @@ func readWriteLoop(idle *libuv.Check) { fn: func(err error) error { isEOF := err == io.EOF if !isEOF { - if cerr := taskData.pc.canceled(); cerr != nil { + if cerr := pc.canceled(); cerr != nil { return cerr } } @@ -1265,24 +1571,22 @@ func readWriteLoop(idle *libuv.Check) { taskData.taskId = readDone bodyForeachTask.SetUserdata(c.Pointer(taskData)) t.exec.Push(bodyForeachTask) - (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + if taskData.req.timer != nil { + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + } // TODO(spongehah) select blocking(readWriteLoop) //select { //case taskData.resc <- responseAndError{res: resp}: //case <-taskData.callerGone: // // defer - // taskData.pc.close(taskData.pc.closeErr) - // // TODO(spongehah) ConnPool(readWriteLoop) - // //t.removeIdleConn(pc) + // readLoopDefer(pc, t) // continue //} select { case <-taskData.callerGone: // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) continue default: } @@ -1301,45 +1605,48 @@ func readWriteLoop(idle *libuv.Check) { } checkTaskType(task, readDone) - //bodyEOF := task.Type() == hyper.TaskEmpty + bodyEOF := task.Type() == hyper.TaskEmpty // free the task task.Free() - t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc := taskData.pc + + replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool // TODO(spongehah) ConnPool(readWriteLoop) - //taskData.pc.alive = taskData.pc.alive && + pc.alive = pc.alive && + bodyEOF && + replaced && pc.tryPutIdleConn() + //pc.alive = pc.alive && // bodyEOF && // !pc.sawEOF && // pc.wroteRequest() && // replaced && tryPutIdleConn(trace) - // TODO(spongehah) cancel(pc.readWriteLoop) + // TODO(spongehah) timeout(t.readWriteLoop) //case <-rw.rc.req.Cancel: - // taskData.pc.alive = false + // pc.alive = false // pc.t.CancelRequest(rw.rc.req) //case <-rw.rc.req.Context().Done(): - // taskData.pc.alive = false + // pc.alive = false // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - //case <-taskData.pc.closech: - // taskData.pc.alive = false + //case <-pc.closech: + // pc.alive = false //} - select { - case <-taskData.req.timeoutch: - continue - case <-taskData.pc.closech: - taskData.pc.alive = false - default: - } + //select { + //case <-taskData.req.timeoutch: + // continue + //case <-pc.closech: + // pc.alive = false + //default: + //} - if taskData.pc.alive == false { + if pc.alive == false { // defer - taskData.pc.close(taskData.pc.closeErr) - // TODO(spongehah) ConnPool(readWriteLoop) - //t.removeIdleConn(pc) + readLoopDefer(pc, t) } - testHookReadLoopBeforeNextRead() + //testHookReadLoopBeforeNextRead() if debugReadWriteLoop { println("readDone end") } @@ -1350,6 +1657,12 @@ func readWriteLoop(idle *libuv.Check) { } } +func readLoopDefer(pc *persistConn, t *Transport) { + pc.close(pc.closeErr) + // TODO(spongehah) ConnPool(readLoopDefer) + t.removeIdleConn(pc) +} + // ---------------------------------------------------------- type taskData struct { @@ -1374,6 +1687,9 @@ type connData struct { } func (conn *connData) Close() error { + if conn == nil { + return nil + } if conn.ReadWaker != nil { conn.ReadWaker.Free() conn.ReadWaker = nil @@ -1535,9 +1851,7 @@ func onTimeout(timer *libuv.Timer) { pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) // defer - pc.close(pc.closeErr) - // TODO(spongehah) ConnPool(onTimeout) - //t.removeIdleConn(pc) + readLoopDefer(pc, pc.t) } } @@ -1555,7 +1869,7 @@ type taskId c.Int const ( notSet taskId = iota - write + handshake read readDone ) @@ -1563,13 +1877,13 @@ const ( // checkTaskType checks the task type func checkTaskType(task *hyper.Task, curTaskId taskId) error { switch curTaskId { - case write: + case handshake: if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::write]handshake task error!\n") + log.Printf("[readWriteLoop::handshake]handshake task error!\n") return fail((*hyper.Error)(task.Value())) } if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("[readWriteLoop::write]unexpected task type\n") + return fmt.Errorf("[readWriteLoop::handshake]unexpected task type\n") } return nil case read: @@ -1746,6 +2060,10 @@ type persistConn struct { writeLoopDone chan struct{} // closed when readWriteLoop ends + // Both guarded by Transport.idleMu: + idleAt time.Time // time it last become idle + idleTimer *libuv.Timer // holding an onIdleConnTimeout to close it + mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed @@ -1754,11 +2072,14 @@ type persistConn struct { // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the // original Request given to RoundTrip is not modified) + reused bool // whether conn has had successful request/response and is being reused. mutateHeaderFunc func(Header) // other - alive bool // Replace the alive in readLoop - closeErr error // Replace the closeErr in readLoop + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop + tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop + client *hyper.ClientConn } func (pc *persistConn) cancelRequest(err error) { @@ -1779,7 +2100,18 @@ func (pc *persistConn) close(err error) { pc.closeLocked(err) } +// markReused marks this connection as having been successfully used for a +// request and response. +func (pc *persistConn) markReused() { + pc.mu.Lock() + pc.reused = true + pc.mu.Unlock() +} + func (pc *persistConn) closeLocked(err error) { + if debugSwitch { + println("pc closed") + } if err == nil { panic("nil error") } @@ -1795,6 +2127,7 @@ func (pc *persistConn) closeLocked(err error) { } close(pc.closech) close(pc.writeLoopDone) + pc.client.Free() } } pc.mutateHeaderFunc = nil @@ -1866,6 +2199,14 @@ func (pc *persistConn) canceled() error { return pc.canceledErr } +// isReused reports whether this connection has been used before. +func (pc *persistConn) isReused() bool { + pc.mu.Lock() + r := pc.reused + pc.mu.Unlock() + return r +} + // isBroken reports whether this connection is in a known broken state. func (pc *persistConn) isBroken() bool { pc.mu.Lock() @@ -1874,6 +2215,107 @@ func (pc *persistConn) isBroken() bool { return b } +// shouldRetryRequest reports whether we should retry sending a failed +// HTTP request on a new connection. The non-nil input error is the +// error from roundTrip. +func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { + if http2isNoCachedConnError(err) { + // Issue 16582: if the user started a bunch of + // requests at once, they can all pick the same conn + // and violate the server's max concurrent streams. + // Instead, match the HTTP/1 behavior for now and dial + // again to get a new TCP connection, rather than failing + // this request. + return true + } + if err == errMissingHost { + // User error. + return false + } + if !pc.isReused() { + // This was a fresh connection. There's no reason the server + // should've hung up on us. + // + // Also, if we retried now, we could loop forever + // creating new connections and retrying if the server + // is just hanging up on us because it doesn't like + // our request (as opposed to sending an error). + return false + } + if _, ok := err.(nothingWrittenError); ok { + // We never wrote anything, so it's safe to retry, if there's no body or we + // can "rewind" the body with GetBody. + return req.outgoingLength() == 0 || req.GetBody != nil + } + if !req.isReplayable() { + // Don't retry non-idempotent requests. + return false + } + if _, ok := err.(transportReadFromServerError); ok { + // We got some non-EOF net.Conn.Read failure reading + // the 1st response byte from the server. + return true + } + if err == errServerClosedIdle { + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. + return true + } + return false // conservatively +} + +// closeConnIfStillIdle closes the connection if it's still sitting idle. +// This is what's called by the persistConn's idleTimer, and is run in its +// own goroutine. +func (pc *persistConn) closeConnIfStillIdle() bool { + t := pc.t + isLock := t.idleMu.TryLock() + if isLock { + defer t.idleMu.Unlock() + pc.closeConnIfStillIdleLocked() + return true + } + return false +} + +func (pc *persistConn) closeConnIfStillIdleLocked() { + t := pc.t + if _, ok := t.idleLRU.m[pc]; !ok { + // Not idle. + return + } + t.removeIdleConnLocked(pc) + pc.close(errIdleConnTimeout) +} + +func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { + if pc.closed != nil { + return + } + if is408Message(resp) { + pc.closeLocked(errServerClosedIdle) + return + } + pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", err)) +} + +func is408Message(resp *hyper.Response) bool { + httpVersion := int(resp.Version()) + if httpVersion != 10 && httpVersion != 11 { + return false + } + return resp.Status() == 408 +} + +// isNoCachedConnError reports whether err is of type noCachedConnError +// or its equivalent renamed type in net/http2's h2_bundle.go. Both types +// may coexist in the same running program. +func http2isNoCachedConnError(err error) bool { // h2_bundle.go + _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) + return ok +} + // connectMethod is the map key (in its String form) for keeping persistent // TCP connections alive for subsequent HTTP requests. // @@ -1967,7 +2409,8 @@ type wantConn struct { key connectMethodKey // cm.key() ctx context.Context // context for dial timeoutch chan struct{} // tmp timeout to replace ctx - ready chan struct{} // closed when pc, err pair is delivered + ready bool + //ready chan struct{} // closed when pc, err pair is delivered // hooks for testing to know when dials are done // beforeDial is called in the getConn goroutine when the dial is queued. @@ -1985,25 +2428,24 @@ type wantConn struct { func (w *wantConn) cancel(t *Transport, err error) { w.mu.Lock() if w.pc == nil && w.err == nil { - close(w.ready) // catch misbehavior in future delivery + w.ready = true // catch misbehavior in future delivery } - //pc := w.pc + pc := w.pc w.pc = nil w.err = err w.mu.Unlock() // TODO(spongehah) ConnPool(w.cancel) - //if pc != nil { - // t.putOrCloseIdleConn(pc) - //} + if pc != nil { + t.putOrCloseIdleConn(pc) + } } // waiting reports whether w is still waiting for an answer (connection or error). func (w *wantConn) waiting() bool { - select { - case <-w.ready: + if w.ready { return false - default: + } else { return true } } @@ -2022,12 +2464,7 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { if w.pc == nil && w.err == nil { panic("net/http: internal error: misuse of tryDeliver") } - select { - case <-w.timeoutch: - pc.close(errors.New("request timeout: dialConn timeout")) - default: - } - close(w.ready) + w.ready = true return true } @@ -2200,3 +2637,42 @@ func (gz *gzipReader) Read(p []byte) (n int, err error) { func (gz *gzipReader) Close() error { return gz.body.Close() } + +type connLRU struct { + ll *list.List // list.Element.Value type is of *persistConn + m map[*persistConn]*list.Element +} + +// add adds pc to the head of the linked list. +func (cl *connLRU) add(pc *persistConn) { + if cl.ll == nil { + cl.ll = list.New() + cl.m = make(map[*persistConn]*list.Element) + } + ele := cl.ll.PushFront(pc) + if _, ok := cl.m[pc]; ok { + panic("persistConn was already in LRU") + } + cl.m[pc] = ele +} + +func (cl *connLRU) removeOldest() *persistConn { + ele := cl.ll.Back() + pc := ele.Value.(*persistConn) + cl.ll.Remove(ele) + delete(cl.m, pc) + return pc +} + +// remove removes pc from cl. +func (cl *connLRU) remove(pc *persistConn) { + if ele, ok := cl.m[pc]; ok { + cl.ll.Remove(ele) + delete(cl.m, pc) + } +} + +// len returns the number of items in the cache. +func (cl *connLRU) len() int { + return len(cl.m) +} diff --git a/x/net/http/util.go b/x/net/http/util.go index bec22a8..bfd9fc3 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -7,7 +7,7 @@ import ( "golang.org/x/net/idna" - "github.com/goplus/llgo/x/net" + "github.com/goplus/llgoexamples/x/net" ) /** From 16054aefa2980378ac13b302f2532e91501e1653 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Wed, 4 Sep 2024 18:27:43 +0800 Subject: [PATCH 29/55] refactor(x/net/http): Rewrite request logic Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 29 ++++- x/net/http/request.go | 81 ++++++------ x/net/http/response.go | 21 +++- x/net/http/server.go | 257 ++++++++++++++++++++++----------------- 4 files changed, 231 insertions(+), 157 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 36920d0..2b20f0b 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -9,19 +9,38 @@ import ( func echoHandler(w http.ResponseWriter, r *http.Request) { fmt.Printf("echoHandler called\n") - //TODO: read body and echo + //TODO(hackerchai): read body and echo + // fmt.Printf("> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) + // for key, values := range r.Header { + // for _, value := range values { + // fmt.Printf("> %s: %s\n", key, value) + // } + // } + // fmt.Printf("URL: %s\n", r.URL.String()) + // // fmt.Println("ContentLength: %d", r.ContentLength) + // // fmt.Println("TransferEncoding: %s", r.TransferEncoding) + // //TODO: read body and echo // body, err := io.ReadAll(r.Body) + // println("body read") + // if err != nil { // http.Error(w, "Error reading request body", http.StatusInternalServerError) // return // } // defer r.Body.Close() - // fmt.Printf("body: %s\n", string(body)) - //w.Header().Set("Content-Type", "text/plain") - //w.Write(body) + // fmt.Printf("body read") + // w.Header().Set("Content-Type", "text/plain") + // w.Write(body) + fmt.Printf("> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) + for key, values := range r.Header { + for _, value := range values { + fmt.Printf("> %s: %s\n", key, value) + } + } + fmt.Printf("URL: %s\n", r.URL.String()) w.Header().Set("Content-Type", "text/plain") - w.Write([]byte("echoHandler called\n")) + w.Write([]byte("hello world\n")) } func main() { diff --git a/x/net/http/request.go b/x/net/http/request.go index b991420..9d458c3 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -3,6 +3,8 @@ package http import ( "fmt" "io" + + //"mime/multipart" "net/url" "strings" "time" @@ -25,14 +27,22 @@ type Request struct { TransferEncoding []string Close bool Host string - timeout time.Duration + // Form url.Values + // PostForm url.Values + // MultipartForm *multipart.Form + RemoteAddr string + RequestURI string + timeout time.Duration } -func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Request, error) { +func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { + println("readRequest called") req := Request{ - Header: make(Header), - timeout: 0, + Header: make(Header), + timeout: 0, + Body: nil, } + req.RemoteAddr = conn.remoteAddr headers := hyperReq.Headers() if headers != nil { @@ -41,13 +51,6 @@ func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Reques return nil, fmt.Errorf("failed to get request headers") } - fmt.Printf("Headers:\n") - for key, values := range req.Header { - for _, value := range values { - fmt.Printf("%s: %s\n", key, value) - } - } - var host string for key, values := range req.Header { if strings.EqualFold(key, "Host") { @@ -66,11 +69,10 @@ func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Reques } methodStr := string(method[:methodLen]) - fmt.Printf("Method: %s\n", methodStr) var scheme, authority, pathAndQuery [1024]byte schemeLen, authorityLen, pathAndQueryLen := unsafe.Sizeof(scheme), unsafe.Sizeof(authority), unsafe.Sizeof(pathAndQuery) - uriResult := hyperReq.URIParts(&scheme[0], &schemeLen, &authority[0], &authorityLen, &pathAndQuery[0], &pathAndQueryLen); + uriResult := hyperReq.URIParts(&scheme[0], &schemeLen, &authority[0], &authorityLen, &pathAndQuery[0], &pathAndQueryLen) if uriResult != hyper.OK { return nil, fmt.Errorf("failed to get URI parts: %v", uriResult) } @@ -95,11 +97,11 @@ func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Reques } req.Host = authorityStr req.Method = methodStr + req.RequestURI = pathAndQueryStr var proto string var protoMajor, protoMinor int version := hyperReq.Version() - fmt.Printf("Version: %d\n", version) switch version { case hyper.HTTPVersion10: proto = "HTTP/1.0" @@ -122,38 +124,36 @@ func newRequest(ListenAddr string, conn *conn, hyperReq *hyper.Request) (*Reques } req.Proto = proto req.ProtoMajor = protoMajor - req.ProtoMinor = protoMinor + req.ProtoMinor = protoMinor - urlStr := fmt.Sprintf("%s://%s%s", schemeStr, host, pathAndQueryStr) - fmt.Printf("URL: %s\n", urlStr) + urlStr := fmt.Sprintf("%s://%s%s", schemeStr, authorityStr, pathAndQueryStr) url, err := url.Parse(urlStr) if err != nil { return nil, err } req.URL = url - if methodStr == "POST" || methodStr == "PUT" || methodStr == "PATCH" { - body := hyperReq.Body() - if body != nil { - var bodyWriter *io.PipeWriter - req.Body, bodyWriter = io.Pipe() - task := body.Foreach(getBodyChunk, c.Pointer(&bodyWriter), nil) - if task != nil { - r := conn.Executor.Push(task) - if r != hyper.OK { - fmt.Printf("failed to push body foreach task: %d\n", r) - task.Free() - return nil, fmt.Errorf("failed to push body foreach task: %v", r) - } - } else { - return nil, fmt.Errorf("failed to create body foreach task") + body := hyperReq.Body() + if body != nil { + req.Body, conn.bodyWriter = io.Pipe() + task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) + if task != nil { + r := conn.executor.Push(task) + if r != hyper.OK { + fmt.Printf("failed to push body foreach task: %d\n", r) + task.Free() + return nil, fmt.Errorf("failed to push body foreach task: %v", r) } - } else { - return nil, fmt.Errorf("failed to get request body") + return nil, fmt.Errorf("failed to create body foreach task") } + + } else { + return nil, fmt.Errorf("failed to get request body") } + defer hyperReq.Free() + return &req, nil } @@ -177,7 +177,18 @@ func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { writer := (*io.PipeWriter)(userdata) buf := chunk.Bytes() len := chunk.Len() - writer.Write(unsafe.Slice(buf, len)) + bytes := unsafe.Slice(buf, len) + //debug + fmt.Printf("Writing %d bytes to response body\n", len) + fmt.Printf("Body chunk: %s\n", string(bytes)) + + _, err := writer.Write(bytes) + fmt.Printf("Body chunk written\n") + if err != nil { + fmt.Println("Error writing to response body:", err) + writer.Close() + return hyper.IterBreak + } return hyper.IterContinue } diff --git a/x/net/http/response.go b/x/net/http/response.go index 14273c6..20e6f01 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -26,7 +26,6 @@ type body struct { var DefaultChunkSize uintptr = 8192 - func newResponse(channel *hyper.ResponseChannel) *response { fmt.Printf("newResponse called\n") resp := response{ @@ -77,6 +76,15 @@ func (r *response) WriteHeader(statusCode int) { return } } + + //debug + fmt.Printf("< HTTP/1.1 %d\n", statusCode) + for key, values := range r.header { + for _, value := range values { + fmt.Printf("< %s: %s\n", key, value) + } + } + r.resp = newResp } @@ -87,8 +95,8 @@ func (r *response) finalize() error { } bodyData := body{ - data: r.body, - len: uintptr(len(r.body)), + data: r.body, + len: uintptr(len(r.body)), readLen: 0, } fmt.Printf("bodyData constructed\n") @@ -117,7 +125,12 @@ func (r *response) finalize() error { func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { fmt.Printf("setBodyDataFunc called\n") body := (*body)(userdata) + if body.len > 0 { + //debug + fmt.Println("<") + fmt.Printf("%s", string(body.data)) + if body.len > DefaultChunkSize { *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) body.readLen += DefaultChunkSize @@ -136,4 +149,4 @@ func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) fmt.Printf("error setting body data: %s\n", c.GoString(c.Strerror(os.Errno))) return hyper.PollError -} \ No newline at end of file +} diff --git a/x/net/http/server.go b/x/net/http/server.go index c0578fd..d006a95 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -3,6 +3,7 @@ package http import ( "errors" "fmt" + "io" "os" "strconv" "sync" @@ -32,34 +33,36 @@ type Server struct { Addr string Handler Handler - uvLoop *libuv.Loop - uvServer libuv.Tcp - inShutdown atomic.Bool - http1Opts *hyper.Http1ServerconnOptions - http2Opts *hyper.Http2ServerconnOptions - checkHandle libuv.Check + uvLoop *libuv.Loop + uvServer libuv.Tcp + inShutdown atomic.Bool + http1Opts *hyper.Http1ServerconnOptions + http2Opts *hyper.Http2ServerconnOptions + checkHandle libuv.Check + idleHandle libuv.Idle mu sync.Mutex activeConnections map[*conn]struct{} } type conn struct { - Stream libuv.Tcp - PollHandle libuv.Poll - EventMask c.Uint - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker - IsClosing atomic.Bool - ClosedHandles int32 - Executor *hyper.Executor + stream libuv.Tcp + pollHandle libuv.Poll + eventMask c.Uint + readWaker *hyper.Waker + writeWaker *hyper.Waker + isClosing atomic.Bool + closedHandles int32 + executor *hyper.Executor + remoteAddr string + bodyWriter *io.PipeWriter } type serviceUserdata struct { - Host [128]c.Char - Port [8]c.Char - Conn *conn - Server *Server - ListenAddr string + host [128]c.Char + port [8]c.Char + conn *conn + server *Server } func NewServer(addr string) *Server { @@ -134,6 +137,18 @@ func (srv *Server) ListenAndServe() error { os.Exit(1) } + // if r := libuv.InitIdle(srv.uvLoop, &srv.idleHandle); r != 0 { + // fmt.Fprintf(os.Stderr, "Failed to initialize idle handler: %d\n", r) + // os.Exit(1) + // } + + // (*libuv.Handle)(unsafe.Pointer(&srv.idleHandle)).SetData(unsafe.Pointer(srv)) + + // if r := srv.idleHandle.Start(onIdle); r != 0 { + // fmt.Fprintf(os.Stderr, "Failed to start idle handler: %d\n", r) + // os.Exit(1) + // } + fmt.Printf("Listening on %s\n", srv.Addr) for { @@ -173,70 +188,70 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { return } - libuv.InitTcp(srv.uvLoop, &conn.Stream) - conn.Stream.Data = unsafe.Pointer(conn) + libuv.InitTcp(srv.uvLoop, &conn.stream) + conn.stream.Data = unsafe.Pointer(conn) - if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(&conn.Stream))) == 0 { + if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(&conn.stream))) == 0 { fmt.Println("Accepted new connection") - r := libuv.PollInit(srv.uvLoop, &conn.PollHandle, libuv.OsFd(conn.Stream.GetIoWatcherFd())) + r := libuv.PollInit(srv.uvLoop, &conn.pollHandle, libuv.OsFd(conn.stream.GetIoWatcherFd())) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Data = unsafe.Pointer(conn) - + + (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Data = unsafe.Pointer(conn) + if !updateConnRegistrations(conn, true) { - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } userdata := createServiceUserdata() - userdata.Server = srv + userdata.server = srv if userdata == nil { fmt.Fprintf(os.Stderr, "Failed to create service userdata\n") - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - fmt.Printf("ListenAddr: %s\n", srv.Addr) - userdata.ListenAddr = srv.Addr var addr cnet.SockaddrStorage addrlen := c.Int(unsafe.Sizeof(addr)) - conn.Stream.Getpeername((*cnet.SockAddr)(c.Pointer(&addr)), &addrlen) + conn.stream.Getpeername((*cnet.SockAddr)(c.Pointer(&addr)), &addrlen) if addr.Family == cnet.AF_INET { s := (*cnet.SockaddrIn)(unsafe.Pointer(&addr)) - libuv.Ip4Name(s, (*c.Char)(&userdata.Host[0]), unsafe.Sizeof(userdata.Host)) - c.Snprintf((*c.Char)(&userdata.Port[0]), unsafe.Sizeof(userdata.Port), c.Str("%d"), cnet.Ntohs(s.Port)) + libuv.Ip4Name(s, (*c.Char)(&userdata.host[0]), unsafe.Sizeof(userdata.host)) + c.Snprintf((*c.Char)(&userdata.port[0]), unsafe.Sizeof(userdata.port), c.Str("%d"), cnet.Ntohs(s.Port)) } else if addr.Family == cnet.AF_INET6 { s := (*cnet.SockaddrIn6)(unsafe.Pointer(&addr)) - libuv.Ip6Name(s, (*c.Char)(&userdata.Host[0]), unsafe.Sizeof(userdata.Host)) - c.Snprintf((*c.Char)(&userdata.Port[0]), unsafe.Sizeof(userdata.Port), c.Str("%d"), cnet.Ntohs(s.Port)) + libuv.Ip6Name(s, (*c.Char)(&userdata.host[0]), unsafe.Sizeof(userdata.host)) + c.Snprintf((*c.Char)(&userdata.port[0]), unsafe.Sizeof(userdata.port), c.Str("%d"), cnet.Ntohs(s.Port)) } + conn.remoteAddr = c.GoString((*c.Char)(&userdata.host[0])) + ":" + c.GoString((*c.Char)(&userdata.port[0])) + executor := hyper.NewExecutor() if executor == nil { fmt.Fprintf(os.Stderr, "Failed to create Executor\n") - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - conn.Executor = executor + conn.executor = executor fmt.Println("Conn created") srv.trackConn(conn, true) fmt.Println("Conn tracked") - userdata.Conn = conn + userdata.conn = conn io := createIo(conn) service := hyper.ServiceNew(serverCallback) service.SetUserdata(unsafe.Pointer(userdata), nil) - http1Opts := hyper.Http1ServerconnOptionsNew(conn.Executor) + http1Opts := hyper.Http1ServerconnOptionsNew(conn.executor) if http1Opts == nil { fmt.Fprintf(os.Stderr, "Failed to create http1_opts\n") os.Exit(1) @@ -247,8 +262,8 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { os.Exit(1) } srv.http1Opts = http1Opts - - http2Opts := hyper.Http2ServerconnOptionsNew(conn.Executor) + + http2Opts := hyper.Http2ServerconnOptionsNew(conn.executor) if http2Opts == nil { fmt.Fprintf(os.Stderr, "Failed to create http2_opts\n") os.Exit(1) @@ -266,30 +281,50 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { srv.http2Opts = http2Opts serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) - conn.Executor.Push(serverconn) + conn.executor.Push(serverconn) } else { fmt.Println("Client not accepted") - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) } } func onCheck(handle *libuv.Check) { + //fmt.Println("onCheck called") srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) for conn := range srv.activeConnections { - if conn.Executor != nil { - task := conn.Executor.Poll() + if conn.executor != nil { + task := conn.executor.Poll() for task != nil { srv.handleTask(task) - task = conn.Executor.Poll() + task = conn.executor.Poll() } } } - if srv.shuttingDown() { - fmt.Println("Shutdown initiated, cleaning up...") - handle.Stop() - } + if srv.shuttingDown() { + fmt.Println("Shutdown initiated, cleaning up...") + handle.Stop() + } +} + +func onIdle(handle *libuv.Idle) { + //fmt.Println("onIdle called") + srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) + for conn := range srv.activeConnections { + if conn.executor != nil { + task := conn.executor.Poll() + for task != nil { + srv.handleTask(task) + task = conn.executor.Poll() + } + } + } + + if srv.shuttingDown() { + fmt.Println("Shutdown initiated, cleaning up...") + handle.Stop() + } } func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { @@ -300,7 +335,7 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - req, err := newRequest(userData.ListenAddr, userData.Conn, hyperReq) + req, err := userData.conn.readRequest(hyperReq) if err != nil { fmt.Printf("Error creating request: %v\n", err) return @@ -309,7 +344,7 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h res := newResponse(channel) fmt.Printf("Response created\n") - userData.Server.Handler.ServeHTTP(res, req) + userData.server.Handler.ServeHTTP(res, req) res.finalize() } @@ -321,7 +356,7 @@ func (srv *Server) handleTask(task *hyper.Task) { err := (*hyper.Error)(task.Value()) fmt.Printf("error code: %d\n", err.Code()) - + var errbuf [256]byte errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) fmt.Printf("details: %s\n", errbuf[:errlen]) @@ -360,15 +395,15 @@ func createIo(conn *conn) *hyper.Io { func createServiceUserdata() *serviceUserdata { userdata := &serviceUserdata{} - if userdata == nil { - fmt.Fprintf(os.Stderr, "Failed to allocate service_userdata\n") - } - return userdata + if userdata == nil { + fmt.Fprintf(os.Stderr, "Failed to allocate service_userdata\n") + } + return userdata } func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { conn := (*conn)(userdata) - ret := cnet.Recv(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + ret := cnet.Recv(conn.stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) if ret >= 0 { return uintptr(ret) @@ -378,26 +413,26 @@ func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintp return hyper.IoError } - if conn.ReadWaker != nil { - conn.ReadWaker.Free() + if conn.readWaker != nil { + conn.readWaker.Free() } - if conn.EventMask&c.Uint(libuv.READABLE) == 0 { - conn.EventMask |= c.Uint(libuv.READABLE) - fmt.Printf("ReadCb Event mask: %d\n", conn.EventMask) + if conn.eventMask&c.Uint(libuv.READABLE) == 0 { + conn.eventMask |= c.Uint(libuv.READABLE) + fmt.Printf("ReadCb Event mask: %d\n", conn.eventMask) if !updateConnRegistrations(conn, false) { return hyper.IoError } fmt.Printf("ReadCb updateConnRegistrations\n") } - conn.ReadWaker = ctx.Waker() + conn.readWaker = ctx.Waker() return hyper.IoPending } func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { conn := (*conn)(userdata) - ret := cnet.Send(conn.Stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + ret := cnet.Send(conn.stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) if ret >= 0 { return uintptr(ret) @@ -407,23 +442,22 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint return hyper.IoError } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() + if conn.writeWaker != nil { + conn.writeWaker.Free() } - if conn.EventMask&c.Uint(libuv.WRITABLE) == 0 { - conn.EventMask |= c.Uint(libuv.WRITABLE) - fmt.Printf("WriteCb Event mask: %d\n", conn.EventMask) + if conn.eventMask&c.Uint(libuv.WRITABLE) == 0 { + conn.eventMask |= c.Uint(libuv.WRITABLE) + fmt.Printf("WriteCb Event mask: %d\n", conn.eventMask) if !updateConnRegistrations(conn, false) { return hyper.IoError } } - conn.WriteWaker = ctx.Waker() + conn.writeWaker = ctx.Waker() return hyper.IoPending } - func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { fmt.Printf("onPoll called\n") conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) @@ -433,14 +467,14 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { return } - if events&c.Int(libuv.READABLE) != 0 && conn.ReadWaker != nil { - conn.ReadWaker.Wake() - conn.ReadWaker = nil + if events&c.Int(libuv.READABLE) != 0 && conn.readWaker != nil { + conn.readWaker.Wake() + conn.readWaker = nil } - if events&c.Int(libuv.WRITABLE) != 0 && conn.WriteWaker != nil { - conn.WriteWaker.Wake() - conn.WriteWaker = nil + if events&c.Int(libuv.WRITABLE) != 0 && conn.writeWaker != nil { + conn.writeWaker.Wake() + conn.writeWaker = nil } } @@ -448,26 +482,24 @@ func updateConnRegistrations(conn *conn, create bool) bool { fmt.Println("updateConnRegistrations called") events := c.Int(0) - if conn.EventMask == 0 { + if conn.eventMask == 0 { fmt.Println("No events to poll, skipping poll start.") return true } - fmt.Printf("Event mask: %d\n", conn.EventMask) - if conn.EventMask&c.Uint(libuv.READABLE) != 0 { + fmt.Printf("Event mask: %d\n", conn.eventMask) + if conn.eventMask&c.Uint(libuv.READABLE) != 0 { events |= c.Int(libuv.READABLE) } - if conn.EventMask&c.Uint(libuv.WRITABLE) != 0 { + if conn.eventMask&c.Uint(libuv.WRITABLE) != 0 { events |= c.Int(libuv.WRITABLE) } fmt.Printf("Starting poll with events: %d\n", events) - r := conn.PollHandle.Start(events, onPoll) - //fmt.Println("Poll handle started: %d", r) + r := conn.pollHandle.Start(events, onPoll) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", libuv.Strerror(libuv.Errno(r))) return false } - fmt.Printf("Poll handle started: %d\n", r) return true } @@ -476,36 +508,35 @@ func createConnData() (*conn, error) { if conn == nil { return nil, fmt.Errorf("failed to allocate conn_data") } - conn.IsClosing.Store(false) - conn.ClosedHandles = 0 + conn.isClosing.Store(false) + conn.closedHandles = 0 return conn, nil } func freeConnData(userdata c.Pointer) { conn := (*conn)(userdata) - if conn != nil && !conn.IsClosing.Swap(true){ + if conn != nil && !conn.isClosing.Swap(true) { fmt.Printf("Closing connection...\n") - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil + if conn.readWaker != nil { + conn.readWaker.Free() + conn.readWaker = nil } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil + if conn.writeWaker != nil { + conn.writeWaker.Free() + conn.writeWaker = nil } - if (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).IsClosing() == 0 { - (*libuv.Handle)(unsafe.Pointer(&conn.PollHandle)).Close(nil) - } + if (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) + } - if (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).IsClosing() == 0 { - (*libuv.Handle)(unsafe.Pointer(&conn.Stream)).Close(nil) - } + if (*libuv.Handle)(unsafe.Pointer(&conn.stream)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) + } } } - func closeWalkCb(handle *libuv.Handle, arg c.Pointer) { if handle.IsClosing() == 0 { handle.Close(nil) @@ -518,8 +549,8 @@ func (srv *Server) Close() error { defer srv.mu.Unlock() for c := range srv.activeConnections { - if c.Executor != nil { - c.Executor.Free() + if c.executor != nil { + c.executor.Free() } delete(srv.activeConnections, c) } @@ -530,11 +561,11 @@ func (srv *Server) Close() error { srv.uvLoop.Close() if srv.http1Opts != nil { - srv.http1Opts.Free() - } - if srv.http2Opts != nil { - srv.http2Opts.Free() - } + srv.http1Opts.Free() + } + if srv.http2Opts != nil { + srv.http2Opts.Free() + } return nil } From f03b73f27ee6efc3354719171212e6ccf1fa9c7c Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 5 Sep 2024 18:06:50 +0800 Subject: [PATCH 30/55] refactor(x/net/ht tp/demo): Use echo read demo Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 49 ++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 2b20f0b..31a3d14 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -2,36 +2,13 @@ package main import ( "fmt" - //"io" + "io" "github.com/goplus/llgo/x/net/http" ) func echoHandler(w http.ResponseWriter, r *http.Request) { fmt.Printf("echoHandler called\n") - //TODO(hackerchai): read body and echo - // fmt.Printf("> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) - // for key, values := range r.Header { - // for _, value := range values { - // fmt.Printf("> %s: %s\n", key, value) - // } - // } - // fmt.Printf("URL: %s\n", r.URL.String()) - // // fmt.Println("ContentLength: %d", r.ContentLength) - // // fmt.Println("TransferEncoding: %s", r.TransferEncoding) - // //TODO: read body and echo - // body, err := io.ReadAll(r.Body) - // println("body read") - - // if err != nil { - // http.Error(w, "Error reading request body", http.StatusInternalServerError) - // return - // } - // defer r.Body.Close() - // fmt.Printf("body read") - // w.Header().Set("Content-Type", "text/plain") - // w.Write(body) - fmt.Printf("> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) for key, values := range r.Header { for _, value := range values { @@ -39,8 +16,30 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { } } fmt.Printf("URL: %s\n", r.URL.String()) + fmt.Printf("RemoteAddr: %s\n", r.RemoteAddr) + // fmt.Println("ContentLength: %d", r.ContentLength) + // fmt.Println("TransferEncoding: %s", r.TransferEncoding) + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading request body", http.StatusInternalServerError) + return + } + + defer r.Body.Close() + fmt.Printf("body read") w.Header().Set("Content-Type", "text/plain") - w.Write([]byte("hello world\n")) + w.Write(body) + + // fmt.Printf("> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) + // for key, values := range r.Header { + // for _, value := range values { + // fmt.Printf("> %s: %s\n", key, value) + // } + // } + // fmt.Printf("URL: %s\n", r.URL.String()) + // fmt.Printf("RemoteAddr: %s\n", r.RemoteAddr) + // w.Header().Set("Content-Type", "text/plain") + // w.Write([]byte("hello world\n")) } func main() { From 3f94d3f19e32f2f9179c3c0c0c4e241dfa8bce98 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 5 Sep 2024 18:07:50 +0800 Subject: [PATCH 31/55] refactor(x/net/http): Temporarily fix pipe write stuck Signed-off-by: hackerchai --- x/net/http/request.go | 32 +++++-- x/net/http/response.go | 35 ++++++- x/net/http/server.go | 198 ++++++++++++++++++++++++++-------------- x/net/http/servermux.go | 3 +- 4 files changed, 185 insertions(+), 83 deletions(-) diff --git a/x/net/http/request.go b/x/net/http/request.go index 9d458c3..26187ff 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -137,6 +137,12 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { if body != nil { req.Body, conn.bodyWriter = io.Pipe() task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) + taskData := taskData { + body: nil, + conn: conn, + hyperTaskID: taskGetBody, + } + task.SetUserdata(c.Pointer(&taskData), nil) if task != nil { r := conn.executor.Push(task) if r != hyper.OK { @@ -175,6 +181,10 @@ func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, va func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { fmt.Printf("getBodyChunk called\n") writer := (*io.PipeWriter)(userdata) + if writer == nil { + fmt.Printf("writer is nil\n") + return hyper.IterBreak + } buf := chunk.Bytes() len := chunk.Len() bytes := unsafe.Slice(buf, len) @@ -182,13 +192,21 @@ func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { fmt.Printf("Writing %d bytes to response body\n", len) fmt.Printf("Body chunk: %s\n", string(bytes)) - _, err := writer.Write(bytes) - fmt.Printf("Body chunk written\n") - if err != nil { - fmt.Println("Error writing to response body:", err) - writer.Close() - return hyper.IterBreak - } + go func() { + count, err := writer.Write(bytes) + fmt.Printf("Body chunk written: %d bytes\n", count) + if err != nil { + fmt.Println("Error writing to response body:", err) + writer.Close() + } + }() + // count, err := writer.Write(bytes) + // fmt.Printf("Body chunk written: %d bytes\n", count) + // if err != nil { + // fmt.Println("Error writing to response body:", err) + // writer.Close() + // return hyper.IterBreak + // } return hyper.IterContinue } diff --git a/x/net/http/response.go b/x/net/http/response.go index 20e6f01..29c065d 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -16,6 +16,7 @@ type response struct { body []byte channel *hyper.ResponseChannel resp *hyper.Response + request *Request } type body struct { @@ -24,13 +25,29 @@ type body struct { readLen uintptr } +type taskData struct { + body *body + conn *conn + hyperTaskID +} + +type hyperTaskID int + +const ( + taskSetBody hyperTaskID = iota + taskGetBody +) + + + var DefaultChunkSize uintptr = 8192 -func newResponse(channel *hyper.ResponseChannel) *response { +func newResponse(request *Request, channel *hyper.ResponseChannel) *response { fmt.Printf("newResponse called\n") resp := response{ header: make(Header), channel: channel, + request: request, } return &resp } @@ -90,6 +107,12 @@ func (r *response) WriteHeader(statusCode int) { func (r *response) finalize() error { fmt.Printf("finalize called\n") + err := r.request.Body.Close() + if err != nil { + return err + } + fmt.Printf("request body closed\n") + if !r.written { r.WriteHeader(200) } @@ -105,8 +128,13 @@ func (r *response) finalize() error { if body == nil { return fmt.Errorf("failed to create body") } + taskData := taskData{ + body: &bodyData, + conn: nil, + hyperTaskID: taskSetBody, + } body.SetDataFunc(setBodyDataFunc) - body.SetUserdata(unsafe.Pointer(&bodyData), nil) + body.SetUserdata(unsafe.Pointer(&taskData), nil) fmt.Printf("bodyData userdata set\n") fmt.Printf("bodyData set\n") @@ -124,12 +152,13 @@ func (r *response) finalize() error { func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { fmt.Printf("setBodyDataFunc called\n") - body := (*body)(userdata) + body := (*taskData)(userdata).body if body.len > 0 { //debug fmt.Println("<") fmt.Printf("%s", string(body.data)) + fmt.Println("") if body.len > DefaultChunkSize { *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) diff --git a/x/net/http/server.go b/x/net/http/server.go index d006a95..c6aee65 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -33,13 +33,11 @@ type Server struct { Addr string Handler Handler - uvLoop *libuv.Loop - uvServer libuv.Tcp - inShutdown atomic.Bool - http1Opts *hyper.Http1ServerconnOptions - http2Opts *hyper.Http2ServerconnOptions - checkHandle libuv.Check - idleHandle libuv.Idle + uvLoop *libuv.Loop + uvServer libuv.Tcp + inShutdown atomic.Bool + //checkHandle libuv.Check + idleHandle libuv.Idle mu sync.Mutex activeConnections map[*conn]struct{} @@ -51,6 +49,8 @@ type conn struct { eventMask c.Uint readWaker *hyper.Waker writeWaker *hyper.Waker + http1Opts *hyper.Http1ServerconnOptions + http2Opts *hyper.Http2ServerconnOptions isClosing atomic.Bool closedHandles int32 executor *hyper.Executor @@ -119,49 +119,43 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) } - //(*libuv.Stream)(&srv.uvServer).Data = unsafe.Pointer(srv) - (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).SetData(unsafe.Pointer(srv)) + srv.uvServer.Data = unsafe.Pointer(srv) if err := (*libuv.Stream)(&srv.uvServer).Listen(128, onNewConnection); err != 0 { return fmt.Errorf("failed to listen: %v", err) } - if r := libuv.InitCheck(srv.uvLoop, &srv.checkHandle); r != 0 { - fmt.Fprintf(os.Stderr, "Failed to initialize check handler: %d\n", r) - os.Exit(1) - } - - (*libuv.Handle)(unsafe.Pointer(&srv.checkHandle)).SetData(unsafe.Pointer(srv)) - - if r := srv.checkHandle.Start(onCheck); r != 0 { - fmt.Fprintf(os.Stderr, "Failed to start check handler: %d\n", r) - os.Exit(1) - } - - // if r := libuv.InitIdle(srv.uvLoop, &srv.idleHandle); r != 0 { - // fmt.Fprintf(os.Stderr, "Failed to initialize idle handler: %d\n", r) + // if r := libuv.InitCheck(srv.uvLoop, &srv.checkHandle); r != 0 { + // fmt.Fprintf(os.Stderr, "Failed to initialize check handler: %d\n", r) // os.Exit(1) // } - // (*libuv.Handle)(unsafe.Pointer(&srv.idleHandle)).SetData(unsafe.Pointer(srv)) + // (*libuv.Handle)(unsafe.Pointer(&srv.checkHandle)).SetData(unsafe.Pointer(srv)) - // if r := srv.idleHandle.Start(onIdle); r != 0 { - // fmt.Fprintf(os.Stderr, "Failed to start idle handler: %d\n", r) + // if r := srv.checkHandle.Start(onCheck); r != 0 { + // fmt.Fprintf(os.Stderr, "Failed to start check handler: %d\n", r) // os.Exit(1) // } - fmt.Printf("Listening on %s\n", srv.Addr) + if r := libuv.InitIdle(srv.uvLoop, &srv.idleHandle); r != 0 { + fmt.Fprintf(os.Stderr, "Failed to initialize idle handler: %d\n", r) + os.Exit(1) + } - for { - res := srv.uvLoop.Run(libuv.RUN_NOWAIT) - if res < 0 { - fmt.Fprintf(os.Stderr, "uv_loop_run error: %s\n", libuv.Strerror(libuv.Errno(res))) - break - } + (*libuv.Handle)(unsafe.Pointer(&srv.idleHandle)).SetData(unsafe.Pointer(srv)) - if srv.shuttingDown() { - break - } + if r := srv.idleHandle.Start(onIdle); r != 0 { + fmt.Fprintf(os.Stderr, "Failed to start idle handler: %d\n", r) + os.Exit(1) } + + fmt.Printf("Listening on %s\n", srv.Addr) + + res := srv.uvLoop.Run(libuv.RUN_DEFAULT) + if res != 0 { + fmt.Fprintf(os.Stderr, "Error in event loop: %v\n", res) + os.Exit(1) + } + return nil } @@ -250,7 +244,6 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { io := createIo(conn) service := hyper.ServiceNew(serverCallback) service.SetUserdata(unsafe.Pointer(userdata), nil) - http1Opts := hyper.Http1ServerconnOptionsNew(conn.executor) if http1Opts == nil { fmt.Fprintf(os.Stderr, "Failed to create http1_opts\n") @@ -261,7 +254,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { fmt.Fprintf(os.Stderr, "Failed to set header read timeout for http1_opts\n") os.Exit(1) } - srv.http1Opts = http1Opts + conn.http1Opts = http1Opts http2Opts := hyper.Http2ServerconnOptionsNew(conn.executor) if http2Opts == nil { @@ -278,7 +271,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { fmt.Fprintf(os.Stderr, "Failed to set keep alive timeout for http2_opts\n") os.Exit(1) } - srv.http2Opts = http2Opts + conn.http2Opts = http2Opts serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) conn.executor.Push(serverconn) @@ -289,29 +282,30 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { } } -func onCheck(handle *libuv.Check) { - //fmt.Println("onCheck called") - srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) - for conn := range srv.activeConnections { - if conn.executor != nil { - task := conn.executor.Poll() - for task != nil { - srv.handleTask(task) - task = conn.executor.Poll() - } - } - } - - if srv.shuttingDown() { - fmt.Println("Shutdown initiated, cleaning up...") - handle.Stop() - } -} +// func onCheck(handle *libuv.Check) { +// //fmt.Println("onCheck called") +// srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) +// for conn := range srv.activeConnections { +// if conn.executor != nil { +// task := conn.executor.Poll() +// for task != nil { +// srv.handleTask(task) +// task = conn.executor.Poll() +// } +// } +// } + +// if srv.shuttingDown() { +// fmt.Println("Shutdown initiated, cleaning up...") +// handle.Stop() +// } +// } func onIdle(handle *libuv.Idle) { //fmt.Println("onIdle called") srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) for conn := range srv.activeConnections { + //fmt.Println("onIdle conn called") if conn.executor != nil { task := conn.executor.Poll() for task != nil { @@ -341,16 +335,35 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - res := newResponse(channel) + res := newResponse(req, channel) fmt.Printf("Response created\n") - userData.server.Handler.ServeHTTP(res, req) + go func() { + userData.server.Handler.ServeHTTP(res, req) + res.finalize() + }() - res.finalize() + // userData.server.Handler.ServeHTTP(res, req) + + // res.finalize() } func (srv *Server) handleTask(task *hyper.Task) { taskType := task.Type() + taskData := (*taskData)(task.Userdata()) + fmt.Println("handleTask called") + if taskData != nil { + if taskData.hyperTaskID == taskGetBody { + fmt.Println("taskGetBody called") + if taskData.conn != nil && taskData.conn.bodyWriter != nil { + fmt.Println("taskGetBody calling Close") + taskData.conn.bodyWriter.Close() + } + } else if taskData.hyperTaskID == taskSetBody { + fmt.Println("taskSetBody called") + } + } + if taskType == hyper.TaskError { fmt.Println("hyper task failed with error!") @@ -527,6 +540,20 @@ func freeConnData(userdata c.Pointer) { conn.writeWaker = nil } + if conn.executor != nil { + conn.executor.Free() + conn.executor = nil + } + + if conn.http1Opts != nil { + conn.http1Opts.Free() + conn.http1Opts = nil + } + if conn.http2Opts != nil { + conn.http2Opts.Free() + conn.http2Opts = nil + } + if (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).IsClosing() == 0 { (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) } @@ -549,23 +576,16 @@ func (srv *Server) Close() error { defer srv.mu.Unlock() for c := range srv.activeConnections { - if c.executor != nil { - c.executor.Free() - } + c.Close() + delete(srv.activeConnections, c) } srv.uvLoop.Walk(closeWalkCb, nil) - srv.uvLoop.Run(libuv.RUN_DEFAULT) + srv.uvLoop.Run(libuv.RUN_ONCE) + (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).Close(nil) srv.uvLoop.Close() - - if srv.http1Opts != nil { - srv.http1Opts.Free() - } - if srv.http2Opts != nil { - srv.http2Opts.Free() - } return nil } @@ -573,6 +593,42 @@ func (s *Server) shuttingDown() bool { return s.inShutdown.Load() } +func (c *conn) shuttingDown() bool { + return c.isClosing.Load() +} + +func (c *conn) Close() { + c.isClosing.Store(true) + if c.shuttingDown() { + return + } + + if c.readWaker != nil { + c.readWaker.Free() + c.readWaker = nil + } + if c.writeWaker != nil { + c.writeWaker.Free() + c.writeWaker = nil + } + + if c.executor != nil { + c.executor.Free() + c.executor = nil + } + if c.http1Opts != nil { + c.http1Opts.Free() + c.http1Opts = nil + } + if c.http2Opts != nil { + c.http2Opts.Free() + c.http2Opts = nil + } + + (*libuv.Handle)(unsafe.Pointer(&c.pollHandle)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&c.stream)).Close(nil) +} + type HandlerFunc func(ResponseWriter, *Request) func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index 21d9b20..6da8bce 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -41,7 +41,6 @@ func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Re } func (mux *ServeMux) Handle(pattern string, handler Handler) { - fmt.Printf("Handle called with pattern: %s\n", pattern) mux.mu.Lock() defer mux.mu.Unlock() @@ -56,4 +55,4 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { } mux.m[pattern] = muxEntry{h: handler, pattern: pattern} -} \ No newline at end of file +} From 757e3a102f134e6edcc2e57efe8b72ef21b20857 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 6 Sep 2024 18:23:29 +0800 Subject: [PATCH 32/55] refactor(x/net/http): Rewrite getBodyData logic using hyper_body_data Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 13 +---- x/net/http/request.go | 7 ++- x/net/http/response.go | 32 ++++++------ x/net/http/server.go | 102 +++++++++++++++++++++++++++++---------- 4 files changed, 99 insertions(+), 55 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 31a3d14..2e47992 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -26,20 +26,9 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - fmt.Printf("body read") + fmt.Println("body read") w.Header().Set("Content-Type", "text/plain") w.Write(body) - - // fmt.Printf("> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) - // for key, values := range r.Header { - // for _, value := range values { - // fmt.Printf("> %s: %s\n", key, value) - // } - // } - // fmt.Printf("URL: %s\n", r.URL.String()) - // fmt.Printf("RemoteAddr: %s\n", r.RemoteAddr) - // w.Header().Set("Content-Type", "text/plain") - // w.Write([]byte("hello world\n")) } func main() { diff --git a/x/net/http/request.go b/x/net/http/request.go index 26187ff..ba97c85 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -136,11 +136,14 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { body := hyperReq.Body() if body != nil { req.Body, conn.bodyWriter = io.Pipe() - task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) + //task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) + task := body.Data() + taskID := taskGetBody taskData := taskData { + hyperBody: body, body: nil, conn: conn, - hyperTaskID: taskGetBody, + hyperTaskID: taskID, } task.SetUserdata(c.Pointer(&taskData), nil) if task != nil { diff --git a/x/net/http/response.go b/x/net/http/response.go index 29c065d..4d150ba 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -26,6 +26,7 @@ type body struct { } type taskData struct { + hyperBody *hyper.Body body *body conn *conn hyperTaskID @@ -48,6 +49,7 @@ func newResponse(request *Request, channel *hyper.ResponseChannel) *response { header: make(Header), channel: channel, request: request, + resp: hyper.NewResponse(), } return &resp } @@ -72,11 +74,9 @@ func (r *response) WriteHeader(statusCode int) { r.written = true r.statusCode = statusCode - newResp := hyper.NewResponse() + r.resp.SetStatus(uint16(statusCode)) - newResp.SetStatus(uint16(statusCode)) - - headers := newResp.Headers() + headers := r.resp.Headers() for key, values := range r.header { valueLen := len(values) if valueLen > 1 { @@ -95,14 +95,12 @@ func (r *response) WriteHeader(statusCode int) { } //debug - fmt.Printf("< HTTP/1.1 %d\n", statusCode) - for key, values := range r.header { - for _, value := range values { - fmt.Printf("< %s: %s\n", key, value) - } - } - - r.resp = newResp + // fmt.Printf("< HTTP/1.1 %d\n", statusCode) + // for key, values := range r.header { + // for _, value := range values { + // fmt.Printf("< %s: %s\n", key, value) + // } + // } } func (r *response) finalize() error { @@ -128,13 +126,14 @@ func (r *response) finalize() error { if body == nil { return fmt.Errorf("failed to create body") } - taskData := taskData{ + taskData := &taskData{ + hyperBody: nil, body: &bodyData, conn: nil, hyperTaskID: taskSetBody, } body.SetDataFunc(setBodyDataFunc) - body.SetUserdata(unsafe.Pointer(&taskData), nil) + body.SetUserdata(unsafe.Pointer(taskData), nil) fmt.Printf("bodyData userdata set\n") fmt.Printf("bodyData set\n") @@ -146,13 +145,14 @@ func (r *response) finalize() error { fmt.Printf("body set\n") r.channel.Send(r.resp) - fmt.Printf("response sent\n") + fmt.Println("response sent") return nil } func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { fmt.Printf("setBodyDataFunc called\n") - body := (*taskData)(userdata).body + taskData := (*taskData)(userdata) + body := taskData.body if body.len > 0 { //debug diff --git a/x/net/http/server.go b/x/net/http/server.go index c6aee65..cf1df59 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -350,37 +350,89 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h func (srv *Server) handleTask(task *hyper.Task) { taskType := task.Type() - taskData := (*taskData)(task.Userdata()) - fmt.Println("handleTask called") - if taskData != nil { - if taskData.hyperTaskID == taskGetBody { - fmt.Println("taskGetBody called") - if taskData.conn != nil && taskData.conn.bodyWriter != nil { - fmt.Println("taskGetBody calling Close") - taskData.conn.bodyWriter.Close() - } - } else if taskData.hyperTaskID == taskSetBody { - fmt.Println("taskSetBody called") - } + fmt.Printf("taskType: %d\n", taskType) + payload := (*taskData)(task.Userdata()) + if payload == nil { + fmt.Println("taskData is nil") + task.Free() + return } - if taskType == hyper.TaskError { - fmt.Println("hyper task failed with error!") + taskID := payload.hyperTaskID - err := (*hyper.Error)(task.Value()) - fmt.Printf("error code: %d\n", err.Code()) + fmt.Printf("taskID: %d\n", taskID) - var errbuf [256]byte - errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) - fmt.Printf("details: %s\n", errbuf[:errlen]) + if taskID == taskGetBody { + fmt.Println("taskGetBody called") + if taskType == hyper.TaskError { + fmt.Println("taskGetBody error") + err := (*hyper.Error)(task.Value()) + fmt.Printf("error code: %d\n", err.Code()) - err.Free() - task.Free() - } else if taskType == hyper.TaskEmpty { - fmt.Println("internal hyper task complete") + var errbuf [256]byte + errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) + fmt.Printf("details: %s\n", errbuf[:errlen]) + err.Free() + task.Free() + } + + if taskType == hyper.TaskBuf { + fmt.Println("taskGetBody write buf") + buf := (*hyper.Buf)(task.Value()) + bytes := unsafe.Slice(buf.Bytes(), buf.Len()) + payload.conn.bodyWriter.Write(bytes) + fmt.Println("taskGetBody wrote to bodyWriter") + buf.Free() + task.Free() + fmt.Println("taskGetBody free task") + + fmt.Println("taskGetBody get body task") + getBodyTask := payload.hyperBody.Data() + getBodyTask.SetUserdata(c.Pointer(payload), nil) + if getBodyTask != nil { + fmt.Println("taskGetBody push get body task") + r := payload.conn.executor.Push(getBodyTask) + fmt.Printf("taskGetBody push get body task: %d\n", r) + if r != hyper.OK { + fmt.Printf("failed to push get body task: %d\n", r) + getBodyTask.Free() + } + } + } + + if taskType == hyper.TaskEmpty { + fmt.Println("taskGetBody close bodyWriter") + payload.conn.bodyWriter.Close() + fmt.Println("taskGetBody free task") + task.Free() + } + } else if taskID == taskSetBody { + fmt.Println("taskSetBody called") + if taskType == hyper.TaskError { + fmt.Println("taskSetBody error") + err := (*hyper.Error)(task.Value()) + fmt.Printf("error code: %d\n", err.Code()) + + var errbuf [256]byte + errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) + fmt.Printf("details: %s\n", errbuf[:errlen]) + err.Free() + task.Free() + } + + if taskType == hyper.TaskEmpty { + fmt.Println("taskSetBody free task") + task.Free() + } + } + + if taskType == hyper.TaskEmpty { + fmt.Println("taskEmpty called") task.Free() - } else if taskType == hyper.TaskServerconn { - fmt.Println("server connection task complete") + } + + if taskType == hyper.TaskServerconn { + fmt.Println("taskServerconn called") task.Free() } } From 650cb6cb41a6578a4d266180c05776f447f00733 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Mon, 9 Sep 2024 18:02:40 +0800 Subject: [PATCH 33/55] feat(x/net/http): Implement responseStream reader & writer Signed-off-by: hackerchai --- x/net/http/response_stream/response_stream.go | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 x/net/http/response_stream/response_stream.go diff --git a/x/net/http/response_stream/response_stream.go b/x/net/http/response_stream/response_stream.go new file mode 100644 index 0000000..bb0c786 --- /dev/null +++ b/x/net/http/response_stream/response_stream.go @@ -0,0 +1,178 @@ +package response_stream + +import ( + "errors" + "io" + "sync" +) + +type onceError struct { + sync.Mutex + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} + +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +var ErrClosedResponseStream = errors.New("response stream: read/write on closed stream") + +type responseStream struct { + wrMu sync.Mutex + wrCh chan []byte + rdCh chan struct{} + rdRemain []byte + done chan struct{} + once sync.Once + rerr onceError + werr onceError +} + +func (rs *responseStream) write(p []byte) (n int, err error) { + select { + case <-rs.done: + return 0, rs.writeCloseError() + default: + rs.wrMu.Lock() + defer rs.wrMu.Unlock() + } + + rs.wrCh <- p + return len(p), nil +} + +func (rs *responseStream) read(p []byte) (n int, err error) { + if len(rs.rdRemain) > 0 { + n = copy(p, rs.rdRemain) + rs.rdRemain = rs.rdRemain[n:] + if len(rs.rdRemain) == 0 { + select { + case rs.rdCh <- struct{}{}: + default: + } + } + return n, nil + } + + select { + case <-rs.done: + return 0, rs.rerr.Load() + case data := <-rs.wrCh: + n = copy(p, data) + if n < len(data) { + rs.rdRemain = data[n:] + } else { + select { + case rs.rdCh <- struct{}{}: + default: + } + } + return n, nil + } +} + +func (rs *responseStream) closeWrite(err error) error { + if err == nil { + err = io.EOF + } + rs.werr.Store(err) + rs.once.Do(func() { + close(rs.done) + }) + return nil +} + +func (rs *responseStream) closeRead(err error) error { + if err == nil { + err = ErrClosedResponseStream + } + rs.rerr.Store(err) + rs.once.Do(func() { + close(rs.done) + }) + return nil +} + +func (rs *responseStream) Close() error { + rs.once.Do(func() { + close(rs.done) + }) + return rs.readCloseError() +} + +func (rs *responseStream) readCloseError() error { + werr := rs.werr.Load() + if rerr := rs.rerr.Load(); werr == nil && rerr != nil { + return rerr + } + return ErrClosedResponseStream +} + +func (rs *responseStream) writeCloseError() error { + werr := rs.werr.Load() + if werr != nil { + return werr + } + return ErrClosedResponseStream +} + + +type ResponseStreamReader struct { + responseStream +} + +func (rs *ResponseStreamReader) Read(data []byte) (n int, err error) { + return rs.responseStream.read(data) +} + +func (rs *ResponseStreamReader) Close() error { + return rs.closeWithError(nil) +} + +func (rs *ResponseStreamReader) closeWithError(err error) error { + return rs.responseStream.closeRead(err) +} + +type ResponseStreamWriter struct { + r ResponseStreamReader +} + +func (rs *ResponseStreamWriter) Write(data []byte) (n int, err error) { + return rs.r.responseStream.write(data) +} + +func (rs *ResponseStreamWriter) Close() error { + return rs.r.closeWithError(nil) +} + +func (rs *ResponseStreamWriter) closeWithError(err error) error { + return rs.r.responseStream.closeWrite(err) +} + +func (rs *ResponseStreamWriter) GetRdCh() <-chan struct{} { + return rs.r.rdCh +} + +func NewResponseStream() (*ResponseStreamReader, *ResponseStreamWriter) { + rsw := &ResponseStreamWriter{ + r: ResponseStreamReader{ + responseStream: responseStream{ + wrCh: make(chan []byte, 1), + rdCh: make(chan struct{}, 1), + done: make(chan struct{}), + }, + }, + } + return &rsw.r, rsw +} \ No newline at end of file From 34ace8fc8239ce01f9511493aaaa417ebf679dd8 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Mon, 9 Sep 2024 18:03:24 +0800 Subject: [PATCH 34/55] feat(x/net/http): Implement RequestBody logic Signed-off-by: hackerchai --- x/net/http/request.go | 63 ++++++++++++++++------------------ x/net/http/server.go | 79 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 104 insertions(+), 38 deletions(-) diff --git a/x/net/http/request.go b/x/net/http/request.go index ba97c85..58c1c82 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -35,7 +35,7 @@ type Request struct { timeout time.Duration } -func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { +func (conn *conn) readRequest(server *Server, hyperReq *hyper.Request) (*Request, error) { println("readRequest called") req := Request{ Header: make(Header), @@ -135,7 +135,9 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { body := hyperReq.Body() if body != nil { - req.Body, conn.bodyWriter = io.Pipe() + requestBody := newRequestBody() + conn.requestBody = requestBody + req.Body = requestBody //task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) task := body.Data() taskID := taskGetBody @@ -181,35 +183,28 @@ func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, va return hyper.IterContinue } -func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { - fmt.Printf("getBodyChunk called\n") - writer := (*io.PipeWriter)(userdata) - if writer == nil { - fmt.Printf("writer is nil\n") - return hyper.IterBreak - } - buf := chunk.Bytes() - len := chunk.Len() - bytes := unsafe.Slice(buf, len) - //debug - fmt.Printf("Writing %d bytes to response body\n", len) - fmt.Printf("Body chunk: %s\n", string(bytes)) - - go func() { - count, err := writer.Write(bytes) - fmt.Printf("Body chunk written: %d bytes\n", count) - if err != nil { - fmt.Println("Error writing to response body:", err) - writer.Close() - } - }() - // count, err := writer.Write(bytes) - // fmt.Printf("Body chunk written: %d bytes\n", count) - // if err != nil { - // fmt.Println("Error writing to response body:", err) - // writer.Close() - // return hyper.IterBreak - // } - - return hyper.IterContinue -} +// func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { +// fmt.Printf("getBodyChunk called\n") +// writer := (*io.PipeWriter)(userdata) +// if writer == nil { +// fmt.Printf("writer is nil\n") +// return hyper.IterBreak +// } +// buf := chunk.Bytes() +// len := chunk.Len() +// bytes := unsafe.Slice(buf, len) +// //debug +// fmt.Printf("Writing %d bytes to response body\n", len) +// fmt.Printf("Body chunk: %s\n", string(bytes)) + +// go func() { +// count, err := writer.Write(bytes) +// fmt.Printf("Body chunk written: %d bytes\n", count) +// if err != nil { +// fmt.Println("Error writing to response body:", err) +// writer.Close() +// } +// }() + +// return hyper.IterContinue +// } diff --git a/x/net/http/server.go b/x/net/http/server.go index cf1df59..23ccdd1 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -3,7 +3,6 @@ package http import ( "errors" "fmt" - "io" "os" "strconv" "sync" @@ -17,6 +16,7 @@ import ( "github.com/goplus/llgo/c/syscall" "github.com/goplus/llgo/rust/hyper" "github.com/goplus/llgo/x/net" + "github.com/goplus/llgo/x/net/http/response_stream" ) type Handler interface { @@ -41,6 +41,9 @@ type Server struct { mu sync.Mutex activeConnections map[*conn]struct{} + + channels []<-chan struct{} + channelMutex sync.RWMutex } type conn struct { @@ -55,7 +58,35 @@ type conn struct { closedHandles int32 executor *hyper.Executor remoteAddr string - bodyWriter *io.PipeWriter + requestBody *requestBody +} + +type requestBody struct { + chunk []byte + readCh chan []byte + readToReadCh chan struct{} +} + +func newRequestBody() *requestBody { + return &requestBody{ + readCh: make(chan []byte, 1), + readToReadCh: make(chan struct{}, 1), + } +} + +func (rb *requestBody) Read(p []byte) (n int, err error) { + if len(rb.chunk) == 0 { + n = copy(p, rb.chunk) + rb.chunk = rb.chunk[n:] + if len(rb.chunk) > 0 { + return + } + } + rb.readToReadCh <- struct{}{} + rb.chunk = <-rb.readCh + n = copy(p, rb.chunk) + rb.chunk = rb.chunk[n:] + return } type serviceUserdata struct { @@ -308,6 +339,35 @@ func onIdle(handle *libuv.Idle) { //fmt.Println("onIdle conn called") if conn.executor != nil { task := conn.executor.Poll() + select { + case <-conn.requestBody.readToReadCh: + fmt.Println("readToReadCh signaled") + payload := (*taskData)(task.Userdata()) + if payload == nil { + fmt.Println("taskData is nil") + task.Free() + return + } + + taskID := payload.hyperTaskID + + fmt.Printf("taskID: %d\n", taskID) + + fmt.Println("taskGetBody get body task") + getBodyTask := payload.hyperBody.Data() + getBodyTask.SetUserdata(c.Pointer(payload), nil) + if getBodyTask != nil { + fmt.Println("taskGetBody push get body task") + r := payload.conn.executor.Push(getBodyTask) + fmt.Printf("taskGetBody push get body task: %d\n", r) + if r != hyper.OK { + fmt.Printf("failed to push get body task: %d\n", r) + getBodyTask.Free() + } + } + default: + fmt.Println("readToReadCh not signaled") + } for task != nil { srv.handleTask(task) task = conn.executor.Poll() @@ -315,6 +375,16 @@ func onIdle(handle *libuv.Idle) { } } + srv.channelMutex.RLock() + for _, ch := range srv.channels { + select { + case <-ch: + fmt.Printf("Received signal from channel\n") + default: + } + } + srv.channelMutex.RUnlock() + if srv.shuttingDown() { fmt.Println("Shutdown initiated, cleaning up...") handle.Stop() @@ -329,7 +399,7 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - req, err := userData.conn.readRequest(hyperReq) + req, err := userData.conn.readRequest(userData.server, hyperReq) if err != nil { fmt.Printf("Error creating request: %v\n", err) return @@ -349,6 +419,7 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h } func (srv *Server) handleTask(task *hyper.Task) { + taskType := task.Type() fmt.Printf("taskType: %d\n", taskType) payload := (*taskData)(task.Userdata()) @@ -380,7 +451,7 @@ func (srv *Server) handleTask(task *hyper.Task) { fmt.Println("taskGetBody write buf") buf := (*hyper.Buf)(task.Value()) bytes := unsafe.Slice(buf.Bytes(), buf.Len()) - payload.conn.bodyWriter.Write(bytes) + payload.conn.requestBody.readCh <- bytes fmt.Println("taskGetBody wrote to bodyWriter") buf.Free() task.Free() From 799e91aeb38273ffbdffaa7c1a3ac0a055402a7d Mon Sep 17 00:00:00 2001 From: hackerchai Date: Tue, 10 Sep 2024 18:45:13 +0800 Subject: [PATCH 35/55] refactor(x/net/http/demo): Update demo Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 2e47992..35bf001 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -19,13 +19,29 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { fmt.Printf("RemoteAddr: %s\n", r.RemoteAddr) // fmt.Println("ContentLength: %d", r.ContentLength) // fmt.Println("TransferEncoding: %s", r.TransferEncoding) + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "Error reading request body", http.StatusInternalServerError) return } + + // var body []byte + // buffer := make([]byte, 1024) + // for { + // n, err := r.Body.Read(buffer) + // if err != nil && err != io.EOF { + // http.Error(w, "Error reading request body", http.StatusInternalServerError) + // return + // } + // body = append(body, buffer[:n]...) + // if err == io.EOF { + // break + // } + // } + - defer r.Body.Close() + r.Body.Close() fmt.Println("body read") w.Header().Set("Content-Type", "text/plain") w.Write(body) From 94af890c6c3b296e0a5a7b5435c015c06eabff25 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Tue, 10 Sep 2024 18:46:10 +0800 Subject: [PATCH 36/55] refactor(x/net/http/demo): Remove response_stream Signed-off-by: hackerchai --- x/net/http/response_stream/response_stream.go | 178 ------------------ 1 file changed, 178 deletions(-) delete mode 100644 x/net/http/response_stream/response_stream.go diff --git a/x/net/http/response_stream/response_stream.go b/x/net/http/response_stream/response_stream.go deleted file mode 100644 index bb0c786..0000000 --- a/x/net/http/response_stream/response_stream.go +++ /dev/null @@ -1,178 +0,0 @@ -package response_stream - -import ( - "errors" - "io" - "sync" -) - -type onceError struct { - sync.Mutex - err error -} - -func (a *onceError) Store(err error) { - a.Lock() - defer a.Unlock() - if a.err != nil { - return - } - a.err = err -} - -func (a *onceError) Load() error { - a.Lock() - defer a.Unlock() - return a.err -} - -var ErrClosedResponseStream = errors.New("response stream: read/write on closed stream") - -type responseStream struct { - wrMu sync.Mutex - wrCh chan []byte - rdCh chan struct{} - rdRemain []byte - done chan struct{} - once sync.Once - rerr onceError - werr onceError -} - -func (rs *responseStream) write(p []byte) (n int, err error) { - select { - case <-rs.done: - return 0, rs.writeCloseError() - default: - rs.wrMu.Lock() - defer rs.wrMu.Unlock() - } - - rs.wrCh <- p - return len(p), nil -} - -func (rs *responseStream) read(p []byte) (n int, err error) { - if len(rs.rdRemain) > 0 { - n = copy(p, rs.rdRemain) - rs.rdRemain = rs.rdRemain[n:] - if len(rs.rdRemain) == 0 { - select { - case rs.rdCh <- struct{}{}: - default: - } - } - return n, nil - } - - select { - case <-rs.done: - return 0, rs.rerr.Load() - case data := <-rs.wrCh: - n = copy(p, data) - if n < len(data) { - rs.rdRemain = data[n:] - } else { - select { - case rs.rdCh <- struct{}{}: - default: - } - } - return n, nil - } -} - -func (rs *responseStream) closeWrite(err error) error { - if err == nil { - err = io.EOF - } - rs.werr.Store(err) - rs.once.Do(func() { - close(rs.done) - }) - return nil -} - -func (rs *responseStream) closeRead(err error) error { - if err == nil { - err = ErrClosedResponseStream - } - rs.rerr.Store(err) - rs.once.Do(func() { - close(rs.done) - }) - return nil -} - -func (rs *responseStream) Close() error { - rs.once.Do(func() { - close(rs.done) - }) - return rs.readCloseError() -} - -func (rs *responseStream) readCloseError() error { - werr := rs.werr.Load() - if rerr := rs.rerr.Load(); werr == nil && rerr != nil { - return rerr - } - return ErrClosedResponseStream -} - -func (rs *responseStream) writeCloseError() error { - werr := rs.werr.Load() - if werr != nil { - return werr - } - return ErrClosedResponseStream -} - - -type ResponseStreamReader struct { - responseStream -} - -func (rs *ResponseStreamReader) Read(data []byte) (n int, err error) { - return rs.responseStream.read(data) -} - -func (rs *ResponseStreamReader) Close() error { - return rs.closeWithError(nil) -} - -func (rs *ResponseStreamReader) closeWithError(err error) error { - return rs.responseStream.closeRead(err) -} - -type ResponseStreamWriter struct { - r ResponseStreamReader -} - -func (rs *ResponseStreamWriter) Write(data []byte) (n int, err error) { - return rs.r.responseStream.write(data) -} - -func (rs *ResponseStreamWriter) Close() error { - return rs.r.closeWithError(nil) -} - -func (rs *ResponseStreamWriter) closeWithError(err error) error { - return rs.r.responseStream.closeWrite(err) -} - -func (rs *ResponseStreamWriter) GetRdCh() <-chan struct{} { - return rs.r.rdCh -} - -func NewResponseStream() (*ResponseStreamReader, *ResponseStreamWriter) { - rsw := &ResponseStreamWriter{ - r: ResponseStreamReader{ - responseStream: responseStream{ - wrCh: make(chan []byte, 1), - rdCh: make(chan struct{}, 1), - done: make(chan struct{}), - }, - }, - } - return &rsw.r, rsw -} \ No newline at end of file From 75c69d5622977e249ab0e5fee9f7b99dea4cbe9c Mon Sep 17 00:00:00 2001 From: hackerchai Date: Tue, 10 Sep 2024 18:46:46 +0800 Subject: [PATCH 37/55] refactor(x/net/http): Implement requestBody logic Signed-off-by: hackerchai --- x/net/http/request.go | 13 +- x/net/http/request_body.go | 122 ++++++++++++++++++ x/net/http/response.go | 16 +-- x/net/http/server.go | 255 ++++++++++++++++--------------------- 4 files changed, 249 insertions(+), 157 deletions(-) create mode 100644 x/net/http/request_body.go diff --git a/x/net/http/request.go b/x/net/http/request.go index 58c1c82..3179e55 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -35,7 +35,7 @@ type Request struct { timeout time.Duration } -func (conn *conn) readRequest(server *Server, hyperReq *hyper.Request) (*Request, error) { +func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { println("readRequest called") req := Request{ Header: make(Header), @@ -135,16 +135,17 @@ func (conn *conn) readRequest(server *Server, hyperReq *hyper.Request) (*Request body := hyperReq.Body() if body != nil { + fmt.Println("Body is not nil!!!!!!!") requestBody := newRequestBody() conn.requestBody = requestBody req.Body = requestBody //task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) task := body.Data() taskID := taskGetBody - taskData := taskData { - hyperBody: body, - body: nil, - conn: conn, + taskData := taskData{ + hyperBody: body, + body: nil, + conn: conn, hyperTaskID: taskID, } task.SetUserdata(c.Pointer(&taskData), nil) @@ -163,7 +164,7 @@ func (conn *conn) readRequest(server *Server, hyperReq *hyper.Request) (*Request return nil, fmt.Errorf("failed to get request body") } - defer hyperReq.Free() + hyperReq.Free() return &req, nil } diff --git a/x/net/http/request_body.go b/x/net/http/request_body.go new file mode 100644 index 0000000..4a85ccd --- /dev/null +++ b/x/net/http/request_body.go @@ -0,0 +1,122 @@ +package http + +import ( + "errors" + "fmt" + "io" + "sync" +) + +type onceError struct { + sync.Mutex + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} + +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +type requestBody struct { + chunk []byte + readCh chan []byte + readyCh chan struct{} + + once sync.Once + done chan struct{} + + rerr onceError +} + +var ( + ErrClosedRequestBody = errors.New("request body: read/write on closed body") +) + +func newRequestBody() *requestBody { + return &requestBody{ + readCh: make(chan []byte, 1), + readyCh: make(chan struct{}, 1), + done: make(chan struct{}), + } +} + +func (rb *requestBody) Read(p []byte) (n int, err error) { + fmt.Println("RequestBody Read called") + select { + case <-rb.done: + fmt.Println("Read done") + return 0, rb.rerr.Load() + default: + } + + if len(rb.chunk) > 0 { + fmt.Println("Read remaining chunk") + n = copy(p, rb.chunk) + rb.chunk = rb.chunk[n:] + if len(rb.chunk) > 0 { + return + } + } + + fmt.Println("readyCh waiting") + select { + case rb.readyCh <- struct{}{}: + fmt.Println("readyCh signaled") + default: + fmt.Println("readyCh skipped (channel full)") + } + + select { + case rb.chunk = <-rb.readCh: + fmt.Printf("Read chunk received: %s\n", string(rb.chunk)) + case <-rb.done: + return 0, rb.rerr.Load() + default: + if len(rb.chunk) == 0 { + fmt.Println("Read ended") + return 0, io.EOF + } + } + fmt.Printf("Read chunk received: %s\n", string(rb.chunk)) + if len(rb.chunk) == 0 { + fmt.Println("Read ended") + return 0, io.EOF + } + n = copy(p, rb.chunk) + rb.chunk = rb.chunk[n:] + fmt.Printf("Read chunk copied: %d bytes\n", n) + return +} + +func (rb *requestBody) readCloseError() error { + if rerr := rb.rerr.Load(); rerr != nil { + return rerr + } + return ErrClosedRequestBody +} + +func (rb *requestBody) closeRead(err error) error { + fmt.Println("closeRead called") + if err == nil { + err = ErrClosedRequestBody + } + rb.rerr.Store(err) + rb.once.Do(func() { + close(rb.done) + }) + return nil +} + +func (rb *requestBody) Close() error { + return rb.closeRead(nil) +} diff --git a/x/net/http/response.go b/x/net/http/response.go index 4d150ba..0f3cb49 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -29,7 +29,7 @@ type taskData struct { hyperBody *hyper.Body body *body conn *conn - hyperTaskID + hyperTaskID hyperTaskID } type hyperTaskID int @@ -39,8 +39,6 @@ const ( taskGetBody ) - - var DefaultChunkSize uintptr = 8192 func newResponse(request *Request, channel *hyper.ResponseChannel) *response { @@ -95,12 +93,12 @@ func (r *response) WriteHeader(statusCode int) { } //debug - // fmt.Printf("< HTTP/1.1 %d\n", statusCode) - // for key, values := range r.header { - // for _, value := range values { - // fmt.Printf("< %s: %s\n", key, value) - // } - // } + fmt.Printf("< HTTP/1.1 %d\n", statusCode) + for key, values := range r.header { + for _, value := range values { + fmt.Printf("< %s: %s\n", key, value) + } + } } func (r *response) finalize() error { diff --git a/x/net/http/server.go b/x/net/http/server.go index 23ccdd1..0455f7d 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -16,7 +16,6 @@ import ( "github.com/goplus/llgo/c/syscall" "github.com/goplus/llgo/rust/hyper" "github.com/goplus/llgo/x/net" - "github.com/goplus/llgo/x/net/http/response_stream" ) type Handler interface { @@ -41,9 +40,6 @@ type Server struct { mu sync.Mutex activeConnections map[*conn]struct{} - - channels []<-chan struct{} - channelMutex sync.RWMutex } type conn struct { @@ -61,34 +57,6 @@ type conn struct { requestBody *requestBody } -type requestBody struct { - chunk []byte - readCh chan []byte - readToReadCh chan struct{} -} - -func newRequestBody() *requestBody { - return &requestBody{ - readCh: make(chan []byte, 1), - readToReadCh: make(chan struct{}, 1), - } -} - -func (rb *requestBody) Read(p []byte) (n int, err error) { - if len(rb.chunk) == 0 { - n = copy(p, rb.chunk) - rb.chunk = rb.chunk[n:] - if len(rb.chunk) > 0 { - return - } - } - rb.readToReadCh <- struct{}{} - rb.chunk = <-rb.readCh - n = copy(p, rb.chunk) - rb.chunk = rb.chunk[n:] - return -} - type serviceUserdata struct { host [128]c.Char port [8]c.Char @@ -336,55 +304,16 @@ func onIdle(handle *libuv.Idle) { //fmt.Println("onIdle called") srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) for conn := range srv.activeConnections { - //fmt.Println("onIdle conn called") if conn.executor != nil { task := conn.executor.Poll() - select { - case <-conn.requestBody.readToReadCh: - fmt.Println("readToReadCh signaled") - payload := (*taskData)(task.Userdata()) - if payload == nil { - fmt.Println("taskData is nil") - task.Free() - return - } - - taskID := payload.hyperTaskID - - fmt.Printf("taskID: %d\n", taskID) - - fmt.Println("taskGetBody get body task") - getBodyTask := payload.hyperBody.Data() - getBodyTask.SetUserdata(c.Pointer(payload), nil) - if getBodyTask != nil { - fmt.Println("taskGetBody push get body task") - r := payload.conn.executor.Push(getBodyTask) - fmt.Printf("taskGetBody push get body task: %d\n", r) - if r != hyper.OK { - fmt.Printf("failed to push get body task: %d\n", r) - getBodyTask.Free() - } - } - default: - fmt.Println("readToReadCh not signaled") - } for task != nil { - srv.handleTask(task) + srv.handleTask(conn, task) task = conn.executor.Poll() + srv.handleRead(conn, task) } } } - srv.channelMutex.RLock() - for _, ch := range srv.channels { - select { - case <-ch: - fmt.Printf("Received signal from channel\n") - default: - } - } - srv.channelMutex.RUnlock() - if srv.shuttingDown() { fmt.Println("Shutdown initiated, cleaning up...") handle.Stop() @@ -399,7 +328,7 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - req, err := userData.conn.readRequest(userData.server, hyperReq) + req, err := userData.conn.readRequest(hyperReq) if err != nil { fmt.Printf("Error creating request: %v\n", err) return @@ -418,82 +347,124 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h // res.finalize() } -func (srv *Server) handleTask(task *hyper.Task) { - - taskType := task.Type() - fmt.Printf("taskType: %d\n", taskType) +func (srv *Server) handleRead(conn *conn, task *hyper.Task) { payload := (*taskData)(task.Userdata()) if payload == nil { - fmt.Println("taskData is nil") - task.Free() + fmt.Println("taskData is nil, no need to handle read") return } - - taskID := payload.hyperTaskID - - fmt.Printf("taskID: %d\n", taskID) - - if taskID == taskGetBody { - fmt.Println("taskGetBody called") - if taskType == hyper.TaskError { - fmt.Println("taskGetBody error") - err := (*hyper.Error)(task.Value()) - fmt.Printf("error code: %d\n", err.Code()) - - var errbuf [256]byte - errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) - fmt.Printf("details: %s\n", errbuf[:errlen]) - err.Free() - task.Free() - } - - if taskType == hyper.TaskBuf { - fmt.Println("taskGetBody write buf") - buf := (*hyper.Buf)(task.Value()) - bytes := unsafe.Slice(buf.Bytes(), buf.Len()) - payload.conn.requestBody.readCh <- bytes - fmt.Println("taskGetBody wrote to bodyWriter") - buf.Free() - task.Free() - fmt.Println("taskGetBody free task") - - fmt.Println("taskGetBody get body task") - getBodyTask := payload.hyperBody.Data() - getBodyTask.SetUserdata(c.Pointer(payload), nil) - if getBodyTask != nil { - fmt.Println("taskGetBody push get body task") - r := payload.conn.executor.Push(getBodyTask) - fmt.Printf("taskGetBody push get body task: %d\n", r) - if r != hyper.OK { - fmt.Printf("failed to push get body task: %d\n", r) - getBodyTask.Free() - } + + select { + case <-conn.requestBody.readyCh: + fmt.Println("readyCh signaled") + + fmt.Println("taskGetBody get body task form readyCh") + getBodyTask := payload.hyperBody.Data() + getBodyTask.SetUserdata(c.Pointer(payload), nil) + if getBodyTask != nil { + fmt.Println("taskGetBody push get body task") + r := payload.conn.executor.Push(getBodyTask) + fmt.Printf("taskGetBody push get body task: %d\n", r) + if r != hyper.OK { + fmt.Printf("failed to push get body task: %d\n", r) + getBodyTask.Free() } } + default: + fmt.Println("readToReadCh not signaled") + } +} - if taskType == hyper.TaskEmpty { - fmt.Println("taskGetBody close bodyWriter") - payload.conn.bodyWriter.Close() - fmt.Println("taskGetBody free task") - task.Free() - } - } else if taskID == taskSetBody { - fmt.Println("taskSetBody called") - if taskType == hyper.TaskError { - fmt.Println("taskSetBody error") - err := (*hyper.Error)(task.Value()) - fmt.Printf("error code: %d\n", err.Code()) - - var errbuf [256]byte - errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) - fmt.Printf("details: %s\n", errbuf[:errlen]) - err.Free() - task.Free() - } +func (srv *Server) handleTask(conn *conn, task *hyper.Task) { + taskType := task.Type() + //debug + switch taskType { + case hyper.TaskEmpty: + fmt.Println("Task type: Empty") + case hyper.TaskBuf: + fmt.Println("Task type: Buffer") + case hyper.TaskError: + fmt.Println("Task type: Error") + case hyper.TaskServerconn: + fmt.Println("Task type: Serverconn") + default: + fmt.Println("Unknown task type") + } - if taskType == hyper.TaskEmpty { - fmt.Println("taskSetBody free task") - task.Free() + payload := (*taskData)(task.Userdata()) + if payload != nil { + taskID := payload.hyperTaskID + + // select { + // case <-conn.requestBody.readyCh: + // fmt.Println("readyCh recieved") + + // fmt.Println("taskGetBody get body task form readyCh") + // getBodyTask := payload.hyperBody.Data() + // getBodyTask.SetUserdata(c.Pointer(payload), nil) + // if getBodyTask != nil { + // fmt.Println("taskGetBody push get body task") + // r := payload.conn.executor.Push(getBodyTask) + // fmt.Printf("taskGetBody push get body task: %d\n", r) + // if r != hyper.OK { + // fmt.Printf("failed to push get body task: %d\n", r) + // getBodyTask.Free() + // } + // } + // default: + // fmt.Println("readyCh not recieved") + // } + + if taskID == taskGetBody { + fmt.Println("taskGetBody called") + if taskType == hyper.TaskError { + fmt.Println("taskGetBody error") + err := (*hyper.Error)(task.Value()) + fmt.Printf("error code: %d\n", err.Code()) + + var errbuf [256]byte + errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) + fmt.Printf("details: %s\n", errbuf[:errlen]) + err.Free() + task.Free() + } + + if taskType == hyper.TaskBuf { + fmt.Println("taskGetBody write buf") + buf := (*hyper.Buf)(task.Value()) + bytes := unsafe.Slice(buf.Bytes(), buf.Len()) + fmt.Printf("taskGetBody writing to bodyWriter: %s\n", string(bytes)) + buf.Free() + task.Free() + fmt.Println("taskGetBody free task") + payload.conn.requestBody.readCh <- bytes + fmt.Println("taskGetBody wrote to bodyWriter") + } + + if taskType == hyper.TaskEmpty { + fmt.Println("taskGetBody close requestBody") + payload.conn.requestBody.Close() + fmt.Println("taskGetBody free task") + task.Free() + } + } else if taskID == taskSetBody { + fmt.Println("taskSetBody called") + if taskType == hyper.TaskError { + fmt.Println("taskSetBody error") + err := (*hyper.Error)(task.Value()) + fmt.Printf("error code: %d\n", err.Code()) + + var errbuf [256]byte + errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) + fmt.Printf("details: %s\n", errbuf[:errlen]) + err.Free() + task.Free() + } + + if taskType == hyper.TaskEmpty { + fmt.Println("taskSetBody free task") + task.Free() + } } } From 025037702b60cfee2b9b900a7f501ca8b8067200 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Wed, 11 Sep 2024 19:42:54 +0800 Subject: [PATCH 38/55] refactor(x/net/http/demo): Implement requestBody logic and intergrate with main loop Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 18 ++-- x/net/http/request.go | 39 ++------- x/net/http/request_body.go | 77 +++++++----------- x/net/http/response.go | 122 ++++++++++++++++----------- x/net/http/server.go | 163 +++++++++++++++++++++---------------- x/net/http/servermux.go | 4 +- 6 files changed, 214 insertions(+), 209 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 35bf001..9471272 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -8,15 +8,15 @@ import ( ) func echoHandler(w http.ResponseWriter, r *http.Request) { - fmt.Printf("echoHandler called\n") - fmt.Printf("> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) + fmt.Printf("[debug] echoHandler called\n") + fmt.Printf(">> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) for key, values := range r.Header { for _, value := range values { - fmt.Printf("> %s: %s\n", key, value) + fmt.Printf(">> %s: %s\n", key, value) } } - fmt.Printf("URL: %s\n", r.URL.String()) - fmt.Printf("RemoteAddr: %s\n", r.RemoteAddr) + fmt.Printf(">> URL: %s\n", r.URL.String()) + fmt.Printf(">> RemoteAddr: %s\n", r.RemoteAddr) // fmt.Println("ContentLength: %d", r.ContentLength) // fmt.Println("TransferEncoding: %s", r.TransferEncoding) @@ -25,7 +25,8 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Error reading request body", http.StatusInternalServerError) return } - + defer r.Body.Close() + // var body []byte // buffer := make([]byte, 1024) // for { @@ -40,9 +41,8 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { // } // } - - r.Body.Close() - fmt.Println("body read") + fmt.Printf(">> Body: %s\n", string(body)) + fmt.Println("[debug] body read done") w.Header().Set("Content-Type", "text/plain") w.Write(body) } diff --git a/x/net/http/request.go b/x/net/http/request.go index 3179e55..5ae022f 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -36,7 +36,7 @@ type Request struct { } func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { - println("readRequest called") + println("[debug] readRequest called") req := Request{ Header: make(Header), timeout: 0, @@ -135,11 +135,6 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { body := hyperReq.Body() if body != nil { - fmt.Println("Body is not nil!!!!!!!") - requestBody := newRequestBody() - conn.requestBody = requestBody - req.Body = requestBody - //task := body.Foreach(getBodyChunk, c.Pointer(conn.bodyWriter), nil) task := body.Data() taskID := taskGetBody taskData := taskData{ @@ -149,6 +144,12 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { hyperTaskID: taskID, } task.SetUserdata(c.Pointer(&taskData), nil) + requestBody := newRequestBody(conn.asyncHandle) + conn.requestBody = requestBody + req.Body = requestBody + + conn.asyncHandle.SetData(c.Pointer(&taskData)) + fmt.Println("[debug] async task set") if task != nil { r := conn.executor.Push(task) if r != hyper.OK { @@ -183,29 +184,3 @@ func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, va } return hyper.IterContinue } - -// func getBodyChunk(userdata c.Pointer, chunk *hyper.Buf) c.Int { -// fmt.Printf("getBodyChunk called\n") -// writer := (*io.PipeWriter)(userdata) -// if writer == nil { -// fmt.Printf("writer is nil\n") -// return hyper.IterBreak -// } -// buf := chunk.Bytes() -// len := chunk.Len() -// bytes := unsafe.Slice(buf, len) -// //debug -// fmt.Printf("Writing %d bytes to response body\n", len) -// fmt.Printf("Body chunk: %s\n", string(bytes)) - -// go func() { -// count, err := writer.Write(bytes) -// fmt.Printf("Body chunk written: %d bytes\n", count) -// if err != nil { -// fmt.Println("Error writing to response body:", err) -// writer.Close() -// } -// }() - -// return hyper.IterContinue -// } diff --git a/x/net/http/request_body.go b/x/net/http/request_body.go index 4a85ccd..be19f56 100644 --- a/x/net/http/request_body.go +++ b/x/net/http/request_body.go @@ -5,6 +5,8 @@ import ( "fmt" "io" "sync" + + "github.com/goplus/llgo/c/libuv" ) type onceError struct { @@ -28,9 +30,9 @@ func (a *onceError) Load() error { } type requestBody struct { - chunk []byte - readCh chan []byte - readyCh chan struct{} + chunk []byte + readCh chan []byte + asyncHandle *libuv.Async once sync.Once done chan struct{} @@ -42,60 +44,42 @@ var ( ErrClosedRequestBody = errors.New("request body: read/write on closed body") ) -func newRequestBody() *requestBody { +func newRequestBody(asyncHandle *libuv.Async) *requestBody { return &requestBody{ - readCh: make(chan []byte, 1), - readyCh: make(chan struct{}, 1), - done: make(chan struct{}), + readCh: make(chan []byte, 1), + done: make(chan struct{}), + asyncHandle: asyncHandle, } } func (rb *requestBody) Read(p []byte) (n int, err error) { - fmt.Println("RequestBody Read called") - select { - case <-rb.done: - fmt.Println("Read done") - return 0, rb.rerr.Load() - default: - } - + fmt.Println("[debug] requestBody.Read called") + // If there are still unread chunks, read them first if len(rb.chunk) > 0 { - fmt.Println("Read remaining chunk") n = copy(p, rb.chunk) rb.chunk = rb.chunk[n:] - if len(rb.chunk) > 0 { - return - } + return n, nil } - fmt.Println("readyCh waiting") + // Attempt to read a new chunk from a channel select { - case rb.readyCh <- struct{}{}: - fmt.Println("readyCh signaled") - default: - fmt.Println("readyCh skipped (channel full)") - } - - select { - case rb.chunk = <-rb.readCh: - fmt.Printf("Read chunk received: %s\n", string(rb.chunk)) - case <-rb.done: - return 0, rb.rerr.Load() - default: - if len(rb.chunk) == 0 { - fmt.Println("Read ended") - return 0, io.EOF + case chunk, ok := <-rb.readCh: + if !ok { + // The channel has been closed, indicating that all data has been read + return 0, rb.readCloseError() } + n = copy(p, chunk) + if n < len(chunk) { + // If the capacity of p is insufficient to hold the whole chunk, save the rest of the chunk + rb.chunk = chunk[n:] + } + fmt.Println("[debug] requestBody.Read async send") + rb.asyncHandle.Send() + return n, nil + case <-rb.done: + // If the done channel is closed, the read needs to be terminated + return 0, rb.readCloseError() } - fmt.Printf("Read chunk received: %s\n", string(rb.chunk)) - if len(rb.chunk) == 0 { - fmt.Println("Read ended") - return 0, io.EOF - } - n = copy(p, rb.chunk) - rb.chunk = rb.chunk[n:] - fmt.Printf("Read chunk copied: %d bytes\n", n) - return } func (rb *requestBody) readCloseError() error { @@ -106,14 +90,15 @@ func (rb *requestBody) readCloseError() error { } func (rb *requestBody) closeRead(err error) error { - fmt.Println("closeRead called") + fmt.Println("[debug] RequestBody closeRead called") if err == nil { - err = ErrClosedRequestBody + err = io.EOF } rb.rerr.Store(err) rb.once.Do(func() { close(rb.done) }) + //close(rb.done) return nil } diff --git a/x/net/http/response.go b/x/net/http/response.go index 0f3cb49..c3ac428 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -5,18 +5,20 @@ import ( "unsafe" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/os" "github.com/goplus/llgo/rust/hyper" ) type response struct { - header Header - statusCode int - written bool - body []byte - channel *hyper.ResponseChannel - resp *hyper.Response - request *Request + header Header + statusCode int + written bool + body []byte + hyperChannel *hyper.ResponseChannel + hyperResp *hyper.Response + request *Request + asyncHandler *libuv.Async } type body struct { @@ -26,9 +28,9 @@ type body struct { } type taskData struct { - hyperBody *hyper.Body - body *body - conn *conn + hyperBody *hyper.Body + body *body + conn *conn hyperTaskID hyperTaskID } @@ -41,15 +43,18 @@ const ( var DefaultChunkSize uintptr = 8192 -func newResponse(request *Request, channel *hyper.ResponseChannel) *response { - fmt.Printf("newResponse called\n") - resp := response{ - header: make(Header), - channel: channel, - request: request, - resp: hyper.NewResponse(), +func newResponse(request *Request, hyperChannel *hyper.ResponseChannel) *response { + fmt.Printf("[debug] newResponse called\n") + + return &response{ + header: make(Header), + hyperChannel: hyperChannel, + //request: request, + statusCode: 200, + written: false, + body: nil, + hyperResp: hyper.NewResponse(), } - return &resp } func (r *response) Header() Header { @@ -60,21 +65,33 @@ func (r *response) Write(data []byte) (int, error) { if !r.written { r.WriteHeader(200) } + fmt.Printf("[debug] data: %s\n", string(data)) r.body = append(r.body, data...) + fmt.Printf("[debug] r.body: %s\n", string(r.body)) return len(data), nil } func (r *response) WriteHeader(statusCode int) { - fmt.Printf("WriteHeader called\n") + fmt.Println("[debug] WriteHeader called") if r.written { return } r.written = true r.statusCode = statusCode - r.resp.SetStatus(uint16(statusCode)) + r.hyperResp.SetStatus(uint16(statusCode)) - headers := r.resp.Headers() + fmt.Println("[debug] WriteHeaderStatusCode done") + + //debug + // fmt.Printf("[debug] < HTTP/1.1 %d\n", statusCode) + // for key, values := range r.header { + // for _, value := range values { + // fmt.Printf("< %s: %s\n", key, value) + // } + // } + + headers := r.hyperResp.Headers() for key, values := range r.header { valueLen := len(values) if valueLen > 1 { @@ -92,71 +109,76 @@ func (r *response) WriteHeader(statusCode int) { } } - //debug - fmt.Printf("< HTTP/1.1 %d\n", statusCode) - for key, values := range r.header { - for _, value := range values { - fmt.Printf("< %s: %s\n", key, value) - } - } + fmt.Println("[debug] WriteHeaderHeaders done") + + fmt.Println("[debug] WriteHeader done") } func (r *response) finalize() error { - fmt.Printf("finalize called\n") - err := r.request.Body.Close() - if err != nil { - return err - } - fmt.Printf("request body closed\n") + fmt.Printf("[debug] finalize called\n") + // err := r.request.Body.Close() + // if err != nil { + // return err + // } + // fmt.Printf("[debug] request body closed\n") if !r.written { r.WriteHeader(200) } + r.hyperResp = hyper.NewResponse() + + if r.hyperResp == nil { + return fmt.Errorf("failed to create response") + } + bodyData := body{ data: r.body, len: uintptr(len(r.body)), readLen: 0, } - fmt.Printf("bodyData constructed\n") + fmt.Println("[debug] bodyData constructed") body := hyper.NewBody() if body == nil { return fmt.Errorf("failed to create body") } taskData := &taskData{ - hyperBody: nil, - body: &bodyData, - conn: nil, + hyperBody: nil, + body: &bodyData, + conn: nil, hyperTaskID: taskSetBody, } body.SetDataFunc(setBodyDataFunc) body.SetUserdata(unsafe.Pointer(taskData), nil) - fmt.Printf("bodyData userdata set\n") + fmt.Println("[debug] bodyData userdata set") - fmt.Printf("bodyData set\n") + fmt.Println("[debug] bodyData set") - resBody := r.resp.SetBody(body) + resBody := r.hyperResp.SetBody(body) if resBody != hyper.OK { return fmt.Errorf("failed to set body") } - fmt.Printf("body set\n") + fmt.Println("[debug] body set") - r.channel.Send(r.resp) - fmt.Println("response sent") + r.hyperChannel.Send(r.hyperResp) + fmt.Println("[debug] response sent") return nil } func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - fmt.Printf("setBodyDataFunc called\n") + fmt.Println("[debug] setBodyDataFunc called") taskData := (*taskData)(userdata) + if taskData == nil { + fmt.Println("[debug] taskData is nil") + return hyper.PollError + } body := taskData.body if body.len > 0 { //debug - fmt.Println("<") - fmt.Printf("%s", string(body.data)) - fmt.Println("") + fmt.Println("[debug]<") + fmt.Printf("[debug]%s\n", string(body.data)) if body.len > DefaultChunkSize { *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) @@ -167,13 +189,15 @@ func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) body.readLen += body.len body.len = 0 } + fmt.Println("[debug] setBodyDataFunc done") return hyper.PollReady } if body.len == 0 { *chunk = nil + fmt.Println("[debug] setBodyDataFunc done") return hyper.PollReady } - fmt.Printf("error setting body data: %s\n", c.GoString(c.Strerror(os.Errno))) + fmt.Printf("[debug] error setting body data: %s\n", c.GoString(c.Strerror(os.Errno))) return hyper.PollError } diff --git a/x/net/http/server.go b/x/net/http/server.go index 0455f7d..cc9601e 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -55,6 +55,7 @@ type conn struct { executor *hyper.Executor remoteAddr string requestBody *requestBody + asyncHandle *libuv.Async } type serviceUserdata struct { @@ -163,7 +164,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { } func onNewConnection(serverStream *libuv.Stream, status c.Int) { - fmt.Println("onNewConnection called") + fmt.Println("[debug] onNewConnection called") if status < 0 { fmt.Printf("New connection error: %s\n", libuv.Strerror(libuv.Errno(status))) return @@ -181,11 +182,16 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { return } + fmt.Println("[debug] async handle creating") + + conn.asyncHandle = &libuv.Async{} + srv.uvLoop.Async(conn.asyncHandle, onAsync) + libuv.InitTcp(srv.uvLoop, &conn.stream) conn.stream.Data = unsafe.Pointer(conn) if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(&conn.stream))) == 0 { - fmt.Println("Accepted new connection") + fmt.Println("[debug] Accepted new connection") r := libuv.PollInit(srv.uvLoop, &conn.pollHandle, libuv.OsFd(conn.stream.GetIoWatcherFd())) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) @@ -234,9 +240,9 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { } conn.executor = executor - fmt.Println("Conn created") + fmt.Println("[debug] Conn created") srv.trackConn(conn, true) - fmt.Println("Conn tracked") + fmt.Println("[debug] Conn tracked") userdata.conn = conn @@ -275,12 +281,27 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) conn.executor.Push(serverconn) } else { - fmt.Println("Client not accepted") + fmt.Println("[debug] Client not accepted") (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) } } +func onAsync(asyncHandle *libuv.Async) { + fmt.Println("[debug] onAsync called") + taskData := (*taskData)(asyncHandle.GetData()) + dataTask := taskData.hyperBody.Data() + dataTask.SetUserdata(c.Pointer(taskData), nil) + if dataTask != nil { + r := taskData.conn.executor.Push(dataTask) + fmt.Printf("[debug] onAsync push data task: %d\n", r) + if r != hyper.OK { + fmt.Printf("failed to push data task: %d\n", r) + dataTask.Free() + } + } +} + // func onCheck(handle *libuv.Check) { // //fmt.Println("onCheck called") // srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) @@ -307,9 +328,9 @@ func onIdle(handle *libuv.Idle) { if conn.executor != nil { task := conn.executor.Poll() for task != nil { - srv.handleTask(conn, task) + srv.handleTask(task) + //srv.handleRead(conn, task) task = conn.executor.Poll() - srv.handleRead(conn, task) } } } @@ -335,7 +356,7 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h } res := newResponse(req, channel) - fmt.Printf("Response created\n") + fmt.Println("[debug] Response created") go func() { userData.server.Handler.ServeHTTP(res, req) @@ -347,58 +368,58 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h // res.finalize() } -func (srv *Server) handleRead(conn *conn, task *hyper.Task) { - payload := (*taskData)(task.Userdata()) - if payload == nil { - fmt.Println("taskData is nil, no need to handle read") - return - } - - select { - case <-conn.requestBody.readyCh: - fmt.Println("readyCh signaled") - - fmt.Println("taskGetBody get body task form readyCh") - getBodyTask := payload.hyperBody.Data() - getBodyTask.SetUserdata(c.Pointer(payload), nil) - if getBodyTask != nil { - fmt.Println("taskGetBody push get body task") - r := payload.conn.executor.Push(getBodyTask) - fmt.Printf("taskGetBody push get body task: %d\n", r) - if r != hyper.OK { - fmt.Printf("failed to push get body task: %d\n", r) - getBodyTask.Free() - } - } - default: - fmt.Println("readToReadCh not signaled") - } -} +// func (srv *Server) handleRead(conn *conn, task *hyper.Task) { +// payload := (*taskData)(task.Userdata()) +// if payload == nil { +// fmt.Println("taskData is nil, no need to handle read") +// return +// } -func (srv *Server) handleTask(conn *conn, task *hyper.Task) { +// select { +// case <-conn.requestBody.readyCh: +// fmt.Println("readyCh signaled") + +// fmt.Println("taskGetBody get body task form readyCh") +// getBodyTask := payload.hyperBody.Data() +// getBodyTask.SetUserdata(c.Pointer(payload), nil) +// if getBodyTask != nil { +// fmt.Println("taskGetBody push get body task") +// r := payload.conn.executor.Push(getBodyTask) +// fmt.Printf("taskGetBody push get body task: %d\n", r) +// if r != hyper.OK { +// fmt.Printf("failed to push get body task: %d\n", r) +// getBodyTask.Free() +// } +// } +// default: +// fmt.Println("readToReadCh not signaled") +// } +// } + +func (srv *Server) handleTask(task *hyper.Task) { taskType := task.Type() //debug switch taskType { case hyper.TaskEmpty: - fmt.Println("Task type: Empty") + fmt.Println("[debug] Task type: Empty") case hyper.TaskBuf: - fmt.Println("Task type: Buffer") + fmt.Println("[debug] Task type: Buffer") case hyper.TaskError: - fmt.Println("Task type: Error") + fmt.Println("[debug] Task type: Error") case hyper.TaskServerconn: - fmt.Println("Task type: Serverconn") + fmt.Println("[debug] Task type: Serverconn") default: - fmt.Println("Unknown task type") + fmt.Println("[debug] Unknown task type") } payload := (*taskData)(task.Userdata()) if payload != nil { taskID := payload.hyperTaskID - + // select { // case <-conn.requestBody.readyCh: // fmt.Println("readyCh recieved") - + // fmt.Println("taskGetBody get body task form readyCh") // getBodyTask := payload.hyperBody.Data() // getBodyTask.SetUserdata(c.Pointer(payload), nil) @@ -414,67 +435,67 @@ func (srv *Server) handleTask(conn *conn, task *hyper.Task) { // default: // fmt.Println("readyCh not recieved") // } - + if taskID == taskGetBody { - fmt.Println("taskGetBody called") + fmt.Println("[debug] taskGetBody called") if taskType == hyper.TaskError { - fmt.Println("taskGetBody error") + fmt.Println("[debug] taskGetBody error") err := (*hyper.Error)(task.Value()) fmt.Printf("error code: %d\n", err.Code()) - + var errbuf [256]byte errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) fmt.Printf("details: %s\n", errbuf[:errlen]) err.Free() task.Free() } - + if taskType == hyper.TaskBuf { - fmt.Println("taskGetBody write buf") + fmt.Println("[debug] taskGetBody write buf") buf := (*hyper.Buf)(task.Value()) bytes := unsafe.Slice(buf.Bytes(), buf.Len()) - fmt.Printf("taskGetBody writing to bodyWriter: %s\n", string(bytes)) + fmt.Printf("[debug] taskGetBody writing to bodyWriter: %s\n", string(bytes)) buf.Free() task.Free() - fmt.Println("taskGetBody free task") + fmt.Println("[debug] taskGetBody free task") payload.conn.requestBody.readCh <- bytes - fmt.Println("taskGetBody wrote to bodyWriter") + fmt.Println("[debug] taskGetBody wrote to bodyWriter") } - + if taskType == hyper.TaskEmpty { - fmt.Println("taskGetBody close requestBody") + fmt.Println("[debug] taskGetBody close requestBody") payload.conn.requestBody.Close() - fmt.Println("taskGetBody free task") + fmt.Println("[debug] taskGetBody free task") task.Free() } } else if taskID == taskSetBody { - fmt.Println("taskSetBody called") + fmt.Println("[debug] taskSetBody called") if taskType == hyper.TaskError { - fmt.Println("taskSetBody error") + fmt.Println("[debug] taskSetBody error") err := (*hyper.Error)(task.Value()) fmt.Printf("error code: %d\n", err.Code()) - + var errbuf [256]byte errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) fmt.Printf("details: %s\n", errbuf[:errlen]) err.Free() task.Free() } - + if taskType == hyper.TaskEmpty { - fmt.Println("taskSetBody free task") + fmt.Println("[debug] taskSetBody free task") task.Free() } } } if taskType == hyper.TaskEmpty { - fmt.Println("taskEmpty called") + fmt.Println("[debug] taskEmpty called") task.Free() } if taskType == hyper.TaskServerconn { - fmt.Println("taskServerconn called") + fmt.Println("[debug] taskServerconn called") task.Free() } } @@ -526,11 +547,11 @@ func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintp if conn.eventMask&c.Uint(libuv.READABLE) == 0 { conn.eventMask |= c.Uint(libuv.READABLE) - fmt.Printf("ReadCb Event mask: %d\n", conn.eventMask) + fmt.Printf("[debug] ReadCb Event mask: %d\n", conn.eventMask) if !updateConnRegistrations(conn, false) { return hyper.IoError } - fmt.Printf("ReadCb updateConnRegistrations\n") + fmt.Printf("[debug] ReadCb updateConnRegistrations\n") } conn.readWaker = ctx.Waker() @@ -555,7 +576,7 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint if conn.eventMask&c.Uint(libuv.WRITABLE) == 0 { conn.eventMask |= c.Uint(libuv.WRITABLE) - fmt.Printf("WriteCb Event mask: %d\n", conn.eventMask) + fmt.Printf("[debug] WriteCb Event mask: %d\n", conn.eventMask) if !updateConnRegistrations(conn, false) { return hyper.IoError } @@ -566,7 +587,7 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint } func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { - fmt.Printf("onPoll called\n") + fmt.Printf("[debug] onPoll called\n") conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) if status < 0 { @@ -586,14 +607,14 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { } func updateConnRegistrations(conn *conn, create bool) bool { - fmt.Println("updateConnRegistrations called") + fmt.Println("[debug] updateConnRegistrations called") events := c.Int(0) if conn.eventMask == 0 { - fmt.Println("No events to poll, skipping poll start.") + fmt.Println("[debug] No events to poll, skipping poll start.") return true } - fmt.Printf("Event mask: %d\n", conn.eventMask) + fmt.Printf("[debug] Event mask: %d\n", conn.eventMask) if conn.eventMask&c.Uint(libuv.READABLE) != 0 { events |= c.Int(libuv.READABLE) } @@ -601,7 +622,7 @@ func updateConnRegistrations(conn *conn, create bool) bool { events |= c.Int(libuv.WRITABLE) } - fmt.Printf("Starting poll with events: %d\n", events) + fmt.Printf("[debug] Starting poll with events: %d\n", events) r := conn.pollHandle.Start(events, onPoll) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", libuv.Strerror(libuv.Errno(r))) @@ -624,7 +645,7 @@ func createConnData() (*conn, error) { func freeConnData(userdata c.Pointer) { conn := (*conn)(userdata) if conn != nil && !conn.isClosing.Swap(true) { - fmt.Printf("Closing connection...\n") + fmt.Printf("[debug] Closing connection...\n") if conn.readWaker != nil { conn.readWaker.Free() conn.readWaker = nil diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index 6da8bce..a210a7b 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -19,9 +19,9 @@ type muxEntry struct { var DefaultServeMux = &ServeMux{m: make(map[string]muxEntry)} func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { - fmt.Printf("ServeHTTP called\n") + fmt.Printf("[debug] ServeHTTP called\n") h, pattern := mux.Handler(r) - fmt.Printf("Handler found for pattern: %s\n", pattern) + fmt.Printf("[debug] Handler found for pattern: %s\n", pattern) h.ServeHTTP(w, r) } From e55b2616a086b5ef126178efab8a3f4d6884f4d4 Mon Sep 17 00:00:00 2001 From: spongehah <2635879218@qq.com> Date: Wed, 11 Sep 2024 13:22:10 +0800 Subject: [PATCH 39/55] WIP(x/net/http/client): Implement BodyChunk --- x/net/http/_demo/chunked/chunked.go | 29 + x/net/http/_demo/get/get.go | 6 +- x/net/http/_demo/headers/headers.go | 6 +- .../_demo/maxConnsPerHost/maxConnsPerHost.go | 6 +- x/net/http/_demo/post/post.go | 5 + x/net/http/_demo/redirect/redirect.go | 6 +- x/net/http/_demo/reuseConn/reuseConn.go | 12 +- x/net/http/_demo/server/chunkedServer.go | 42 + x/net/http/_demo/upload/upload.go | 8 +- x/net/http/bodyChunk.go | 104 ++ x/net/http/client.go | 3 +- x/net/http/header.go | 70 +- x/net/http/request.go | 11 +- x/net/http/response.go | 228 +++- x/net/http/transfer.go | 57 +- x/net/http/transport.go | 1084 +++++++---------- 16 files changed, 893 insertions(+), 784 deletions(-) create mode 100644 x/net/http/_demo/chunked/chunked.go create mode 100644 x/net/http/_demo/server/chunkedServer.go create mode 100644 x/net/http/bodyChunk.go diff --git a/x/net/http/_demo/chunked/chunked.go b/x/net/http/_demo/chunked/chunked.go new file mode 100644 index 0000000..7b33c0c --- /dev/null +++ b/x/net/http/_demo/chunked/chunked.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "io" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func main() { + resp, err := http.Get("http://localhost:8080/chunked") + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(body)) +} diff --git a/x/net/http/_demo/get/get.go b/x/net/http/_demo/get/get.go index 6e91bd4..392cc72 100644 --- a/x/net/http/_demo/get/get.go +++ b/x/net/http/_demo/get/get.go @@ -15,7 +15,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/headers/headers.go b/x/net/http/_demo/headers/headers.go index 5538923..41cc15f 100644 --- a/x/net/http/_demo/headers/headers.go +++ b/x/net/http/_demo/headers/headers.go @@ -38,7 +38,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { println(err.Error()) diff --git a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go index 5662251..eff95fc 100644 --- a/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go +++ b/x/net/http/_demo/maxConnsPerHost/maxConnsPerHost.go @@ -22,7 +22,11 @@ func main() { defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/post/post.go b/x/net/http/_demo/post/post.go index fd756b3..b700028 100644 --- a/x/net/http/_demo/post/post.go +++ b/x/net/http/_demo/post/post.go @@ -17,6 +17,11 @@ func main() { } defer resp.Body.Close() fmt.Println(resp.Status) + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/redirect/redirect.go b/x/net/http/_demo/redirect/redirect.go index f189255..3d40f3b 100644 --- a/x/net/http/_demo/redirect/redirect.go +++ b/x/net/http/_demo/redirect/redirect.go @@ -16,7 +16,11 @@ func main() { defer resp.Body.Close() fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) fmt.Println(resp.Proto) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/reuseConn/reuseConn.go b/x/net/http/_demo/reuseConn/reuseConn.go index bb460ce..bccfe9d 100644 --- a/x/net/http/_demo/reuseConn/reuseConn.go +++ b/x/net/http/_demo/reuseConn/reuseConn.go @@ -15,7 +15,11 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) @@ -31,7 +35,11 @@ func main() { return } fmt.Println(resp.Status, "read bytes: ", resp.ContentLength) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } body, err = io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/_demo/server/chunkedServer.go b/x/net/http/_demo/server/chunkedServer.go new file mode 100644 index 0000000..b79ad60 --- /dev/null +++ b/x/net/http/_demo/server/chunkedServer.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + "net/http" +) + +func chunkedHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Content-Type", "text/plain") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + sentence := "This is a chunked encoded response. It will be sent in multiple parts. Note the delay between each section." + + words := []string{} + start := 0 + for i, r := range sentence { + if r == '。' || r == ',' || i == len(sentence)-1 { + words = append(words, sentence[start:i+1]) + start = i + 1 + } + } + + for _, word := range words { + fmt.Fprintf(w, "%s", word) + flusher.Flush() + } +} + +func main() { + http.HandleFunc("/chunked", chunkedHandler) + fmt.Println("Starting server on :8080") + err := http.ListenAndServe(":8080", nil) + if err != nil { + fmt.Printf("Error starting server: %s\n", err) + } +} \ No newline at end of file diff --git a/x/net/http/_demo/upload/upload.go b/x/net/http/_demo/upload/upload.go index fe7256b..b5baffa 100644 --- a/x/net/http/_demo/upload/upload.go +++ b/x/net/http/_demo/upload/upload.go @@ -11,7 +11,7 @@ import ( func main() { url := "http://httpbin.org/post" //url := "http://localhost:8080" - filePath := "/Users/spongehah/go/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path + filePath := "/Users/spongehah/Documents/code/GOPATH/src/llgo/x/net/http/_demo/upload/example.txt" // Replace with your file path //filePath := "/Users/spongehah/Downloads/xiaoshuo.txt" // Replace with your file path file, err := os.Open(filePath) @@ -36,7 +36,11 @@ func main() { } defer resp.Body.Close() fmt.Println("Status:", resp.Status) - resp.PrintHeaders() + for key, values := range resp.Header { + for _, value := range values { + fmt.Printf("%s: %s\n", key, value) + } + } respBody, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) diff --git a/x/net/http/bodyChunk.go b/x/net/http/bodyChunk.go new file mode 100644 index 0000000..c1d1072 --- /dev/null +++ b/x/net/http/bodyChunk.go @@ -0,0 +1,104 @@ +package http + +import ( + "errors" + "io" + "sync" + + "github.com/goplus/llgo/c/libuv" +) + +type onceError struct { + sync.Mutex + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} + +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { + return &bodyChunk{ + readCh: make(chan []byte, 1), + done: make(chan struct{}), + asyncHandle: asyncHandle, + } +} + +type bodyChunk struct { + chunk []byte + readCh chan []byte + asyncHandle *libuv.Async + + once sync.Once + done chan struct{} + + rerr onceError +} + +var ( + errClosedBodyChunk = errors.New("bodyChunk: read/write on closed body") +) + +func (bc *bodyChunk) Read(p []byte) (n int, err error) { + for n < len(p) { + if len(bc.chunk) == 0 { + select { + case chunk, ok := <-bc.readCh: + if !ok { + if n > 0 { + return n, nil + } + return 0, bc.readCloseError() + } + bc.chunk = chunk + bc.asyncHandle.Send() + case <-bc.done: + if n > 0 { + return n, nil + } + return 0, io.EOF + } + } + + copied := copy(p[n:], bc.chunk) + n += copied + bc.chunk = bc.chunk[copied:] + } + + return n, nil +} + +func (bc *bodyChunk) Close() error { + return bc.closeRead(nil) +} + +func (bc *bodyChunk) readCloseError() error { + if rerr := bc.rerr.Load(); rerr != nil { + return rerr + } + return errClosedBodyChunk +} + +func (bc *bodyChunk) closeRead(err error) error { + if err == nil { + err = io.EOF + } + bc.rerr.Store(err) + bc.once.Do(func() { + close(bc.done) + }) + //close(bc.done) + return nil +} diff --git a/x/net/http/client.go b/x/net/http/client.go index 002397a..7e26395 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -307,7 +307,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) timeout(send) + // TODO(spongehah) tmp timeout(send) //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) req.timeoutch = make(chan struct{}, 1) req.deadline = deadline @@ -490,7 +490,6 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return knownRoundTripperImpl(altRT, req) } return true - // TODO(spongehah) http2 //case *http2Transport, http2noDialH2RoundTripper: // return true } diff --git a/x/net/http/header.go b/x/net/http/header.go index 0d1e2cc..7c95411 100644 --- a/x/net/http/header.go +++ b/x/net/http/header.go @@ -75,15 +75,6 @@ func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) } -// CanonicalHeaderKey returns the canonical format of the -// header key s. The canonicalization converts the first -// letter and any letter following a hyphen to upper case; -// the rest are converted to lowercase. For example, the -// canonical key for "accept-encoding" is "Accept-Encoding". -// If s contains a space or invalid header field bytes, it is -// returned without modifications. -func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } - // Clone returns a copy of h or nil if h is nil. func (h Header) Clone() Header { if h == nil { @@ -111,28 +102,6 @@ func (h Header) Clone() Header { return h2 } -var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") - -type keyValues struct { - key string - values []string -} - -// A headerSorter implements sort.Interface by sorting a []keyValues -// by key. It's used as a pointer, so it can fit in a sort.Interface -// interface value without allocation. -type headerSorter struct { - kvs []keyValues -} - -func (s *headerSorter) Len() int { return len(s.kvs) } -func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } -func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } - -var headerSorterPool = sync.Pool{ - New: func() any { return new(headerSorter) }, -} - // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. @@ -199,6 +168,37 @@ func (h Header) writeSubset(reqHeaders *hyper.Headers, exclude map[string]bool) return nil } +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + // hasToken reports whether token appears with v, ASCII // case-insensitive, with space or comma boundaries. // token must be all lowercase. @@ -251,11 +251,3 @@ func appendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, va resp.Header.Add(nameStr, valueStr) return hyper.IterContinue } - -func (resp *Response) PrintHeaders() { - for key, values := range resp.Header { - for _, value := range values { - fmt.Printf("%s: %s\n", key, value) - } - } -} diff --git a/x/net/http/request.go b/x/net/http/request.go index c5146ed..e9279fc 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -294,7 +294,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype //} // Prepare the hyper.Request - hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra) + hyperReq, err := r.newHyperRequest(taskData.pc.isProxy, taskData.req.extra, taskData.req) if err != nil { return err } @@ -308,7 +308,7 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype return err } -func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.Request, error) { +func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header, treq *transportRequest) (*hyper.Request, error) { // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. // @@ -401,11 +401,6 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Process Body,ContentLength,Close,Trailer - //tw, err := newTransferWriter(r) - //if err != nil { - // return err - //} - //err = tw.writeHeader(w, trace) err = r.writeHeader(reqHeaders) if err != nil { return nil, err @@ -433,7 +428,7 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header) (*hyper.R } // Write body and trailer - err = r.writeBody(hyperReq) + err = r.writeBody(hyperReq, treq) if err != nil { return nil, err } diff --git a/x/net/http/response.go b/x/net/http/response.go index 6ff5b3d..a3a96fc 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -1,10 +1,12 @@ package http import ( + "compress/gzip" + "errors" "fmt" "io" "strconv" - "unsafe" + "sync" "github.com/goplus/llgo/c" "github.com/goplus/llgoexamples/rust/hyper" @@ -32,7 +34,198 @@ func (r *Response) closeBody() { } } -func ReadResponse(r *io.PipeReader, req *Request, hyperResp *hyper.Response) (*Response, error) { +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func (r *Response) bodyIsWritable() bool { + _, ok := r.Body.(io.Writer) + return ok +} + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} + +func (r *Response) checkRespBody(taskData *taskData) (needContinue bool) { + pc := taskData.pc + bodyWritable := r.bodyIsWritable() + hasBody := taskData.req.Method != "HEAD" && r.ContentLength != 0 + + if r.Close || taskData.req.Close || r.StatusCode <= 199 || bodyWritable { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. + pc.alive = false + } + + if !hasBody || bodyWritable { + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) + + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() + + if bodyWritable { + pc.closeErr = errCallerOwnsConn + } + + select { + case taskData.resc <- responseAndError{res: r}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return true + } + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + readLoopDefer(pc, false) + return true + } + return false +} + +func (r *Response) wrapRespBody(taskData *taskData) { + body := &bodyEOFSignal{ + body: r.Body, + earlyCloseFn: func() error { + return nil + }, + fn: func(err error) error { + isEOF := err == io.EOF + if !isEOF { + if cerr := taskData.pc.canceled(); cerr != nil { + return cerr + } + } + return err + }, + } + r.Body = body + // TODO(spongehah) gzip(wrapRespBody) + //if taskData.addedGzip && EqualFold(r.Header.Get("Content-Encoding"), "gzip") { + // println("gzip reader") + // r.Body = &gzipReader{body: body} + // r.Header.Del("Content-Encoding") + // r.Header.Del("Content-Length") + // r.ContentLength = -1 + // r.Uncompressed = true + //} +} + +// bodyEOFSignal is used by the HTTP/1 transport when reading response +// bodies to make sure we see the end of a response body before +// proceeding and reading on the connection again. +// +// It wraps a ReadCloser but runs fn (if non-nil) at most +// once, right before its final (error-producing) Read or Close call +// returns. fn should return the new error to return from Read or Close. +// +// If earlyCloseFn is non-nil and Close is called before io.EOF is +// seen, earlyCloseFn is called instead of fn, and its return value is +// the return value from Close. +type bodyEOFSignal struct { + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) error // err will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen +} + +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { + es.mu.Lock() + closed, rerr := es.closed, es.rerr + es.mu.Unlock() + if closed { + return 0, errReadOnClosedResBody + } + if rerr != nil { + return 0, rerr + } + + n, err = es.body.Read(p) + if err != nil { + es.mu.Lock() + defer es.mu.Unlock() + if es.rerr == nil { + es.rerr = err + } + err = es.condfn(err) + } + return +} + +func (es *bodyEOFSignal) Close() error { + es.mu.Lock() + defer es.mu.Unlock() + if es.closed { + return nil + } + es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } + err := es.body.Close() + return es.condfn(err) +} + +// caller must hold es.mu. +func (es *bodyEOFSignal) condfn(err error) error { + if es.fn == nil { + return err + } + err = es.fn(err) + es.fn = nil + return err +} + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type gzipReader struct { + _ incomparable + body *bodyEOFSignal // underlying HTTP/1 response body framing + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // any error from gzip.NewReader; sticky +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + if gz.zerr == nil { + gz.zr, gz.zerr = gzip.NewReader(gz.body) + } + if gz.zerr != nil { + return 0, gz.zerr + } + } + + gz.body.mu.Lock() + if gz.body.closed { + err = errReadOnClosedResBody + } + gz.body.mu.Unlock() + + if err != nil { + return 0, err + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} + +func ReadResponse(r io.ReadCloser, req *Request, hyperResp *hyper.Response) (*Response, error) { resp := &Response{ Request: req, Header: make(Header), @@ -65,20 +258,6 @@ func readResponseLineAndHeader(resp *Response, hyperResp *hyper.Response) { headers.Foreach(appendToResponseHeader, c.Pointer(resp)) } -// appendToResponseBody BodyForeachCallback function: Process the response body -func appendToResponseBody(userdata c.Pointer, chunk *hyper.Buf) c.Int { - writer := (*io.PipeWriter)(userdata) - bufLen := chunk.Len() - bytes := unsafe.Slice(chunk.Bytes(), bufLen) - _, err := writer.Write(bytes) - if err != nil { - fmt.Println("Error writing to response body:", err) - writer.Close() - return hyper.IterBreak - } - return hyper.IterContinue -} - // RFC 7234, section 5.4: Should treat // // Pragma: no-cache @@ -94,26 +273,9 @@ func fixPragmaCacheControl(header Header) { } } -// Cookies parses and returns the cookies set in the Set-Cookie headers. -func (r *Response) Cookies() []*Cookie { - return readSetCookies(r.Header) -} - // isProtocolSwitchHeader reports whether the request or response header // is for a protocol switch. func isProtocolSwitchHeader(h Header) bool { return h.Get("Upgrade") != "" && HeaderValuesContainsToken(h["Connection"], "Upgrade") } - -// bodyIsWritable reports whether the Body supports writing. The -// Transport returns Writable bodies for 101 Switching Protocols -// responses. -// The Transport uses this method to determine whether a persistent -// connection is done being managed from its perspective. Once we -// return a writable response body to a user, the net/http package is -// done managing that connection. -func (r *Response) bodyIsWritable() bool { - _, ok := r.Body.(io.Writer) - return ok -} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 103200c..818fb3c 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -98,7 +98,7 @@ func (uste *unsupportedTEError) Error() string { } // msg is *Request or *Response. -func readTransfer(msg any, r *io.PipeReader) (err error) { +func readTransfer(msg any, r io.ReadCloser) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -173,19 +173,17 @@ func readTransfer(msg any, r *io.PipeReader) (err error) { if isResponse && noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { t.Body = NoBody } else { - // TODO(spongehah) ChunkReader(readTransfer) - //t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} - t.Body = &body{src: r, hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: r, closer: r, hdr: msg, r: r, closing: t.Close} } case realLength == 0: t.Body = NoBody case realLength > 0: - t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} + t.Body = &body{src: io.LimitReader(r, realLength), closer: r, closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) - t.Body = &body{src: r, closing: t.Close} + t.Body = &body{src: r, closer: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) t.Body = NoBody @@ -349,9 +347,9 @@ func fixTrailer(header Header, chunked bool) (Header, error) { // Close ensures that the body has been fully read // and then reads the trailer if necessary. type body struct { - src io.Reader - hdr any // non-nil (Response or Request) value means read trailer - //r *bufio.Reader // underlying wire-format reader for the trailer + src io.Reader + closer io.Closer + hdr any // non-nil (Response or Request) value means read trailer r io.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? doEarlyClose bool // whether Close should stop early @@ -476,6 +474,15 @@ func (b *body) Close() error { _, err = io.Copy(io.Discard, bodyLocked{b}) } b.closed = true + + // Close bodyChunk + if b.closer != nil { + closeErr := b.closer.Close() + if err == nil { + err = closeErr + } + } + return err } @@ -654,26 +661,26 @@ func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) // files (*os.File types) are properly optimized. // // This function is only intended for use in writeBody. -func (req *Request) unwrapBody() io.Reader { - if r, ok := unwrapNopCloser(req.Body); ok { +func (r *Request) unwrapBody() io.Reader { + if r, ok := unwrapNopCloser(r.Body); ok { return r } - if r, ok := req.Body.(*readTrackingBody); ok { + if r, ok := r.Body.(*readTrackingBody); ok { r.didRead = true return r.ReadCloser } - return req.Body + return r.Body } -func (r *Request) writeBody(hyperReq *hyper.Request) error { +func (r *Request) writeBody(hyperReq *hyper.Request, treq *transportRequest) error { if r.Body != nil { var body = r.unwrapBody() hyperReqBody := hyper.NewBody() buf := make([]byte, defaultChunkSize) reqData := &bodyReq{ - body: body, - buf: buf, - closeBody: r.closeBody, + body: body, + buf: buf, + treq: treq, } hyperReqBody.SetUserdata(c.Pointer(reqData)) hyperReqBody.SetDataFunc(setPostData) @@ -683,9 +690,9 @@ func (r *Request) writeBody(hyperReq *hyper.Request) error { } type bodyReq struct { - body io.Reader - buf []byte - closeBody func() error + body io.Reader + buf []byte + treq *transportRequest } func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { @@ -694,10 +701,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In if err != nil { if err == io.EOF { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } fmt.Println("error reading request body: ", err) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } if n > 0 { @@ -706,10 +714,11 @@ func setPostData(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.In } if n == 0 { *chunk = nil - req.closeBody() + req.treq.closeBody() return hyper.PollReady } - req.closeBody() - fmt.Printf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.closeBody() + err = fmt.Errorf("error reading request body: %s\n", c.GoString(c.Strerror(os.Errno))) + req.treq.setError(requestBodyReadError{err}) return hyper.PollError } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 44d721d..8075133 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -1,7 +1,6 @@ package http import ( - "compress/gzip" "container/list" "context" "errors" @@ -37,7 +36,12 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -const debugSwitch = true + +// Debug switch provided for developers +const ( + debugSwitch = true + debugReadWriteLoop = true +) type Transport struct { idleMu sync.Mutex @@ -46,53 +50,24 @@ type Transport struct { idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns idleLRU connLRU - altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme reqMu sync.Mutex reqCanceler map[cancelKey]func(error) - Proxy func(*Request) (*url.URL, error) + + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme connsPerHostMu sync.Mutex connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns - // DisableKeepAlives, if true, disables HTTP keep-alives and - // will only use the connection to the server for a single - // HTTP request. - // - // This is unrelated to the similarly named TCP keep-alives. - DisableKeepAlives bool - - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool + Proxy func(*Request) (*url.URL, error) - // MaxIdleConns controls the maximum number of idle (keep-alive) - // connections across all hosts. Zero means no limit. - MaxIdleConns int + DisableKeepAlives bool + DisableCompression bool - // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) connections to keep per-host. If zero, - // DefaultMaxIdleConnsPerHost is used. + MaxIdleConns int MaxIdleConnsPerHost int - - // MaxConnsPerHost optionally limits the total number of - // connections per host, including connections in the dialing, - // active, and idle states. On limit violation, dials will block. - // - // Zero means no limit. - MaxConnsPerHost int - - // IdleConnTimeout is the maximum amount of time an idle - // (keep-alive) connection will remain idle before closing - // itself. - // Zero means no limit. - IdleConnTimeout time.Duration + MaxConnsPerHost int + IdleConnTimeout time.Duration // libuv and hyper related loopInitOnce sync.Once @@ -516,14 +491,11 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { // useRegisteredProtocol reports whether an alternate protocol (as registered // with Transport.RegisterProtocol) should be respected for this request. func (t *Transport) useRegisteredProtocol(req *Request) bool { - if req.URL.Scheme == "https" && req.requiresHTTP1() { - // If this request requires HTTP/1, don't use the - // "https" alternate protocol, which is used by the - // HTTP/2 code to take over requests if there's an - // existing cached HTTP/2 connection. - return false - } - return true + // If this request requires HTTP/1, don't use the + // "https" alternate protocol, which is used by the + // HTTP/2 code to take over requests if there's an + // existing cached HTTP/2 connection. + return !(req.URL.Scheme == "https" && req.requiresHTTP1()) } // CancelRequest cancels an in-flight request by closing its connection. @@ -573,6 +545,8 @@ func (t *Transport) closeLocked(err error) { } } +// ---------------------------------------------------------- + func getMilliseconds(deadline time.Time) uint64 { microseconds := deadline.Sub(time.Now()).Microseconds() milliseconds := microseconds / 1e3 @@ -582,15 +556,13 @@ func getMilliseconds(deadline time.Time) uint64 { return uint64(milliseconds) } -// ---------------------------------------------------------- - func (t *Transport) RoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("RoundTrip start") - defer println("RoundTrip end") + println("############### RoundTrip start") + defer println("############### RoundTrip end") } t.loopInitOnce.Do(func() { - println("init loop") + println("############### init loop") t.loop = libuv.LoopNew() t.async = &libuv.Async{} t.exec = hyper.NewExecutor() @@ -620,7 +592,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Start(onTimeout, getMilliseconds(req.deadline), 0) if debugSwitch { - println("timer start") + println("############### timer start") } didTimeout = func() bool { return req.timer.GetDueIn() == 0 } stopTimer = func() { @@ -628,7 +600,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer.Stop() (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) if debugSwitch { - println("timer close") + println("############### timer close") } } } else { @@ -654,8 +626,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { func (t *Transport) doRoundTrip(req *Request) (*Response, error) { if debugSwitch { - println("doRoundTrip start") - defer println("doRoundTrip end") + println("############### doRoundTrip start") + defer println("############### doRoundTrip end") } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() @@ -715,7 +687,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } for { - // TODO(spongehah) timeout(t.doRoundTrip) //select { //case <-ctx.Done(): // req.closeBody() @@ -766,7 +737,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } // Failed. Clean up and determine whether to retry. - // TODO(spongehah) ConnPool(t.doRoundTrip) if http2isNoCachedConnError(err) { if t.removeIdleConn(pconn) { t.decConnsPerHost(pconn.cacheKey) @@ -800,8 +770,8 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { if debugSwitch { - println("getConn start") - defer println("getConn end") + println("############### getConn start") + defer println("############### getConn end") } req := treq.Request //trace := treq.trace @@ -824,13 +794,11 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi } }() - // TODO(spongehah) ConnPool(t.getConn) // Queue for idle connection. if delivered := t.queueForIdleConn(w); delivered { pc := w.pc // Trace only for HTTP/1. // HTTP/2 calls trace.GotConn itself. - // TODO(spongehah) trace(t.getConn) //if pc.alt == nil && trace != nil && trace.GotConn != nil { // trace.GotConn(pc.gotIdleConnTrace(pc.idleAt)) //} @@ -853,28 +821,28 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()}) //} if w.err != nil { - // If the request has been canceled, that's probably - // what caused w.err; if so, prefer to return the - // cancellation error (see golang.org/issue/16049). - select { - // TODO(spongehah) timeout(t.getConn) - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() - case <-req.timeoutch: - if debugSwitch { - println("getConn: timeoutch") - } - return nil, errors.New("timeout: req.Context().Err()") - case err := <-cancelc: - if err == errRequestCanceled { - err = errRequestCanceledConn - } - return nil, err - default: - // return below + return nil, w.err + } + // If the request has been canceled, that's probably + // what caused w.err; if so, prefer to return the + // cancellation error (see golang.org/issue/16049). + select { + //case <-req.Cancel: + // return nil, errRequestCanceledConn + //case <-req.Context().Done(): + // return nil, req.Context().Err() + case <-req.timeoutch: + if debugSwitch { + println("############### getConn: timeoutch") + } + return nil, errors.New("timeout: req.Context().Err()") + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn } + return nil, err + default: + // return below } return w.pc, w.err } @@ -883,8 +851,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { if debugSwitch { - println("queueForDial start") - defer println("queueForDial end") + println("############### queueForDial start") + defer println("############### queueForDial end") } w.beforeDial() @@ -919,13 +887,12 @@ func (t *Transport) queueForDial(w *wantConn) { // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { if debugSwitch { - println("dialConnFor start") - defer println("dialConnFor end") + println("############### dialConnFor start") + defer println("############### dialConnFor end") } defer w.afterDial() pc, err := t.dialConn(w.timeoutch, w.cm) - // TODO(spongehah) ConnPool(t.dialConnFor) delivered := w.tryDeliver(pc, err) // If the connection was successfully established but was not passed to w, // or is a shareable HTTP/2 connection @@ -994,8 +961,8 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn *persistConn, err error) { if debugSwitch { - println("dialConn start") - defer println("dialConn end") + println("############### dialConn start") + defer println("############### dialConn end") } select { case <-timeoutch: @@ -1009,7 +976,9 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * closech: make(chan struct{}, 1), writeLoopDone: make(chan struct{}, 1), alive: true, + chunkAsync: &libuv.Async{}, } + t.loop.Async(pconn.chunkAsync, readyToRead) //trace := httptrace.ContextClientTrace(ctx) //wrapErr := func(err error) error { @@ -1102,6 +1071,21 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // } //} + pconn.closeErr = errReadLoopExiting + pconn.tryPutIdleConn = func() bool { + if err := pconn.t.tryPutIdleConn(pconn); err != nil { + pconn.closeErr = err + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") @@ -1114,8 +1098,8 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * func (t *Transport) dial(addr string) (*connData, error) { if debugSwitch { - println("dial start") - defer println("dial end") + println("############### dial start") + defer println("############### dial end") } host, port, err := net.SplitHostPort(addr) if err != nil { @@ -1150,12 +1134,11 @@ func (t *Transport) dial(addr string) (*connData, error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { if debugSwitch { - println("roundTrip start") - defer println("roundTrip end") + println("############### roundTrip start") + defer println("############### roundTrip end") } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - // TODO(spongehah) ConnPool(pc.roundTrip) pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -1168,40 +1151,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err headerFn(req.extraHeaders()) } - // Ask for a compressed version if the caller didn't set their - // own value for Accept-Encoding. We only attempt to - // uncompress the gzip stream if we were the layer that - // requested it. - requestedGzip := false - // TODO(spongehah) gzip(pc.roundTrip) - //if !pc.t.DisableCompression && - // req.Header.Get("Accept-Encoding") == "" && - // req.Header.Get("Range") == "" && - // req.Method != "HEAD" { - // // Request gzip only, not deflate. Deflate is ambiguous and - // // not as universally supported anyway. - // // See: https://zlib.net/zlib_faq.html#faq39 - // // - // // Note that we don't request this for HEAD requests, - // // due to a bug in nginx: - // // https://trac.nginx.org/nginx/ticket/358 - // // https://golang.org/issue/5522 - // // - // // We don't request gzip if the request is for a range, since - // // auto-decoding a portion of a gzipped document will just fail - // // anyway. See https://golang.org/issue/8923 - // requestedGzip = true - // req.extraHeaders().Set("Accept-Encoding", "gzip") - //} - - // The 100-continue operation in Hyper is handled in the newHyperRequest function. - - // Keep-Alive - if pc.t.DisableKeepAlives && - !req.wantsClose() && - !isProtocolSwitchHeader(req.Header) { - req.extraHeaders().Set("Connection", "close") - } + // Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). + requestedGzip := pc.setExtraHeaders(req) gone := make(chan struct{}, 1) defer close(gone) @@ -1229,9 +1180,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } if pc.client == nil && !pc.isReused() { - println("first") // Hookup the IO - hyperIo := newIoWithConnReadWrite(pc.conn) + hyperIo := newHyperIo(pc.conn) // We need an executor generally to poll futures // Prepare client options opts := hyper.NewClientConnOptions() @@ -1243,7 +1193,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // Send the request to readWriteLoop(). pc.t.exec.Push(handshakeTask) } else { - println("second") taskData.taskId = read err = req.write(pc.client, taskData, pc.t.exec) if err != nil { @@ -1264,12 +1213,12 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err for { testHookWaitResLoop() if debugSwitch { - println("roundTrip for") + println("############### roundTrip for") } select { case err := <-writeErrCh: if debugSwitch { - println("roundTrip: writeErrch") + println("############### roundTrip: writeErrch") } if err != nil { pc.close(fmt.Errorf("write error: %w", err)) @@ -1278,17 +1227,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return nil, pc.mapRoundTripError(req, startBytesWritten, err) } - //if d := pc.t.ResponseHeaderTimeout; d > 0 { - // if debugRoundTrip { - // //req.logf("starting timer for %v", d) - // } - // timer := time.NewTimer(d) - // defer timer.Stop() // prevent leaks - // respHeaderTimer = timer.C - //} case <-pcClosed: if debugSwitch { - println("roundTrip: pcClosed") + println("############### roundTrip: pcClosed") } pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { @@ -1297,7 +1238,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err //case <-respHeaderTimer: case re := <-resc: if debugSwitch { - println("roundTrip: resc") + println("############### roundTrip: resc") } if (re.res == nil) == (re.err == nil) { return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) @@ -1306,7 +1247,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - // TODO(spongehah) timeout(pc.roundTrip) //case <-cancelChan: // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) // cancelChan = nil @@ -1316,7 +1256,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // ctxDoneChan = nil case <-timeoutch: if debugSwitch { - println("roundTrip: timeoutch") + println("############### roundTrip: timeoutch") } canceled = pc.t.cancelRequest(req.cancelKey, errors.New("timeout: req.Context().Err()")) timeoutch = nil @@ -1330,361 +1270,232 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err func readWriteLoop(checker *libuv.Check) { t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) - // Read this once, before loop starts. (to avoid races in tests) - //testHookMu.Lock() - //testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead - //testHookMu.Unlock() - - const debugReadWriteLoop = true // Debug switch provided for developers - - // The polling state machine! - // Poll all ready tasks and act on them... - for { - task := t.exec.Poll() + // The polling state machine! Poll all ready tasks and act on them... + task := t.exec.Poll() + for task != nil { if debugSwitch { - println("polling") - } - if task == nil { - return - } - taskData := (*taskData)(task.Userdata()) - var taskId taskId - if taskData != nil { - taskId = taskData.taskId - } else { - taskId = notSet + println("############### polling") } + t.handleTask(task) + task = t.exec.Poll() + } +} + +func (t *Transport) handleTask(task *hyper.Task) { + taskData := (*taskData)(task.Userdata()) + if taskData == nil { + // A background task for hyper_client completed... + task.Free() + return + } + var err error + pc := taskData.pc + // If original taskId is set, we need to check it + err = checkTaskType(task, taskData) + if err != nil { + readLoopDefer(pc, true) + return + } + switch taskData.taskId { + case handshake: if debugReadWriteLoop { - println("taskId: ", taskId) + println("############### write") } - switch taskId { - case handshake: - if debugReadWriteLoop { - println("write") - } - - err := checkTaskType(task, handshake) - if err != nil { - taskData.writeErrCh <- err - task.Free() - continue - } - - pc := taskData.pc - select { - case <-pc.closech: - task.Free() - continue - default: - } - pc.client = (*hyper.ClientConn)(task.Value()) + // Check if the connection is closed + select { + case <-pc.closech: task.Free() + return + default: + } - // TODO(spongehah) Proxy(writeLoop) - taskData.taskId = read - err = taskData.req.Request.write(pc.client, taskData, t.exec) + pc.client = (*hyper.ClientConn)(task.Value()) + task.Free() - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } - - if debugReadWriteLoop { - println("write end") - } - case read: - if debugReadWriteLoop { - println("read") - } - - pc := taskData.pc - - err := checkTaskType(task, read) - if bre, ok := err.(requestBodyReadError); ok { - err = bre.error - // Errors reading from the user's - // Request.Body are high priority. - // Set it here before sending on the - // channels below or calling - // pc.close() which tears down - // connections and causes other - // errors. - taskData.req.setError(err) - } - if err != nil { - //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function - pc.close(err) - continue - } + // TODO(spongehah) Proxy(writeLoop) + taskData.taskId = read + err = taskData.req.Request.write(pc.client, taskData, t.exec) - if pc.closeErr == nil { - pc.closeErr = errReadLoopExiting - } - // TODO(spongehah) ConnPool(readWriteLoop) - if pc.tryPutIdleConn == nil { - pc.tryPutIdleConn = func() bool { - if err := pc.t.tryPutIdleConn(pc); err != nil { - pc.closeErr = err - // TODO(spongehah) trace(readWriteLoop) - //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - //} - return false - } - //if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - //} - return true - } - } + if err != nil { + //pc.writeErrCh <- err // to the body reader, which might recycle us + taskData.writeErrCh <- err // to the roundTrip function + pc.close(err) + return + } - // Take the results - hyperResp := (*hyper.Response)(task.Value()) - task.Free() + if debugReadWriteLoop { + println("############### write end") + } + case read: + if debugReadWriteLoop { + println("############### read") + } - pc.mu.Lock() - if pc.numExpectedResponses == 0 { - pc.readLoopPeekFailLocked(hyperResp, err) - pc.mu.Unlock() + // Take the results + hyperResp := (*hyper.Response)(task.Value()) + task.Free() - // defer - readLoopDefer(pc, t) - continue - } + //pc.mu.Lock() + if pc.numExpectedResponses == 0 { + pc.readLoopPeekFailLocked(hyperResp, err) pc.mu.Unlock() + readLoopDefer(pc, true) + return + } + //pc.mu.Unlock() - //trace := httptrace.ContextClientTrace(rc.req.Context()) - - var resp *Response - var respBody *hyper.Body - if err == nil { - var pr *io.PipeReader - pr, taskData.bodyWriter = io.Pipe() - resp, err = ReadResponse(pr, taskData.req.Request, hyperResp) - respBody = hyperResp.Body() - } else { - err = transportReadFromServerError{err} - pc.closeErr = err - } + var resp *Response + if err == nil { + pc.chunkAsync.SetData(c.Pointer(taskData)) + bc := newBodyChunk(pc.chunkAsync) + pc.bodyChunk = bc + resp, err = ReadResponse(bc, taskData.req.Request, hyperResp) + taskData.hyperBody = hyperResp.Body() + } else { + err = transportReadFromServerError{err} + pc.closeErr = err + } - // No longer need the response - hyperResp.Free() + // No longer need the response + hyperResp.Free() - if err != nil { - select { - case taskData.resc <- responseAndError{err: err}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // defer - readLoopDefer(pc, t) - continue + if err != nil { + select { + case taskData.resc <- responseAndError{err: err}: + case <-taskData.callerGone: + readLoopDefer(pc, true) + return } + readLoopDefer(pc, true) + return + } - pc.mu.Lock() - pc.numExpectedResponses-- - pc.mu.Unlock() - - bodyWritable := resp.bodyIsWritable() - hasBody := taskData.req.Method != "HEAD" && resp.ContentLength != 0 + dataTask := taskData.hyperBody.Data() + taskData.taskId = readBodyChunk + dataTask.SetUserdata(c.Pointer(taskData)) + t.exec.Push(dataTask) - if resp.Close || taskData.req.Close || resp.StatusCode <= 199 || bodyWritable { - // Don't do keep-alive on error if either party requested a close - // or we get an unexpected informational (1xx) response. - // StatusCode 100 is already handled above. - pc.alive = false - } + if !taskData.req.deadline.IsZero() { + (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + } - if !hasBody || bodyWritable { - replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) - - // TODO(spongehah) ConnPool(readWriteLoop) - // Put the idle conn back into the pool before we send the response - // so if they process it quickly and make another request, they'll - // get this same conn. But we use the unbuffered channel 'rc' - // to guarantee that persistConn.roundTrip got out of its select - // potentially waiting for this persistConn to close. - pc.alive = pc.alive && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && pc.tryPutIdleConn() - - if bodyWritable { - pc.closeErr = errCallerOwnsConn - } + //pc.mu.Lock() + pc.numExpectedResponses-- + //pc.mu.Unlock() - select { - case taskData.resc <- responseAndError{res: resp}: - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - } - // Now that they've read from the unbuffered channel, they're safely - // out of the select that also waits on this goroutine to die, so - // we're allowed to exit now if needed (if alive is false) - //testHookReadLoopBeforeNextRead() - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } - continue - } + needContinue := resp.checkRespBody(taskData) + if needContinue { + return + } - body := &bodyEOFSignal{ - body: resp.Body, - earlyCloseFn: func() error { - taskData.bodyWriter.Close() - return nil - }, - fn: func(err error) error { - isEOF := err == io.EOF - if !isEOF { - if cerr := pc.canceled(); cerr != nil { - return cerr - } - } - return err - }, - } - resp.Body = body - - // TODO(spongehah) gzip(pc.readWriteLoop) - //if taskData.addedGzip && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { - // println("gzip reader") - // resp.Body = &gzipReader{body: body} - // resp.Header.Del("Content-Encoding") - // resp.Header.Del("Content-Length") - // resp.ContentLength = -1 - // resp.Uncompressed = true - //} + resp.wrapRespBody(taskData) - bodyForeachTask := respBody.Foreach(appendToResponseBody, c.Pointer(taskData.bodyWriter)) - taskData.taskId = readDone - bodyForeachTask.SetUserdata(c.Pointer(taskData)) - t.exec.Push(bodyForeachTask) - if taskData.req.timer != nil { - (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData - } + // FIXME: Waiting for the channel bug to be fixed + //select { + //case taskData.resc <- responseAndError{res: resp}: + //case <-taskData.callerGone: + // // defer + // readLoopDefer(pc, true) + // return + //} + select { + case <-taskData.callerGone: + readLoopDefer(pc, true) + return + default: + } + taskData.resc <- responseAndError{res: resp} - // TODO(spongehah) select blocking(readWriteLoop) - //select { - //case taskData.resc <- responseAndError{res: resp}: - //case <-taskData.callerGone: - // // defer - // readLoopDefer(pc, t) - // continue - //} - select { - case <-taskData.callerGone: - // defer - readLoopDefer(pc, t) - continue - default: - } - taskData.resc <- responseAndError{res: resp} + if debugReadWriteLoop { + println("############### read end") + } + case readBodyChunk: + if debugReadWriteLoop { + println("############### readBodyChunk") + } + taskType := task.Type() + if taskType == hyper.TaskBuf { + chunk := (*hyper.Buf)(task.Value()) + chunkLen := chunk.Len() + bytes := unsafe.Slice(chunk.Bytes(), chunkLen) + // Free chunk and task + chunk.Free() + task.Free() + // Write to the channel + pc.bodyChunk.readCh <- bytes if debugReadWriteLoop { - println("read end") - } - case readDone: - // A background task of reading the response body is completed - if debugReadWriteLoop { - println("readDone") - } - if taskData.bodyWriter != nil { - taskData.bodyWriter.Close() + println("############### readBodyChunk end [buf]") } - checkTaskType(task, readDone) - - bodyEOF := task.Type() == hyper.TaskEmpty - // free the task - task.Free() - - pc := taskData.pc - - replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool - // TODO(spongehah) ConnPool(readWriteLoop) - pc.alive = pc.alive && - bodyEOF && - replaced && pc.tryPutIdleConn() - //pc.alive = pc.alive && - // bodyEOF && - // !pc.sawEOF && - // pc.wroteRequest() && - // replaced && tryPutIdleConn(trace) - - // TODO(spongehah) timeout(t.readWriteLoop) - //case <-rw.rc.req.Cancel: - // pc.alive = false - // pc.t.CancelRequest(rw.rc.req) - //case <-rw.rc.req.Context().Done(): - // pc.alive = false - // pc.t.cancelRequest(rw.rc.cancelKey, rw.rc.req.Context().Err()) - //case <-pc.closech: - // pc.alive = false - //} + return + } - //select { - //case <-taskData.req.timeoutch: - // continue - //case <-pc.closech: - // pc.alive = false - //default: - //} + // taskType == taskEmpty (check in checkTaskType) + task.Free() + taskData.hyperBody.Free() + taskData.hyperBody = nil + pc.bodyChunk.Close() + replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc.alive = pc.alive && + replaced && pc.tryPutIdleConn() - if pc.alive == false { - // defer - readLoopDefer(pc, t) - } + readLoopDefer(pc, false) - //testHookReadLoopBeforeNextRead() - if debugReadWriteLoop { - println("readDone end") - } - case notSet: - // A background task for hyper_client completed... - task.Free() + if debugReadWriteLoop { + println("############### readBodyChunk end [empty]") } } } -func readLoopDefer(pc *persistConn, t *Transport) { +func readyToRead(aysnc *libuv.Async) { + println("############### AsyncCb: readyToRead") + taskData := (*taskData)(aysnc.GetData()) + dataTask := taskData.hyperBody.Data() + dataTask.SetUserdata(c.Pointer(taskData)) + taskData.pc.t.exec.Push(dataTask) +} + +// readLoopDefer Replace the defer function of readLoop in stdlib +func readLoopDefer(pc *persistConn, force bool) { + if pc.alive == true && !force { + return + } pc.close(pc.closeErr) - // TODO(spongehah) ConnPool(readLoopDefer) - t.removeIdleConn(pc) + pc.t.removeIdleConn(pc) } // ---------------------------------------------------------- +type connData struct { + TcpHandle libuv.Tcp + ConnectReq libuv.Connect + ReadBuf libuv.Buf + ReadBufFilled uintptr + nwrite int64 // bytes written(Replaced from persistConn's nwrite) + ReadWaker *hyper.Waker + WriteWaker *hyper.Waker +} + type taskData struct { taskId taskId - bodyWriter *io.PipeWriter req *transportRequest pc *persistConn addedGzip bool writeErrCh chan error callerGone chan struct{} resc chan responseAndError + hyperBody *hyper.Body } -type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr - nwrite int64 // bytes written(Replaced from persistConn's nwrite) - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker -} +// taskId The unique identifier of the next task polled from the executor +type taskId c.Int + +const ( + handshake taskId = iota + 1 + read + readBodyChunk +) func (conn *connData) Close() error { if conn == nil { @@ -1709,8 +1520,8 @@ func (conn *connData) Close() error { // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { if debugSwitch { - println("connect start") - defer println("connect end") + println("############### connect start") + defer println("############### connect end") } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) @@ -1723,8 +1534,6 @@ func onConnect(req *libuv.Connect, status c.Int) { // allocBuffer allocates a buffer for reading from a socket func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { - //conn := (*ConnData)(handle.Data) - //conn := (*struct{ data *ConnData })(c.Pointer(handle)).data conn := (*connData)(handle.GetData()) if conn.ReadBuf.Base == nil { conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) @@ -1738,31 +1547,21 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { // onRead is the libuv callback for reading from a socket // This callback function is called when data is available to be read func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { - // Get the connection data associated with the stream conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) - - // If data was read (nread > 0) if nread > 0 { - // Update the amount of filled buffer conn.ReadBufFilled += uintptr(nread) } - // If there's a pending read waker if conn.ReadWaker != nil { // Wake up the pending read operation of Hyper conn.ReadWaker.Wake() - // Clear the waker reference conn.ReadWaker = nil } } // readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) - - // If there's data in the buffer if conn.ReadBufFilled > 0 { - // Calculate how much data to copy (minimum of filled amount and requested amount) var toCopy uintptr if bufLen < conn.ReadBufFilled { toCopy = bufLen @@ -1775,71 +1574,52 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) // Update the amount of filled buffer conn.ReadBufFilled -= toCopy - // Return the number of bytes copied return toCopy } - // If no data in buffer, set up a waker to wait for more data - // Free the old waker if it exists if conn.ReadWaker != nil { conn.ReadWaker.Free() } - // Create a new waker conn.ReadWaker = ctx.Waker() - // Return HYPER_IO_PENDING to indicate operation is pending, waiting for more data return hyper.IoPending } // onWrite is the libuv callback for writing to a socket // Callback function called after a write operation completes func onWrite(req *libuv.Write, status c.Int) { - // Get the connection data associated with the write request conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - - // If there's a pending write waker if conn.WriteWaker != nil { // Wake up the pending write operation conn.WriteWaker.Wake() - // Clear the waker reference conn.WriteWaker = nil } } // writeCallBack write callback function for Hyper library func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { - // Get the user data (connection data) conn := (*connData)(userdata) - // Create a libuv buffer initBuf := libuv.InitBuf((*c.Char)(c.Pointer(buf)), c.Uint(bufLen)) req := &libuv.Write{} - // Associate the connection data with the write request (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - // Perform the asynchronous write operation ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) - // If the write operation was successfully initiated if ret >= 0 { conn.nwrite += int64(bufLen) - // Return the number of bytes to be written return bufLen } - // If the write operation can't complete immediately, set up a waker to wait for completion if conn.WriteWaker != nil { - // Free the old waker if it exists conn.WriteWaker.Free() } - // Create a new waker conn.WriteWaker = ctx.Waker() - // Return HYPER_IO_PENDING to indicate operation is pending, waiting for write to complete return hyper.IoPending } // onTimeout is the libuv callback for a timeout func onTimeout(timer *libuv.Timer) { if debugSwitch { - println("onTimeout start") - defer println("onTimeout end") + println("############### onTimeout start") + defer println("############### onTimeout end") } data := (*timeoutData)((*libuv.Handle)(c.Pointer(timer)).GetData()) close(data.timeoutch) @@ -1850,13 +1630,12 @@ func onTimeout(timer *libuv.Timer) { pc := taskData.pc pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) - // defer - readLoopDefer(pc, pc.t) + readLoopDefer(pc, true) } } -// newIoWithConnReadWrite creates a new IO with read and write callbacks -func newIoWithConnReadWrite(connData *connData) *hyper.Io { +// newHyperIo creates a new IO with read and write callbacks +func newHyperIo(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() hyperIo.SetUserdata(c.Pointer(connData)) hyperIo.SetRead(readCallBack) @@ -1864,79 +1643,98 @@ func newIoWithConnReadWrite(connData *connData) *hyper.Io { return hyperIo } -// taskId The unique identifier of the next task polled from the executor -type taskId c.Int - -const ( - notSet taskId = iota - handshake - read - readDone -) - // checkTaskType checks the task type -func checkTaskType(task *hyper.Task, curTaskId taskId) error { - switch curTaskId { - case handshake: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::handshake]handshake task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskClientConn { - return fmt.Errorf("[readWriteLoop::handshake]unexpected task type\n") - } - return nil - case read: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::read]write task error!\n") - return fail((*hyper.Error)(task.Value())) - } - if task.Type() != hyper.TaskResponse { - c.Printf(c.Str("[readWriteLoop::read]unexpected task type\n")) - return errors.New("[readWriteLoop::read]unexpected task type\n") +func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { + curTaskId := taskData.taskId + taskType := task.Type() + if taskType == hyper.TaskError { + err = fail((*hyper.Error)(task.Value()), curTaskId) + } + if err == nil { + switch curTaskId { + case handshake: + if taskType != hyper.TaskClientConn { + err = errors.New("Unexpected hyper task type: expected to be TaskClientConn, actual is " + strTaskType(taskType)) + } + case read: + if taskType != hyper.TaskResponse { + err = errors.New("Unexpected hyper task type: expected to be TaskResponse, actual is " + strTaskType(taskType)) + } + case readBodyChunk: + if taskType != hyper.TaskBuf && taskType != hyper.TaskEmpty { + err = errors.New("Unexpected hyper task type: expected to be TaskBuf / TaskEmpty, actual is " + strTaskType(taskType)) + } } - return nil - case readDone: - if task.Type() == hyper.TaskError { - log.Printf("[readWriteLoop::readDone]read response body error!\n") - return fail((*hyper.Error)(task.Value())) + } + if err != nil { + task.Free() + if curTaskId == handshake || curTaskId == read { + taskData.writeErrCh <- err + taskData.pc.close(err) } - return nil - case notSet: + taskData.pc.alive = false } - return errors.New("[readWriteLoop]unexpected task type\n") + return } // fail prints the error details and panics -func fail(err *hyper.Error) error { +func fail(err *hyper.Error, taskId taskId) error { if err != nil { - c.Printf(c.Str("[readWriteLoop]error code: %d\n"), err.Code()) // grab the error details var errBuf [256]c.Char errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) - - c.Printf(c.Str("[readWriteLoop]details: %.*s\n"), c.Int(errLen), c.Pointer(&errBuf[:][0])) + errDetails := unsafe.SliceData(errBuf[:errLen]) + details := c.GoString(errDetails) // clean up the error err.Free() - return fmt.Errorf("[readWriteLoop]hyper request error, error code: %d\n", int(err.Code())) + return fmt.Errorf("hyper request error, taskId: %s, details: %s\n", strTaskId(taskId), details) } return nil } +func strTaskType(taskType hyper.TaskReturnType) string { + switch taskType { + case hyper.TaskClientConn: + return "TaskClientConn" + case hyper.TaskResponse: + return "TaskResponse" + case hyper.TaskBuf: + return "TaskBuf" + case hyper.TaskEmpty: + return "TaskEmpty" + case hyper.TaskError: + return "TaskError" + default: + return "Unknown" + } +} + +func strTaskId(taskId taskId) string { + switch taskId { + case handshake: + return "handshake" + case read: + return "read" + case readBodyChunk: + return "readBodyChunk" + default: + return "notSet" + } +} + // ---------------------------------------------------------- // error values for debugging and testing, not seen by users. var ( - errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") - errConnBroken = errors.New("http: putIdleConn: connection is in bad state") - errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") - errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") - errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") - errCloseIdleConns = errors.New("http: CloseIdleConnections called") - errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") - errReadWriteLoopExiting = errors.New("http: Transport.readWriteLoop exiting") - errIdleConnTimeout = errors.New("http: idle connection timeout") + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: Transport.readWriteLoop.read exiting") + errIdleConnTimeout = errors.New("http: idle connection timeout") // errServerClosedIdle is not seen by users for idempotent requests, but may be // seen by a user if the server shuts down an idle connection and sends its FIN @@ -1971,14 +1769,6 @@ func (e *httpError) Error() string { return e.err } func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } -// fakeLocker is a sync.Locker which does nothing. It's used to guard -// test-only fields when not under test, to avoid runtime atomic -// overhead. -type fakeLocker struct{} - -func (fakeLocker) Lock() {} -func (fakeLocker) Unlock() {} - // nothingWrittenError wraps a write errors which ended up writing zero bytes. type nothingWrittenError struct { error @@ -2014,9 +1804,6 @@ var ( testHookRoundTripRetried = nop testHookPrePendingDial = nop testHookPostPendingDial = nop - - testHookMu sync.Locker = fakeLocker{} // guards following - testHookReadLoopBeforeNextRead = nop ) var portMap = map[string]string{ @@ -2076,10 +1863,34 @@ type persistConn struct { mutateHeaderFunc func(Header) // other - alive bool // Replace the alive in readLoop - closeErr error // Replace the closeErr in readLoop - tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop - client *hyper.ClientConn + alive bool // Replace the alive in readLoop + closeErr error // Replace the closeErr in readLoop + tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop + client *hyper.ClientConn // http long connection client handle + bodyChunk *bodyChunk // Implement non-blocking consumption of each responseBody chunk + chunkAsync *libuv.Async // Notifying that the received chunk has been read +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in +// a "keep-alive" state. It does not interrupt any connections currently +// in use. +func (t *Transport) CloseIdleConnections() { + //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t.idleMu.Lock() + m := t.idleConn + t.idleConn = nil + t.closeIdle = true // close newly idle connections + t.idleLRU = connLRU{} + t.idleMu.Unlock() + for _, conns := range m { + for _, pconn := range conns { + pconn.close(errCloseIdleConns) + } + } + //if t2 := t.h2transport; t2 != nil { + // t2.CloseIdleConnections() + //} } func (pc *persistConn) cancelRequest(err error) { @@ -2110,7 +1921,7 @@ func (pc *persistConn) markReused() { func (pc *persistConn) closeLocked(err error) { if debugSwitch { - println("pc closed") + println("############### pc closed") } if err == nil { panic("nil error") @@ -2128,6 +1939,7 @@ func (pc *persistConn) closeLocked(err error) { close(pc.closech) close(pc.writeLoopDone) pc.client.Free() + pc.chunkAsync.Close(nil) } } pc.mutateHeaderFunc = nil @@ -2256,13 +2068,11 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { // the 1st response byte from the server. return true } - if err == errServerClosedIdle { - // The server replied with io.EOF while we were trying to - // read the response. Probably an unfortunately keep-alive - // timeout, just as the client was writing a request. - return true - } - return false // conservatively + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. + // conservatively return false. + return err == errServerClosedIdle } // closeConnIfStillIdle closes the connection if it's still sitting idle. @@ -2300,6 +2110,45 @@ func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", err)) } +// setExtraHeaders Set extra headers, such as Accept-Encoding, Connection(Keep-Alive). +func (pc *persistConn) setExtraHeaders(req *transportRequest) bool { + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempt to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + // TODO(spongehah) gzip(pc.roundTrip) + //if !pc.t.DisableCompression && + // req.Header.Get("Accept-Encoding") == "" && + // req.Header.Get("Range") == "" && + // req.Method != "HEAD" { + // // Request gzip only, not deflate. Deflate is ambiguous and + // // not as universally supported anyway. + // // See: https://zlib.net/zlib_faq.html#faq39 + // // + // // Note that we don't request this for HEAD requests, + // // due to a bug in nginx: + // // https://trac.nginx.org/nginx/ticket/358 + // // https://golang.org/issue/5522 + // // + // // We don't request gzip if the request is for a range, since + // // auto-decoding a portion of a gzipped document will just fail + // // anyway. See https://golang.org/issue/8923 + // requestedGzip = true + // req.extraHeaders().Set("Accept-Encoding", "gzip") + //} + + // The 100-continue operation in Hyper is handled in the newHyperRequest function. + + // Keep-Alive + if pc.t.DisableKeepAlives && + !req.wantsClose() && + !isProtocolSwitchHeader(req.Header) { + req.extraHeaders().Set("Connection", "close") + } + return requestedGzip +} + func is408Message(resp *hyper.Response) bool { httpVersion := int(resp.Version()) if httpVersion != 10 && httpVersion != 11 { @@ -2435,7 +2284,6 @@ func (w *wantConn) cancel(t *Transport, err error) { w.err = err w.mu.Unlock() - // TODO(spongehah) ConnPool(w.cancel) if pc != nil { t.putOrCloseIdleConn(pc) } @@ -2534,110 +2382,6 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } } -// bodyEOFSignal is used by the HTTP/1 transport when reading response -// bodies to make sure we see the end of a response body before -// proceeding and reading on the connection again. -// -// It wraps a ReadCloser but runs fn (if non-nil) at most -// once, right before its final (error-producing) Read or Close call -// returns. fn should return the new error to return from Read or Close. -// -// If earlyCloseFn is non-nil and Close is called before io.EOF is -// seen, earlyCloseFn is called instead of fn, and its return value is -// the return value from Close. -type bodyEOFSignal struct { - body io.ReadCloser - mu sync.Mutex // guards following 4 fields - closed bool // whether Close has been called - rerr error // sticky Read error - fn func(error) error // err will be nil on Read io.EOF - earlyCloseFn func() error // optional alt Close func used if io.EOF not seen -} - -var errReadOnClosedResBody = errors.New("http: read on closed response body") - -func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { - es.mu.Lock() - closed, rerr := es.closed, es.rerr - es.mu.Unlock() - if closed { - return 0, errReadOnClosedResBody - } - if rerr != nil { - return 0, rerr - } - - n, err = es.body.Read(p) - if err != nil { - es.mu.Lock() - defer es.mu.Unlock() - if es.rerr == nil { - es.rerr = err - } - err = es.condfn(err) - } - return -} - -func (es *bodyEOFSignal) Close() error { - es.mu.Lock() - defer es.mu.Unlock() - if es.closed { - return nil - } - es.closed = true - if es.earlyCloseFn != nil && es.rerr != io.EOF { - return es.earlyCloseFn() - } - err := es.body.Close() - return es.condfn(err) -} - -// caller must hold es.mu. -func (es *bodyEOFSignal) condfn(err error) error { - if es.fn == nil { - return err - } - err = es.fn(err) - es.fn = nil - return err -} - -// gzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read -type gzipReader struct { - _ incomparable - body *bodyEOFSignal // underlying HTTP/1 response body framing - zr *gzip.Reader // lazily-initialized gzip reader - zerr error // any error from gzip.NewReader; sticky -} - -func (gz *gzipReader) Read(p []byte) (n int, err error) { - if gz.zr == nil { - if gz.zerr == nil { - gz.zr, gz.zerr = gzip.NewReader(gz.body) - } - if gz.zerr != nil { - return 0, gz.zerr - } - } - - gz.body.mu.Lock() - if gz.body.closed { - err = errReadOnClosedResBody - } - gz.body.mu.Unlock() - - if err != nil { - return 0, err - } - return gz.zr.Read(p) -} - -func (gz *gzipReader) Close() error { - return gz.body.Close() -} - type connLRU struct { ll *list.List // list.Element.Value type is of *persistConn m map[*persistConn]*list.Element From 67f1d214b0cf28677b6049f6ce09fe2d995609b4 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 12 Sep 2024 17:43:39 +0800 Subject: [PATCH 40/55] refactor(x/net/http/demo): Neat http echo demo Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 9471272..725410b 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -17,8 +17,6 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { } fmt.Printf(">> URL: %s\n", r.URL.String()) fmt.Printf(">> RemoteAddr: %s\n", r.RemoteAddr) - // fmt.Println("ContentLength: %d", r.ContentLength) - // fmt.Println("TransferEncoding: %s", r.TransferEncoding) body, err := io.ReadAll(r.Body) if err != nil { @@ -26,23 +24,8 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { return } defer r.Body.Close() - - // var body []byte - // buffer := make([]byte, 1024) - // for { - // n, err := r.Body.Read(buffer) - // if err != nil && err != io.EOF { - // http.Error(w, "Error reading request body", http.StatusInternalServerError) - // return - // } - // body = append(body, buffer[:n]...) - // if err == io.EOF { - // break - // } - // } - fmt.Printf(">> Body: %s\n", string(body)) - fmt.Println("[debug] body read done") + w.Header().Set("Content-Type", "text/plain") w.Write(body) } From 0ce31d61aad3447f3981acb492c4c91c2f9b7791 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 12 Sep 2024 17:44:50 +0800 Subject: [PATCH 41/55] refactor(x/net/http/demo): Re-implement requestBody Signed-off-by: hackerchai --- x/net/http/request_body.go | 31 ++++--------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/x/net/http/request_body.go b/x/net/http/request_body.go index be19f56..7bb5dff 100644 --- a/x/net/http/request_body.go +++ b/x/net/http/request_body.go @@ -9,26 +9,6 @@ import ( "github.com/goplus/llgo/c/libuv" ) -type onceError struct { - sync.Mutex - err error -} - -func (a *onceError) Store(err error) { - a.Lock() - defer a.Unlock() - if a.err != nil { - return - } - a.err = err -} - -func (a *onceError) Load() error { - a.Lock() - defer a.Unlock() - return a.err -} - type requestBody struct { chunk []byte readCh chan []byte @@ -37,7 +17,7 @@ type requestBody struct { once sync.Once done chan struct{} - rerr onceError + rerr error } var ( @@ -83,7 +63,7 @@ func (rb *requestBody) Read(p []byte) (n int, err error) { } func (rb *requestBody) readCloseError() error { - if rerr := rb.rerr.Load(); rerr != nil { + if rerr := rb.rerr; rerr != nil { return rerr } return ErrClosedRequestBody @@ -94,11 +74,8 @@ func (rb *requestBody) closeRead(err error) error { if err == nil { err = io.EOF } - rb.rerr.Store(err) - rb.once.Do(func() { - close(rb.done) - }) - //close(rb.done) + rb.rerr = err + close(rb.done) return nil } From 3fae8cdbadac8eee2864a2360a8d59fe0a42a4fd Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 12 Sep 2024 17:46:01 +0800 Subject: [PATCH 42/55] fix(x/net/http): Fix nil pointer error & optimize naming Signed-off-by: hackerchai --- x/net/http/request.go | 10 +- x/net/http/response.go | 55 ++++----- x/net/http/server.go | 248 +++++++++++++++-------------------------- 3 files changed, 118 insertions(+), 195 deletions(-) diff --git a/x/net/http/request.go b/x/net/http/request.go index 5ae022f..2389794 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -136,12 +136,12 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { body := hyperReq.Body() if body != nil { task := body.Data() - taskID := taskGetBody + taskFlag := getBodyTask taskData := taskData{ - hyperBody: body, - body: nil, - conn: conn, - hyperTaskID: taskID, + hyperBody: body, + responseBody: nil, + conn: conn, + taskFlag: taskFlag, } task.SetUserdata(c.Pointer(&taskData), nil) requestBody := newRequestBody(conn.asyncHandle) diff --git a/x/net/http/response.go b/x/net/http/response.go index c3ac428..9b166c9 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -5,7 +5,6 @@ import ( "unsafe" "github.com/goplus/llgo/c" - "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgo/c/os" "github.com/goplus/llgo/rust/hyper" ) @@ -17,39 +16,36 @@ type response struct { body []byte hyperChannel *hyper.ResponseChannel hyperResp *hyper.Response - request *Request - asyncHandler *libuv.Async } -type body struct { +type responseBodyRaw struct { data []byte len uintptr readLen uintptr } type taskData struct { - hyperBody *hyper.Body - body *body - conn *conn - hyperTaskID hyperTaskID + hyperBody *hyper.Body + responseBody *responseBodyRaw + conn *conn + taskFlag taskFlag } -type hyperTaskID int +type taskFlag int const ( - taskSetBody hyperTaskID = iota - taskGetBody + setBodyTask taskFlag = iota + getBodyTask ) var DefaultChunkSize uintptr = 8192 -func newResponse(request *Request, hyperChannel *hyper.ResponseChannel) *response { +func newResponse(hyperChannel *hyper.ResponseChannel) *response { fmt.Printf("[debug] newResponse called\n") return &response{ header: make(Header), hyperChannel: hyperChannel, - //request: request, statusCode: 200, written: false, body: nil, @@ -65,9 +61,7 @@ func (r *response) Write(data []byte) (int, error) { if !r.written { r.WriteHeader(200) } - fmt.Printf("[debug] data: %s\n", string(data)) r.body = append(r.body, data...) - fmt.Printf("[debug] r.body: %s\n", string(r.body)) return len(data), nil } @@ -84,12 +78,12 @@ func (r *response) WriteHeader(statusCode int) { fmt.Println("[debug] WriteHeaderStatusCode done") //debug - // fmt.Printf("[debug] < HTTP/1.1 %d\n", statusCode) - // for key, values := range r.header { - // for _, value := range values { - // fmt.Printf("< %s: %s\n", key, value) - // } - // } + fmt.Printf("[debug] < HTTP/1.1 %d\n", statusCode) + for key, values := range r.header { + for _, value := range values { + fmt.Printf("< %s: %s\n", key, value) + } + } headers := r.hyperResp.Headers() for key, values := range r.header { @@ -109,18 +103,11 @@ func (r *response) WriteHeader(statusCode int) { } } - fmt.Println("[debug] WriteHeaderHeaders done") - fmt.Println("[debug] WriteHeader done") } func (r *response) finalize() error { fmt.Printf("[debug] finalize called\n") - // err := r.request.Body.Close() - // if err != nil { - // return err - // } - // fmt.Printf("[debug] request body closed\n") if !r.written { r.WriteHeader(200) @@ -132,7 +119,7 @@ func (r *response) finalize() error { return fmt.Errorf("failed to create response") } - bodyData := body{ + bodyData := responseBodyRaw{ data: r.body, len: uintptr(len(r.body)), readLen: 0, @@ -144,10 +131,10 @@ func (r *response) finalize() error { return fmt.Errorf("failed to create body") } taskData := &taskData{ - hyperBody: nil, - body: &bodyData, - conn: nil, - hyperTaskID: taskSetBody, + hyperBody: nil, + responseBody: &bodyData, + conn: nil, + taskFlag: setBodyTask, } body.SetDataFunc(setBodyDataFunc) body.SetUserdata(unsafe.Pointer(taskData), nil) @@ -173,7 +160,7 @@ func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) fmt.Println("[debug] taskData is nil") return hyper.PollError } - body := taskData.body + body := taskData.responseBody if body.len > 0 { //debug diff --git a/x/net/http/server.go b/x/net/http/server.go index cc9601e..0b26cf1 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -35,7 +35,6 @@ type Server struct { uvLoop *libuv.Loop uvServer libuv.Tcp inShutdown atomic.Bool - //checkHandle libuv.Check idleHandle libuv.Idle mu sync.Mutex @@ -124,18 +123,6 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("failed to listen: %v", err) } - // if r := libuv.InitCheck(srv.uvLoop, &srv.checkHandle); r != 0 { - // fmt.Fprintf(os.Stderr, "Failed to initialize check handler: %d\n", r) - // os.Exit(1) - // } - - // (*libuv.Handle)(unsafe.Pointer(&srv.checkHandle)).SetData(unsafe.Pointer(srv)) - - // if r := srv.checkHandle.Start(onCheck); r != 0 { - // fmt.Fprintf(os.Stderr, "Failed to start check handler: %d\n", r) - // os.Exit(1) - // } - if r := libuv.InitIdle(srv.uvLoop, &srv.idleHandle); r != 0 { fmt.Fprintf(os.Stderr, "Failed to initialize idle handler: %d\n", r) os.Exit(1) @@ -302,34 +289,13 @@ func onAsync(asyncHandle *libuv.Async) { } } -// func onCheck(handle *libuv.Check) { -// //fmt.Println("onCheck called") -// srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) -// for conn := range srv.activeConnections { -// if conn.executor != nil { -// task := conn.executor.Poll() -// for task != nil { -// srv.handleTask(task) -// task = conn.executor.Poll() -// } -// } -// } - -// if srv.shuttingDown() { -// fmt.Println("Shutdown initiated, cleaning up...") -// handle.Stop() -// } -// } - func onIdle(handle *libuv.Idle) { - //fmt.Println("onIdle called") srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) for conn := range srv.activeConnections { if conn.executor != nil { task := conn.executor.Poll() for task != nil { srv.handleTask(task) - //srv.handleRead(conn, task) task = conn.executor.Poll() } } @@ -355,151 +321,121 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - res := newResponse(req, channel) + res := newResponse(channel) fmt.Println("[debug] Response created") + //TODO(hackerchai): replace with no goroutine + // userData.server.Handler.ServeHTTP(res, req) + // res.finalize() go func() { userData.server.Handler.ServeHTTP(res, req) res.finalize() }() - - // userData.server.Handler.ServeHTTP(res, req) - - // res.finalize() } -// func (srv *Server) handleRead(conn *conn, task *hyper.Task) { -// payload := (*taskData)(task.Userdata()) -// if payload == nil { -// fmt.Println("taskData is nil, no need to handle read") -// return -// } - -// select { -// case <-conn.requestBody.readyCh: -// fmt.Println("readyCh signaled") - -// fmt.Println("taskGetBody get body task form readyCh") -// getBodyTask := payload.hyperBody.Data() -// getBodyTask.SetUserdata(c.Pointer(payload), nil) -// if getBodyTask != nil { -// fmt.Println("taskGetBody push get body task") -// r := payload.conn.executor.Push(getBodyTask) -// fmt.Printf("taskGetBody push get body task: %d\n", r) -// if r != hyper.OK { -// fmt.Printf("failed to push get body task: %d\n", r) -// getBodyTask.Free() -// } -// } -// default: -// fmt.Println("readToReadCh not signaled") -// } -// } - func (srv *Server) handleTask(task *hyper.Task) { - taskType := task.Type() - //debug - switch taskType { - case hyper.TaskEmpty: - fmt.Println("[debug] Task type: Empty") - case hyper.TaskBuf: - fmt.Println("[debug] Task type: Buffer") - case hyper.TaskError: - fmt.Println("[debug] Task type: Error") - case hyper.TaskServerconn: - fmt.Println("[debug] Task type: Serverconn") - default: - fmt.Println("[debug] Unknown task type") - } + hyperTaskType := task.Type() + // Debug + fmt.Printf("[debug] Task type: %s\n", getTaskTypeString(hyperTaskType)) payload := (*taskData)(task.Userdata()) - if payload != nil { - taskID := payload.hyperTaskID - - // select { - // case <-conn.requestBody.readyCh: - // fmt.Println("readyCh recieved") - - // fmt.Println("taskGetBody get body task form readyCh") - // getBodyTask := payload.hyperBody.Data() - // getBodyTask.SetUserdata(c.Pointer(payload), nil) - // if getBodyTask != nil { - // fmt.Println("taskGetBody push get body task") - // r := payload.conn.executor.Push(getBodyTask) - // fmt.Printf("taskGetBody push get body task: %d\n", r) - // if r != hyper.OK { - // fmt.Printf("failed to push get body task: %d\n", r) - // getBodyTask.Free() - // } - // } - // default: - // fmt.Println("readyCh not recieved") - // } - - if taskID == taskGetBody { - fmt.Println("[debug] taskGetBody called") - if taskType == hyper.TaskError { - fmt.Println("[debug] taskGetBody error") - err := (*hyper.Error)(task.Value()) - fmt.Printf("error code: %d\n", err.Code()) - - var errbuf [256]byte - errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) - fmt.Printf("details: %s\n", errbuf[:errlen]) - err.Free() - task.Free() - } - - if taskType == hyper.TaskBuf { - fmt.Println("[debug] taskGetBody write buf") - buf := (*hyper.Buf)(task.Value()) - bytes := unsafe.Slice(buf.Bytes(), buf.Len()) - fmt.Printf("[debug] taskGetBody writing to bodyWriter: %s\n", string(bytes)) - buf.Free() - task.Free() - fmt.Println("[debug] taskGetBody free task") - payload.conn.requestBody.readCh <- bytes - fmt.Println("[debug] taskGetBody wrote to bodyWriter") - } - if taskType == hyper.TaskEmpty { - fmt.Println("[debug] taskGetBody close requestBody") - payload.conn.requestBody.Close() - fmt.Println("[debug] taskGetBody free task") - task.Free() - } - } else if taskID == taskSetBody { - fmt.Println("[debug] taskSetBody called") - if taskType == hyper.TaskError { - fmt.Println("[debug] taskSetBody error") - err := (*hyper.Error)(task.Value()) - fmt.Printf("error code: %d\n", err.Code()) - - var errbuf [256]byte - errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) - fmt.Printf("details: %s\n", errbuf[:errlen]) - err.Free() - task.Free() - } + // Debug + if payload == nil { + fmt.Println("[debug] task data is nil") + } - if taskType == hyper.TaskEmpty { - fmt.Println("[debug] taskSetBody free task") - task.Free() - } + if payload != nil { + switch payload.taskFlag { + case getBodyTask: + handleGetBodyTask(hyperTaskType, task, payload) + return + case setBodyTask: + handleSetBodyTask(hyperTaskType, task) + return + default: + fmt.Println("[debug] Unknown response task type") + return } } - if taskType == hyper.TaskEmpty { - fmt.Println("[debug] taskEmpty called") + switch hyperTaskType { + case hyper.TaskError: + handleTaskError(task) + return + case hyper.TaskEmpty: + fmt.Println("[debug] Empty task handled") + task.Free() + return + case hyper.TaskServerconn: + fmt.Println("[debug] Server connection task handled") task.Free() + return } +} - if taskType == hyper.TaskServerconn { - fmt.Println("[debug] taskServerconn called") +func handleGetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task, payload *taskData) { + switch hyperTaskType { + case hyper.TaskError: + handleTaskError(task) + case hyper.TaskBuf: + handleTaskBuffer(task, payload) + case hyper.TaskEmpty: + fmt.Println("[debug] Get body task closing request body") + payload.conn.requestBody.Close() task.Free() } } +func handleSetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task) { + switch hyperTaskType { + case hyper.TaskError: + handleTaskError(task) + case hyper.TaskEmpty: + fmt.Println("[debug] Set body task freeing") + task.Free() + } +} + +func handleTaskError(task *hyper.Task) { + err := (*hyper.Error)(task.Value()) + fmt.Printf("Error code: %d\n", err.Code()) + + var errbuf [256]byte + errlen := err.Print(&errbuf[0], unsafe.Sizeof(errbuf)) + fmt.Printf("Details: %s\n", errbuf[:errlen]) + err.Free() + task.Free() +} + +func handleTaskBuffer(task *hyper.Task, payload *taskData) { + buf := (*hyper.Buf)(task.Value()) + bytes := unsafe.Slice(buf.Bytes(), buf.Len()) + payload.conn.requestBody.readCh <- bytes + fmt.Printf("[debug] Task get body writing to bodyWriter: %s\n", string(bytes)) + buf.Free() + task.Free() +} + +func getTaskTypeString(taskType hyper.TaskReturnType) string { + switch taskType { + case hyper.TaskEmpty: + return "Empty" + case hyper.TaskBuf: + return "Buffer" + case hyper.TaskError: + return "Error" + case hyper.TaskServerconn: + return "Server connection" + case hyper.TaskClientConn: + return "Client connection" + case hyper.TaskResponse: + return "Response" + default: + return "Unknown" + } +} + func (s *Server) trackConn(c *conn, add bool) { s.mu.Lock() defer s.mu.Unlock() From 22ab2d5cba6eb387f621849d87fde72ac2658ec5 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 12 Sep 2024 17:50:34 +0800 Subject: [PATCH 43/55] fix(x/net/http): Remove updateConnRegistrations unsuse args Signed-off-by: hackerchai --- x/net/http/server.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/x/net/http/server.go b/x/net/http/server.go index 0b26cf1..af440d6 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -188,7 +188,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Data = unsafe.Pointer(conn) - if !updateConnRegistrations(conn, true) { + if !updateConnRegistrations(conn) { (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return @@ -484,7 +484,7 @@ func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintp if conn.eventMask&c.Uint(libuv.READABLE) == 0 { conn.eventMask |= c.Uint(libuv.READABLE) fmt.Printf("[debug] ReadCb Event mask: %d\n", conn.eventMask) - if !updateConnRegistrations(conn, false) { + if !updateConnRegistrations(conn) { return hyper.IoError } fmt.Printf("[debug] ReadCb updateConnRegistrations\n") @@ -513,7 +513,7 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint if conn.eventMask&c.Uint(libuv.WRITABLE) == 0 { conn.eventMask |= c.Uint(libuv.WRITABLE) fmt.Printf("[debug] WriteCb Event mask: %d\n", conn.eventMask) - if !updateConnRegistrations(conn, false) { + if !updateConnRegistrations(conn) { return hyper.IoError } } @@ -542,7 +542,7 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { } } -func updateConnRegistrations(conn *conn, create bool) bool { +func updateConnRegistrations(conn *conn) bool { fmt.Println("[debug] updateConnRegistrations called") events := c.Int(0) From e1c3717c943f76a66a761995a8a631830f236f65 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 12 Sep 2024 18:11:59 +0800 Subject: [PATCH 44/55] refactor(x/net/http): Rewrite loop <-> hyper.executor logic Signed-off-by: hackerchai --- x/net/http/request.go | 14 ++++++----- x/net/http/response.go | 9 ++++--- x/net/http/server.go | 57 +++++++++++++++++++++--------------------- 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/x/net/http/request.go b/x/net/http/request.go index 2389794..e8ad072 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -35,7 +35,7 @@ type Request struct { timeout time.Duration } -func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { +func (conn *conn) readRequest(srv *Server, hyperReq *hyper.Request) (*Request, error) { println("[debug] readRequest called") req := Request{ Header: make(Header), @@ -137,21 +137,23 @@ func (conn *conn) readRequest(hyperReq *hyper.Request) (*Request, error) { if body != nil { task := body.Data() taskFlag := getBodyTask + + requestBody := newRequestBody(conn.asyncHandle) + req.Body = requestBody + taskData := taskData{ hyperBody: body, responseBody: nil, - conn: conn, + requestBody: requestBody, taskFlag: taskFlag, + server: srv, } task.SetUserdata(c.Pointer(&taskData), nil) - requestBody := newRequestBody(conn.asyncHandle) - conn.requestBody = requestBody - req.Body = requestBody conn.asyncHandle.SetData(c.Pointer(&taskData)) fmt.Println("[debug] async task set") if task != nil { - r := conn.executor.Push(task) + r := srv.executor.Push(task) if r != hyper.OK { fmt.Printf("failed to push body foreach task: %d\n", r) task.Free() diff --git a/x/net/http/response.go b/x/net/http/response.go index 9b166c9..acd853f 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -14,6 +14,7 @@ type response struct { statusCode int written bool body []byte + server *Server hyperChannel *hyper.ResponseChannel hyperResp *hyper.Response } @@ -27,7 +28,8 @@ type responseBodyRaw struct { type taskData struct { hyperBody *hyper.Body responseBody *responseBodyRaw - conn *conn + requestBody *requestBody + server *Server taskFlag taskFlag } @@ -40,12 +42,13 @@ const ( var DefaultChunkSize uintptr = 8192 -func newResponse(hyperChannel *hyper.ResponseChannel) *response { +func newResponse(server *Server, hyperChannel *hyper.ResponseChannel) *response { fmt.Printf("[debug] newResponse called\n") return &response{ header: make(Header), hyperChannel: hyperChannel, + server: server, statusCode: 200, written: false, body: nil, @@ -133,7 +136,7 @@ func (r *response) finalize() error { taskData := &taskData{ hyperBody: nil, responseBody: &bodyData, - conn: nil, + server: r.server, taskFlag: setBodyTask, } body.SetDataFunc(setBodyDataFunc) diff --git a/x/net/http/server.go b/x/net/http/server.go index af440d6..aae76de 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -37,6 +37,8 @@ type Server struct { inShutdown atomic.Bool idleHandle libuv.Idle + executor *hyper.Executor + mu sync.Mutex activeConnections map[*conn]struct{} } @@ -51,9 +53,7 @@ type conn struct { http2Opts *hyper.Http2ServerconnOptions isClosing atomic.Bool closedHandles int32 - executor *hyper.Executor remoteAddr string - requestBody *requestBody asyncHandle *libuv.Async } @@ -225,7 +225,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - conn.executor = executor + srv.executor = executor fmt.Println("[debug] Conn created") srv.trackConn(conn, true) @@ -236,7 +236,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { io := createIo(conn) service := hyper.ServiceNew(serverCallback) service.SetUserdata(unsafe.Pointer(userdata), nil) - http1Opts := hyper.Http1ServerconnOptionsNew(conn.executor) + http1Opts := hyper.Http1ServerconnOptionsNew(srv.executor) if http1Opts == nil { fmt.Fprintf(os.Stderr, "Failed to create http1_opts\n") os.Exit(1) @@ -248,7 +248,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { } conn.http1Opts = http1Opts - http2Opts := hyper.Http2ServerconnOptionsNew(conn.executor) + http2Opts := hyper.Http2ServerconnOptionsNew(srv.executor) if http2Opts == nil { fmt.Fprintf(os.Stderr, "Failed to create http2_opts\n") os.Exit(1) @@ -266,7 +266,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { conn.http2Opts = http2Opts serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) - conn.executor.Push(serverconn) + srv.executor.Push(serverconn) } else { fmt.Println("[debug] Client not accepted") (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) @@ -280,7 +280,7 @@ func onAsync(asyncHandle *libuv.Async) { dataTask := taskData.hyperBody.Data() dataTask.SetUserdata(c.Pointer(taskData), nil) if dataTask != nil { - r := taskData.conn.executor.Push(dataTask) + r := taskData.server.executor.Push(dataTask) fmt.Printf("[debug] onAsync push data task: %d\n", r) if r != hyper.OK { fmt.Printf("failed to push data task: %d\n", r) @@ -291,13 +291,11 @@ func onAsync(asyncHandle *libuv.Async) { func onIdle(handle *libuv.Idle) { srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) - for conn := range srv.activeConnections { - if conn.executor != nil { - task := conn.executor.Poll() - for task != nil { - srv.handleTask(task) - task = conn.executor.Poll() - } + if srv.executor != nil { + task := srv.executor.Poll() + for task != nil { + srv.handleTask(task) + task = srv.executor.Poll() } } @@ -309,19 +307,24 @@ func onIdle(handle *libuv.Idle) { func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { userData := (*serviceUserdata)(userdata) + srv := userData.server + if srv == nil { + fmt.Fprintf(os.Stderr, "Error: Received null server\n") + return + } if hyperReq == nil { fmt.Fprintf(os.Stderr, "Error: Received null request\n") return } - req, err := userData.conn.readRequest(hyperReq) + req, err := userData.conn.readRequest(srv, hyperReq) if err != nil { fmt.Printf("Error creating request: %v\n", err) return } - res := newResponse(channel) + res := newResponse(srv, channel) fmt.Println("[debug] Response created") //TODO(hackerchai): replace with no goroutine @@ -348,7 +351,7 @@ func (srv *Server) handleTask(task *hyper.Task) { if payload != nil { switch payload.taskFlag { case getBodyTask: - handleGetBodyTask(hyperTaskType, task, payload) + handleGetBodyTask(srv, hyperTaskType, task, payload) return case setBodyTask: handleSetBodyTask(hyperTaskType, task) @@ -374,7 +377,7 @@ func (srv *Server) handleTask(task *hyper.Task) { } } -func handleGetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task, payload *taskData) { +func handleGetBodyTask(srv *Server, hyperTaskType hyper.TaskReturnType, task *hyper.Task, payload *taskData) { switch hyperTaskType { case hyper.TaskError: handleTaskError(task) @@ -382,7 +385,7 @@ func handleGetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task, pay handleTaskBuffer(task, payload) case hyper.TaskEmpty: fmt.Println("[debug] Get body task closing request body") - payload.conn.requestBody.Close() + payload.requestBody.Close() task.Free() } } @@ -411,7 +414,7 @@ func handleTaskError(task *hyper.Task) { func handleTaskBuffer(task *hyper.Task, payload *taskData) { buf := (*hyper.Buf)(task.Value()) bytes := unsafe.Slice(buf.Bytes(), buf.Len()) - payload.conn.requestBody.readCh <- bytes + payload.requestBody.readCh <- bytes fmt.Printf("[debug] Task get body writing to bodyWriter: %s\n", string(bytes)) buf.Free() task.Free() @@ -591,11 +594,6 @@ func freeConnData(userdata c.Pointer) { conn.writeWaker = nil } - if conn.executor != nil { - conn.executor.Free() - conn.executor = nil - } - if conn.http1Opts != nil { conn.http1Opts.Free() conn.http1Opts = nil @@ -632,6 +630,11 @@ func (srv *Server) Close() error { delete(srv.activeConnections, c) } + if srv.executor != nil { + srv.executor.Free() + srv.executor = nil + } + srv.uvLoop.Walk(closeWalkCb, nil) srv.uvLoop.Run(libuv.RUN_ONCE) (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).Close(nil) @@ -663,10 +666,6 @@ func (c *conn) Close() { c.writeWaker = nil } - if c.executor != nil { - c.executor.Free() - c.executor = nil - } if c.http1Opts != nil { c.http1Opts.Free() c.http1Opts = nil From 84209571a296360ec25763e845948066f7193aec Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 12 Sep 2024 18:16:56 +0800 Subject: [PATCH 45/55] fix(x/net/http): Remove unsused sync.Once Signed-off-by: hackerchai --- x/net/http/request_body.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/x/net/http/request_body.go b/x/net/http/request_body.go index 7bb5dff..1ca7f6b 100644 --- a/x/net/http/request_body.go +++ b/x/net/http/request_body.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "io" - "sync" "github.com/goplus/llgo/c/libuv" ) @@ -14,7 +13,6 @@ type requestBody struct { readCh chan []byte asyncHandle *libuv.Async - once sync.Once done chan struct{} rerr error From 03b3d7bf95ebe6d879afef9b00add2b21d568351 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 19 Sep 2024 11:55:29 +0800 Subject: [PATCH 46/55] fix(x/net/http): Implement multi-thread eventLoop Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 19 +- x/net/http/request.go | 26 +- x/net/http/request_body.go | 53 +-- x/net/http/response.go | 21 +- x/net/http/server.go | 645 +++++++++++++++++++++++++------------ x/net/http/servermux.go | 7 + 6 files changed, 494 insertions(+), 277 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 725410b..19e62d5 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "io" "github.com/goplus/llgo/x/net/http" ) @@ -18,16 +17,18 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { fmt.Printf(">> URL: %s\n", r.URL.String()) fmt.Printf(">> RemoteAddr: %s\n", r.RemoteAddr) - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Error reading request body", http.StatusInternalServerError) - return - } - defer r.Body.Close() - fmt.Printf(">> Body: %s\n", string(body)) + // body, err := io.ReadAll(r.Body) + // if err != nil { + // http.Error(w, "Error reading request body", http.StatusInternalServerError) + // return + // } + // defer r.Body.Close() + // fmt.Printf(">> Body: %s\n", string(body)) + // w.Header().Set("Content-Type", "text/plain") + // w.Write(body) w.Header().Set("Content-Type", "text/plain") - w.Write(body) + w.Write([]byte("Hello, World!")) } func main() { diff --git a/x/net/http/request.go b/x/net/http/request.go index e8ad072..0d4a600 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -12,6 +12,7 @@ import ( "github.com/goplus/llgo/c" "github.com/goplus/llgo/rust/hyper" + "github.com/goplus/llgo/c/libuv" ) type Request struct { @@ -35,14 +36,14 @@ type Request struct { timeout time.Duration } -func (conn *conn) readRequest(srv *Server, hyperReq *hyper.Request) (*Request, error) { +func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotifyHandle *libuv.Async, remoteAddr string) (*Request, error) { println("[debug] readRequest called") req := Request{ Header: make(Header), timeout: 0, Body: nil, } - req.RemoteAddr = conn.remoteAddr + req.RemoteAddr = remoteAddr headers := hyperReq.Headers() if headers != nil { @@ -135,10 +136,10 @@ func (conn *conn) readRequest(srv *Server, hyperReq *hyper.Request) (*Request, e body := hyperReq.Body() if body != nil { - task := body.Data() + //task := body.Data() taskFlag := getBodyTask - requestBody := newRequestBody(conn.asyncHandle) + requestBody := newRequestBody(requestNotifyHandle) req.Body = requestBody taskData := taskData{ @@ -146,28 +147,17 @@ func (conn *conn) readRequest(srv *Server, hyperReq *hyper.Request) (*Request, e responseBody: nil, requestBody: requestBody, taskFlag: taskFlag, - server: srv, + executor: executor, } - task.SetUserdata(c.Pointer(&taskData), nil) - conn.asyncHandle.SetData(c.Pointer(&taskData)) + requestNotifyHandle.SetData(c.Pointer(&taskData)) fmt.Println("[debug] async task set") - if task != nil { - r := srv.executor.Push(task) - if r != hyper.OK { - fmt.Printf("failed to push body foreach task: %d\n", r) - task.Free() - return nil, fmt.Errorf("failed to push body foreach task: %v", r) - } - } else { - return nil, fmt.Errorf("failed to create body foreach task") - } } else { return nil, fmt.Errorf("failed to get request body") } - hyperReq.Free() + //hyperReq.Free() return &req, nil } diff --git a/x/net/http/request_body.go b/x/net/http/request_body.go index 1ca7f6b..e4424c5 100644 --- a/x/net/http/request_body.go +++ b/x/net/http/request_body.go @@ -3,7 +3,6 @@ package http import ( "errors" "fmt" - "io" "github.com/goplus/llgo/c/libuv" ) @@ -31,33 +30,34 @@ func newRequestBody(asyncHandle *libuv.Async) *requestBody { } func (rb *requestBody) Read(p []byte) (n int, err error) { - fmt.Println("[debug] requestBody.Read called") - // If there are still unread chunks, read them first - if len(rb.chunk) > 0 { - n = copy(p, rb.chunk) - rb.chunk = rb.chunk[n:] - return n, nil + fmt.Println("[debug] RequestBody Read called") + select { + case <-rb.done: + err = rb.readCloseError() + return + default: } - // Attempt to read a new chunk from a channel - select { - case chunk, ok := <-rb.readCh: - if !ok { - // The channel has been closed, indicating that all data has been read - return 0, rb.readCloseError() - } - n = copy(p, chunk) - if n < len(chunk) { - // If the capacity of p is insufficient to hold the whole chunk, save the rest of the chunk - rb.chunk = chunk[n:] + for n < len(p) { + if len(rb.chunk) == 0 { + rb.asyncHandle.Send() + fmt.Println("[debug] RequestBody Read asyncHandle.Send called") + select { + case chunk := <-rb.readCh: + rb.chunk = chunk + fmt.Println("[debug] RequestBody Read chunk received") + case <-rb.done: + err = rb.readCloseError() + return + } } - fmt.Println("[debug] requestBody.Read async send") - rb.asyncHandle.Send() - return n, nil - case <-rb.done: - // If the done channel is closed, the read needs to be terminated - return 0, rb.readCloseError() + + copied := copy(p[n:], rb.chunk) + n += copied + rb.chunk = rb.chunk[copied:] } + + return } func (rb *requestBody) readCloseError() error { @@ -69,8 +69,11 @@ func (rb *requestBody) readCloseError() error { func (rb *requestBody) closeRead(err error) error { fmt.Println("[debug] RequestBody closeRead called") + if rb.rerr != nil { + return nil + } if err == nil { - err = io.EOF + err = ErrClosedRequestBody } rb.rerr = err close(rb.done) diff --git a/x/net/http/response.go b/x/net/http/response.go index acd853f..7798925 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -14,7 +14,6 @@ type response struct { statusCode int written bool body []byte - server *Server hyperChannel *hyper.ResponseChannel hyperResp *hyper.Response } @@ -29,7 +28,7 @@ type taskData struct { hyperBody *hyper.Body responseBody *responseBodyRaw requestBody *requestBody - server *Server + executor *hyper.Executor taskFlag taskFlag } @@ -42,16 +41,14 @@ const ( var DefaultChunkSize uintptr = 8192 -func newResponse(server *Server, hyperChannel *hyper.ResponseChannel) *response { +func newResponse(hyperChannel *hyper.ResponseChannel) *response { fmt.Printf("[debug] newResponse called\n") return &response{ header: make(Header), - hyperChannel: hyperChannel, - server: server, - statusCode: 200, written: false, - body: nil, + statusCode: 200, + hyperChannel: hyperChannel, hyperResp: hyper.NewResponse(), } } @@ -84,7 +81,7 @@ func (r *response) WriteHeader(statusCode int) { fmt.Printf("[debug] < HTTP/1.1 %d\n", statusCode) for key, values := range r.header { for _, value := range values { - fmt.Printf("< %s: %s\n", key, value) + fmt.Printf("[debug] < %s: %s\n", key, value) } } @@ -116,8 +113,6 @@ func (r *response) finalize() error { r.WriteHeader(200) } - r.hyperResp = hyper.NewResponse() - if r.hyperResp == nil { return fmt.Errorf("failed to create response") } @@ -134,9 +129,10 @@ func (r *response) finalize() error { return fmt.Errorf("failed to create body") } taskData := &taskData{ - hyperBody: nil, + hyperBody: body, responseBody: &bodyData, - server: r.server, + requestBody: nil, + executor: nil, taskFlag: setBodyTask, } body.SetDataFunc(setBodyDataFunc) @@ -163,6 +159,7 @@ func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) fmt.Println("[debug] taskData is nil") return hyper.PollError } + fmt.Println("[debug] taskData is not nil") body := taskData.responseBody if body.len > 0 { diff --git a/x/net/http/server.go b/x/net/http/server.go index aae76de..e629cfe 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -18,6 +18,14 @@ import ( "github.com/goplus/llgo/x/net" ) +// var requestNotifyHandle *libuv.Async +const _SC_NPROCESSORS_ONLN c.Int = 58 + +var cpuCount int +var asyncHandleMapMu sync.Mutex +var asyncHandleMap = make(map[int]*libuv.Async) +var connID int32 + type Handler interface { ServeHTTP(ResponseWriter, *Request) } @@ -32,45 +40,182 @@ type Server struct { Addr string Handler Handler + // uvLoop *libuv.Loop + // uvServer libuv.Tcp + + isShutdown atomic.Bool + // idleHandle libuv.Idle + + // executor *hyper.Executor + // http1Opts *hyper.Http1ServerconnOptions + // http2Opts *hyper.Http2ServerconnOptions + + eventLoop []*eventLoop + + // mu sync.Mutex + // activeConnections map[*conn]struct{} +} + +type eventLoop struct { uvLoop *libuv.Loop uvServer libuv.Tcp - inShutdown atomic.Bool idleHandle libuv.Idle - executor *hyper.Executor + executor *hyper.Executor + http1Opts *hyper.Http1ServerconnOptions + http2Opts *hyper.Http2ServerconnOptions + isShutdown atomic.Bool mu sync.Mutex activeConnections map[*conn]struct{} } type conn struct { + asyncID int stream libuv.Tcp pollHandle libuv.Poll eventMask c.Uint readWaker *hyper.Waker writeWaker *hyper.Waker - http1Opts *hyper.Http1ServerconnOptions - http2Opts *hyper.Http2ServerconnOptions isClosing atomic.Bool closedHandles int32 remoteAddr string - asyncHandle *libuv.Async } type serviceUserdata struct { - host [128]c.Char - port [8]c.Char - conn *conn - server *Server + asyncHandleID int + host [128]c.Char + port [8]c.Char + executor *hyper.Executor +} + +type threadArg struct { + host string + port int + eventLoop *eventLoop } func NewServer(addr string) *Server { - activeClients := make(map[*conn]struct{}) return &Server{ - Addr: addr, - Handler: DefaultServeMux, + Addr: addr, + Handler: DefaultServeMux, + } +} + +func newEventLoop() (*eventLoop, error) { + activeClients := make(map[*conn]struct{}) + el := &eventLoop{ activeConnections: activeClients, } + + executor := hyper.NewExecutor() + if executor == nil { + return nil, fmt.Errorf("failed to create Executor") + } + el.executor = executor + + http1Opts := hyper.Http1ServerconnOptionsNew(el.executor) + if http1Opts == nil { + return nil, fmt.Errorf("failed to create http1_opts") + } + if hyperResult := http1Opts.HeaderReadTimeout(5 * 1000); hyperResult != hyper.OK { + return nil, fmt.Errorf("failed to set header read timeout for http1_opts") + } + el.http1Opts = http1Opts + + http2Opts := hyper.Http2ServerconnOptionsNew(el.executor) + if http2Opts == nil { + return nil, fmt.Errorf("failed to create http2_opts") + } + if hyperResult := http2Opts.KeepAliveInterval(5); hyperResult != hyper.OK { + return nil, fmt.Errorf("failed to set keep alive interval for http2_opts") + } + if hyperResult := http2Opts.KeepAliveTimeout(5); hyperResult != hyper.OK { + return nil, fmt.Errorf("failed to set keep alive timeout for http2_opts") + } + el.http2Opts = http2Opts + + el.uvLoop = libuv.LoopNew() + if el.uvLoop == nil { + return nil, fmt.Errorf("failed to get default loop") + } + el.uvLoop.SetData(unsafe.Pointer(el)) + + if r := libuv.InitTcpEx(el.uvLoop, &el.uvServer, cnet.AF_INET); r != 0 { + return nil, fmt.Errorf("failed to init TCP: %v", libuv.Strerror(libuv.Errno(r))) + } + + return el, nil +} + +func (el *eventLoop) run(host string, port int) error { + var sockaddr cnet.SockaddrIn + if r := libuv.Ip4Addr(c.AllocaCStr(host), c.Int(port), &sockaddr); r != 0 { + return fmt.Errorf("failed to create IP address: %v", libuv.Strerror(libuv.Errno(r))) + } + + // Set SO_REUSEADDR + // yes := c.Int(1) + // fmt.Println("[debug] el.uvServer.GetIoWatcherFd(): ", el.uvServer.GetIoWatcherFd()) + // result := cnet.SetSockOpt(el.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) + // if result != 0 { + // return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) + // } + + // result = cnet.SetSockOpt(el.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) + // if result != 0 { + // return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) + // } + + if err := setReuseAddr(&el.uvServer); err != nil { + return fmt.Errorf("failed to set SO_REUSEADDR: %v", err) + } + + if r := el.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); r != 0 { + return fmt.Errorf("failed to bind: %v", libuv.Strerror(libuv.Errno(r))) + } + + //el.uvServer.Data = unsafe.Pointer(el) + if err := (*libuv.Stream)(&el.uvServer).Listen(128, onNewConnection); err != 0 { + return fmt.Errorf("failed to listen: %v", err) + } + + if r := libuv.InitIdle(el.uvLoop, &el.idleHandle); r != 0 { + return fmt.Errorf("failed to initialize idle handler: %d", r) + } + + (*libuv.Handle)(unsafe.Pointer(&el.idleHandle)).SetData(unsafe.Pointer(el)) + + if r := el.idleHandle.Start(onIdle); r != 0 { + return fmt.Errorf("failed to start idle handler: %d", r) + } + + //os.Setenv("UV_THREADPOOL_SIZE", "1") + + if r := el.uvLoop.Run(libuv.RUN_DEFAULT); r != 0 { + return fmt.Errorf("error in event loop: %d", r) + } + + return nil +} + +func setReuseAddr(handle *libuv.Tcp) error { + var fd libuv.OsFd + result := (*libuv.Handle)(unsafe.Pointer(handle)).Fileno(&fd) + if result != 0 { + return fmt.Errorf("Error getting file descriptor") + } + + yes := c.Int(1) + if err := cnet.SetSockOpt(c.Int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))); err != 0 { + return fmt.Errorf("Error setting SO_REUSEADDR") + } + + if err := cnet.SetSockOpt(c.Int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))); err != 0 { + return fmt.Errorf("Error setting SO_REUSEPORT") + } + + return nil } // ErrServerClosed is returned by the [Server.Serve], [ServeTLS], [ListenAndServe], @@ -83,15 +228,30 @@ func ListenAndServe(addr string, handler Handler) error { } func (srv *Server) ListenAndServe() error { - srv.uvLoop = libuv.DefaultLoop() - if srv.uvLoop == nil { - return fmt.Errorf("failed to get default loop") + cpuCount = int(c.Sysconf(_SC_NPROCESSORS_ONLN)) + if cpuCount <= 0 { + cpuCount = 4 } - if r := libuv.InitTcp(srv.uvLoop, &srv.uvServer); r != 0 { - return fmt.Errorf("failed to init TCP: %v", libuv.Strerror(libuv.Errno(r))) + fmt.Printf("[debug] cpuCount: %d\n", cpuCount) + + for i := 0; i < cpuCount; i++ { + el, err := newEventLoop() + if err != nil { + return fmt.Errorf("failed to create event loop: %v", err) + } + srv.eventLoop = append(srv.eventLoop, el) } + // el, err := newEventLoop() + // if err != nil { + // return fmt.Errorf("failed to create event loop: %v", err) + // } + // el2, err := newEventLoop() + // if err != nil { + // return fmt.Errorf("failed to create event loop: %v", err) + // } + host, port, err := net.SplitHostPort(srv.Addr) if err != nil { return fmt.Errorf("invalid address %q: %v", srv.Addr, err) @@ -102,50 +262,77 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("invalid port number: %v", err) } - var sockaddr cnet.SockaddrIn - if r := libuv.Ip4Addr(c.AllocaCStr(host), c.Int(portNum), &sockaddr); r != 0 { - return fmt.Errorf("failed to create IP address: %v", libuv.Strerror(libuv.Errno(r))) - } - - if r := srv.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); r != 0 { - return fmt.Errorf("failed to bind: %v", libuv.Strerror(libuv.Errno(r))) - } - - // Set SO_REUSEADDR - yes := c.Int(1) - result := cnet.SetSockOpt(srv.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) - if result != 0 { - return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) - } + // go func() { + // err = el2.run(host, portNum) + // if err != nil { + // println("[debug] failed to run event loop: %v", err) + // } + // }() + + //TODO(hackerchai): new logic for poll + // go func() { + // for { + // task := el.executor.Poll() + // if task != nil { + // handleTask(task) + // task = el.executor.Poll() + // } + // } + // }() + + // err = el.run(host, portNum) + // if err != nil { + // return fmt.Errorf("failed to run event loop: %v", err) + // } + + // Create a libuv thread pool with the same number of threads as event loops + threadPool := make([]*libuv.Thread, len(srv.eventLoop)) + for i := range threadPool { + threadPool[i] = &libuv.Thread{} + } + + // Start each event loop in its own thread + for i, el := range srv.eventLoop { + threadArg := &threadArg{ + host: host, + port: portNum, + eventLoop: el, + } - srv.uvServer.Data = unsafe.Pointer(srv) - if err := (*libuv.Stream)(&srv.uvServer).Listen(128, onNewConnection); err != 0 { - return fmt.Errorf("failed to listen: %v", err) - } + fmt.Printf("[debug] Creating thread %d\n", i) - if r := libuv.InitIdle(srv.uvLoop, &srv.idleHandle); r != 0 { - fmt.Fprintf(os.Stderr, "Failed to initialize idle handler: %d\n", r) - os.Exit(1) + if result := threadPool[i].Create(runEventLoopInThread, unsafe.Pointer(threadArg)); result != 0 { + return fmt.Errorf("failed to create thread: %v", err) + } } - (*libuv.Handle)(unsafe.Pointer(&srv.idleHandle)).SetData(unsafe.Pointer(srv)) - - if r := srv.idleHandle.Start(onIdle); r != 0 { - fmt.Fprintf(os.Stderr, "Failed to start idle handler: %d\n", r) - os.Exit(1) + // Wait for all threads to complete + for _, thread := range threadPool { + if result := thread.Join(); result != 0 { + fmt.Printf("[debug] Failed to join thread: %v\n", err) + } } fmt.Printf("Listening on %s\n", srv.Addr) - res := srv.uvLoop.Run(libuv.RUN_DEFAULT) - if res != 0 { - fmt.Fprintf(os.Stderr, "Error in event loop: %v\n", res) - os.Exit(1) - } + // if r := srv.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); r != 0 { + // return fmt.Errorf("failed to bind: %v", libuv.Strerror(libuv.Errno(r))) + // } return nil } +func runEventLoopInThread(arg c.Pointer) { + tArg := (*threadArg)(arg) + host := tArg.host + port := tArg.port + el := tArg.eventLoop + err := el.run(host, port) + if err != nil { + fmt.Printf("[debug] failed to run event loop: %v", err) + } +} + func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { DefaultServeMux.HandleFunc(pattern, handler) } @@ -157,9 +344,19 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { return } - srv := (*Server)(serverStream.Data) - if srv == nil { - fmt.Fprintf(os.Stderr, "Server is nil\n") + // srv := (*Server)((*libuv.Handle)(unsafe.Pointer(serverStream)).GetData()) + // if srv == nil { + // fmt.Fprintf(os.Stderr, "Server is nil\n") + // return + // } + // el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(serverStream)).GetData()) + // if el == nil { + // fmt.Fprintf(os.Stderr, "Event loop is nil\n") + // return + // } + el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(serverStream)).GetLoop().GetData()) + if el == nil { + fmt.Fprintf(os.Stderr, "Event loop is nil\n") return } @@ -171,36 +368,43 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { fmt.Println("[debug] async handle creating") - conn.asyncHandle = &libuv.Async{} - srv.uvLoop.Async(conn.asyncHandle, onAsync) + requestNotifyHandle := &libuv.Async{} + el.uvLoop.Async(requestNotifyHandle, onAsync) + fmt.Println("[debug] async handle created") + asyncHandleMapMu.Lock() + asyncHandleMap[conn.asyncID] = requestNotifyHandle + asyncHandleMapMu.Unlock() + fmt.Println("[debug] async handle added to map") - libuv.InitTcp(srv.uvLoop, &conn.stream) + libuv.InitTcp(el.uvLoop, &conn.stream) conn.stream.Data = unsafe.Pointer(conn) if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(&conn.stream))) == 0 { fmt.Println("[debug] Accepted new connection") - r := libuv.PollInit(srv.uvLoop, &conn.pollHandle, libuv.OsFd(conn.stream.GetIoWatcherFd())) - if r < 0 { - fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) - (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) - return - } - (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Data = unsafe.Pointer(conn) - - if !updateConnRegistrations(conn) { - (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) + userData := createServiceUserdata() + if userData == nil { + fmt.Fprintf(os.Stderr, "Failed to create service userdata\n") (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - userdata := createServiceUserdata() - userdata.server = srv - if userdata == nil { - fmt.Fprintf(os.Stderr, "Failed to create service userdata\n") + if el.executor == nil { + fmt.Fprintf(os.Stderr, "Failed to get executor\n") (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } + //userData.setExecutor(srv.executor) + userData.executor = el.executor + userData.asyncHandleID = conn.asyncID + + // if srv.Handler == nil { + // fmt.Fprintf(os.Stderr, "Failed to get handler\n") + // (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) + // return + // } + //userData.handler = srv.Handler + //userData.requestNotifyHandle = requestNotifyHandle var addr cnet.SockaddrStorage addrlen := c.Int(unsafe.Sizeof(addr)) @@ -208,65 +412,42 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { if addr.Family == cnet.AF_INET { s := (*cnet.SockaddrIn)(unsafe.Pointer(&addr)) - libuv.Ip4Name(s, (*c.Char)(&userdata.host[0]), unsafe.Sizeof(userdata.host)) - c.Snprintf((*c.Char)(&userdata.port[0]), unsafe.Sizeof(userdata.port), c.Str("%d"), cnet.Ntohs(s.Port)) + libuv.Ip4Name(s, (*c.Char)(&userData.host[0]), unsafe.Sizeof(userData.host)) + c.Snprintf((*c.Char)(&userData.port[0]), unsafe.Sizeof(userData.port), c.Str("%d"), cnet.Ntohs(s.Port)) } else if addr.Family == cnet.AF_INET6 { s := (*cnet.SockaddrIn6)(unsafe.Pointer(&addr)) - libuv.Ip6Name(s, (*c.Char)(&userdata.host[0]), unsafe.Sizeof(userdata.host)) - c.Snprintf((*c.Char)(&userdata.port[0]), unsafe.Sizeof(userdata.port), c.Str("%d"), cnet.Ntohs(s.Port)) + libuv.Ip6Name(s, (*c.Char)(&userData.host[0]), unsafe.Sizeof(userData.host)) + c.Snprintf((*c.Char)(&userData.port[0]), unsafe.Sizeof(userData.port), c.Str("%d"), cnet.Ntohs(s.Port)) + } + + //TODO(hackerchai): use userData.host and userData.port + conn.remoteAddr = c.GoString((*c.Char)(&userData.host[0])) + ":" + c.GoString((*c.Char)(&userData.port[0])) + + r := libuv.PollInit(el.uvLoop, &conn.pollHandle, libuv.OsFd(conn.stream.GetIoWatcherFd())) + if r < 0 { + fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) + (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) + return } - conn.remoteAddr = c.GoString((*c.Char)(&userdata.host[0])) + ":" + c.GoString((*c.Char)(&userdata.port[0])) + (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Data = unsafe.Pointer(conn) - executor := hyper.NewExecutor() - if executor == nil { - fmt.Fprintf(os.Stderr, "Failed to create Executor\n") + if !updateConnRegistrations(conn) { (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - srv.executor = executor fmt.Println("[debug] Conn created") - srv.trackConn(conn, true) + el.trackConn(conn, true) fmt.Println("[debug] Conn tracked") - userdata.conn = conn - io := createIo(conn) service := hyper.ServiceNew(serverCallback) - service.SetUserdata(unsafe.Pointer(userdata), nil) - http1Opts := hyper.Http1ServerconnOptionsNew(srv.executor) - if http1Opts == nil { - fmt.Fprintf(os.Stderr, "Failed to create http1_opts\n") - os.Exit(1) - } - result := http1Opts.HeaderReadTimeout(5 * 1000) - if result != hyper.OK { - fmt.Fprintf(os.Stderr, "Failed to set header read timeout for http1_opts\n") - os.Exit(1) - } - conn.http1Opts = http1Opts + service.SetUserdata(unsafe.Pointer(userData), nil) - http2Opts := hyper.Http2ServerconnOptionsNew(srv.executor) - if http2Opts == nil { - fmt.Fprintf(os.Stderr, "Failed to create http2_opts\n") - os.Exit(1) - } - result = http2Opts.KeepAliveInterval(5) - if result != hyper.OK { - fmt.Fprintf(os.Stderr, "Failed to set keep alive interval for http2_opts\n") - os.Exit(1) - } - result = http2Opts.KeepAliveTimeout(5) - if result != hyper.OK { - fmt.Fprintf(os.Stderr, "Failed to set keep alive timeout for http2_opts\n") - os.Exit(1) - } - conn.http2Opts = http2Opts - - serverconn := hyper.ServeHttpXConnection(http1Opts, http2Opts, io, service) - srv.executor.Push(serverconn) + serverConn := hyper.ServeHttpXConnection(el.http1Opts, el.http2Opts, io, service) + el.executor.Push(serverConn) } else { fmt.Println("[debug] Client not accepted") (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) @@ -277,10 +458,14 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { func onAsync(asyncHandle *libuv.Async) { fmt.Println("[debug] onAsync called") taskData := (*taskData)(asyncHandle.GetData()) + if taskData == nil { + fmt.Println("[debug] taskData is nil") + return + } dataTask := taskData.hyperBody.Data() dataTask.SetUserdata(c.Pointer(taskData), nil) if dataTask != nil { - r := taskData.server.executor.Push(dataTask) + r := taskData.executor.Push(dataTask) fmt.Printf("[debug] onAsync push data task: %d\n", r) if r != hyper.OK { fmt.Printf("failed to push data task: %d\n", r) @@ -290,26 +475,49 @@ func onAsync(asyncHandle *libuv.Async) { } func onIdle(handle *libuv.Idle) { - srv := (*Server)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) - if srv.executor != nil { - task := srv.executor.Poll() + // el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) + el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(handle)).GetLoop().GetData()) + if el.executor != nil { + task := el.executor.Poll() for task != nil { - srv.handleTask(task) - task = srv.executor.Poll() + handleTask(task) + task = el.executor.Poll() } } - if srv.shuttingDown() { + if el.shuttingDown() { fmt.Println("Shutdown initiated, cleaning up...") handle.Stop() } } -func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { - userData := (*serviceUserdata)(userdata) - srv := userData.server - if srv == nil { - fmt.Fprintf(os.Stderr, "Error: Received null server\n") +func doNothing(handle *libuv.Idle) { + return +} + +// func (s *serviceUserdata) setExecutor(exec *hyper.Executor) { +// s.executor.Store(exec) +// } + +// func (s *serviceUserdata) getExecutor() *hyper.Executor { +// return s.executor.Load() +// } + +func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { + payload := (*serviceUserdata)(userData) + // srv := userData.server + // if srv == nil { + // fmt.Fprintf(os.Stderr, "Error: Received null server\n") + // return + // } + if payload == nil { + fmt.Fprintf(os.Stderr, "Error: Received null userData\n") + return + } + + executor := payload.executor + if executor == nil { + fmt.Fprintf(os.Stderr, "Error: Received null executor\n") return } @@ -318,25 +526,48 @@ func serverCallback(userdata unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - req, err := userData.conn.readRequest(srv, hyperReq) + connID := payload.asyncHandleID + asyncHandleMapMu.Lock() + requestNotifyHandle, ok := asyncHandleMap[connID] + asyncHandleMapMu.Unlock() + if !ok { + fmt.Println("[debug] requestNotifyHandle not found") + return + } + + host := payload.host + port := payload.port + remoteAddr := c.GoString(&host[0]) + ":" + c.GoString(&port[0]) + fmt.Printf("[debug] Remote address: %s\n", remoteAddr) + + req, err := readRequest(executor, hyperReq, requestNotifyHandle, remoteAddr) if err != nil { fmt.Printf("Error creating request: %v\n", err) return } - res := newResponse(srv, channel) + res := newResponse(channel) fmt.Println("[debug] Response created") //TODO(hackerchai): replace with no goroutine - // userData.server.Handler.ServeHTTP(res, req) - // res.finalize() - go func() { - userData.server.Handler.ServeHTTP(res, req) - res.finalize() - }() -} - -func (srv *Server) handleTask(task *hyper.Task) { + fmt.Println("[debug] Serving HTTP") + DefaultServeMux.ServeHTTP(res, req) + //srv.Handler.ServeHTTP(res, req) + fmt.Println("[debug] Response finalizing") + res.finalize() + fmt.Println("[debug] Response finalized") + + // go func() { + // fmt.Println("[debug] Serving HTTP") + // DefaultServeMux.ServeHTTP(res, req) + // //srv.Handler.ServeHTTP(res, req) + // fmt.Println("[debug] Response finalizing") + // res.finalize() + // fmt.Println("[debug] Response finalized") + // }() +} + +func handleTask(task *hyper.Task) { hyperTaskType := task.Type() // Debug fmt.Printf("[debug] Task type: %s\n", getTaskTypeString(hyperTaskType)) @@ -351,8 +582,7 @@ func (srv *Server) handleTask(task *hyper.Task) { if payload != nil { switch payload.taskFlag { case getBodyTask: - handleGetBodyTask(srv, hyperTaskType, task, payload) - return + handleGetBodyTask(hyperTaskType, task, payload) case setBodyTask: handleSetBodyTask(hyperTaskType, task) return @@ -374,10 +604,13 @@ func (srv *Server) handleTask(task *hyper.Task) { fmt.Println("[debug] Server connection task handled") task.Free() return + default: + fmt.Println("[debug] Unknown task type") + return } } -func handleGetBodyTask(srv *Server, hyperTaskType hyper.TaskReturnType, task *hyper.Task, payload *taskData) { +func handleGetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task, payload *taskData) { switch hyperTaskType { case hyper.TaskError: handleTaskError(task) @@ -385,7 +618,9 @@ func handleGetBodyTask(srv *Server, hyperTaskType hyper.TaskReturnType, task *hy handleTaskBuffer(task, payload) case hyper.TaskEmpty: fmt.Println("[debug] Get body task closing request body") - payload.requestBody.Close() + if payload.requestBody != nil { + payload.requestBody.Close() + } task.Free() } } @@ -439,16 +674,16 @@ func getTaskTypeString(taskType hyper.TaskReturnType) string { } } -func (s *Server) trackConn(c *conn, add bool) { - s.mu.Lock() - defer s.mu.Unlock() - if s.activeConnections == nil { - s.activeConnections = make(map[*conn]struct{}) +func (el *eventLoop) trackConn(c *conn, add bool) { + el.mu.Lock() + defer el.mu.Unlock() + if el.activeConnections == nil { + el.activeConnections = make(map[*conn]struct{}) } if add { - s.activeConnections[c] = struct{}{} + el.activeConnections[c] = struct{}{} } else { - delete(s.activeConnections, c) + delete(el.activeConnections, c) } } @@ -526,7 +761,6 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint } func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { - fmt.Printf("[debug] onPoll called\n") conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) if status < 0 { @@ -577,40 +811,14 @@ func createConnData() (*conn, error) { } conn.isClosing.Store(false) conn.closedHandles = 0 + conn.asyncID = int(connID) + 1 return conn, nil } func freeConnData(userdata c.Pointer) { conn := (*conn)(userdata) - if conn != nil && !conn.isClosing.Swap(true) { - fmt.Printf("[debug] Closing connection...\n") - if conn.readWaker != nil { - conn.readWaker.Free() - conn.readWaker = nil - } - if conn.writeWaker != nil { - conn.writeWaker.Free() - conn.writeWaker = nil - } - - if conn.http1Opts != nil { - conn.http1Opts.Free() - conn.http1Opts = nil - } - if conn.http2Opts != nil { - conn.http2Opts.Free() - conn.http2Opts = nil - } - - if (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).IsClosing() == 0 { - (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) - } - - if (*libuv.Handle)(unsafe.Pointer(&conn.stream)).IsClosing() == 0 { - (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) - } - } + conn.Close() } func closeWalkCb(handle *libuv.Handle, arg c.Pointer) { @@ -620,31 +828,48 @@ func closeWalkCb(handle *libuv.Handle, arg c.Pointer) { } func (srv *Server) Close() error { - srv.inShutdown.Store(true) - srv.mu.Lock() - defer srv.mu.Unlock() + srv.isShutdown.Store(true) - for c := range srv.activeConnections { - c.Close() + // for c := range el.activeConnections { + // c.Close() - delete(srv.activeConnections, c) - } + // delete(srv.activeConnections, c) + // } - if srv.executor != nil { - srv.executor.Free() - srv.executor = nil - } + // if srv.executor != nil { + // srv.executor.Free() + // srv.executor = nil + // } + + // if exec := srv.executor; exec != nil { + // srv.executor = nil + // exec.Free() + // } - srv.uvLoop.Walk(closeWalkCb, nil) - srv.uvLoop.Run(libuv.RUN_ONCE) - (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).Close(nil) + // if srv.http1Opts != nil { + // srv.http1Opts.Free() + // srv.http1Opts = nil + // } - srv.uvLoop.Close() + // if srv.http2Opts != nil { + // srv.http2Opts.Free() + // srv.http2Opts = nil + // } + + // srv.uvLoop.Walk(closeWalkCb, nil) + // srv.uvLoop.Run(libuv.RUN_ONCE) + // (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).Close(nil) + + // srv.uvLoop.Close() return nil } func (s *Server) shuttingDown() bool { - return s.inShutdown.Load() + return s.isShutdown.Load() +} + +func (el *eventLoop) shuttingDown() bool { + return el.isShutdown.Load() } func (c *conn) shuttingDown() bool { @@ -652,31 +877,25 @@ func (c *conn) shuttingDown() bool { } func (c *conn) Close() { - c.isClosing.Store(true) - if c.shuttingDown() { - return - } + if c != nil && !c.isClosing.Swap(true) { + fmt.Printf("[debug] Closing connection...\n") + if c.readWaker != nil { + c.readWaker.Free() + c.readWaker = nil + } + if c.writeWaker != nil { + c.writeWaker.Free() + c.writeWaker = nil + } - if c.readWaker != nil { - c.readWaker.Free() - c.readWaker = nil - } - if c.writeWaker != nil { - c.writeWaker.Free() - c.writeWaker = nil - } + if (*libuv.Handle)(unsafe.Pointer(&c.pollHandle)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&c.pollHandle)).Close(nil) + } - if c.http1Opts != nil { - c.http1Opts.Free() - c.http1Opts = nil - } - if c.http2Opts != nil { - c.http2Opts.Free() - c.http2Opts = nil + if (*libuv.Handle)(unsafe.Pointer(&c.stream)).IsClosing() == 0 { + (*libuv.Handle)(unsafe.Pointer(&c.stream)).Close(nil) + } } - - (*libuv.Handle)(unsafe.Pointer(&c.pollHandle)).Close(nil) - (*libuv.Handle)(unsafe.Pointer(&c.stream)).Close(nil) } type HandlerFunc func(ResponseWriter, *Request) diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index a210a7b..e111186 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -20,14 +20,21 @@ var DefaultServeMux = &ServeMux{m: make(map[string]muxEntry)} func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { fmt.Printf("[debug] ServeHTTP called\n") + // NotFoundHandler().ServeHTTP(w, r) + // return h, pattern := mux.Handler(r) fmt.Printf("[debug] Handler found for pattern: %s\n", pattern) h.ServeHTTP(w, r) } func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + fmt.Printf("[debug] Mux Handler called\n") mux.mu.RLock() defer mux.mu.RUnlock() + if r.URL == nil { + fmt.Println("[debug] r.URL is nil") + } + fmt.Printf("[debug] Handler called: r.URL.Path = %s\n", r.URL.Path) h, pattern = mux.m[r.URL.Path].h, r.URL.Path if h == nil { From e8ea4124f17b0f171a00bec8547ee6cc53caa1e5 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Thu, 19 Sep 2024 18:36:00 +0800 Subject: [PATCH 47/55] refactor(x/net/http): Implement close logic & other optimize Signed-off-by: hackerchai --- x/net/http/server.go | 249 ++++++++++++++----------------------------- 1 file changed, 78 insertions(+), 171 deletions(-) diff --git a/x/net/http/server.go b/x/net/http/server.go index e629cfe..ade6c5f 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -18,13 +18,14 @@ import ( "github.com/goplus/llgo/x/net" ) -// var requestNotifyHandle *libuv.Async const _SC_NPROCESSORS_ONLN c.Int = 58 - var cpuCount int + var asyncHandleMapMu sync.Mutex -var asyncHandleMap = make(map[int]*libuv.Async) -var connID int32 +var asyncHandleMap = make(map[int64]*libuv.Async) + +// connID is used to generate unique IDs for each connection +var connID int64 type Handler interface { ServeHTTP(ResponseWriter, *Request) @@ -39,21 +40,9 @@ type ResponseWriter interface { type Server struct { Addr string Handler Handler - - // uvLoop *libuv.Loop - // uvServer libuv.Tcp - isShutdown atomic.Bool - // idleHandle libuv.Idle - - // executor *hyper.Executor - // http1Opts *hyper.Http1ServerconnOptions - // http2Opts *hyper.Http2ServerconnOptions eventLoop []*eventLoop - - // mu sync.Mutex - // activeConnections map[*conn]struct{} } type eventLoop struct { @@ -71,7 +60,7 @@ type eventLoop struct { } type conn struct { - asyncID int + asyncID int64 stream libuv.Tcp pollHandle libuv.Poll eventMask c.Uint @@ -83,18 +72,12 @@ type conn struct { } type serviceUserdata struct { - asyncHandleID int + asyncHandleID c.Long host [128]c.Char port [8]c.Char executor *hyper.Executor } -type threadArg struct { - host string - port int - eventLoop *eventLoop -} - func NewServer(addr string) *Server { return &Server{ Addr: addr, @@ -154,19 +137,6 @@ func (el *eventLoop) run(host string, port int) error { return fmt.Errorf("failed to create IP address: %v", libuv.Strerror(libuv.Errno(r))) } - // Set SO_REUSEADDR - // yes := c.Int(1) - // fmt.Println("[debug] el.uvServer.GetIoWatcherFd(): ", el.uvServer.GetIoWatcherFd()) - // result := cnet.SetSockOpt(el.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) - // if result != 0 { - // return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) - // } - - // result = cnet.SetSockOpt(el.uvServer.GetIoWatcherFd(), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))) - // if result != 0 { - // return fmt.Errorf("failed to set SO_REUSEADDR: %v", result) - // } - if err := setReuseAddr(&el.uvServer); err != nil { return fmt.Errorf("failed to set SO_REUSEADDR: %v", err) } @@ -175,7 +145,6 @@ func (el *eventLoop) run(host string, port int) error { return fmt.Errorf("failed to bind: %v", libuv.Strerror(libuv.Errno(r))) } - //el.uvServer.Data = unsafe.Pointer(el) if err := (*libuv.Stream)(&el.uvServer).Listen(128, onNewConnection); err != 0 { return fmt.Errorf("failed to listen: %v", err) } @@ -190,8 +159,6 @@ func (el *eventLoop) run(host string, port int) error { return fmt.Errorf("failed to start idle handler: %d", r) } - //os.Setenv("UV_THREADPOOL_SIZE", "1") - if r := el.uvLoop.Run(libuv.RUN_DEFAULT); r != 0 { return fmt.Errorf("error in event loop: %d", r) } @@ -228,6 +195,7 @@ func ListenAndServe(addr string, handler Handler) error { } func (srv *Server) ListenAndServe() error { + connID = 0 cpuCount = int(c.Sysconf(_SC_NPROCESSORS_ONLN)) if cpuCount <= 0 { cpuCount = 4 @@ -243,15 +211,6 @@ func (srv *Server) ListenAndServe() error { srv.eventLoop = append(srv.eventLoop, el) } - // el, err := newEventLoop() - // if err != nil { - // return fmt.Errorf("failed to create event loop: %v", err) - // } - // el2, err := newEventLoop() - // if err != nil { - // return fmt.Errorf("failed to create event loop: %v", err) - // } - host, port, err := net.SplitHostPort(srv.Addr) if err != nil { return fmt.Errorf("invalid address %q: %v", srv.Addr, err) @@ -262,13 +221,6 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("invalid port number: %v", err) } - // go func() { - // err = el2.run(host, portNum) - // if err != nil { - // println("[debug] failed to run event loop: %v", err) - // } - // }() - //TODO(hackerchai): new logic for poll // go func() { // for { @@ -280,59 +232,27 @@ func (srv *Server) ListenAndServe() error { // } // }() - // err = el.run(host, portNum) - // if err != nil { - // return fmt.Errorf("failed to run event loop: %v", err) - // } + errChan := make(chan error, len(srv.eventLoop)) + var wg sync.WaitGroup - // Create a libuv thread pool with the same number of threads as event loops - threadPool := make([]*libuv.Thread, len(srv.eventLoop)) - for i := range threadPool { - threadPool[i] = &libuv.Thread{} + for _, el := range srv.eventLoop { + wg.Add(1) + go func(el *eventLoop) { + err := el.run(host, portNum) + if err != nil { + errChan <- fmt.Errorf("failed to run event loop: %v", err) + } + wg.Done() + }(el) } - // Start each event loop in its own thread - for i, el := range srv.eventLoop { - threadArg := &threadArg{ - host: host, - port: portNum, - eventLoop: el, - } - - fmt.Printf("[debug] Creating thread %d\n", i) - - if result := threadPool[i].Create(runEventLoopInThread, unsafe.Pointer(threadArg)); result != 0 { - return fmt.Errorf("failed to create thread: %v", err) - } - } - - // Wait for all threads to complete - for _, thread := range threadPool { - if result := thread.Join(); result != 0 { - fmt.Printf("[debug] Failed to join thread: %v\n", err) - } - } + wg.Wait() fmt.Printf("Listening on %s\n", srv.Addr) - // if r := srv.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); r != 0 { - // return fmt.Errorf("failed to bind: %v", libuv.Strerror(libuv.Errno(r))) - // } - return nil } -func runEventLoopInThread(arg c.Pointer) { - tArg := (*threadArg)(arg) - host := tArg.host - port := tArg.port - el := tArg.eventLoop - err := el.run(host, port) - if err != nil { - fmt.Printf("[debug] failed to run event loop: %v", err) - } -} - func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { DefaultServeMux.HandleFunc(pattern, handler) } @@ -344,16 +264,6 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { return } - // srv := (*Server)((*libuv.Handle)(unsafe.Pointer(serverStream)).GetData()) - // if srv == nil { - // fmt.Fprintf(os.Stderr, "Server is nil\n") - // return - // } - // el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(serverStream)).GetData()) - // if el == nil { - // fmt.Fprintf(os.Stderr, "Event loop is nil\n") - // return - // } el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(serverStream)).GetLoop().GetData()) if el == nil { fmt.Fprintf(os.Stderr, "Event loop is nil\n") @@ -394,17 +304,9 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - //userData.setExecutor(srv.executor) - userData.executor = el.executor - userData.asyncHandleID = conn.asyncID - // if srv.Handler == nil { - // fmt.Fprintf(os.Stderr, "Failed to get handler\n") - // (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) - // return - // } - //userData.handler = srv.Handler - //userData.requestNotifyHandle = requestNotifyHandle + userData.executor = el.executor + userData.asyncHandleID = c.Long(conn.asyncID) var addr cnet.SockaddrStorage addrlen := c.Int(unsafe.Sizeof(addr)) @@ -420,7 +322,6 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { c.Snprintf((*c.Char)(&userData.port[0]), unsafe.Sizeof(userData.port), c.Str("%d"), cnet.Ntohs(s.Port)) } - //TODO(hackerchai): use userData.host and userData.port conn.remoteAddr = c.GoString((*c.Char)(&userData.host[0])) + ":" + c.GoString((*c.Char)(&userData.port[0])) r := libuv.PollInit(el.uvLoop, &conn.pollHandle, libuv.OsFd(conn.stream.GetIoWatcherFd())) @@ -475,7 +376,6 @@ func onAsync(asyncHandle *libuv.Async) { } func onIdle(handle *libuv.Idle) { - // el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(handle)).GetLoop().GetData()) if el.executor != nil { task := el.executor.Poll() @@ -495,21 +395,9 @@ func doNothing(handle *libuv.Idle) { return } -// func (s *serviceUserdata) setExecutor(exec *hyper.Executor) { -// s.executor.Store(exec) -// } - -// func (s *serviceUserdata) getExecutor() *hyper.Executor { -// return s.executor.Load() -// } - func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { + payload := (*serviceUserdata)(userData) - // srv := userData.server - // if srv == nil { - // fmt.Fprintf(os.Stderr, "Error: Received null server\n") - // return - // } if payload == nil { fmt.Fprintf(os.Stderr, "Error: Received null userData\n") return @@ -518,15 +406,27 @@ func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *h executor := payload.executor if executor == nil { fmt.Fprintf(os.Stderr, "Error: Received null executor\n") + fmt.Printf("[debug] host: %s\n", c.GoString(&payload.host[0])) + fmt.Printf("[debug] port: %s\n", c.GoString(&payload.port[0])) return } + host := payload.host + port := payload.port + + if payload.asyncHandleID == 0 { + fmt.Fprintf(os.Stderr, "Error: Received null asyncHandleID\n") + return + } + connID := int64(payload.asyncHandleID) + + + if hyperReq == nil { fmt.Fprintf(os.Stderr, "Error: Received null request\n") return } - connID := payload.asyncHandleID asyncHandleMapMu.Lock() requestNotifyHandle, ok := asyncHandleMap[connID] asyncHandleMapMu.Unlock() @@ -535,8 +435,6 @@ func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *h return } - host := payload.host - port := payload.port remoteAddr := c.GoString(&host[0]) + ":" + c.GoString(&port[0]) fmt.Printf("[debug] Remote address: %s\n", remoteAddr) @@ -811,7 +709,7 @@ func createConnData() (*conn, error) { } conn.isClosing.Store(false) conn.closedHandles = 0 - conn.asyncID = int(connID) + 1 + conn.asyncID = conn.asyncID + 1 return conn, nil } @@ -830,52 +728,51 @@ func closeWalkCb(handle *libuv.Handle, arg c.Pointer) { func (srv *Server) Close() error { srv.isShutdown.Store(true) - // for c := range el.activeConnections { - // c.Close() + for _, el := range srv.eventLoop { + el.Close() + } - // delete(srv.activeConnections, c) - // } + return nil +} - // if srv.executor != nil { - // srv.executor.Free() - // srv.executor = nil - // } +func (s *Server) shuttingDown() bool { + return s.isShutdown.Load() +} - // if exec := srv.executor; exec != nil { - // srv.executor = nil - // exec.Free() - // } +func (el *eventLoop) Close() error { + el.isShutdown.Store(true) - // if srv.http1Opts != nil { - // srv.http1Opts.Free() - // srv.http1Opts = nil - // } + for c := range el.activeConnections { + c.Close() + el.trackConn(c, false) + } - // if srv.http2Opts != nil { - // srv.http2Opts.Free() - // srv.http2Opts = nil - // } + if el.executor != nil { + el.executor.Free() + el.executor = nil + } + if el.http1Opts != nil { + el.http1Opts.Free() + el.http1Opts = nil + } + if el.http2Opts != nil { + el.http2Opts.Free() + el.http2Opts = nil + } - // srv.uvLoop.Walk(closeWalkCb, nil) - // srv.uvLoop.Run(libuv.RUN_ONCE) - // (*libuv.Handle)(unsafe.Pointer(&srv.uvServer)).Close(nil) + el.uvLoop.Walk(closeWalkCb, nil) + el.uvLoop.Run(libuv.RUN_ONCE) + (*libuv.Handle)(unsafe.Pointer(&el.uvServer)).Close(nil) + (*libuv.Handle)(unsafe.Pointer(&el.idleHandle)).Close(nil) + el.uvLoop.Close() - // srv.uvLoop.Close() return nil } -func (s *Server) shuttingDown() bool { - return s.isShutdown.Load() -} - func (el *eventLoop) shuttingDown() bool { return el.isShutdown.Load() } -func (c *conn) shuttingDown() bool { - return c.isClosing.Load() -} - func (c *conn) Close() { if c != nil && !c.isClosing.Swap(true) { fmt.Printf("[debug] Closing connection...\n") @@ -895,9 +792,19 @@ func (c *conn) Close() { if (*libuv.Handle)(unsafe.Pointer(&c.stream)).IsClosing() == 0 { (*libuv.Handle)(unsafe.Pointer(&c.stream)).Close(nil) } + + if asyncHandleMap[c.asyncID] != nil { + asyncHandleMapMu.Lock() + delete(asyncHandleMap, c.asyncID) + asyncHandleMapMu.Unlock() + } } } +func (c *conn) shuttingDown() bool { + return c.isClosing.Load() +} + type HandlerFunc func(ResponseWriter, *Request) func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { From a4e86311ecd8999669de14624189c4b42de2c151 Mon Sep 17 00:00:00 2001 From: spongehah <2635879218@qq.com> Date: Fri, 20 Sep 2024 12:48:10 +0800 Subject: [PATCH 48/55] WIP(x/net/http/client): Mutiple eventLoop --- go.mod | 2 +- go.sum | 4 +- .../_demo/parallelRequest/parallelRequest.go | 43 ++ x/net/http/bodyChunk.go | 84 +-- x/net/http/client.go | 193 +++--- x/net/http/request.go | 73 +-- x/net/http/response.go | 23 +- x/net/http/server.go | 7 - x/net/http/transfer.go | 60 +- x/net/http/transport.go | 597 ++++++++++-------- 10 files changed, 525 insertions(+), 561 deletions(-) create mode 100644 x/net/http/_demo/parallelRequest/parallelRequest.go diff --git a/go.mod b/go.mod index e893515..95f17e6 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/goplus/llgoexamples go 1.20 require ( - github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b + github.com/goplus/llgo v0.9.8-0.20240919105235-c6436ea6d196 golang.org/x/net v0.28.0 ) diff --git a/go.sum b/go.sum index 5d7faad..08150d6 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b h1:iC0vVA8F2DNJ9wVyHI9fP9U0nM+si3LSQJ1TtGftXyo= -github.com/goplus/llgo v0.9.7-0.20240830010153-2434fd778f0b/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= +github.com/goplus/llgo v0.9.8-0.20240919105235-c6436ea6d196 h1:LckJktvgChf3x0eex+GT//JkYRj1uiT4uMLzyrg3ChU= +github.com/goplus/llgo v0.9.8-0.20240919105235-c6436ea6d196/go.mod h1:5Fs+08NslqofJ7xtOiIXugkurYOoQvY02ZkFNWA1uEI= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= diff --git a/x/net/http/_demo/parallelRequest/parallelRequest.go b/x/net/http/_demo/parallelRequest/parallelRequest.go new file mode 100644 index 0000000..0bcb336 --- /dev/null +++ b/x/net/http/_demo/parallelRequest/parallelRequest.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "sync" + + "github.com/goplus/llgoexamples/x/net/http" +) + +func worker(id int, wg *sync.WaitGroup) { + defer wg.Done() + resp, err := http.Get("http://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(id, ":", resp.Status) + //body, err := io.ReadAll(resp.Body) + //if err != nil { + // fmt.Println(err) + // return + //} + //fmt.Println(string(body)) + resp.Body.Close() +} + +func main() { + var wait sync.WaitGroup + for i := 0; i < 500; i++ { + wait.Add(1) + go worker(i, &wait) + } + wait.Wait() + fmt.Println("All done") + + resp, err := http.Get("http://www.baidu.com") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(resp.Status) + resp.Body.Close() +} diff --git a/x/net/http/bodyChunk.go b/x/net/http/bodyChunk.go index c1d1072..01d9e74 100644 --- a/x/net/http/bodyChunk.go +++ b/x/net/http/bodyChunk.go @@ -2,73 +2,49 @@ package http import ( "errors" - "io" - "sync" "github.com/goplus/llgo/c/libuv" ) -type onceError struct { - sync.Mutex - err error -} - -func (a *onceError) Store(err error) { - a.Lock() - defer a.Unlock() - if a.err != nil { - return - } - a.err = err -} - -func (a *onceError) Load() error { - a.Lock() - defer a.Unlock() - return a.err -} - -func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { - return &bodyChunk{ - readCh: make(chan []byte, 1), - done: make(chan struct{}), - asyncHandle: asyncHandle, - } -} - type bodyChunk struct { chunk []byte readCh chan []byte asyncHandle *libuv.Async - once sync.Once done chan struct{} - rerr onceError + rerr error } var ( errClosedBodyChunk = errors.New("bodyChunk: read/write on closed body") ) +func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { + return &bodyChunk{ + readCh: make(chan []byte, 1), + done: make(chan struct{}), + asyncHandle: asyncHandle, + } +} + func (bc *bodyChunk) Read(p []byte) (n int, err error) { + select { + case <-bc.done: + err = bc.readCloseError() + return + default: + } + for n < len(p) { if len(bc.chunk) == 0 { + bc.asyncHandle.Send() select { - case chunk, ok := <-bc.readCh: - if !ok { - if n > 0 { - return n, nil - } - return 0, bc.readCloseError() - } + case chunk := <-bc.readCh: bc.chunk = chunk - bc.asyncHandle.Send() case <-bc.done: - if n > 0 { - return n, nil - } - return 0, io.EOF + err = bc.readCloseError() + return } } @@ -77,28 +53,28 @@ func (bc *bodyChunk) Read(p []byte) (n int, err error) { bc.chunk = bc.chunk[copied:] } - return n, nil + return } func (bc *bodyChunk) Close() error { - return bc.closeRead(nil) + return bc.closeWithError(nil) } func (bc *bodyChunk) readCloseError() error { - if rerr := bc.rerr.Load(); rerr != nil { + if rerr := bc.rerr; rerr != nil { return rerr } return errClosedBodyChunk } -func (bc *bodyChunk) closeRead(err error) error { +func (bc *bodyChunk) closeWithError(err error) error { + if bc.rerr != nil { + return nil + } if err == nil { - err = io.EOF + err = errClosedBodyChunk } - bc.rerr.Store(err) - bc.once.Do(func() { - close(bc.done) - }) - //close(bc.done) + bc.rerr = err + close(bc.done) return nil } diff --git a/x/net/http/client.go b/x/net/http/client.go index 7e26395..fa62732 100644 --- a/x/net/http/client.go +++ b/x/net/http/client.go @@ -11,8 +11,6 @@ import ( "reflect" "sort" "strings" - "sync" - "sync/atomic" "time" ) @@ -157,8 +155,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { URL: u, Header: make(Header), Host: host, - Cancel: ireq.Cancel, - ctx: ireq.ctx, + //Cancel: ireq.Cancel, timer: ireq.timer, timeoutch: ireq.timeoutch, @@ -307,16 +304,15 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d forkReq() } - // TODO(spongehah) tmp timeout(send) + // TODO(hah) tmp timeout(send): LLGo has not yet implemented startTimer. //stopTimer, didTimeout := setRequestCancel(req, rt, deadline) req.timeoutch = make(chan struct{}, 1) req.deadline = deadline - req.ctx.Done() if deadline.IsZero() { didTimeout = alwaysFalse defer close(req.timeoutch) } else { - didTimeout = func() bool { return req.timer.GetDueIn() == 0 } + didTimeout = func() bool { return time.Now().After(deadline) } } resp, err = rt.RoundTrip(req) @@ -478,110 +474,83 @@ func (b *cancelTimerBody) Close() error { return err } -// knownRoundTripperImpl reports whether rt is a RoundTripper that's -// maintained by the Go team and known to implement the latest -// optional semantics (notably contexts). The Request is used -// to check whether this particular request is using an alternate protocol, -// in which case we need to check the RoundTripper for that protocol. -func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { - switch t := rt.(type) { - case *Transport: - if altRT := t.alternateRoundTripper(req); altRT != nil { - return knownRoundTripperImpl(altRT, req) - } - return true - //case *http2Transport, http2noDialH2RoundTripper: - // return true - } - // There's a very minor chance of a false positive with this. - // Instead of detecting our golang.org/x/net/http2.Transport, - // it might detect a Transport type in a different http2 - // package. But I know of none, and the only problem would be - // some temporarily leaked goroutines if the transport didn't - // support contexts. So this is a good enough heuristic: - if reflect.TypeOf(rt).String() == "*http2.Transport" { - return true - } - return false -} - -// setRequestCancel sets req.Cancel and adds a deadline context to req -// if deadline is non-zero. The RoundTripper's type is used to -// determine whether the legacy CancelRequest behavior should be used. +//// setRequestCancel sets req.Cancel and adds a deadline context to req +//// if deadline is non-zero. The RoundTripper's type is used to +//// determine whether the legacy CancelRequest behavior should be used. +//// +//// As background, there are three ways to cancel a request: +//// First was Transport.CancelRequest. (deprecated) +//// Second was Request.Cancel. +//// Third was Request.Context. +//// This function populates the second and third, and uses the first if it really needs to. +//func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { +// if deadline.IsZero() { +// return nop, alwaysFalse +// } +// knownTransport := knownRoundTripperImpl(rt, req) +// oldCtx := req.Context() // -// As background, there are three ways to cancel a request: -// First was Transport.CancelRequest. (deprecated) -// Second was Request.Cancel. -// Third was Request.Context. -// This function populates the second and third, and uses the first if it really needs to. -func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { - if deadline.IsZero() { - return nop, alwaysFalse - } - knownTransport := knownRoundTripperImpl(rt, req) - oldCtx := req.Context() - - if req.Cancel == nil && knownTransport { - // If they already had a Request.Context that's - // expiring sooner, do nothing: - if !timeBeforeContextDeadline(deadline, oldCtx) { - return nop, alwaysFalse - } - - var cancelCtx func() - req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) - return cancelCtx, func() bool { return time.Now().After(deadline) } - } - initialReqCancel := req.Cancel // the user's original Request.Cancel, if any - - var cancelCtx func() - if timeBeforeContextDeadline(deadline, oldCtx) { - req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) - } - - cancel := make(chan struct{}) - req.Cancel = cancel - - doCancel := func() { - // The second way in the func comment above: - close(cancel) - // The first way, used only for RoundTripper - // implementations written before Go 1.5 or Go 1.6. - type canceler interface{ CancelRequest(*Request) } - if v, ok := rt.(canceler); ok { - v.CancelRequest(req) - } - } - - stopTimerCh := make(chan struct{}) - var once sync.Once - stopTimer = func() { - once.Do(func() { - close(stopTimerCh) - if cancelCtx != nil { - cancelCtx() - } - }) - } - - timer := time.NewTimer(time.Until(deadline)) - var timedOut atomic.Bool - - go func() { - select { - case <-initialReqCancel: - doCancel() - timer.Stop() - case <-timer.C: - timedOut.Store(true) - doCancel() - case <-stopTimerCh: - timer.Stop() - } - }() - - return stopTimer, timedOut.Load -} +// if req.Cancel == nil && knownTransport { +// // If they already had a Request.Context that's +// // expiring sooner, do nothing: +// if !timeBeforeContextDeadline(deadline, oldCtx) { +// return nop, alwaysFalse +// } +// +// var cancelCtx func() +// req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) +// return cancelCtx, func() bool { return time.Now().After(deadline) } +// } +// initialReqCancel := req.Cancel // the user's original Request.Cancel, if any +// +// var cancelCtx func() +// if timeBeforeContextDeadline(deadline, oldCtx) { +// req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline) +// } +// +// cancel := make(chan struct{}) +// req.Cancel = cancel +// +// doCancel := func() { +// // The second way in the func comment above: +// close(cancel) +// // The first way, used only for RoundTripper +// // implementations written before Go 1.5 or Go 1.6. +// type canceler interface{ CancelRequest(*Request) } +// if v, ok := rt.(canceler); ok { +// v.CancelRequest(req) +// } +// } +// +// stopTimerCh := make(chan struct{}) +// var once sync.Once +// stopTimer = func() { +// once.Do(func() { +// close(stopTimerCh) +// if cancelCtx != nil { +// cancelCtx() +// } +// }) +// } +// +// timer := time.NewTimer(time.Until(deadline)) +// var timedOut atomic.Bool +// +// go func() { +// select { +// case <-initialReqCancel: +// doCancel() +// timer.Stop() +// case <-timer.C: +// timedOut.Store(true) +// doCancel() +// case <-stopTimerCh: +// timer.Stop() +// } +// }() +// +// return stopTimer, timedOut.Load +//} // timeBeforeContextDeadline reports whether the non-zero Time t is // before ctx's deadline, if any. If ctx does not have a deadline, it @@ -594,7 +563,7 @@ func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { return t.Before(d) } -/*// knownRoundTripperImpl reports whether rt is a RoundTripper that's +// knownRoundTripperImpl reports whether rt is a RoundTripper that's // maintained by the Go team and known to implement the latest // optional semantics (notably contexts). The Request is used // to check whether this particular request is using an alternate protocol, @@ -619,7 +588,7 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return true } return false -}*/ +} // makeHeadersCopier makes a function that copies headers from the // initial Request, ireq. For every redirect, this function must be called diff --git a/x/net/http/request.go b/x/net/http/request.go index e9279fc..37d6408 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -2,7 +2,6 @@ package http import ( "bytes" - "context" "errors" "fmt" "io" @@ -31,20 +30,16 @@ type Request struct { TransferEncoding []string Close bool Host string - //Form url.Values - //PostForm url.Values - //MultipartForm *multipart.Form - Trailer Header + // Form url.Values + // PostForm url.Values + // MultipartForm *multipart.Form RemoteAddr string RequestURI string - //TLS *tls.ConnectionState - Cancel <-chan struct{} Response *Response - ctx context.Context deadline time.Time - timeoutch chan struct{} //tmp timeout + timeoutch chan struct{} timer *libuv.Timer } @@ -75,34 +70,8 @@ var reqWriteExcludeHeader = map[string]bool{ type requestBodyReadError struct{ error } // NewRequest wraps NewRequestWithContext using context.Background. -func NewRequest(method, url string, body io.Reader) (*Request, error) { - return NewRequestWithContext(context.Background(), method, url, body) -} - -// NewRequestWithContext returns a new Request given a method, URL, and -// optional body. -// -// If the provided body is also an io.Closer, the returned -// Request.Body is set to body and will be closed by the Client -// methods Do, Post, and PostForm, and Transport.RoundTrip. -// -// NewRequestWithContext returns a Request suitable for use with -// Client.Do or Transport.RoundTrip. To create a request for use with -// testing a Server Handler, either use the NewRequest function in the -// net/http/httptest package, use ReadRequest, or manually update the -// Request fields. For an outgoing client request, the context -// controls the entire lifetime of a request and its response: -// obtaining a connection, sending the request, and reading the -// response headers and body. See the Request type's documentation for -// the difference between inbound and outbound request fields. -// -// If body is of type *bytes.Buffer, *bytes.Reader, or -// *strings.Reader, the returned request's ContentLength is set to its -// exact value (instead of -1), GetBody is populated (so 307 and 308 -// redirects can replay the body), and Body is set to NoBody if the -// ContentLength is 0. -func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.Reader) (*Request, error) { - // TODO(spongehah) Hyper only supports http +func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + // TODO(hah) Hyper only supports http isHttpPrefix := strings.HasPrefix(urlStr, "http://") isHttpsPrefix := strings.HasPrefix(urlStr, "https://") if !isHttpPrefix && !isHttpsPrefix { @@ -121,9 +90,6 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R if !validMethod(method) { return nil, fmt.Errorf("net/http: invalid method %q", method) } - if ctx == nil { - return nil, errors.New("net/http: nil Context") - } u, err := url.Parse(urlStr) if err != nil { return nil, err @@ -135,7 +101,6 @@ func NewRequestWithContext(ctx context.Context, method, urlStr string, body io.R // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) req := &Request{ - ctx: ctx, Method: method, URL: u, Proto: "HTTP/1.1", @@ -228,24 +193,6 @@ func (r *Request) isReplayable() bool { return false } -// Context returns the request's context. To change the context, use -// Clone or WithContext. -// -// The returned context is always non-nil; it defaults to the -// background context. -// -// For outgoing client requests, the context controls cancellation. -// -// For incoming server requests, the context is canceled when the -// client's connection closes, the request is canceled (with HTTP/2), -// or when the ServeHTTP method returns. -func (r *Request) Context() context.Context { - if r.ctx != nil { - return r.ctx - } - return context.Background() -} - // AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, // AddCookie does not attach more than one Cookie header field. That // means all cookies, if any, are written into the same line, @@ -300,7 +247,11 @@ func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hype } // Send it! sendTask := client.Send(hyperReq) - sendTask.SetUserdata(c.Pointer(taskData)) + if sendTask == nil { + println("############### write: sendTask is nil") + return errors.New("failed to send the request") + } + sendTask.SetUserdata(c.Pointer(taskData), nil) sendRes := exec.Push(sendTask) if sendRes != hyper.OK { err = errors.New("failed to send the request") @@ -424,7 +375,7 @@ func (r *Request) newHyperRequest(usingProxy bool, extraHeader Header, treq *tra // Wait for 100-continue if expected. if r.ProtoAtLeast(1, 1) && r.Body != nil && r.expectsContinue() { - hyperReq.OnInformational(printInformational, nil) + hyperReq.OnInformational(printInformational, nil, nil) } // Write body and trailer diff --git a/x/net/http/response.go b/x/net/http/response.go index a3a96fc..da7c3e4 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -81,13 +81,19 @@ func (r *Response) checkRespBody(taskData *taskData) (needContinue bool) { select { case taskData.resc <- responseAndError{res: r}: case <-taskData.callerGone: - readLoopDefer(pc, true) + if debugSwitch { + println("############### checkRespBody callerGone") + } + closeAndRemoveIdleConn(pc, true) return true } // Now that they've read from the unbuffered channel, they're safely // out of the select that also waits on this goroutine to die, so // we're allowed to exit now if needed (if alive is false) - readLoopDefer(pc, false) + if debugSwitch { + println("############### checkRespBody return") + } + closeAndRemoveIdleConn(pc, false) return true } return false @@ -97,6 +103,17 @@ func (r *Response) wrapRespBody(taskData *taskData) { body := &bodyEOFSignal{ body: r.Body, earlyCloseFn: func() error { + // If the response body is closed prematurely, + // the hyperBody needs to be recycled and the persistConn needs to be handled. + taskData.closeHyperBody() + select { + case <-taskData.pc.closech: + taskData.pc.t.removeIdleConn(taskData.pc) + default: + } + replaced := taskData.pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + taskData.pc.alive = taskData.pc.alive && + replaced && taskData.pc.tryPutIdleConn() return nil }, fn: func(err error) error { @@ -110,7 +127,7 @@ func (r *Response) wrapRespBody(taskData *taskData) { }, } r.Body = body - // TODO(spongehah) gzip(wrapRespBody) + // TODO(hah) gzip(wrapRespBody): The compress/gzip library still has a bug. An exception occurs when calling gzip.NewReader(). //if taskData.addedGzip && EqualFold(r.Header.Get("Content-Encoding"), "gzip") { // println("gzip reader") // r.Body = &gzipReader{body: body} diff --git a/x/net/http/server.go b/x/net/http/server.go index 5c4c58d..f38cbd0 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -10,10 +10,3 @@ package http // size is anyway. (if we have the bytes on the machine, we might as // well read them) const maxPostHandlerReadBytes = 256 << 10 - -type readResult struct { - _ incomparable - n int - err error - b byte // byte read, if n == 1 -} diff --git a/x/net/http/transfer.go b/x/net/http/transfer.go index 818fb3c..12f3d70 100644 --- a/x/net/http/transfer.go +++ b/x/net/http/transfer.go @@ -28,7 +28,6 @@ type transferReader struct { ContentLength int64 Chunked bool Close bool - Trailer Header } // parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. @@ -151,10 +150,6 @@ func readTransfer(msg any, r io.ReadCloser) (err error) { t.ContentLength = realLength } - // TODO(spongehah) Trailer(readTransfer) - // Trailer - //t.Trailer, err = fixTrailer(t.Header, t.Chunked) - // If there is no Content-Length or chunked Transfer-Encoding on a *Response // and the status is not 1xx, 204 or 304, then the body is unbounded. // See RFC 7230, section 3.3. @@ -301,48 +296,6 @@ func parseContentLength(cl string) (int64, error) { } -// Parse the trailer header. -func fixTrailer(header Header, chunked bool) (Header, error) { - vv, ok := header["Trailer"] - if !ok { - return nil, nil - } - if !chunked { - // Trailer and no chunking: - // this is an invalid use case for trailer header. - // Nevertheless, no error will be returned and we - // let users decide if this is a valid HTTP message. - // The Trailer header will be kept in Response.Header - // but not populate Response.Trailer. - // See issue #27197. - return nil, nil - } - header.Del("Trailer") - - trailer := make(Header) - var err error - for _, v := range vv { - foreachHeaderElement(v, func(key string) { - key = CanonicalHeaderKey(key) - switch key { - case "Transfer-Encoding", "Trailer", "Content-Length": - if err == nil { - err = badStringError("bad trailer key", key) - return - } - } - trailer[key] = nil - }) - } - if err != nil { - return nil, err - } - if len(trailer) == 0 { - return nil, nil - } - return trailer, nil -} - // body turns a Reader into a ReadCloser. // Close ensures that the body has been fully read // and then reads the trailer if necessary. @@ -387,16 +340,6 @@ func (b *body) readLocked(p []byte) (n int, err error) { b.sawEOF = true // Chunked case. Read the trailer. if b.hdr != nil { - // TODO(spongehah) Trailer(b.readLocked) - //if e := b.readTrailer(); e != nil { - // err = e - // // Something went wrong in the trailer, we must not allow any - // // further reads of any kind to succeed from body, nor any - // // subsequent requests on the server connection. See - // // golang.org/issue/12027 - // b.sawEOF = false - // b.closed = true - //} b.hdr = nil } else { // If the server declared the Content-Length, our body is a LimitedReader @@ -634,7 +577,6 @@ func (r *Request) writeHeader(reqHeaders *hyper.Headers) error { // 'Content-Length' and 'Transfer-Encoding:chunked' are already handled by hyper // Write Trailer header - // TODO(spongehah) Trailer(writeHeader) return nil } @@ -682,7 +624,7 @@ func (r *Request) writeBody(hyperReq *hyper.Request, treq *transportRequest) err buf: buf, treq: treq, } - hyperReqBody.SetUserdata(c.Pointer(reqData)) + hyperReqBody.SetUserdata(c.Pointer(reqData), nil) hyperReqBody.SetDataFunc(setPostData) hyperReq.SetBody(hyperReqBody) } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 8075133..e47bd2a 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "hash/fnv" "io" "log" "net/url" @@ -27,7 +28,6 @@ import ( // as directed by the environment variables HTTP_PROXY, HTTPS_PROXY // and NO_PROXY (or the lowercase versions thereof). var DefaultTransport RoundTripper = &Transport{ - //Proxy: ProxyFromEnvironment, Proxy: nil, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -36,6 +36,7 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 +const _SC_NPROCESSORS_ONLN c.Int = 58 // Debug switch provided for developers const ( @@ -69,11 +70,10 @@ type Transport struct { MaxConnsPerHost int IdleConnTimeout time.Duration - // libuv and hyper related - loopInitOnce sync.Once - loop *libuv.Loop - async *libuv.Async - exec *hyper.Executor + loopsMu sync.Mutex + loops []*clientEventLoop + isClosing atomic.Bool + //curLoop atomic.Uint32 } // A cancelKey is the key of the reqCanceler map. @@ -183,6 +183,9 @@ func (tr *transportRequest) setError(err error) { func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { if err := t.tryPutIdleConn(pconn); err != nil { + if debugSwitch { + println("############### putOrCloseIdleConn: close") + } pconn.close(err) } } @@ -274,6 +277,9 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { t.idleLRU.add(pconn) if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns { oldest := t.idleLRU.removeOldest() + if debugSwitch { + println("############### tryPutIdleConn: removeOldest") + } oldest.close(errTooManyIdle) t.removeIdleConnLocked(oldest) } @@ -287,7 +293,7 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) } else { pconn.idleTimer = &libuv.Timer{} - libuv.InitTimer(t.loop, pconn.idleTimer) + libuv.InitTimer(pconn.eventLoop.loop, pconn.idleTimer) (*libuv.Handle)(c.Pointer(pconn.idleTimer)).SetData(c.Pointer(pconn)) pconn.idleTimer.Start(onIdleConnTimeout, idleConnTimeout, 0) } @@ -343,7 +349,9 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { // See whether this connection has been idle too long, considering // only the wall time (the Round(0)), in case this is a laptop or VM // coming out of suspend with previously cached idle connections. - tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + // FIXME: Round() is not supported in llgo + //tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime) + tooOld := !oldTime.IsZero() && pconn.idleAt.Before(oldTime) if tooOld { // Async cleanup. Launch in its own goroutine (as if a // time.AfterFunc called it); it acquires idleMu, which we're @@ -403,9 +411,10 @@ func (t *Transport) removeIdleConn(pconn *persistConn) bool { // t.idleMu must be held. func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { - if pconn.idleTimer != nil { + if pconn.idleTimer != nil && (*libuv.Handle)(c.Pointer(pconn.idleTimer)).IsClosing() == 0 { pconn.idleTimer.Stop() (*libuv.Handle)(c.Pointer(pconn.idleTimer)).Close(nil) + pconn.idleTimer = nil } t.idleLRU.remove(pconn) key := pconn.cacheKey @@ -467,13 +476,14 @@ func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { return true } -func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { +func (t *Transport) connectMethodForRequest(treq *transportRequest, loop *clientEventLoop) (cm connectMethod, err error) { cm.targetScheme = treq.URL.Scheme cm.targetAddr = canonicalAddr(treq.URL) if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) } cm.onlyH1 = treq.requiresHTTP1() + cm.eventLoop = loop return cm, err } @@ -524,25 +534,56 @@ func (t *Transport) cancelRequest(key cancelKey, err error) bool { return cancel != nil } -func (t *Transport) close(err error) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - t.closeLocked(err) +func (t *Transport) Close() { + if t != nil && !t.isClosing.Swap(true) { + t.CloseIdleConnections() + for _, el := range t.loops { + el.Close() + } + } } -func (t *Transport) closeLocked(err error) { - if err != nil { - fmt.Println(err) - } - if t.loop != nil { - t.loop.Close() - } - if t.async != nil { - t.async.Close(nil) +type clientEventLoop struct { + // libuv and hyper related + loop *libuv.Loop + async *libuv.Async + exec *hyper.Executor + isRunning atomic.Bool + isClosing atomic.Bool +} + +func (el *clientEventLoop) Close() { + if el != nil && !el.isClosing.Swap(true) { + if el.loop != nil && (*libuv.Handle)(c.Pointer(el.loop)).IsClosing() == 0 { + el.loop.Close() + el.loop = nil + } + if el.async != nil && (*libuv.Handle)(c.Pointer(el.async)).IsClosing() == 0 { + el.async.Close(nil) + el.async = nil + } + if el.exec != nil { + el.exec.Free() + el.exec = nil + } } - if t.exec != nil { - t.exec.Free() +} + +func (el *clientEventLoop) run() { + if el.isRunning.Load() { + return } + + el.loop.Async(el.async, nil) + + checker := &libuv.Idle{} + libuv.InitIdle(el.loop, checker) + (*libuv.Handle)(c.Pointer(checker)).SetData(c.Pointer(el)) + checker.Start(readWriteLoop) + + go el.loop.Run(libuv.RUN_DEFAULT) + + el.isRunning.Store(true) } // ---------------------------------------------------------- @@ -556,26 +597,65 @@ func getMilliseconds(deadline time.Time) uint64 { return uint64(milliseconds) } +var cpuCount int + +func init() { + cpuCount = int(c.Sysconf(_SC_NPROCESSORS_ONLN)) + if cpuCount <= 0 { + cpuCount = 4 + } +} + +func (t *Transport) getOrInitClientEventLoop(i uint32) *clientEventLoop { + if el := t.loops[i]; el != nil { + return el + } + + eventLoop := &clientEventLoop{ + loop: libuv.LoopNew(), + async: &libuv.Async{}, + exec: hyper.NewExecutor(), + } + + eventLoop.run() + + t.loops[i] = eventLoop + return eventLoop +} + +func (t *Transport) getClientEventLoop(req *Request) *clientEventLoop { + t.loopsMu.Lock() + defer t.loopsMu.Unlock() + if t.loops == nil { + t.loops = make([]*clientEventLoop, cpuCount) + } + + key := t.getLoopKey(req) + h := fnv.New32a() + h.Write([]byte(key)) + hashcode := h.Sum32() + + return t.getOrInitClientEventLoop(hashcode % uint32(cpuCount)) + //i := (t.curLoop.Add(1) - 1) % uint32(cpuCount) + //return t.getOrInitClientEventLoop(i) +} + +func (t *Transport) getLoopKey(req *Request) string { + proxyStr := "" + if t.Proxy != nil { + proxyURL, _ := t.Proxy(req) + proxyStr = proxyURL.String() + } + return req.URL.String() + proxyStr +} + func (t *Transport) RoundTrip(req *Request) (*Response, error) { if debugSwitch { println("############### RoundTrip start") defer println("############### RoundTrip end") } - t.loopInitOnce.Do(func() { - println("############### init loop") - t.loop = libuv.LoopNew() - t.async = &libuv.Async{} - t.exec = hyper.NewExecutor() - - t.loop.Async(t.async, nil) - checker := &libuv.Check{} - libuv.InitCheck(t.loop, checker) - (*libuv.Handle)(c.Pointer(checker)).SetData(c.Pointer(t)) - checker.Start(readWriteLoop) - - go t.loop.Run(libuv.RUN_DEFAULT) - }) + eventLoop := t.getClientEventLoop(req) // If timeout is set, start the timer var didTimeout func() bool @@ -583,7 +663,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // Only the first request will initialize the timer if req.timer == nil && !req.deadline.IsZero() { req.timer = &libuv.Timer{} - libuv.InitTimer(t.loop, req.timer) + libuv.InitTimer(eventLoop.loop, req.timer) ch := &timeoutData{ timeoutch: req.timeoutch, taskData: nil, @@ -598,7 +678,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { stopTimer = func() { close(req.timeoutch) req.timer.Stop() - (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) + if (*libuv.Handle)(c.Pointer(req.timer)).IsClosing() == 0 { + (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) + } if debugSwitch { println("############### timer close") } @@ -608,7 +690,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { stopTimer = nop } - resp, err := t.doRoundTrip(req) + resp, err := t.doRoundTrip(req, eventLoop) if err != nil { stopTimer() return nil, err @@ -624,7 +706,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { return resp, nil } -func (t *Transport) doRoundTrip(req *Request) (*Response, error) { +func (t *Transport) doRoundTrip(req *Request, loop *clientEventLoop) (*Response, error) { if debugSwitch { println("############### doRoundTrip start") defer println("############### doRoundTrip end") @@ -687,12 +769,6 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { } for { - //select { - //case <-ctx.Done(): - // req.closeBody() - // return nil, ctx.Err() - //default: - //} select { case <-req.timeoutch: req.closeBody() @@ -703,7 +779,7 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { // treq gets modified by roundTrip, so we need to recreate for each retry. //treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} treq := &transportRequest{Request: req, cancelKey: cancelKey} - cm, err := t.connectMethodForRequest(treq) + cm, err := t.connectMethodForRequest(treq, loop) if err != nil { req.closeBody() return nil, err @@ -716,6 +792,7 @@ func (t *Transport) doRoundTrip(req *Request) (*Response, error) { pconn, err := t.getConn(treq, cm) if err != nil { + println("################# getConn err != nil") t.setReqCanceler(cancelKey, nil) req.closeBody() return nil, err @@ -827,10 +904,6 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // what caused w.err; if so, prefer to return the // cancellation error (see golang.org/issue/16049). select { - //case <-req.Cancel: - // return nil, errRequestCanceledConn - //case <-req.Context().Done(): - // return nil, req.Context().Err() case <-req.timeoutch: if debugSwitch { println("############### getConn: timeoutch") @@ -977,68 +1050,23 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * writeLoopDone: make(chan struct{}, 1), alive: true, chunkAsync: &libuv.Async{}, + eventLoop: cm.eventLoop, } - t.loop.Async(pconn.chunkAsync, readyToRead) + cm.eventLoop.loop.Async(pconn.chunkAsync, readyToRead) - //trace := httptrace.ContextClientTrace(ctx) - //wrapErr := func(err error) error { - // if cm.proxyURL != nil { - // // Return a typed error, per Issue 16997 - // return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} - // } - // return err - //} - // - //if cm.scheme() == "https" && t.hasCustomTLSDialer() { - // var err error - // pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr()) - // if err != nil { - // return nil, wrapErr(err) - // } - // if tc, ok := pconn.conn.(*tls.Conn); ok { - // // Handshake here, in case DialTLS didn't. TLSNextProto below - // // depends on it for knowing the connection state. - // if trace != nil && trace.TLSHandshakeStart != nil { - // trace.TLSHandshakeStart() - // } - // if err := tc.HandshakeContext(ctx); err != nil { - // go pconn.conn.Close() - // if trace != nil && trace.TLSHandshakeDone != nil { - // trace.TLSHandshakeDone(tls.ConnectionState{}, err) - // } - // return nil, err - // } - // cs := tc.ConnectionState() - // if trace != nil && trace.TLSHandshakeDone != nil { - // trace.TLSHandshakeDone(cs, nil) - // } - // pconn.tlsState = &cs - // } - //} else { - //conn, err := t.dial(ctx, "tcp", cm.addr()) - conn, err := t.dial(cm.addr()) + conn, err := t.dial(cm) if err != nil { return nil, err } pconn.conn = conn - //if cm.scheme() == "https" { - // var firstTLSHost string - // if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { - // return nil, wrapErr(err) - // } - // if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { - // return nil, wrapErr(err) - // } - //} - //} select { case <-timeoutch: conn.Close() return default: } - // TODO(spongehah) Proxy(https/sock5)(t.dialConn) + // TODO(hah) Proxy(https/sock5)(t.dialConn) // Proxy setup. switch { case cm.proxyURL == nil: @@ -1054,41 +1082,14 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * // case cm.targetScheme == "https": } - //if cm.proxyURL != nil && cm.targetScheme == "https" { - // if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { - // return nil, err - // } - //} - // - //if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { - // if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - // alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) - // if e, ok := alt.(erringRoundTripper); ok { - // // pconn.conn was closed by next (http2configureTransports.upgradeFn). - // return nil, e.RoundTripErr() - // } - // return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil - // } - //} - pconn.closeErr = errReadLoopExiting - pconn.tryPutIdleConn = func() bool { - if err := pconn.t.tryPutIdleConn(pconn); err != nil { - pconn.closeErr = err - //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { - // trace.PutIdleConn(err) - //} - return false - } - //if trace != nil && trace.PutIdleConn != nil { - // trace.PutIdleConn(nil) - //} - return true - } select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") + if debugSwitch { + println("############### dialConn: timeoutch") + } pconn.close(err) return nil, err default: @@ -1096,11 +1097,12 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * return pconn, nil } -func (t *Transport) dial(addr string) (*connData, error) { +func (t *Transport) dial(cm connectMethod) (*connData, error) { if debugSwitch { println("############### dial start") defer println("############### dial end") } + addr := cm.addr() host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -1108,8 +1110,8 @@ func (t *Transport) dial(addr string) (*connData, error) { conn := new(connData) - libuv.InitTcp(t.loop, &conn.TcpHandle) - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).SetData(c.Pointer(conn)) + libuv.InitTcp(cm.eventLoop.loop, &conn.tcpHandle) + (*libuv.Handle)(c.Pointer(&conn.tcpHandle)).SetData(c.Pointer(conn)) var hints cnet.AddrInfo c.Memset(c.Pointer(&hints), 0, unsafe.Sizeof(hints)) @@ -1122,8 +1124,8 @@ func (t *Transport) dial(addr string) (*connData, error) { return nil, fmt.Errorf("getaddrinfo error\n") } - (*libuv.Req)(c.Pointer(&conn.ConnectReq)).SetData(c.Pointer(conn)) - status = libuv.TcpConnect(&conn.ConnectReq, &conn.TcpHandle, res.Addr, onConnect) + (*libuv.Req)(c.Pointer(&conn.connectReq)).SetData(c.Pointer(conn)) + status = libuv.TcpConnect(&conn.connectReq, &conn.tcpHandle, res.Addr, onConnect) if status != 0 { return nil, fmt.Errorf("connect error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) } @@ -1179,33 +1181,32 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err resc: resc, } - if pc.client == nil && !pc.isReused() { - // Hookup the IO - hyperIo := newHyperIo(pc.conn) - // We need an executor generally to poll futures - // Prepare client options - opts := hyper.NewClientConnOptions() - opts.Exec(pc.t.exec) - // send the handshake - handshakeTask := hyper.Handshake(hyperIo, opts) - taskData.taskId = handshake - handshakeTask.SetUserdata(c.Pointer(taskData)) - // Send the request to readWriteLoop(). - pc.t.exec.Push(handshakeTask) - } else { - taskData.taskId = read - err = req.write(pc.client, taskData, pc.t.exec) - if err != nil { - writeErrCh <- err - } - } + //if pc.client == nil && !pc.isReused() { + // Hookup the IO + hyperIo := newHyperIo(pc.conn) + // We need an executor generally to poll futures + // Prepare client options + opts := hyper.NewClientConnOptions() + opts.Exec(pc.eventLoop.exec) + // send the handshake + handshakeTask := hyper.Handshake(hyperIo, opts) + taskData.taskId = handshake + handshakeTask.SetUserdata(c.Pointer(taskData), nil) + // Send the request to readWriteLoop(). + pc.eventLoop.exec.Push(handshakeTask) + //} else { + // println("############### roundTrip: pc.client != nil") + // taskData.taskId = read + // err = req.write(pc.client, taskData, pc.eventLoop.exec) + // if err != nil { + // writeErrCh <- err + // pc.close(err) + // } + //} // Wake up libuv. Loop - pc.t.async.Send() + pc.eventLoop.async.Send() - //var respHeaderTimer <-chan time.Time - //cancelChan := req.Request.Cancel - //ctxDoneChan := req.Context().Done() timeoutch := req.timeoutch pcClosed := pc.closech canceled := false @@ -1221,6 +1222,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err println("############### roundTrip: writeErrch") } if err != nil { + if debugSwitch { + println("############### roundTrip: writeErrch err != nil") + } pc.close(fmt.Errorf("write error: %w", err)) if pc.conn.nwrite == startBytesWritten { err = nothingWrittenError{err} @@ -1247,13 +1251,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) } return re.res, nil - //case <-cancelChan: - // canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) - // cancelChan = nil - //case <-ctxDoneChan: - // canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) - // cancelChan = nil - // ctxDoneChan = nil case <-timeoutch: if debugSwitch { println("############### roundTrip: timeoutch") @@ -1267,21 +1264,21 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // readWriteLoop handles the main I/O loop for a persistent connection. // It processes incoming requests, sends them to the server, and handles responses. -func readWriteLoop(checker *libuv.Check) { - t := (*Transport)((*libuv.Handle)(c.Pointer(checker)).GetData()) +func readWriteLoop(checker *libuv.Idle) { + eventLoop := (*clientEventLoop)((*libuv.Handle)(c.Pointer(checker)).GetData()) // The polling state machine! Poll all ready tasks and act on them... - task := t.exec.Poll() + task := eventLoop.exec.Poll() for task != nil { if debugSwitch { println("############### polling") } - t.handleTask(task) - task = t.exec.Poll() + eventLoop.handleTask(task) + task = eventLoop.exec.Poll() } } -func (t *Transport) handleTask(task *hyper.Task) { +func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { taskData := (*taskData)(task.Userdata()) if taskData == nil { // A background task for hyper_client completed... @@ -1293,7 +1290,10 @@ func (t *Transport) handleTask(task *hyper.Task) { // If original taskId is set, we need to check it err = checkTaskType(task, taskData) if err != nil { - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask: checkTaskType err != nil") + } + closeAndRemoveIdleConn(pc, true) return } switch taskData.taskId { @@ -1313,13 +1313,16 @@ func (t *Transport) handleTask(task *hyper.Task) { pc.client = (*hyper.ClientConn)(task.Value()) task.Free() - // TODO(spongehah) Proxy(writeLoop) + // TODO(hah) Proxy(writeLoop) taskData.taskId = read - err = taskData.req.Request.write(pc.client, taskData, t.exec) + err = taskData.req.Request.write(pc.client, taskData, eventLoop.exec) if err != nil { //pc.writeErrCh <- err // to the body reader, which might recycle us taskData.writeErrCh <- err // to the roundTrip function + if debugSwitch { + println("############### handleTask: write err != nil") + } pc.close(err) return } @@ -1332,6 +1335,20 @@ func (t *Transport) handleTask(task *hyper.Task) { println("############### read") } + pc.tryPutIdleConn = func() bool { + if err := pc.t.tryPutIdleConn(pc); err != nil { + pc.closeErr = err + //if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + // trace.PutIdleConn(err) + //} + return false + } + //if trace != nil && trace.PutIdleConn != nil { + // trace.PutIdleConn(nil) + //} + return true + } + // Take the results hyperResp := (*hyper.Response)(task.Value()) task.Free() @@ -1340,7 +1357,10 @@ func (t *Transport) handleTask(task *hyper.Task) { if pc.numExpectedResponses == 0 { pc.readLoopPeekFailLocked(hyperResp, err) pc.mu.Unlock() - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask: numExpectedResponses == 0") + } + closeAndRemoveIdleConn(pc, true) return } //pc.mu.Unlock() @@ -1361,20 +1381,25 @@ func (t *Transport) handleTask(task *hyper.Task) { hyperResp.Free() if err != nil { + pc.bodyChunk.closeWithError(err) + taskData.closeHyperBody() select { case taskData.resc <- responseAndError{err: err}: case <-taskData.callerGone: - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask read: callerGone") + } + closeAndRemoveIdleConn(pc, true) return } - readLoopDefer(pc, true) + if debugSwitch { + println("############### handleTask: read err != nil") + } + closeAndRemoveIdleConn(pc, true) return } - dataTask := taskData.hyperBody.Data() taskData.taskId = readBodyChunk - dataTask.SetUserdata(c.Pointer(taskData)) - t.exec.Push(dataTask) if !taskData.req.deadline.IsZero() { (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData @@ -1391,21 +1416,18 @@ func (t *Transport) handleTask(task *hyper.Task) { resp.wrapRespBody(taskData) - // FIXME: Waiting for the channel bug to be fixed - //select { - //case taskData.resc <- responseAndError{res: resp}: - //case <-taskData.callerGone: - // // defer - // readLoopDefer(pc, true) - // return - //} select { + case taskData.resc <- responseAndError{res: resp}: case <-taskData.callerGone: - readLoopDefer(pc, true) + // defer + if debugSwitch { + println("############### handleTask read: callerGone 2") + } + pc.bodyChunk.Close() + taskData.closeHyperBody() + closeAndRemoveIdleConn(pc, true) return - default: } - taskData.resc <- responseAndError{res: resp} if debugReadWriteLoop { println("############### read end") @@ -1433,14 +1455,16 @@ func (t *Transport) handleTask(task *hyper.Task) { // taskType == taskEmpty (check in checkTaskType) task.Free() - taskData.hyperBody.Free() - taskData.hyperBody = nil - pc.bodyChunk.Close() - replaced := t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc.bodyChunk.closeWithError(io.EOF) + taskData.closeHyperBody() + replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool pc.alive = pc.alive && replaced && pc.tryPutIdleConn() - readLoopDefer(pc, false) + if debugSwitch { + println("############### handleTask readBodyChunk: alive: ", pc.alive) + } + closeAndRemoveIdleConn(pc, false) if debugReadWriteLoop { println("############### readBodyChunk end [empty]") @@ -1449,18 +1473,20 @@ func (t *Transport) handleTask(task *hyper.Task) { } func readyToRead(aysnc *libuv.Async) { - println("############### AsyncCb: readyToRead") taskData := (*taskData)(aysnc.GetData()) dataTask := taskData.hyperBody.Data() - dataTask.SetUserdata(c.Pointer(taskData)) - taskData.pc.t.exec.Push(dataTask) + dataTask.SetUserdata(c.Pointer(taskData), nil) + taskData.pc.eventLoop.exec.Push(dataTask) } -// readLoopDefer Replace the defer function of readLoop in stdlib -func readLoopDefer(pc *persistConn, force bool) { +// closeAndRemoveIdleConn Replace the defer function of readLoop in stdlib +func closeAndRemoveIdleConn(pc *persistConn, force bool) { if pc.alive == true && !force { return } + if debugSwitch { + println("############### closeAndRemoveIdleConn, force:", force) + } pc.close(pc.closeErr) pc.t.removeIdleConn(pc) } @@ -1468,13 +1494,14 @@ func readLoopDefer(pc *persistConn, force bool) { // ---------------------------------------------------------- type connData struct { - TcpHandle libuv.Tcp - ConnectReq libuv.Connect - ReadBuf libuv.Buf - ReadBufFilled uintptr + tcpHandle libuv.Tcp + connectReq libuv.Connect + readBuf libuv.Buf + readBufFilled uintptr nwrite int64 // bytes written(Replaced from persistConn's nwrite) - ReadWaker *hyper.Waker - WriteWaker *hyper.Waker + readWaker *hyper.Waker + writeWaker *hyper.Waker + isClosing atomic.Bool } type taskData struct { @@ -1497,24 +1524,32 @@ const ( readBodyChunk ) -func (conn *connData) Close() error { - if conn == nil { - return nil - } - if conn.ReadWaker != nil { - conn.ReadWaker.Free() - conn.ReadWaker = nil - } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() - conn.WriteWaker = nil +func (conn *connData) Close() { + if conn != nil && !conn.isClosing.Swap(true) { + if conn.readWaker != nil { + conn.readWaker.Free() + conn.readWaker = nil + } + if conn.writeWaker != nil { + conn.writeWaker.Free() + conn.writeWaker = nil + } + //if conn.readBuf.Base != nil { + // c.Free(c.Pointer(conn.readBuf.Base)) + // conn.readBuf.Base = nil + //} + if (*libuv.Handle)(c.Pointer(&conn.tcpHandle)).IsClosing() == 0 { + (*libuv.Handle)(c.Pointer(&conn.tcpHandle)).Close(nil) + } + conn = nil } - if conn.ReadBuf.Base != nil { - c.Free(c.Pointer(conn.ReadBuf.Base)) - conn.ReadBuf.Base = nil +} + +func (d *taskData) closeHyperBody() { + if d.hyperBody != nil { + d.hyperBody.Free() + d.hyperBody = nil } - (*libuv.Handle)(c.Pointer(&conn.TcpHandle)).Close(nil) - return nil } // onConnect is the libuv callback for a successful connection @@ -1524,24 +1559,28 @@ func onConnect(req *libuv.Connect, status c.Int) { defer println("############### connect end") } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - if status < 0 { - c.Fprintf(c.Stderr, c.Str("connect error: %d\n"), libuv.Strerror(libuv.Errno(status))) + c.Fprintf(c.Stderr, c.Str("connect error: %s\n"), c.GoString(libuv.Strerror(libuv.Errno(status)))) + conn.Close() return } - (*libuv.Stream)(c.Pointer(&conn.TcpHandle)).StartRead(allocBuffer, onRead) + + // Keep-Alive + conn.tcpHandle.KeepAlive(1, 60) + + (*libuv.Stream)(c.Pointer(&conn.tcpHandle)).StartRead(allocBuffer, onRead) } // allocBuffer allocates a buffer for reading from a socket func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { conn := (*connData)(handle.GetData()) - if conn.ReadBuf.Base == nil { - conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) - //base := make([]byte, suggestedSize) - //conn.ReadBuf = libuv.InitBuf((*c.Char)(c.Pointer(&base[0])), c.Uint(suggestedSize)) - conn.ReadBufFilled = 0 + if conn.readBuf.Base == nil { + //conn.readBuf = libuv.InitBuf((*c.Char)(c.Malloc(suggestedSize)), c.Uint(suggestedSize)) + base := make([]byte, suggestedSize) + conn.readBuf = libuv.InitBuf((*c.Char)(c.Pointer(&base[0])), c.Uint(suggestedSize)) + conn.readBufFilled = 0 } - *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+conn.ReadBufFilled)), c.Uint(suggestedSize-conn.ReadBufFilled)) + *buf = libuv.InitBuf((*c.Char)(c.Pointer(uintptr(c.Pointer(conn.readBuf.Base))+conn.readBufFilled)), c.Uint(suggestedSize-conn.readBufFilled)) } // onRead is the libuv callback for reading from a socket @@ -1549,38 +1588,39 @@ func allocBuffer(handle *libuv.Handle, suggestedSize uintptr, buf *libuv.Buf) { func onRead(stream *libuv.Stream, nread c.Long, buf *libuv.Buf) { conn := (*connData)((*libuv.Handle)(c.Pointer(stream)).GetData()) if nread > 0 { - conn.ReadBufFilled += uintptr(nread) + conn.readBufFilled += uintptr(nread) } - if conn.ReadWaker != nil { + if conn.readWaker != nil { // Wake up the pending read operation of Hyper - conn.ReadWaker.Wake() - conn.ReadWaker = nil + conn.readWaker.Wake() + conn.readWaker = nil } } // readCallBack read callback function for Hyper library func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uintptr) uintptr { conn := (*connData)(userdata) - if conn.ReadBufFilled > 0 { + if conn.readBufFilled > 0 { var toCopy uintptr - if bufLen < conn.ReadBufFilled { + if bufLen < conn.readBufFilled { toCopy = bufLen } else { - toCopy = conn.ReadBufFilled + toCopy = conn.readBufFilled } // Copy data from read buffer to Hyper's buffer - c.Memcpy(c.Pointer(buf), c.Pointer(conn.ReadBuf.Base), toCopy) + c.Memcpy(c.Pointer(buf), c.Pointer(conn.readBuf.Base), toCopy) // Move remaining data to the beginning of the buffer - c.Memmove(c.Pointer(conn.ReadBuf.Base), c.Pointer(uintptr(c.Pointer(conn.ReadBuf.Base))+toCopy), conn.ReadBufFilled-toCopy) + c.Memmove(c.Pointer(conn.readBuf.Base), c.Pointer(uintptr(c.Pointer(conn.readBuf.Base))+toCopy), conn.readBufFilled-toCopy) // Update the amount of filled buffer - conn.ReadBufFilled -= toCopy + conn.readBufFilled -= toCopy return toCopy } - if conn.ReadWaker != nil { - conn.ReadWaker.Free() + if conn.readWaker != nil { + conn.readWaker.Free() } - conn.ReadWaker = ctx.Waker() + conn.readWaker = ctx.Waker() + println("############### readCallBack: IoPending") return hyper.IoPending } @@ -1588,10 +1628,10 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin // Callback function called after a write operation completes func onWrite(req *libuv.Write, status c.Int) { conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) - if conn.WriteWaker != nil { + if conn.writeWaker != nil { // Wake up the pending write operation - conn.WriteWaker.Wake() - conn.WriteWaker = nil + conn.writeWaker.Wake() + conn.writeWaker = nil } } @@ -1602,16 +1642,17 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui req := &libuv.Write{} (*libuv.Req)(c.Pointer(req)).SetData(c.Pointer(conn)) - ret := req.Write((*libuv.Stream)(c.Pointer(&conn.TcpHandle)), &initBuf, 1, onWrite) + ret := req.Write((*libuv.Stream)(c.Pointer(&conn.tcpHandle)), &initBuf, 1, onWrite) if ret >= 0 { conn.nwrite += int64(bufLen) return bufLen } - if conn.WriteWaker != nil { - conn.WriteWaker.Free() + if conn.writeWaker != nil { + conn.writeWaker.Free() } - conn.WriteWaker = ctx.Waker() + conn.writeWaker = ctx.Waker() + println("############### writeCallBack: IoPending") return hyper.IoPending } @@ -1630,14 +1671,14 @@ func onTimeout(timer *libuv.Timer) { pc := taskData.pc pc.alive = false pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) - readLoopDefer(pc, true) + closeAndRemoveIdleConn(pc, true) } } // newHyperIo creates a new IO with read and write callbacks func newHyperIo(connData *connData) *hyper.Io { hyperIo := hyper.NewIo() - hyperIo.SetUserdata(c.Pointer(connData)) + hyperIo.SetUserdata(c.Pointer(connData), nil) hyperIo.SetRead(readCallBack) hyperIo.SetWrite(writeCallBack) return hyperIo @@ -1670,8 +1711,16 @@ func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { task.Free() if curTaskId == handshake || curTaskId == read { taskData.writeErrCh <- err + if debugSwitch { + println("############### checkTaskType: writeErrCh") + } taskData.pc.close(err) } + if taskData.pc.bodyChunk != nil { + taskData.pc.bodyChunk.Close() + taskData.pc.bodyChunk = nil + } + taskData.closeHyperBody() taskData.pc.alive = false } return @@ -1685,6 +1734,7 @@ func fail(err *hyper.Error, taskId taskId) error { errLen := err.Print((*uint8)(c.Pointer(&errBuf[:][0])), uintptr(len(errBuf))) errDetails := unsafe.SliceData(errBuf[:errLen]) details := c.GoString(errDetails) + fmt.Println(details) // clean up the error err.Free() @@ -1837,7 +1887,9 @@ type persistConn struct { // If it's non-nil, the rest of the fields are unused. alt RoundTripper - t *Transport + t *Transport + eventLoop *clientEventLoop + cacheKey connectMethodKey conn *connData //tlsState *tls.ConnectionState @@ -1876,6 +1928,9 @@ type persistConn struct { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { + if debugSwitch { + println("############### CloseIdleConnections") + } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t.idleMu.Lock() m := t.idleConn @@ -1888,12 +1943,16 @@ func (t *Transport) CloseIdleConnections() { pconn.close(errCloseIdleConns) } } + //if t2 := t.h2transport; t2 != nil { // t2.CloseIdleConnections() //} } func (pc *persistConn) cancelRequest(err error) { + if debugSwitch { + println("############### cancelRequest") + } pc.mu.Lock() defer pc.mu.Unlock() pc.canceledErr = err @@ -1938,8 +1997,14 @@ func (pc *persistConn) closeLocked(err error) { } close(pc.closech) close(pc.writeLoopDone) - pc.client.Free() - pc.chunkAsync.Close(nil) + if pc.client != nil { + pc.client.Free() + pc.client = nil + } + if pc.chunkAsync != nil && pc.chunkAsync.IsClosing() == 0 { + pc.chunkAsync.Close(nil) + pc.chunkAsync = nil + } } } pc.mutateHeaderFunc = nil @@ -2096,10 +2161,16 @@ func (pc *persistConn) closeConnIfStillIdleLocked() { return } t.removeIdleConnLocked(pc) + if debugSwitch { + println("############### closeConnIfStillIdleLocked") + } pc.close(errIdleConnTimeout) } func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { + if debugSwitch { + println("############### readLoopPeekFailLocked") + } if pc.closed != nil { return } @@ -2117,7 +2188,7 @@ func (pc *persistConn) setExtraHeaders(req *transportRequest) bool { // uncompress the gzip stream if we were the layer that // requested it. requestedGzip := false - // TODO(spongehah) gzip(pc.roundTrip) + // TODO(hah) gzip(pc.roundTrip): The compress/gzip library still has a bug. An exception occurs when calling gzip.NewReader(). //if !pc.t.DisableCompression && // req.Header.Get("Accept-Encoding") == "" && // req.Header.Get("Range") == "" && @@ -2190,6 +2261,8 @@ type connectMethod struct { // be reused for different targetAddr values. targetAddr string onlyH1 bool // whether to disable HTTP/2 and force HTTP/1 + + eventLoop *clientEventLoop } // connectMethodKey is the map key version of connectMethod, with a From 49afcb557fd0a1821a0f5c036d3a7cc0f41135bd Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 20 Sep 2024 14:50:17 +0800 Subject: [PATCH 49/55] fix(x/net/http): Fix import name and adapt to future merge with client Signed-off-by: hackerchai --- x/net/http/_demo/http.go | 8 +- x/net/http/header.go | 237 ++++++++++++++++++++++++++++++++++++++- x/net/http/request.go | 12 +- x/net/http/response.go | 2 +- x/net/http/server.go | 102 ++++++----------- x/net/ipsock.go | 17 ++- 6 files changed, 297 insertions(+), 81 deletions(-) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/http.go index 19e62d5..8386873 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/http.go @@ -3,7 +3,7 @@ package main import ( "fmt" - "github.com/goplus/llgo/x/net/http" + "github.com/goplus/llgoexamples/x/net/http" ) func echoHandler(w http.ResponseWriter, r *http.Request) { @@ -17,6 +17,10 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { fmt.Printf(">> URL: %s\n", r.URL.String()) fmt.Printf(">> RemoteAddr: %s\n", r.RemoteAddr) + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("Hello, World!")) + + //TODO(hackerchai): temporarily unable to do blocking operation in handler due to imperfections of goroutine // body, err := io.ReadAll(r.Body) // if err != nil { // http.Error(w, "Error reading request body", http.StatusInternalServerError) @@ -27,8 +31,6 @@ func echoHandler(w http.ResponseWriter, r *http.Request) { // w.Header().Set("Content-Type", "text/plain") // w.Write(body) - w.Header().Set("Content-Type", "text/plain") - w.Write([]byte("Hello, World!")) } func main() { diff --git a/x/net/http/header.go b/x/net/http/header.go index 5579972..ef45acf 100644 --- a/x/net/http/header.go +++ b/x/net/http/header.go @@ -1,22 +1,253 @@ package http +import ( + "fmt" + "net/textproto" + "sort" + "strings" + "sync" + + "github.com/goplus/llgo/c" + "github.com/goplus/llgoexamples/rust/hyper" +) + +// A Header represents the key-value pairs in an HTTP header. +// +// The keys should be in canonical form, as returned by +// CanonicalHeaderKey. type Header map[string][]string +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. func (h Header) Add(key, value string) { - h[key] = append(h[key], value) + textproto.MIMEHeader(h).Add(key, value) } +// Set sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +// To use non-canonical keys, assign to the map directly. func (h Header) Set(key, value string) { - h[key] = []string{value} + textproto.MIMEHeader(h).Set(key, value) } +// Get gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. To use non-canonical keys, +// access the map directly. func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h Header) Values(key string) []string { + return textproto.MIMEHeader(h).Values(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { if v := h[key]; len(v) > 0 { return v[0] } return "" } +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func (h Header) has(key string) bool { + _, ok := h[key] + return ok +} + +// Del deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. func (h Header) Del(key string) { - delete(h, key) + textproto.MIMEHeader(h).Del(key) +} + +// Clone returns a copy of h or nil if h is nil. +func (h Header) Clone() Header { + if h == nil { + return nil + } + + // Find total number of values. + nv := 0 + for _, vv := range h { + nv += len(vv) + } + sv := make([]string, nv) // shared backing array for headers' values + h2 := make(Header, len(h)) + for k, vv := range h { + if vv == nil { + // Preserve nil values. ReverseProxy distinguishes + // between nil and zero-length header values. + h2[k] = nil + continue + } + n := copy(sv, vv) + h2[k] = sv[:n:n] + sv = sv[n:] + } + return h2 +} + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} + +// Write writes a header in wire format. +func (h Header) Write(reqHeaders *hyper.Headers) error { + return h.write(reqHeaders) +} + +func (h Header) write(reqHeaders *hyper.Headers) error { + return h.writeSubset(reqHeaders, nil) +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +// Keys are not canonicalized before checking the exclude map. +func (h Header) WriteSubset(reqHeaders *hyper.Headers, exclude map[string]bool) error { + return h.writeSubset(reqHeaders, exclude) +} + +func (h Header) writeSubset(reqHeaders *hyper.Headers, exclude map[string]bool) error { + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { + if !ValidHeaderFieldName(kv.key) { + // This could be an error. In the common case of + // writing response headers, however, we have no good + // way to provide the error back to the server + // handler, so just drop invalid headers instead. + continue + } + for _, v := range kv.values { + v = headerNewlineToSpace.Replace(v) + v = textproto.TrimString(v) + if reqHeaders.Add(&[]byte(kv.key)[0], c.Strlen(c.AllocaCStr(kv.key)), &[]byte(v)[0], c.Strlen(c.AllocaCStr(v))) != hyper.OK { + headerSorterPool.Put(sorter) + return fmt.Errorf("error adding header %s: %s\n", kv.key, v) + } + //if trace != nil && trace.WroteHeaderField != nil { + // formattedVals = append(formattedVals, v) + //} + } + //if trace != nil && trace.WroteHeaderField != nil { + // trace.WroteHeaderField(kv.key, formattedVals) + // formattedVals = nil + //} + } + + headerSorterPool.Put(sorter) + return nil +} + +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} + +// appendToResponseHeader (HeadersForEachCallback) prints each header to the console +func appendToResponseHeader(userdata c.Pointer, name *uint8, nameLen uintptr, value *uint8, valueLen uintptr) c.Int { + resp := (*Response)(userdata) + nameStr := c.GoString((*int8)(c.Pointer(name)), nameLen) + valueStr := c.GoString((*int8)(c.Pointer(value)), valueLen) + + if resp.Header == nil { + resp.Header = make(Header) + } + resp.Header.Add(nameStr, valueStr) + return hyper.IterContinue } \ No newline at end of file diff --git a/x/net/http/request.go b/x/net/http/request.go index 0d4a600..8142ac2 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -4,15 +4,14 @@ import ( "fmt" "io" - //"mime/multipart" "net/url" "strings" "time" "unsafe" "github.com/goplus/llgo/c" - "github.com/goplus/llgo/rust/hyper" "github.com/goplus/llgo/c/libuv" + "github.com/goplus/llgoexamples/rust/hyper" ) type Request struct { @@ -33,14 +32,18 @@ type Request struct { // MultipartForm *multipart.Form RemoteAddr string RequestURI string - timeout time.Duration + + Response *Response + + deadline time.Time + timeoutch chan struct{} + timer *libuv.Timer } func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotifyHandle *libuv.Async, remoteAddr string) (*Request, error) { println("[debug] readRequest called") req := Request{ Header: make(Header), - timeout: 0, Body: nil, } req.RemoteAddr = remoteAddr @@ -136,7 +139,6 @@ func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotif body := hyperReq.Body() if body != nil { - //task := body.Data() taskFlag := getBodyTask requestBody := newRequestBody(requestNotifyHandle) diff --git a/x/net/http/response.go b/x/net/http/response.go index 7798925..b7794e3 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -6,7 +6,7 @@ import ( "github.com/goplus/llgo/c" "github.com/goplus/llgo/c/os" - "github.com/goplus/llgo/rust/hyper" + "github.com/goplus/llgoexamples/rust/hyper" ) type response struct { diff --git a/x/net/http/server.go b/x/net/http/server.go index ade6c5f..d530645 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -14,18 +14,24 @@ import ( cnet "github.com/goplus/llgo/c/net" cos "github.com/goplus/llgo/c/os" "github.com/goplus/llgo/c/syscall" - "github.com/goplus/llgo/rust/hyper" - "github.com/goplus/llgo/x/net" + "github.com/goplus/llgoexamples/rust/hyper" + "github.com/goplus/llgoexamples/x/net" ) +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 + const _SC_NPROCESSORS_ONLN c.Int = 58 var cpuCount int -var asyncHandleMapMu sync.Mutex -var asyncHandleMap = make(map[int64]*libuv.Async) - -// connID is used to generate unique IDs for each connection -var connID int64 type Handler interface { ServeHTTP(ResponseWriter, *Request) @@ -60,7 +66,6 @@ type eventLoop struct { } type conn struct { - asyncID int64 stream libuv.Tcp pollHandle libuv.Poll eventMask c.Uint @@ -72,9 +77,9 @@ type conn struct { } type serviceUserdata struct { - asyncHandleID c.Long - host [128]c.Char - port [8]c.Char + asyncHandle *libuv.Async + host [128]c.Char + port [8]c.Char executor *hyper.Executor } @@ -195,7 +200,6 @@ func ListenAndServe(addr string, handler Handler) error { } func (srv *Server) ListenAndServe() error { - connID = 0 cpuCount = int(c.Sysconf(_SC_NPROCESSORS_ONLN)) if cpuCount <= 0 { cpuCount = 4 @@ -221,17 +225,6 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("invalid port number: %v", err) } - //TODO(hackerchai): new logic for poll - // go func() { - // for { - // task := el.executor.Poll() - // if task != nil { - // handleTask(task) - // task = el.executor.Poll() - // } - // } - // }() - errChan := make(chan error, len(srv.eventLoop)) var wg sync.WaitGroup @@ -260,7 +253,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { func onNewConnection(serverStream *libuv.Stream, status c.Int) { fmt.Println("[debug] onNewConnection called") if status < 0 { - fmt.Printf("New connection error: %s\n", libuv.Strerror(libuv.Errno(status))) + fmt.Printf("New connection error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) return } @@ -280,11 +273,6 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { requestNotifyHandle := &libuv.Async{} el.uvLoop.Async(requestNotifyHandle, onAsync) - fmt.Println("[debug] async handle created") - asyncHandleMapMu.Lock() - asyncHandleMap[conn.asyncID] = requestNotifyHandle - asyncHandleMapMu.Unlock() - fmt.Println("[debug] async handle added to map") libuv.InitTcp(el.uvLoop, &conn.stream) conn.stream.Data = unsafe.Pointer(conn) @@ -306,7 +294,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { } userData.executor = el.executor - userData.asyncHandleID = c.Long(conn.asyncID) + userData.asyncHandle = requestNotifyHandle var addr cnet.SockaddrStorage addrlen := c.Int(unsafe.Sizeof(addr)) @@ -345,7 +333,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { io := createIo(conn) service := hyper.ServiceNew(serverCallback) - service.SetUserdata(unsafe.Pointer(userData), nil) + service.SetUserdata(unsafe.Pointer(userData), freeServiceUserdata) serverConn := hyper.ServeHttpXConnection(el.http1Opts, el.http2Opts, io, service) el.executor.Push(serverConn) @@ -391,12 +379,7 @@ func onIdle(handle *libuv.Idle) { } } -func doNothing(handle *libuv.Idle) { - return -} - func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { - payload := (*serviceUserdata)(userData) if payload == nil { fmt.Fprintf(os.Stderr, "Error: Received null userData\n") @@ -406,35 +389,23 @@ func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *h executor := payload.executor if executor == nil { fmt.Fprintf(os.Stderr, "Error: Received null executor\n") - fmt.Printf("[debug] host: %s\n", c.GoString(&payload.host[0])) - fmt.Printf("[debug] port: %s\n", c.GoString(&payload.port[0])) return } - host := payload.host - port := payload.port - - if payload.asyncHandleID == 0 { - fmt.Fprintf(os.Stderr, "Error: Received null asyncHandleID\n") + requestNotifyHandle := payload.asyncHandle + if requestNotifyHandle == nil { + fmt.Fprintf(os.Stderr, "Error: Received null asyncHandle\n") return } - connID := int64(payload.asyncHandleID) - + host := payload.host + port := payload.port if hyperReq == nil { fmt.Fprintf(os.Stderr, "Error: Received null request\n") return } - asyncHandleMapMu.Lock() - requestNotifyHandle, ok := asyncHandleMap[connID] - asyncHandleMapMu.Unlock() - if !ok { - fmt.Println("[debug] requestNotifyHandle not found") - return - } - remoteAddr := c.GoString(&host[0]) + ":" + c.GoString(&port[0]) fmt.Printf("[debug] Remote address: %s\n", remoteAddr) @@ -447,21 +418,16 @@ func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *h res := newResponse(channel) fmt.Println("[debug] Response created") - //TODO(hackerchai): replace with no goroutine + //TODO(hackerchai): replace with goroutine to enable blocking operation in handler fmt.Println("[debug] Serving HTTP") DefaultServeMux.ServeHTTP(res, req) - //srv.Handler.ServeHTTP(res, req) fmt.Println("[debug] Response finalizing") res.finalize() fmt.Println("[debug] Response finalized") // go func() { - // fmt.Println("[debug] Serving HTTP") // DefaultServeMux.ServeHTTP(res, req) - // //srv.Handler.ServeHTTP(res, req) - // fmt.Println("[debug] Response finalizing") // res.finalize() - // fmt.Println("[debug] Response finalized") // }() } @@ -594,13 +560,20 @@ func createIo(conn *conn) *hyper.Io { } func createServiceUserdata() *serviceUserdata { - userdata := &serviceUserdata{} + userdata := (*serviceUserdata)(c.Calloc(1, unsafe.Sizeof(serviceUserdata{}))) if userdata == nil { fmt.Fprintf(os.Stderr, "Failed to allocate service_userdata\n") } return userdata } +func freeServiceUserdata(userdata c.Pointer) { + castUserdata := (*serviceUserdata)(userdata) + if castUserdata != nil { + c.Free(c.Pointer(castUserdata)) + } +} + func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { conn := (*conn)(userdata) ret := cnet.Recv(conn.stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) @@ -662,7 +635,7 @@ func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) if status < 0 { - fmt.Fprintf(os.Stderr, "Poll error: %s\n", libuv.Strerror(libuv.Errno(status))) + fmt.Fprintf(os.Stderr, "Poll error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) return } @@ -709,7 +682,6 @@ func createConnData() (*conn, error) { } conn.isClosing.Store(false) conn.closedHandles = 0 - conn.asyncID = conn.asyncID + 1 return conn, nil } @@ -792,12 +764,6 @@ func (c *conn) Close() { if (*libuv.Handle)(unsafe.Pointer(&c.stream)).IsClosing() == 0 { (*libuv.Handle)(unsafe.Pointer(&c.stream)).Close(nil) } - - if asyncHandleMap[c.asyncID] != nil { - asyncHandleMapMu.Lock() - delete(asyncHandleMap, c.asyncID) - asyncHandleMapMu.Unlock() - } } } diff --git a/x/net/ipsock.go b/x/net/ipsock.go index 269d2b5..087eee6 100644 --- a/x/net/ipsock.go +++ b/x/net/ipsock.go @@ -1,5 +1,20 @@ package net +// JoinHostPort combines host and port into a network address of the +// form "host:port". If host contains a colon, as found in literal +// IPv6 addresses, then JoinHostPort returns "[host]:port". +// +// See func Dial for a description of the host and port parameters. +func JoinHostPort(host, port string) string { + // We assume that host is a literal IPv6 address if host has + // colons. + + if IndexByteString(host, ':') >= 0 { + return "[" + host + "]:" + port + } + return host + ":" + port +} + // SplitHostPort splits a network address of the form "host:port", // "host%zone:port", "[host]:port" or "[host%zone]:port" into host or // host%zone and port. @@ -20,7 +35,7 @@ func SplitHostPort(hostport string) (host, port string, err error) { j, k := 0, 0 // The port starts after the last colon. - i := LastIndexByteString(hostport, ':') + i := last(hostport, ':') if i < 0 { return addrErr(hostport, missingPort) } From 0abdeb040735f940945ec44f80ac5440b4cfb8b8 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 20 Sep 2024 15:59:03 +0800 Subject: [PATCH 50/55] refactor(x/net/http/demo): Refactor file structure & rename Signed-off-by: hackerchai --- x/net/http/_demo/{http.go => server/server.go} | 2 +- x/net/http/_demo/{server => transfer}/chunkedServer.go | 0 x/net/http/_demo/{server => transfer}/redirectServer.go | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename x/net/http/_demo/{http.go => server/server.go} (96%) rename x/net/http/_demo/{server => transfer}/chunkedServer.go (100%) rename x/net/http/_demo/{server => transfer}/redirectServer.go (100%) diff --git a/x/net/http/_demo/http.go b/x/net/http/_demo/server/server.go similarity index 96% rename from x/net/http/_demo/http.go rename to x/net/http/_demo/server/server.go index 8386873..cac52ea 100644 --- a/x/net/http/_demo/http.go +++ b/x/net/http/_demo/server/server.go @@ -3,7 +3,7 @@ package main import ( "fmt" - "github.com/goplus/llgoexamples/x/net/http" + "github.com/goplus/llgo/x/net/http" ) func echoHandler(w http.ResponseWriter, r *http.Request) { diff --git a/x/net/http/_demo/server/chunkedServer.go b/x/net/http/_demo/transfer/chunkedServer.go similarity index 100% rename from x/net/http/_demo/server/chunkedServer.go rename to x/net/http/_demo/transfer/chunkedServer.go diff --git a/x/net/http/_demo/server/redirectServer.go b/x/net/http/_demo/transfer/redirectServer.go similarity index 100% rename from x/net/http/_demo/server/redirectServer.go rename to x/net/http/_demo/transfer/redirectServer.go From 6fdc04026d41891a6239ee804c2ce8a7217551ba Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 20 Sep 2024 15:59:35 +0800 Subject: [PATCH 51/55] refactor(x/net/http): Rename into bodyStream Signed-off-by: hackerchai --- x/net/http/bodyChunk.go | 80 ------------------- .../http/{request_body.go => body_stream.go} | 16 ++-- 2 files changed, 8 insertions(+), 88 deletions(-) delete mode 100644 x/net/http/bodyChunk.go rename x/net/http/{request_body.go => body_stream.go} (78%) diff --git a/x/net/http/bodyChunk.go b/x/net/http/bodyChunk.go deleted file mode 100644 index 01d9e74..0000000 --- a/x/net/http/bodyChunk.go +++ /dev/null @@ -1,80 +0,0 @@ -package http - -import ( - "errors" - - "github.com/goplus/llgo/c/libuv" -) - -type bodyChunk struct { - chunk []byte - readCh chan []byte - asyncHandle *libuv.Async - - done chan struct{} - - rerr error -} - -var ( - errClosedBodyChunk = errors.New("bodyChunk: read/write on closed body") -) - -func newBodyChunk(asyncHandle *libuv.Async) *bodyChunk { - return &bodyChunk{ - readCh: make(chan []byte, 1), - done: make(chan struct{}), - asyncHandle: asyncHandle, - } -} - -func (bc *bodyChunk) Read(p []byte) (n int, err error) { - select { - case <-bc.done: - err = bc.readCloseError() - return - default: - } - - for n < len(p) { - if len(bc.chunk) == 0 { - bc.asyncHandle.Send() - select { - case chunk := <-bc.readCh: - bc.chunk = chunk - case <-bc.done: - err = bc.readCloseError() - return - } - } - - copied := copy(p[n:], bc.chunk) - n += copied - bc.chunk = bc.chunk[copied:] - } - - return -} - -func (bc *bodyChunk) Close() error { - return bc.closeWithError(nil) -} - -func (bc *bodyChunk) readCloseError() error { - if rerr := bc.rerr; rerr != nil { - return rerr - } - return errClosedBodyChunk -} - -func (bc *bodyChunk) closeWithError(err error) error { - if bc.rerr != nil { - return nil - } - if err == nil { - err = errClosedBodyChunk - } - bc.rerr = err - close(bc.done) - return nil -} diff --git a/x/net/http/request_body.go b/x/net/http/body_stream.go similarity index 78% rename from x/net/http/request_body.go rename to x/net/http/body_stream.go index e4424c5..64ebfa4 100644 --- a/x/net/http/request_body.go +++ b/x/net/http/body_stream.go @@ -7,7 +7,7 @@ import ( "github.com/goplus/llgo/c/libuv" ) -type requestBody struct { +type bodyStream struct { chunk []byte readCh chan []byte asyncHandle *libuv.Async @@ -21,15 +21,15 @@ var ( ErrClosedRequestBody = errors.New("request body: read/write on closed body") ) -func newRequestBody(asyncHandle *libuv.Async) *requestBody { - return &requestBody{ +func newBodyStream(asyncHandle *libuv.Async) *bodyStream { + return &bodyStream{ readCh: make(chan []byte, 1), done: make(chan struct{}), asyncHandle: asyncHandle, } } -func (rb *requestBody) Read(p []byte) (n int, err error) { +func (rb *bodyStream) Read(p []byte) (n int, err error) { fmt.Println("[debug] RequestBody Read called") select { case <-rb.done: @@ -60,14 +60,14 @@ func (rb *requestBody) Read(p []byte) (n int, err error) { return } -func (rb *requestBody) readCloseError() error { +func (rb *bodyStream) readCloseError() error { if rerr := rb.rerr; rerr != nil { return rerr } return ErrClosedRequestBody } -func (rb *requestBody) closeRead(err error) error { +func (rb *bodyStream) closeWithError(err error) error { fmt.Println("[debug] RequestBody closeRead called") if rb.rerr != nil { return nil @@ -80,6 +80,6 @@ func (rb *requestBody) closeRead(err error) error { return nil } -func (rb *requestBody) Close() error { - return rb.closeRead(nil) +func (rb *bodyStream) Close() error { + return rb.closeWithError(nil) } From 4bf50fee39f584241079d8ebbb712742932b49a4 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 20 Sep 2024 16:00:31 +0800 Subject: [PATCH 52/55] refactor(x/net/http): Multiple fixes & adapt bodyStream struct Signed-off-by: hackerchai --- x/net/http/request.go | 147 ++++++++++++++++++++++++++++++++++++++-- x/net/http/response.go | 11 ++- x/net/http/server.go | 35 ++++++---- x/net/http/transport.go | 117 ++++++++++++++++---------------- x/net/http/util.go | 3 +- 5 files changed, 225 insertions(+), 88 deletions(-) diff --git a/x/net/http/request.go b/x/net/http/request.go index 37d6408..b049578 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -9,12 +9,12 @@ import ( "net/url" "strings" "time" - - "github.com/goplus/llgo/c/libuv" - "golang.org/x/net/idna" + "unsafe" "github.com/goplus/llgo/c" + "github.com/goplus/llgo/c/libuv" "github.com/goplus/llgoexamples/rust/hyper" + "golang.org/x/net/idna" ) type Request struct { @@ -230,7 +230,7 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { // extraHeaders may be nil // waitForContinue may be nil // always closes body -func (r *Request) write(client *hyper.ClientConn, taskData *taskData, exec *hyper.Executor) (err error) { +func (r *Request) write(client *hyper.ClientConn, taskData *clientTaskData, exec *hyper.Executor) (err error) { //trace := httptrace.ContextClientTrace(r.Context()) //if trace != nil && trace.WroteRequest != nil { // defer func() { @@ -488,3 +488,142 @@ func valueOrDefault(value, def string) string { } return def } + +func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotifyHandle *libuv.Async, remoteAddr string) (*Request, error) { + println("[debug] readRequest called") + req := Request{ + Header: make(Header), + Body: nil, + } + req.RemoteAddr = remoteAddr + + headers := hyperReq.Headers() + if headers != nil { + headers.Foreach(addHeader, unsafe.Pointer(&req)) + } else { + return nil, fmt.Errorf("failed to get request headers") + } + + var host string + for key, values := range req.Header { + if strings.EqualFold(key, "Host") { + if len(values) > 0 { + host = values[0] + break + } + } + + } + + method := make([]byte, 32) + methodLen := unsafe.Sizeof(method) + if err := hyperReq.Method(&method[0], &methodLen); err != hyper.OK { + return nil, fmt.Errorf("failed to get method: %v", err) + } + + methodStr := string(method[:methodLen]) + + var scheme, authority, pathAndQuery [1024]byte + schemeLen, authorityLen, pathAndQueryLen := unsafe.Sizeof(scheme), unsafe.Sizeof(authority), unsafe.Sizeof(pathAndQuery) + uriResult := hyperReq.URIParts(&scheme[0], &schemeLen, &authority[0], &authorityLen, &pathAndQuery[0], &pathAndQueryLen) + if uriResult != hyper.OK { + return nil, fmt.Errorf("failed to get URI parts: %v", uriResult) + } + + var schemeStr, authorityStr, pathAndQueryStr string + if schemeLen == 0 { + schemeStr = "http" + } else { + schemeStr = string(scheme[:schemeLen]) + } + + if authorityLen == 0 { + authorityStr = host + } else { + authorityStr = string(authority[:authorityLen]) + } + + if pathAndQueryLen == 0 { + return nil, fmt.Errorf("failed to get URI path and query: %v", uriResult) + } else { + pathAndQueryStr = string(pathAndQuery[:pathAndQueryLen]) + } + req.Host = authorityStr + req.Method = methodStr + req.RequestURI = pathAndQueryStr + + var proto string + var protoMajor, protoMinor int + version := hyperReq.Version() + switch version { + case hyper.HTTPVersion10: + proto = "HTTP/1.0" + protoMajor = 1 + protoMinor = 0 + case hyper.HTTPVersion11: + proto = "HTTP/1.1" + protoMajor = 1 + protoMinor = 1 + case hyper.HTTPVersion2: + proto = "HTTP/2.0" + protoMajor = 2 + protoMinor = 0 + case hyper.HTTPVersionNone: + proto = "HTTP/0.0" + protoMajor = 0 + protoMinor = 0 + default: + return nil, fmt.Errorf("unknown HTTP version: %d", version) + } + req.Proto = proto + req.ProtoMajor = protoMajor + req.ProtoMinor = protoMinor + + urlStr := fmt.Sprintf("%s://%s%s", schemeStr, authorityStr, pathAndQueryStr) + url, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + req.URL = url + + body := hyperReq.Body() + if body != nil { + taskFlag := getBodyTask + + bodyStream := newBodyStream(requestNotifyHandle) + req.Body = bodyStream + + taskData := taskData{ + hyperBody: body, + responseBody: nil, + bodyStream: bodyStream, + taskFlag: taskFlag, + executor: executor, + } + + requestNotifyHandle.SetData(c.Pointer(&taskData)) + fmt.Println("[debug] async task set") + + } else { + return nil, fmt.Errorf("failed to get request body") + } + + //hyperReq.Free() + + return &req, nil +} + +func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, valueLen uintptr) c.Int { + req := (*Request)(data) + key := string(unsafe.Slice(name, nameLen)) + val := string(unsafe.Slice(value, valueLen)) + values := strings.Split(val, ",") + if len(values) > 1 { + for _, v := range values { + req.Header.Add(key, strings.TrimSpace(v)) + } + } else { + req.Header.Add(key, val) + } + return hyper.IterContinue +} diff --git a/x/net/http/response.go b/x/net/http/response.go index eceff38..a23de2e 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -32,7 +32,7 @@ type responseBodyRaw struct { type taskData struct { hyperBody *hyper.Body responseBody *responseBodyRaw - requestBody *requestBody + bodyStream *bodyStream executor *hyper.Executor taskFlag taskFlag } @@ -44,8 +44,6 @@ const ( getBodyTask ) -var DefaultChunkSize uintptr = 8192 - func newResponse(hyperChannel *hyper.ResponseChannel) *response { fmt.Printf("[debug] newResponse called\n") @@ -136,7 +134,7 @@ func (r *response) finalize() error { taskData := &taskData{ hyperBody: body, responseBody: &bodyData, - requestBody: nil, + bodyStream: nil, executor: nil, taskFlag: setBodyTask, } @@ -194,7 +192,6 @@ func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) return hyper.PollError } - type Response struct { Status string // e.g. "200 OK" StatusCode int // e.g. 200 @@ -234,7 +231,7 @@ func (r *Response) Cookies() []*Cookie { return readSetCookies(r.Header) } -func (r *Response) checkRespBody(taskData *taskData) (needContinue bool) { +func (r *Response) checkRespBody(taskData *clientTaskData) (needContinue bool) { pc := taskData.pc bodyWritable := r.bodyIsWritable() hasBody := taskData.req.Method != "HEAD" && r.ContentLength != 0 @@ -282,7 +279,7 @@ func (r *Response) checkRespBody(taskData *taskData) (needContinue bool) { return false } -func (r *Response) wrapRespBody(taskData *taskData) { +func (r *Response) wrapRespBody(taskData *clientTaskData) { body := &bodyEOFSignal{ body: r.Body, earlyCloseFn: func() error { diff --git a/x/net/http/server.go b/x/net/http/server.go index d530645..07a2ae7 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -29,9 +29,14 @@ import ( // well read them) const maxPostHandlerReadBytes = 256 << 10 +// _SC_NPROCESSORS_ONLN is the number of processors on the system const _SC_NPROCESSORS_ONLN c.Int = 58 -var cpuCount int +// DefaultChunkSize is the default chunk size for reading and writing data +var DefaultChunkSize uintptr = 8192 + +// cpuCount is the number of processors on the system +var cpuCount int type Handler interface { ServeHTTP(ResponseWriter, *Request) @@ -44,8 +49,8 @@ type ResponseWriter interface { } type Server struct { - Addr string - Handler Handler + Addr string + Handler Handler isShutdown atomic.Bool eventLoop []*eventLoop @@ -80,7 +85,7 @@ type serviceUserdata struct { asyncHandle *libuv.Async host [128]c.Char port [8]c.Char - executor *hyper.Executor + executor *hyper.Executor } func NewServer(addr string) *Server { @@ -130,7 +135,7 @@ func newEventLoop() (*eventLoop, error) { el.uvLoop.SetData(unsafe.Pointer(el)) if r := libuv.InitTcpEx(el.uvLoop, &el.uvServer, cnet.AF_INET); r != 0 { - return nil, fmt.Errorf("failed to init TCP: %v", libuv.Strerror(libuv.Errno(r))) + return nil, fmt.Errorf("failed to init TCP: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } return el, nil @@ -139,19 +144,19 @@ func newEventLoop() (*eventLoop, error) { func (el *eventLoop) run(host string, port int) error { var sockaddr cnet.SockaddrIn if r := libuv.Ip4Addr(c.AllocaCStr(host), c.Int(port), &sockaddr); r != 0 { - return fmt.Errorf("failed to create IP address: %v", libuv.Strerror(libuv.Errno(r))) + return fmt.Errorf("failed to create IP address: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } if err := setReuseAddr(&el.uvServer); err != nil { - return fmt.Errorf("failed to set SO_REUSEADDR: %v", err) + return fmt.Errorf("failed to set SO_REUSEADDR: %s", err) } if r := el.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); r != 0 { - return fmt.Errorf("failed to bind: %v", libuv.Strerror(libuv.Errno(r))) + return fmt.Errorf("failed to bind: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } - if err := (*libuv.Stream)(&el.uvServer).Listen(128, onNewConnection); err != 0 { - return fmt.Errorf("failed to listen: %v", err) + if r := (*libuv.Stream)(&el.uvServer).Listen(128, onNewConnection); r != 0 { + return fmt.Errorf("failed to listen: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } if r := libuv.InitIdle(el.uvLoop, &el.idleHandle); r != 0 { @@ -314,7 +319,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { r := libuv.PollInit(el.uvLoop, &conn.pollHandle, libuv.OsFd(conn.stream.GetIoWatcherFd())) if r < 0 { - fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", libuv.Strerror(libuv.Errno(r))) + fmt.Fprintf(os.Stderr, "uv_poll_init error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(r)))) (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } @@ -482,8 +487,8 @@ func handleGetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task, pay handleTaskBuffer(task, payload) case hyper.TaskEmpty: fmt.Println("[debug] Get body task closing request body") - if payload.requestBody != nil { - payload.requestBody.Close() + if payload.bodyStream != nil { + payload.bodyStream.Close() } task.Free() } @@ -513,7 +518,7 @@ func handleTaskError(task *hyper.Task) { func handleTaskBuffer(task *hyper.Task, payload *taskData) { buf := (*hyper.Buf)(task.Value()) bytes := unsafe.Slice(buf.Bytes(), buf.Len()) - payload.requestBody.readCh <- bytes + payload.bodyStream.readCh <- bytes fmt.Printf("[debug] Task get body writing to bodyWriter: %s\n", string(bytes)) buf.Free() task.Free() @@ -669,7 +674,7 @@ func updateConnRegistrations(conn *conn) bool { fmt.Printf("[debug] Starting poll with events: %d\n", events) r := conn.pollHandle.Start(events, onPoll) if r < 0 { - fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", libuv.Strerror(libuv.Errno(r))) + fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(r)))) return false } return true diff --git a/x/net/http/transport.go b/x/net/http/transport.go index e47bd2a..1fc027f 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -36,7 +36,6 @@ var DefaultTransport RoundTripper = &Transport{ // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -const _SC_NPROCESSORS_ONLN c.Int = 58 // Debug switch provided for developers const ( @@ -98,7 +97,7 @@ type responseAndError struct { type timeoutData struct { timeoutch chan struct{} - taskData *taskData + clientTaskData *clientTaskData } type readTrackingBody struct { @@ -597,8 +596,6 @@ func getMilliseconds(deadline time.Time) uint64 { return uint64(milliseconds) } -var cpuCount int - func init() { cpuCount = int(c.Sysconf(_SC_NPROCESSORS_ONLN)) if cpuCount <= 0 { @@ -666,7 +663,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { libuv.InitTimer(eventLoop.loop, req.timer) ch := &timeoutData{ timeoutch: req.timeoutch, - taskData: nil, + clientTaskData: nil, } (*libuv.Handle)(c.Pointer(req.timer)).SetData(c.Pointer(ch)) @@ -1172,7 +1169,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err writeErrCh := make(chan error, 1) resc := make(chan responseAndError, 1) - taskData := &taskData{ + clientTaskData := &clientTaskData{ req: req, pc: pc, addedGzip: requestedGzip, @@ -1190,14 +1187,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err opts.Exec(pc.eventLoop.exec) // send the handshake handshakeTask := hyper.Handshake(hyperIo, opts) - taskData.taskId = handshake - handshakeTask.SetUserdata(c.Pointer(taskData), nil) + clientTaskData.taskId = handshake + handshakeTask.SetUserdata(c.Pointer(clientTaskData), nil) // Send the request to readWriteLoop(). pc.eventLoop.exec.Push(handshakeTask) //} else { // println("############### roundTrip: pc.client != nil") - // taskData.taskId = read - // err = req.write(pc.client, taskData, pc.eventLoop.exec) + // clientTaskData.taskId = read + // err = req.write(pc.client, clientTaskData, pc.eventLoop.exec) // if err != nil { // writeErrCh <- err // pc.close(err) @@ -1279,16 +1276,16 @@ func readWriteLoop(checker *libuv.Idle) { } func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { - taskData := (*taskData)(task.Userdata()) - if taskData == nil { + clientTaskData := (*clientTaskData)(task.Userdata()) + if clientTaskData == nil { // A background task for hyper_client completed... task.Free() return } var err error - pc := taskData.pc + pc := clientTaskData.pc // If original taskId is set, we need to check it - err = checkTaskType(task, taskData) + err = checkTaskType(task, clientTaskData) if err != nil { if debugSwitch { println("############### handleTask: checkTaskType err != nil") @@ -1296,7 +1293,7 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { closeAndRemoveIdleConn(pc, true) return } - switch taskData.taskId { + switch clientTaskData.taskId { case handshake: if debugReadWriteLoop { println("############### write") @@ -1314,12 +1311,12 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { task.Free() // TODO(hah) Proxy(writeLoop) - taskData.taskId = read - err = taskData.req.Request.write(pc.client, taskData, eventLoop.exec) + clientTaskData.taskId = read + err = clientTaskData.req.Request.write(pc.client, clientTaskData, eventLoop.exec) if err != nil { //pc.writeErrCh <- err // to the body reader, which might recycle us - taskData.writeErrCh <- err // to the roundTrip function + clientTaskData.writeErrCh <- err // to the roundTrip function if debugSwitch { println("############### handleTask: write err != nil") } @@ -1367,11 +1364,11 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { var resp *Response if err == nil { - pc.chunkAsync.SetData(c.Pointer(taskData)) - bc := newBodyChunk(pc.chunkAsync) - pc.bodyChunk = bc - resp, err = ReadResponse(bc, taskData.req.Request, hyperResp) - taskData.hyperBody = hyperResp.Body() + pc.chunkAsync.SetData(c.Pointer(clientTaskData)) + bc := newBodyStream(pc.chunkAsync) + pc.bodyStream = bc + resp, err = ReadResponse(bc, clientTaskData.req.Request, hyperResp) + clientTaskData.hyperBody = hyperResp.Body() } else { err = transportReadFromServerError{err} pc.closeErr = err @@ -1381,11 +1378,11 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { hyperResp.Free() if err != nil { - pc.bodyChunk.closeWithError(err) - taskData.closeHyperBody() + pc.bodyStream.closeWithError(err) + clientTaskData.closeHyperBody() select { - case taskData.resc <- responseAndError{err: err}: - case <-taskData.callerGone: + case clientTaskData.resc <- responseAndError{err: err}: + case <-clientTaskData.callerGone: if debugSwitch { println("############### handleTask read: callerGone") } @@ -1399,32 +1396,32 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { return } - taskData.taskId = readBodyChunk + clientTaskData.taskId = readBodyChunk - if !taskData.req.deadline.IsZero() { - (*timeoutData)((*libuv.Handle)(c.Pointer(taskData.req.timer)).GetData()).taskData = taskData + if !clientTaskData.req.deadline.IsZero() { + (*timeoutData)((*libuv.Handle)(c.Pointer(clientTaskData.req.timer)).GetData()).clientTaskData = clientTaskData } //pc.mu.Lock() pc.numExpectedResponses-- //pc.mu.Unlock() - needContinue := resp.checkRespBody(taskData) + needContinue := resp.checkRespBody(clientTaskData) if needContinue { return } - resp.wrapRespBody(taskData) + resp.wrapRespBody(clientTaskData) select { - case taskData.resc <- responseAndError{res: resp}: - case <-taskData.callerGone: + case clientTaskData.resc <- responseAndError{res: resp}: + case <-clientTaskData.callerGone: // defer if debugSwitch { println("############### handleTask read: callerGone 2") } - pc.bodyChunk.Close() - taskData.closeHyperBody() + pc.bodyStream.Close() + clientTaskData.closeHyperBody() closeAndRemoveIdleConn(pc, true) return } @@ -1446,7 +1443,7 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { chunk.Free() task.Free() // Write to the channel - pc.bodyChunk.readCh <- bytes + pc.bodyStream.readCh <- bytes if debugReadWriteLoop { println("############### readBodyChunk end [buf]") } @@ -1455,9 +1452,9 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { // taskType == taskEmpty (check in checkTaskType) task.Free() - pc.bodyChunk.closeWithError(io.EOF) - taskData.closeHyperBody() - replaced := pc.t.replaceReqCanceler(taskData.req.cancelKey, nil) // before pc might return to idle pool + pc.bodyStream.closeWithError(io.EOF) + clientTaskData.closeHyperBody() + replaced := pc.t.replaceReqCanceler(clientTaskData.req.cancelKey, nil) // before pc might return to idle pool pc.alive = pc.alive && replaced && pc.tryPutIdleConn() @@ -1473,10 +1470,10 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { } func readyToRead(aysnc *libuv.Async) { - taskData := (*taskData)(aysnc.GetData()) - dataTask := taskData.hyperBody.Data() - dataTask.SetUserdata(c.Pointer(taskData), nil) - taskData.pc.eventLoop.exec.Push(dataTask) + clientTaskData := (*clientTaskData)(aysnc.GetData()) + dataTask := clientTaskData.hyperBody.Data() + dataTask.SetUserdata(c.Pointer(clientTaskData), nil) + clientTaskData.pc.eventLoop.exec.Push(dataTask) } // closeAndRemoveIdleConn Replace the defer function of readLoop in stdlib @@ -1504,7 +1501,7 @@ type connData struct { isClosing atomic.Bool } -type taskData struct { +type clientTaskData struct { taskId taskId req *transportRequest pc *persistConn @@ -1545,7 +1542,7 @@ func (conn *connData) Close() { } } -func (d *taskData) closeHyperBody() { +func (d *clientTaskData) closeHyperBody() { if d.hyperBody != nil { d.hyperBody.Free() d.hyperBody = nil @@ -1666,11 +1663,11 @@ func onTimeout(timer *libuv.Timer) { close(data.timeoutch) timer.Stop() - taskData := data.taskData - if taskData != nil { - pc := taskData.pc + clientTaskData := data.clientTaskData + if clientTaskData != nil { + pc := clientTaskData.pc pc.alive = false - pc.t.cancelRequest(taskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) + pc.t.cancelRequest(clientTaskData.req.cancelKey, errors.New("timeout: req.Context().Err()")) closeAndRemoveIdleConn(pc, true) } } @@ -1685,8 +1682,8 @@ func newHyperIo(connData *connData) *hyper.Io { } // checkTaskType checks the task type -func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { - curTaskId := taskData.taskId +func checkTaskType(task *hyper.Task, clientTaskData *clientTaskData) (err error) { + curTaskId := clientTaskData.taskId taskType := task.Type() if taskType == hyper.TaskError { err = fail((*hyper.Error)(task.Value()), curTaskId) @@ -1710,18 +1707,18 @@ func checkTaskType(task *hyper.Task, taskData *taskData) (err error) { if err != nil { task.Free() if curTaskId == handshake || curTaskId == read { - taskData.writeErrCh <- err + clientTaskData.writeErrCh <- err if debugSwitch { println("############### checkTaskType: writeErrCh") } - taskData.pc.close(err) + clientTaskData.pc.close(err) } - if taskData.pc.bodyChunk != nil { - taskData.pc.bodyChunk.Close() - taskData.pc.bodyChunk = nil + if clientTaskData.pc.bodyStream != nil { + clientTaskData.pc.bodyStream.Close() + clientTaskData.pc.bodyStream = nil } - taskData.closeHyperBody() - taskData.pc.alive = false + clientTaskData.closeHyperBody() + clientTaskData.pc.alive = false } return } @@ -1919,7 +1916,7 @@ type persistConn struct { closeErr error // Replace the closeErr in readLoop tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop client *hyper.ClientConn // http long connection client handle - bodyChunk *bodyChunk // Implement non-blocking consumption of each responseBody chunk + bodyStream *bodyStream // Implement non-blocking consumption of each responseBody chunk chunkAsync *libuv.Async // Notifying that the received chunk has been read } diff --git a/x/net/http/util.go b/x/net/http/util.go index bfd9fc3..c7fc4dc 100644 --- a/x/net/http/util.go +++ b/x/net/http/util.go @@ -5,9 +5,8 @@ import ( "unicode" "unicode/utf8" - "golang.org/x/net/idna" - "github.com/goplus/llgoexamples/x/net" + "golang.org/x/net/idna" ) /** From 6d8087a6892927c9abe9422922e7dfc836a58123 Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 20 Sep 2024 18:38:21 +0800 Subject: [PATCH 53/55] refactor(x/net/http): Remove debug message and add comments Signed-off-by: hackerchai --- x/net/http/_demo/server/server.go | 3 +- x/net/http/request.go | 12 +- x/net/http/response.go | 52 +++------ x/net/http/server.go | 175 ++++++++++++++++++++++-------- x/net/http/servermux.go | 18 ++- 5 files changed, 160 insertions(+), 100 deletions(-) diff --git a/x/net/http/_demo/server/server.go b/x/net/http/_demo/server/server.go index cac52ea..f57f517 100644 --- a/x/net/http/_demo/server/server.go +++ b/x/net/http/_demo/server/server.go @@ -3,11 +3,10 @@ package main import ( "fmt" - "github.com/goplus/llgo/x/net/http" + "github.com/goplus/llgoexamples/x/net/http" ) func echoHandler(w http.ResponseWriter, r *http.Request) { - fmt.Printf("[debug] echoHandler called\n") fmt.Printf(">> %s %s HTTP/%d.%d\n", r.Method, r.RequestURI, r.ProtoMajor, r.ProtoMinor) for key, values := range r.Header { for _, value := range values { diff --git a/x/net/http/request.go b/x/net/http/request.go index b049578..8a13372 100644 --- a/x/net/http/request.go +++ b/x/net/http/request.go @@ -489,14 +489,15 @@ func valueOrDefault(value, def string) string { return def } +// readRequest reads the request from the hyper request func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotifyHandle *libuv.Async, remoteAddr string) (*Request, error) { - println("[debug] readRequest called") req := Request{ Header: make(Header), Body: nil, } req.RemoteAddr = remoteAddr + //get the request headers headers := hyperReq.Headers() if headers != nil { headers.Foreach(addHeader, unsafe.Pointer(&req)) @@ -504,6 +505,7 @@ func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotif return nil, fmt.Errorf("failed to get request headers") } + //get the host from the request header var host string for key, values := range req.Header { if strings.EqualFold(key, "Host") { @@ -537,12 +539,14 @@ func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotif schemeStr = string(scheme[:schemeLen]) } + //if authority is empty, use the host from the request header if authorityLen == 0 { authorityStr = host } else { authorityStr = string(authority[:authorityLen]) } + //if path and query is empty, use the path and query from the request header if pathAndQueryLen == 0 { return nil, fmt.Errorf("failed to get URI path and query: %v", uriResult) } else { @@ -593,7 +597,8 @@ func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotif bodyStream := newBodyStream(requestNotifyHandle) req.Body = bodyStream - taskData := taskData{ + //prepare task data for hyper executor + taskData := serverTaskData{ hyperBody: body, responseBody: nil, bodyStream: bodyStream, @@ -601,8 +606,8 @@ func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotif executor: executor, } + //set task data to the request notify async handle requestNotifyHandle.SetData(c.Pointer(&taskData)) - fmt.Println("[debug] async task set") } else { return nil, fmt.Errorf("failed to get request body") @@ -613,6 +618,7 @@ func readRequest(executor *hyper.Executor, hyperReq *hyper.Request, requestNotif return &req, nil } +// addHeader callback function to add the header to the request func addHeader(data unsafe.Pointer, name *byte, nameLen uintptr, value *byte, valueLen uintptr) c.Int { req := (*Request)(data) key := string(unsafe.Slice(name, nameLen)) diff --git a/x/net/http/response.go b/x/net/http/response.go index a23de2e..d8eca16 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -10,10 +10,10 @@ import ( "unsafe" "github.com/goplus/llgo/c" - "github.com/goplus/llgo/c/os" "github.com/goplus/llgoexamples/rust/hyper" ) +// response is the response of the server type response struct { header Header statusCode int @@ -23,13 +23,15 @@ type response struct { hyperResp *hyper.Response } +// responseBodyRaw is the body of the response type responseBodyRaw struct { data []byte len uintptr readLen uintptr } -type taskData struct { +// serverTaskData is the data of the task +type serverTaskData struct { hyperBody *hyper.Body responseBody *responseBodyRaw bodyStream *bodyStream @@ -37,6 +39,7 @@ type taskData struct { taskFlag taskFlag } +// taskFlag is the sign of the task type taskFlag int const ( @@ -44,9 +47,8 @@ const ( getBodyTask ) +// newResponse creates a new response func newResponse(hyperChannel *hyper.ResponseChannel) *response { - fmt.Printf("[debug] newResponse called\n") - return &response{ header: make(Header), written: false, @@ -56,10 +58,12 @@ func newResponse(hyperChannel *hyper.ResponseChannel) *response { } } +// Header returns the header of the response func (r *response) Header() Header { return r.header } +// Write writes the data to the response func (r *response) Write(data []byte) (int, error) { if !r.written { r.WriteHeader(200) @@ -68,8 +72,8 @@ func (r *response) Write(data []byte) (int, error) { return len(data), nil } +// WriteHeader writes the status code to the response func (r *response) WriteHeader(statusCode int) { - fmt.Println("[debug] WriteHeader called") if r.written { return } @@ -78,16 +82,7 @@ func (r *response) WriteHeader(statusCode int) { r.hyperResp.SetStatus(uint16(statusCode)) - fmt.Println("[debug] WriteHeaderStatusCode done") - - //debug - fmt.Printf("[debug] < HTTP/1.1 %d\n", statusCode) - for key, values := range r.header { - for _, value := range values { - fmt.Printf("[debug] < %s: %s\n", key, value) - } - } - + // set the header to the hyper response headers := r.hyperResp.Headers() for key, values := range r.header { valueLen := len(values) @@ -105,13 +100,10 @@ func (r *response) WriteHeader(statusCode int) { return } } - - fmt.Println("[debug] WriteHeader done") } +// finalize finalizes the response (body & header), it will be called when the response is ready to be sent func (r *response) finalize() error { - fmt.Printf("[debug] finalize called\n") - if !r.written { r.WriteHeader(200) } @@ -125,13 +117,12 @@ func (r *response) finalize() error { len: uintptr(len(r.body)), readLen: 0, } - fmt.Println("[debug] bodyData constructed") body := hyper.NewBody() if body == nil { return fmt.Errorf("failed to create body") } - taskData := &taskData{ + taskData := &serverTaskData{ hyperBody: body, responseBody: &bodyData, bodyStream: nil, @@ -140,36 +131,25 @@ func (r *response) finalize() error { } body.SetDataFunc(setBodyDataFunc) body.SetUserdata(unsafe.Pointer(taskData), nil) - fmt.Println("[debug] bodyData userdata set") - - fmt.Println("[debug] bodyData set") resBody := r.hyperResp.SetBody(body) if resBody != hyper.OK { return fmt.Errorf("failed to set body") } - fmt.Println("[debug] body set") r.hyperChannel.Send(r.hyperResp) - fmt.Println("[debug] response sent") return nil } +// setBodyDataFunc is the callback function to set the body data func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) c.Int { - fmt.Println("[debug] setBodyDataFunc called") - taskData := (*taskData)(userdata) + taskData := (*serverTaskData)(userdata) if taskData == nil { - fmt.Println("[debug] taskData is nil") return hyper.PollError } - fmt.Println("[debug] taskData is not nil") body := taskData.responseBody if body.len > 0 { - //debug - fmt.Println("[debug]<") - fmt.Printf("[debug]%s\n", string(body.data)) - if body.len > DefaultChunkSize { *chunk = hyper.CopyBuf(&body.data[body.readLen], DefaultChunkSize) body.readLen += DefaultChunkSize @@ -179,16 +159,14 @@ func setBodyDataFunc(userdata c.Pointer, ctx *hyper.Context, chunk **hyper.Buf) body.readLen += body.len body.len = 0 } - fmt.Println("[debug] setBodyDataFunc done") return hyper.PollReady } + // if the body is empty, return PollReady directly if body.len == 0 { *chunk = nil - fmt.Println("[debug] setBodyDataFunc done") return hyper.PollReady } - fmt.Printf("[debug] error setting body data: %s\n", c.GoString(c.Strerror(os.Errno))) return hyper.PollError } diff --git a/x/net/http/server.go b/x/net/http/server.go index 07a2ae7..a13bbaa 100644 --- a/x/net/http/server.go +++ b/x/net/http/server.go @@ -3,6 +3,7 @@ package http import ( "errors" "fmt" + "io" "os" "strconv" "sync" @@ -38,16 +39,19 @@ var DefaultChunkSize uintptr = 8192 // cpuCount is the number of processors on the system var cpuCount int +// Handler is the interface implemented by objects that can serve HTTP requests. type Handler interface { ServeHTTP(ResponseWriter, *Request) } +// ResponseWriter is the interface implemented by objects that can write HTTP responses. type ResponseWriter interface { Header() Header Write([]byte) (int, error) WriteHeader(statusCode int) } +// Server is a HTTP server. type Server struct { Addr string Handler Handler @@ -56,6 +60,7 @@ type Server struct { eventLoop []*eventLoop } +// eventLoop is a event loop for the server. type eventLoop struct { uvLoop *libuv.Loop uvServer libuv.Tcp @@ -70,6 +75,7 @@ type eventLoop struct { activeConnections map[*conn]struct{} } +// conn is an abstraction of a connection. type conn struct { stream libuv.Tcp pollHandle libuv.Poll @@ -81,6 +87,7 @@ type conn struct { remoteAddr string } +// serviceUserdata is the user data for the service. type serviceUserdata struct { asyncHandle *libuv.Async host [128]c.Char @@ -88,6 +95,7 @@ type serviceUserdata struct { executor *hyper.Executor } +// NewServer creates a new Server. func NewServer(addr string) *Server { return &Server{ Addr: addr, @@ -95,18 +103,21 @@ func NewServer(addr string) *Server { } } +// newEventLoop creates a new event loop. func newEventLoop() (*eventLoop, error) { activeClients := make(map[*conn]struct{}) el := &eventLoop{ activeConnections: activeClients, } + // create executor executor := hyper.NewExecutor() if executor == nil { return nil, fmt.Errorf("failed to create Executor") } el.executor = executor + // set http options http1Opts := hyper.Http1ServerconnOptionsNew(el.executor) if http1Opts == nil { return nil, fmt.Errorf("failed to create http1_opts") @@ -128,12 +139,16 @@ func newEventLoop() (*eventLoop, error) { } el.http2Opts = http2Opts + // create libuv event loop el.uvLoop = libuv.LoopNew() if el.uvLoop == nil { return nil, fmt.Errorf("failed to get default loop") } + + // set event loop data el.uvLoop.SetData(unsafe.Pointer(el)) + // create libuv TCP server if r := libuv.InitTcpEx(el.uvLoop, &el.uvServer, cnet.AF_INET); r != 0 { return nil, fmt.Errorf("failed to init TCP: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } @@ -141,34 +156,42 @@ func newEventLoop() (*eventLoop, error) { return el, nil } +// run runs the event loop. func (el *eventLoop) run(host string, port int) error { var sockaddr cnet.SockaddrIn if r := libuv.Ip4Addr(c.AllocaCStr(host), c.Int(port), &sockaddr); r != 0 { return fmt.Errorf("failed to create IP address: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } + // set SO_REUSEADDR and SO_REUSEPORT for the server if err := setReuseAddr(&el.uvServer); err != nil { return fmt.Errorf("failed to set SO_REUSEADDR: %s", err) } + // bind the server to the address if r := el.uvServer.Bind((*cnet.SockAddr)(unsafe.Pointer(&sockaddr)), 0); r != 0 { return fmt.Errorf("failed to bind: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } + // listen for new connections if r := (*libuv.Stream)(&el.uvServer).Listen(128, onNewConnection); r != 0 { return fmt.Errorf("failed to listen: %s", c.GoString(libuv.Strerror(libuv.Errno(r)))) } + // create idle handler if r := libuv.InitIdle(el.uvLoop, &el.idleHandle); r != 0 { return fmt.Errorf("failed to initialize idle handler: %d", r) } + // set idle handler data (*libuv.Handle)(unsafe.Pointer(&el.idleHandle)).SetData(unsafe.Pointer(el)) + // start the idle handler if r := el.idleHandle.Start(onIdle); r != 0 { return fmt.Errorf("failed to start idle handler: %d", r) } + // run the libuv event loop if r := el.uvLoop.Run(libuv.RUN_DEFAULT); r != 0 { return fmt.Errorf("error in event loop: %d", r) } @@ -176,6 +199,7 @@ func (el *eventLoop) run(host string, port int) error { return nil } +// setReuseAddr sets the SO_REUSEADDR and SO_REUSEPORT options for the given TCP handle. func setReuseAddr(handle *libuv.Tcp) error { var fd libuv.OsFd result := (*libuv.Handle)(unsafe.Pointer(handle)).Fileno(&fd) @@ -184,10 +208,12 @@ func setReuseAddr(handle *libuv.Tcp) error { } yes := c.Int(1) + // set SO_REUSEADDR if err := cnet.SetSockOpt(c.Int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))); err != 0 { return fmt.Errorf("Error setting SO_REUSEADDR") } + // set SO_REUSEPORT if err := cnet.SetSockOpt(c.Int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, unsafe.Pointer(&yes), c.Uint(unsafe.Sizeof(yes))); err != 0 { return fmt.Errorf("Error setting SO_REUSEPORT") } @@ -199,19 +225,23 @@ func setReuseAddr(handle *libuv.Tcp) error { // and [ListenAndServeTLS] methods after a call to [Server.Shutdown] or [Server.Close]. var ErrServerClosed = errors.New("http: Server closed") +// ListenAndServe listens on the TCP network address addr and then calls Serve +// to handle requests on incoming connections. func ListenAndServe(addr string, handler Handler) error { server := &Server{Addr: addr, Handler: handler} return server.ListenAndServe() } +// ListenAndServe listens on the TCP network address addr and then calls Serve +// to handle requests on incoming connections. func (srv *Server) ListenAndServe() error { + // get the number of processors on the system cpuCount = int(c.Sysconf(_SC_NPROCESSORS_ONLN)) if cpuCount <= 0 { cpuCount = 4 } - fmt.Printf("[debug] cpuCount: %d\n", cpuCount) - + // create event loops for i := 0; i < cpuCount; i++ { el, err := newEventLoop() if err != nil { @@ -220,6 +250,7 @@ func (srv *Server) ListenAndServe() error { srv.eventLoop = append(srv.eventLoop, el) } + // parse the address host, port, err := net.SplitHostPort(srv.Addr) if err != nil { return fmt.Errorf("invalid address %q: %v", srv.Addr, err) @@ -230,6 +261,7 @@ func (srv *Server) ListenAndServe() error { return fmt.Errorf("invalid port number: %v", err) } + // create error channel and wait group errChan := make(chan error, len(srv.eventLoop)) var wg sync.WaitGroup @@ -246,22 +278,32 @@ func (srv *Server) ListenAndServe() error { wg.Wait() + // wait for all event loops to finish + close(errChan) + + // collect errors from all event loops + for err := range errChan { + return fmt.Errorf("error in event loop: %v", err) + } + fmt.Printf("Listening on %s\n", srv.Addr) return nil } +// HandleFunc is a convenience function to register a handler for a pattern. func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { DefaultServeMux.HandleFunc(pattern, handler) } +// onNewConnection is the libuv callback function for new connections. func onNewConnection(serverStream *libuv.Stream, status c.Int) { - fmt.Println("[debug] onNewConnection called") if status < 0 { fmt.Printf("New connection error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) return } + // get the event loop el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(serverStream)).GetLoop().GetData()) if el == nil { fmt.Fprintf(os.Stderr, "Event loop is nil\n") @@ -274,17 +316,16 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { return } - fmt.Println("[debug] async handle creating") - + // create async handle for request to notify hyper executor reading next request chunk requestNotifyHandle := &libuv.Async{} el.uvLoop.Async(requestNotifyHandle, onAsync) + // initialize the TCP connection libuv.InitTcp(el.uvLoop, &conn.stream) conn.stream.Data = unsafe.Pointer(conn) + // accept the connection if serverStream.Accept((*libuv.Stream)(unsafe.Pointer(&conn.stream))) == 0 { - fmt.Println("[debug] Accepted new connection") - userData := createServiceUserdata() if userData == nil { fmt.Fprintf(os.Stderr, "Failed to create service userdata\n") @@ -301,6 +342,7 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { userData.executor = el.executor userData.asyncHandle = requestNotifyHandle + // get the remote address var addr cnet.SockaddrStorage addrlen := c.Int(unsafe.Sizeof(addr)) conn.stream.Getpeername((*cnet.SockAddr)(c.Pointer(&addr)), &addrlen) @@ -326,41 +368,45 @@ func onNewConnection(serverStream *libuv.Stream, status c.Int) { (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Data = unsafe.Pointer(conn) + // update the connection registrations if !updateConnRegistrations(conn) { (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) return } - fmt.Println("[debug] Conn created") + // track the connection el.trackConn(conn, true) - fmt.Println("[debug] Conn tracked") + // create hyper io io := createIo(conn) + + // create hyper service service := hyper.ServiceNew(serverCallback) service.SetUserdata(unsafe.Pointer(userData), freeServiceUserdata) serverConn := hyper.ServeHttpXConnection(el.http1Opts, el.http2Opts, io, service) el.executor.Push(serverConn) } else { - fmt.Println("[debug] Client not accepted") (*libuv.Handle)(unsafe.Pointer(&conn.pollHandle)).Close(nil) (*libuv.Handle)(unsafe.Pointer(&conn.stream)).Close(nil) } } +// onAsync is the libuv callback function for async events. func onAsync(asyncHandle *libuv.Async) { - fmt.Println("[debug] onAsync called") - taskData := (*taskData)(asyncHandle.GetData()) + taskData := (*serverTaskData)(asyncHandle.GetData()) if taskData == nil { - fmt.Println("[debug] taskData is nil") return } + + // set the task data to the hyper body dataTask := taskData.hyperBody.Data() dataTask.SetUserdata(c.Pointer(taskData), nil) + + // push the task to the hyper executor if dataTask != nil { r := taskData.executor.Push(dataTask) - fmt.Printf("[debug] onAsync push data task: %d\n", r) if r != hyper.OK { fmt.Printf("failed to push data task: %d\n", r) dataTask.Free() @@ -368,9 +414,11 @@ func onAsync(asyncHandle *libuv.Async) { } } +// onIdle is the libuv callback function for running libuv event loop. func onIdle(handle *libuv.Idle) { el := (*eventLoop)((*libuv.Handle)(unsafe.Pointer(handle)).GetLoop().GetData()) if el.executor != nil { + // poll the hyper executor for tasks task := el.executor.Poll() for task != nil { handleTask(task) @@ -384,7 +432,9 @@ func onIdle(handle *libuv.Idle) { } } +// serverCallback is the callback function for hyper server connections. func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *hyper.ResponseChannel) { + // get the service userdata payload := (*serviceUserdata)(userData) if payload == nil { fmt.Fprintf(os.Stderr, "Error: Received null userData\n") @@ -412,42 +462,38 @@ func serverCallback(userData unsafe.Pointer, hyperReq *hyper.Request, channel *h } remoteAddr := c.GoString(&host[0]) + ":" + c.GoString(&port[0]) - fmt.Printf("[debug] Remote address: %s\n", remoteAddr) + // read the request req, err := readRequest(executor, hyperReq, requestNotifyHandle, remoteAddr) if err != nil { fmt.Printf("Error creating request: %v\n", err) return } + // create a new response res := newResponse(channel) - fmt.Println("[debug] Response created") - //TODO(hackerchai): replace with goroutine to enable blocking operation in handler - fmt.Println("[debug] Serving HTTP") + // handle the request DefaultServeMux.ServeHTTP(res, req) - fmt.Println("[debug] Response finalizing") + + // finalize the response res.finalize() - fmt.Println("[debug] Response finalized") + //TODO(hackerchai): replace with goroutine to enable blocking operation in handler // go func() { // DefaultServeMux.ServeHTTP(res, req) // res.finalize() // }() } +// handleTask is the callback function for hyper tasks. func handleTask(task *hyper.Task) { hyperTaskType := task.Type() - // Debug - fmt.Printf("[debug] Task type: %s\n", getTaskTypeString(hyperTaskType)) - - payload := (*taskData)(task.Userdata()) - // Debug - if payload == nil { - fmt.Println("[debug] task data is nil") - } + // get the server task data + payload := (*serverTaskData)(task.Userdata()) + // handle the task based on the task flag if payload != nil { switch payload.taskFlag { case getBodyTask: @@ -456,54 +502,52 @@ func handleTask(task *hyper.Task) { handleSetBodyTask(hyperTaskType, task) return default: - fmt.Println("[debug] Unknown response task type") return } } + // if the payload is nil, handle the task based on the task type switch hyperTaskType { case hyper.TaskError: handleTaskError(task) return case hyper.TaskEmpty: - fmt.Println("[debug] Empty task handled") task.Free() return case hyper.TaskServerconn: - fmt.Println("[debug] Server connection task handled") task.Free() return default: - fmt.Println("[debug] Unknown task type") return } } -func handleGetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task, payload *taskData) { +// handleGetBodyTask is the callback function for hyper tasks with get body task type. +func handleGetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task, payload *serverTaskData) { switch hyperTaskType { case hyper.TaskError: handleTaskError(task) case hyper.TaskBuf: handleTaskBuffer(task, payload) case hyper.TaskEmpty: - fmt.Println("[debug] Get body task closing request body") if payload.bodyStream != nil { - payload.bodyStream.Close() + payload.bodyStream.closeWithError(io.EOF) } task.Free() } } +// handleSetBodyTask is the callback function for hyper tasks with set body task type. func handleSetBodyTask(hyperTaskType hyper.TaskReturnType, task *hyper.Task) { switch hyperTaskType { case hyper.TaskError: handleTaskError(task) case hyper.TaskEmpty: - fmt.Println("[debug] Set body task freeing") task.Free() } } +// handleTaskError is the callback function for hyper tasks with error task type. func handleTaskError(task *hyper.Task) { err := (*hyper.Error)(task.Value()) fmt.Printf("Error code: %d\n", err.Code()) @@ -515,15 +559,20 @@ func handleTaskError(task *hyper.Task) { task.Free() } -func handleTaskBuffer(task *hyper.Task, payload *taskData) { +// handleTaskBuffer is the callback function for hyper tasks with buffer task type. +func handleTaskBuffer(task *hyper.Task, payload *serverTaskData) { buf := (*hyper.Buf)(task.Value()) bytes := unsafe.Slice(buf.Bytes(), buf.Len()) + + // push the bytes to the body stream payload.bodyStream.readCh <- bytes - fmt.Printf("[debug] Task get body writing to bodyWriter: %s\n", string(bytes)) + + // free the buffer and the task buf.Free() task.Free() } +// getTaskTypeString is the helper function for getting task type string. func getTaskTypeString(taskType hyper.TaskReturnType) string { switch taskType { case hyper.TaskEmpty: @@ -543,6 +592,7 @@ func getTaskTypeString(taskType hyper.TaskReturnType) string { } } +// trackConn is the helper function for tracking connections into event loop. func (el *eventLoop) trackConn(c *conn, add bool) { el.mu.Lock() defer el.mu.Unlock() @@ -556,6 +606,7 @@ func (el *eventLoop) trackConn(c *conn, add bool) { } } +// createIo is the helper function for creating hyper io. func createIo(conn *conn) *hyper.Io { io := hyper.NewIo() io.SetUserdata(unsafe.Pointer(conn), freeConnData) @@ -564,6 +615,7 @@ func createIo(conn *conn) *hyper.Io { return io } +// createServiceUserdata is the helper function for creating service userdata. func createServiceUserdata() *serviceUserdata { userdata := (*serviceUserdata)(c.Calloc(1, unsafe.Sizeof(serviceUserdata{}))) if userdata == nil { @@ -572,6 +624,7 @@ func createServiceUserdata() *serviceUserdata { return userdata } +// freeServiceUserdata is the helper function for freeing service userdata. func freeServiceUserdata(userdata c.Pointer) { castUserdata := (*serviceUserdata)(userdata) if castUserdata != nil { @@ -579,54 +632,61 @@ func freeServiceUserdata(userdata c.Pointer) { } } +// readCb is the callback function for hyper io read. func readCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { conn := (*conn)(userdata) ret := cnet.Recv(conn.stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + // if the ret is greater than 0, return the ret if ret >= 0 { return uintptr(ret) } + // if the ret is less than 0, return the IoError if uintptr(cos.Errno) != syscall.EAGAIN && uintptr(cos.Errno) != syscall.EWOULDBLOCK { return hyper.IoError } + // if the readWaker is not nil, free the readWaker if conn.readWaker != nil { conn.readWaker.Free() } + // if the eventMask is not readable, set the eventMask to readable if conn.eventMask&c.Uint(libuv.READABLE) == 0 { conn.eventMask |= c.Uint(libuv.READABLE) - fmt.Printf("[debug] ReadCb Event mask: %d\n", conn.eventMask) if !updateConnRegistrations(conn) { return hyper.IoError } - fmt.Printf("[debug] ReadCb updateConnRegistrations\n") } conn.readWaker = ctx.Waker() return hyper.IoPending } +// writeCb is the callback function for hyper io write. func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uintptr) uintptr { conn := (*conn)(userdata) ret := cnet.Send(conn.stream.GetIoWatcherFd(), unsafe.Pointer(buf), bufLen, 0) + // if the ret is greater than 0, return the ret if ret >= 0 { return uintptr(ret) } + // if the ret is less than 0, return the IoError if uintptr(cos.Errno) != syscall.EAGAIN && uintptr(cos.Errno) != syscall.EWOULDBLOCK { return hyper.IoError } + // if the writeWaker is not nil, free the writeWaker if conn.writeWaker != nil { conn.writeWaker.Free() } + // if the eventMask is not writable, set the eventMask to writable if conn.eventMask&c.Uint(libuv.WRITABLE) == 0 { conn.eventMask |= c.Uint(libuv.WRITABLE) - fmt.Printf("[debug] WriteCb Event mask: %d\n", conn.eventMask) if !updateConnRegistrations(conn) { return hyper.IoError } @@ -636,42 +696,51 @@ func writeCb(userdata unsafe.Pointer, ctx *hyper.Context, buf *byte, bufLen uint return hyper.IoPending } +// onPoll is the callback function for libuv poll. func onPoll(handle *libuv.Poll, status c.Int, events c.Int) { + // get the conn conn := (*conn)((*libuv.Handle)(unsafe.Pointer(handle)).GetData()) + // if the status is less than 0, return the PollError if status < 0 { fmt.Fprintf(os.Stderr, "Poll error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(status)))) return } + // if the events is readable and the readWaker is not nil, wake the readWaker if events&c.Int(libuv.READABLE) != 0 && conn.readWaker != nil { conn.readWaker.Wake() conn.readWaker = nil } + // if the events is writable and the writeWaker is not nil, wake the writeWaker if events&c.Int(libuv.WRITABLE) != 0 && conn.writeWaker != nil { conn.writeWaker.Wake() conn.writeWaker = nil } } +// updateConnRegistrations is the helper function for updating connection registrations. func updateConnRegistrations(conn *conn) bool { - fmt.Println("[debug] updateConnRegistrations called") - + // initialize the events events := c.Int(0) + + // if the eventMask is 0, return true if conn.eventMask == 0 { - fmt.Println("[debug] No events to poll, skipping poll start.") return true } - fmt.Printf("[debug] Event mask: %d\n", conn.eventMask) + + // if the eventMask is readable, set the events to readable if conn.eventMask&c.Uint(libuv.READABLE) != 0 { events |= c.Int(libuv.READABLE) } + + // if the eventMask is writable, set the events to writable if conn.eventMask&c.Uint(libuv.WRITABLE) != 0 { events |= c.Int(libuv.WRITABLE) } - fmt.Printf("[debug] Starting poll with events: %d\n", events) + // start the poll r := conn.pollHandle.Start(events, onPoll) if r < 0 { fmt.Fprintf(os.Stderr, "uv_poll_start error: %s\n", c.GoString(libuv.Strerror(libuv.Errno(r)))) @@ -680,6 +749,7 @@ func updateConnRegistrations(conn *conn) bool { return true } +// createConnData is the helper function for creating connection data. func createConnData() (*conn, error) { conn := &conn{} if conn == nil { @@ -691,17 +761,20 @@ func createConnData() (*conn, error) { return conn, nil } +// freeConnData is the helper function for freeing connection data. func freeConnData(userdata c.Pointer) { conn := (*conn)(userdata) conn.Close() } +// closeWalkCb is the callback function for libuv handle close. func closeWalkCb(handle *libuv.Handle, arg c.Pointer) { if handle.IsClosing() == 0 { handle.Close(nil) } } +// Close is the method for closing the server. func (srv *Server) Close() error { srv.isShutdown.Store(true) @@ -712,10 +785,12 @@ func (srv *Server) Close() error { return nil } +// shuttingDown is the method for checking if the server is shutting down. func (s *Server) shuttingDown() bool { return s.isShutdown.Load() } +// Close is the method for closing the event loop. func (el *eventLoop) Close() error { el.isShutdown.Store(true) @@ -746,13 +821,14 @@ func (el *eventLoop) Close() error { return nil } +// shuttingDown is the method for checking if the event loop is shutting down. func (el *eventLoop) shuttingDown() bool { return el.isShutdown.Load() } +// Close is the method for closing the connection. func (c *conn) Close() { if c != nil && !c.isClosing.Swap(true) { - fmt.Printf("[debug] Closing connection...\n") if c.readWaker != nil { c.readWaker.Free() c.readWaker = nil @@ -772,18 +848,23 @@ func (c *conn) Close() { } } +// shuttingDown is the method for checking if the connection is shutting down. func (c *conn) shuttingDown() bool { return c.isClosing.Load() } +// HandlerFunc is the type for handler function. type HandlerFunc func(ResponseWriter, *Request) +// ServeHTTP is the method for serving HTTP. func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { f(w, r) } +// NotFoundHandler is the method for not found handler. func NotFoundHandler() Handler { return HandlerFunc(NotFound) } +// NotFound is the method for not found. func NotFound(w ResponseWriter, r *Request) { w.WriteHeader(404) w.Write([]byte("404 page not found")) diff --git a/x/net/http/servermux.go b/x/net/http/servermux.go index e111186..96b0819 100644 --- a/x/net/http/servermux.go +++ b/x/net/http/servermux.go @@ -1,15 +1,16 @@ package http import ( - "fmt" "sync" ) +// ServeMux is a HTTP request multiplexer type ServeMux struct { mu sync.RWMutex m map[string]muxEntry } +// muxEntry is a HTTP request multiplexer entry type muxEntry struct { h Handler pattern string @@ -19,22 +20,15 @@ type muxEntry struct { var DefaultServeMux = &ServeMux{m: make(map[string]muxEntry)} func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { - fmt.Printf("[debug] ServeHTTP called\n") - // NotFoundHandler().ServeHTTP(w, r) - // return - h, pattern := mux.Handler(r) - fmt.Printf("[debug] Handler found for pattern: %s\n", pattern) + h, _ := mux.Handler(r) h.ServeHTTP(w, r) } +// Handler returns the handler to use for the given request, consulting r.Method, r.Host, and r.URL.Path. +// It always returns a non-nil handler. func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { - fmt.Printf("[debug] Mux Handler called\n") mux.mu.RLock() defer mux.mu.RUnlock() - if r.URL == nil { - fmt.Println("[debug] r.URL is nil") - } - fmt.Printf("[debug] Handler called: r.URL.Path = %s\n", r.URL.Path) h, pattern = mux.m[r.URL.Path].h, r.URL.Path if h == nil { @@ -43,10 +37,12 @@ func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { return } +// HandleFunc registers the handler for the given pattern. func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { mux.Handle(pattern, HandlerFunc(handler)) } +// Handle registers the handler for the given pattern using the default ServeMux. func (mux *ServeMux) Handle(pattern string, handler Handler) { mux.mu.Lock() defer mux.mu.Unlock() From 0382274c95a4e1c7a56e3c389e77edb6b598e03c Mon Sep 17 00:00:00 2001 From: hackerchai Date: Fri, 20 Sep 2024 18:43:33 +0800 Subject: [PATCH 54/55] neat(x/net/http): Go fmt style Signed-off-by: hackerchai --- x/net/bytealg.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/x/net/bytealg.go b/x/net/bytealg.go index 9c9ac68..4c6f008 100644 --- a/x/net/bytealg.go +++ b/x/net/bytealg.go @@ -10,10 +10,10 @@ func LastIndexByteString(s string, c byte) int { } func IndexByteString(s string, c byte) int { - for i := 0; i < len(s); i++ { - if s[i] == c { - return i - } - } - return -1 -} \ No newline at end of file + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} From 1fb70adffaa7f2cd92eec55f9d004f28215583d7 Mon Sep 17 00:00:00 2001 From: spongehah <2635879218@qq.com> Date: Fri, 20 Sep 2024 19:28:41 +0800 Subject: [PATCH 55/55] neat(x/net/http/client): Go fmt style --- .../_demo/parallelRequest/parallelRequest.go | 2 +- x/net/http/body_stream.go | 11 +- x/net/http/response.go | 6 - x/net/http/transport.go | 163 +----------------- 4 files changed, 9 insertions(+), 173 deletions(-) diff --git a/x/net/http/_demo/parallelRequest/parallelRequest.go b/x/net/http/_demo/parallelRequest/parallelRequest.go index 0bcb336..ae1083d 100644 --- a/x/net/http/_demo/parallelRequest/parallelRequest.go +++ b/x/net/http/_demo/parallelRequest/parallelRequest.go @@ -26,7 +26,7 @@ func worker(id int, wg *sync.WaitGroup) { func main() { var wait sync.WaitGroup - for i := 0; i < 500; i++ { + for i := 0; i < 100; i++ { wait.Add(1) go worker(i, &wait) } diff --git a/x/net/http/body_stream.go b/x/net/http/body_stream.go index 64ebfa4..451ef7f 100644 --- a/x/net/http/body_stream.go +++ b/x/net/http/body_stream.go @@ -2,7 +2,6 @@ package http import ( "errors" - "fmt" "github.com/goplus/llgo/c/libuv" ) @@ -18,7 +17,7 @@ type bodyStream struct { } var ( - ErrClosedRequestBody = errors.New("request body: read/write on closed body") + ErrClosedBodyStream = errors.New("body stream: read/write on closed body") ) func newBodyStream(asyncHandle *libuv.Async) *bodyStream { @@ -30,7 +29,6 @@ func newBodyStream(asyncHandle *libuv.Async) *bodyStream { } func (rb *bodyStream) Read(p []byte) (n int, err error) { - fmt.Println("[debug] RequestBody Read called") select { case <-rb.done: err = rb.readCloseError() @@ -41,11 +39,9 @@ func (rb *bodyStream) Read(p []byte) (n int, err error) { for n < len(p) { if len(rb.chunk) == 0 { rb.asyncHandle.Send() - fmt.Println("[debug] RequestBody Read asyncHandle.Send called") select { case chunk := <-rb.readCh: rb.chunk = chunk - fmt.Println("[debug] RequestBody Read chunk received") case <-rb.done: err = rb.readCloseError() return @@ -64,16 +60,15 @@ func (rb *bodyStream) readCloseError() error { if rerr := rb.rerr; rerr != nil { return rerr } - return ErrClosedRequestBody + return ErrClosedBodyStream } func (rb *bodyStream) closeWithError(err error) error { - fmt.Println("[debug] RequestBody closeRead called") if rb.rerr != nil { return nil } if err == nil { - err = ErrClosedRequestBody + err = ErrClosedBodyStream } rb.rerr = err close(rb.done) diff --git a/x/net/http/response.go b/x/net/http/response.go index d8eca16..a64c4b0 100644 --- a/x/net/http/response.go +++ b/x/net/http/response.go @@ -239,18 +239,12 @@ func (r *Response) checkRespBody(taskData *clientTaskData) (needContinue bool) { select { case taskData.resc <- responseAndError{res: r}: case <-taskData.callerGone: - if debugSwitch { - println("############### checkRespBody callerGone") - } closeAndRemoveIdleConn(pc, true) return true } // Now that they've read from the unbuffered channel, they're safely // out of the select that also waits on this goroutine to die, so // we're allowed to exit now if needed (if alive is false) - if debugSwitch { - println("############### checkRespBody return") - } closeAndRemoveIdleConn(pc, false) return true } diff --git a/x/net/http/transport.go b/x/net/http/transport.go index 1fc027f..b7596c1 100644 --- a/x/net/http/transport.go +++ b/x/net/http/transport.go @@ -37,12 +37,6 @@ var DefaultTransport RoundTripper = &Transport{ // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 -// Debug switch provided for developers -const ( - debugSwitch = true - debugReadWriteLoop = true -) - type Transport struct { idleMu sync.Mutex closeIdle bool // user has requested to close all idle conns @@ -96,8 +90,8 @@ type responseAndError struct { } type timeoutData struct { - timeoutch chan struct{} - clientTaskData *clientTaskData + timeoutch chan struct{} + clientTaskData *clientTaskData } type readTrackingBody struct { @@ -182,9 +176,6 @@ func (tr *transportRequest) setError(err error) { func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { if err := t.tryPutIdleConn(pconn); err != nil { - if debugSwitch { - println("############### putOrCloseIdleConn: close") - } pconn.close(err) } } @@ -276,9 +267,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { t.idleLRU.add(pconn) if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns { oldest := t.idleLRU.removeOldest() - if debugSwitch { - println("############### tryPutIdleConn: removeOldest") - } oldest.close(errTooManyIdle) t.removeIdleConnLocked(oldest) } @@ -647,10 +635,6 @@ func (t *Transport) getLoopKey(req *Request) string { } func (t *Transport) RoundTrip(req *Request) (*Response, error) { - if debugSwitch { - println("############### RoundTrip start") - defer println("############### RoundTrip end") - } eventLoop := t.getClientEventLoop(req) @@ -662,15 +646,12 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.timer = &libuv.Timer{} libuv.InitTimer(eventLoop.loop, req.timer) ch := &timeoutData{ - timeoutch: req.timeoutch, - clientTaskData: nil, + timeoutch: req.timeoutch, + clientTaskData: nil, } (*libuv.Handle)(c.Pointer(req.timer)).SetData(c.Pointer(ch)) req.timer.Start(onTimeout, getMilliseconds(req.deadline), 0) - if debugSwitch { - println("############### timer start") - } didTimeout = func() bool { return req.timer.GetDueIn() == 0 } stopTimer = func() { close(req.timeoutch) @@ -678,9 +659,6 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { if (*libuv.Handle)(c.Pointer(req.timer)).IsClosing() == 0 { (*libuv.Handle)(c.Pointer(req.timer)).Close(nil) } - if debugSwitch { - println("############### timer close") - } } } else { didTimeout = alwaysFalse @@ -704,10 +682,6 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } func (t *Transport) doRoundTrip(req *Request, loop *clientEventLoop) (*Response, error) { - if debugSwitch { - println("############### doRoundTrip start") - defer println("############### doRoundTrip end") - } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) //ctx := req.Context() //trace := httptrace.ContextClientTrace(ctx) @@ -843,10 +817,6 @@ func (t *Transport) doRoundTrip(req *Request, loop *clientEventLoop) (*Response, } func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) { - if debugSwitch { - println("############### getConn start") - defer println("############### getConn end") - } req := treq.Request //trace := treq.trace //ctx := req.Context() @@ -902,9 +872,6 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // cancellation error (see golang.org/issue/16049). select { case <-req.timeoutch: - if debugSwitch { - println("############### getConn: timeoutch") - } return nil, errors.New("timeout: req.Context().Err()") case err := <-cancelc: if err == errRequestCanceled { @@ -920,10 +887,6 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // queueForDial queues w to wait for permission to begin dialing. // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { - if debugSwitch { - println("############### queueForDial start") - defer println("############### queueForDial end") - } w.beforeDial() if t.MaxConnsPerHost <= 0 { @@ -956,10 +919,6 @@ func (t *Transport) queueForDial(w *wantConn) { // dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { - if debugSwitch { - println("############### dialConnFor start") - defer println("############### dialConnFor end") - } defer w.afterDial() pc, err := t.dialConn(w.timeoutch, w.cm) @@ -1030,10 +989,6 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { } func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn *persistConn, err error) { - if debugSwitch { - println("############### dialConn start") - defer println("############### dialConn end") - } select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") @@ -1084,9 +1039,6 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * select { case <-timeoutch: err = errors.New("[t.dialConn] request timeout") - if debugSwitch { - println("############### dialConn: timeoutch") - } pconn.close(err) return nil, err default: @@ -1095,10 +1047,6 @@ func (t *Transport) dialConn(timeoutch chan struct{}, cm connectMethod) (pconn * } func (t *Transport) dial(cm connectMethod) (*connData, error) { - if debugSwitch { - println("############### dial start") - defer println("############### dial end") - } addr := cm.addr() host, port, err := net.SplitHostPort(addr) if err != nil { @@ -1132,10 +1080,6 @@ func (t *Transport) dial(cm connectMethod) (*connData, error) { } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - if debugSwitch { - println("############### roundTrip start") - defer println("############### roundTrip end") - } testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { pc.t.putOrCloseIdleConn(pc) @@ -1210,18 +1154,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err for { testHookWaitResLoop() - if debugSwitch { - println("############### roundTrip for") - } select { case err := <-writeErrCh: - if debugSwitch { - println("############### roundTrip: writeErrch") - } if err != nil { - if debugSwitch { - println("############### roundTrip: writeErrch err != nil") - } pc.close(fmt.Errorf("write error: %w", err)) if pc.conn.nwrite == startBytesWritten { err = nothingWrittenError{err} @@ -1229,18 +1164,12 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return nil, pc.mapRoundTripError(req, startBytesWritten, err) } case <-pcClosed: - if debugSwitch { - println("############### roundTrip: pcClosed") - } pcClosed = nil if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) } //case <-respHeaderTimer: case re := <-resc: - if debugSwitch { - println("############### roundTrip: resc") - } if (re.res == nil) == (re.err == nil) { return nil, fmt.Errorf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil) } @@ -1249,9 +1178,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return re.res, nil case <-timeoutch: - if debugSwitch { - println("############### roundTrip: timeoutch") - } canceled = pc.t.cancelRequest(req.cancelKey, errors.New("timeout: req.Context().Err()")) timeoutch = nil return nil, errors.New("request timeout") @@ -1267,9 +1193,6 @@ func readWriteLoop(checker *libuv.Idle) { // The polling state machine! Poll all ready tasks and act on them... task := eventLoop.exec.Poll() for task != nil { - if debugSwitch { - println("############### polling") - } eventLoop.handleTask(task) task = eventLoop.exec.Poll() } @@ -1287,17 +1210,11 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { // If original taskId is set, we need to check it err = checkTaskType(task, clientTaskData) if err != nil { - if debugSwitch { - println("############### handleTask: checkTaskType err != nil") - } closeAndRemoveIdleConn(pc, true) return } switch clientTaskData.taskId { case handshake: - if debugReadWriteLoop { - println("############### write") - } // Check if the connection is closed select { @@ -1317,21 +1234,11 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { if err != nil { //pc.writeErrCh <- err // to the body reader, which might recycle us clientTaskData.writeErrCh <- err // to the roundTrip function - if debugSwitch { - println("############### handleTask: write err != nil") - } pc.close(err) return } - if debugReadWriteLoop { - println("############### write end") - } case read: - if debugReadWriteLoop { - println("############### read") - } - pc.tryPutIdleConn = func() bool { if err := pc.t.tryPutIdleConn(pc); err != nil { pc.closeErr = err @@ -1354,9 +1261,6 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { if pc.numExpectedResponses == 0 { pc.readLoopPeekFailLocked(hyperResp, err) pc.mu.Unlock() - if debugSwitch { - println("############### handleTask: numExpectedResponses == 0") - } closeAndRemoveIdleConn(pc, true) return } @@ -1383,15 +1287,9 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { select { case clientTaskData.resc <- responseAndError{err: err}: case <-clientTaskData.callerGone: - if debugSwitch { - println("############### handleTask read: callerGone") - } closeAndRemoveIdleConn(pc, true) return } - if debugSwitch { - println("############### handleTask: read err != nil") - } closeAndRemoveIdleConn(pc, true) return } @@ -1417,23 +1315,13 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { case clientTaskData.resc <- responseAndError{res: resp}: case <-clientTaskData.callerGone: // defer - if debugSwitch { - println("############### handleTask read: callerGone 2") - } pc.bodyStream.Close() clientTaskData.closeHyperBody() closeAndRemoveIdleConn(pc, true) return } - if debugReadWriteLoop { - println("############### read end") - } case readBodyChunk: - if debugReadWriteLoop { - println("############### readBodyChunk") - } - taskType := task.Type() if taskType == hyper.TaskBuf { chunk := (*hyper.Buf)(task.Value()) @@ -1444,9 +1332,6 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { task.Free() // Write to the channel pc.bodyStream.readCh <- bytes - if debugReadWriteLoop { - println("############### readBodyChunk end [buf]") - } return } @@ -1458,14 +1343,7 @@ func (eventLoop *clientEventLoop) handleTask(task *hyper.Task) { pc.alive = pc.alive && replaced && pc.tryPutIdleConn() - if debugSwitch { - println("############### handleTask readBodyChunk: alive: ", pc.alive) - } closeAndRemoveIdleConn(pc, false) - - if debugReadWriteLoop { - println("############### readBodyChunk end [empty]") - } } } @@ -1481,9 +1359,6 @@ func closeAndRemoveIdleConn(pc *persistConn, force bool) { if pc.alive == true && !force { return } - if debugSwitch { - println("############### closeAndRemoveIdleConn, force:", force) - } pc.close(pc.closeErr) pc.t.removeIdleConn(pc) } @@ -1551,10 +1426,6 @@ func (d *clientTaskData) closeHyperBody() { // onConnect is the libuv callback for a successful connection func onConnect(req *libuv.Connect, status c.Int) { - if debugSwitch { - println("############### connect start") - defer println("############### connect end") - } conn := (*connData)((*libuv.Req)(c.Pointer(req)).GetData()) if status < 0 { c.Fprintf(c.Stderr, c.Str("connect error: %s\n"), c.GoString(libuv.Strerror(libuv.Errno(status)))) @@ -1617,7 +1488,6 @@ func readCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen uin conn.readWaker.Free() } conn.readWaker = ctx.Waker() - println("############### readCallBack: IoPending") return hyper.IoPending } @@ -1649,16 +1519,11 @@ func writeCallBack(userdata c.Pointer, ctx *hyper.Context, buf *uint8, bufLen ui conn.writeWaker.Free() } conn.writeWaker = ctx.Waker() - println("############### writeCallBack: IoPending") return hyper.IoPending } // onTimeout is the libuv callback for a timeout func onTimeout(timer *libuv.Timer) { - if debugSwitch { - println("############### onTimeout start") - defer println("############### onTimeout end") - } data := (*timeoutData)((*libuv.Handle)(c.Pointer(timer)).GetData()) close(data.timeoutch) timer.Stop() @@ -1708,9 +1573,6 @@ func checkTaskType(task *hyper.Task, clientTaskData *clientTaskData) (err error) task.Free() if curTaskId == handshake || curTaskId == read { clientTaskData.writeErrCh <- err - if debugSwitch { - println("############### checkTaskType: writeErrCh") - } clientTaskData.pc.close(err) } if clientTaskData.pc.bodyStream != nil { @@ -1916,7 +1778,7 @@ type persistConn struct { closeErr error // Replace the closeErr in readLoop tryPutIdleConn func() bool // Replace the tryPutIdleConn in readLoop client *hyper.ClientConn // http long connection client handle - bodyStream *bodyStream // Implement non-blocking consumption of each responseBody chunk + bodyStream *bodyStream // Implement non-blocking consumption of each responseBody chunk chunkAsync *libuv.Async // Notifying that the received chunk has been read } @@ -1925,9 +1787,6 @@ type persistConn struct { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - if debugSwitch { - println("############### CloseIdleConnections") - } //t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t.idleMu.Lock() m := t.idleConn @@ -1947,9 +1806,6 @@ func (t *Transport) CloseIdleConnections() { } func (pc *persistConn) cancelRequest(err error) { - if debugSwitch { - println("############### cancelRequest") - } pc.mu.Lock() defer pc.mu.Unlock() pc.canceledErr = err @@ -1976,9 +1832,6 @@ func (pc *persistConn) markReused() { } func (pc *persistConn) closeLocked(err error) { - if debugSwitch { - println("############### pc closed") - } if err == nil { panic("nil error") } @@ -2158,16 +2011,10 @@ func (pc *persistConn) closeConnIfStillIdleLocked() { return } t.removeIdleConnLocked(pc) - if debugSwitch { - println("############### closeConnIfStillIdleLocked") - } pc.close(errIdleConnTimeout) } func (pc *persistConn) readLoopPeekFailLocked(resp *hyper.Response, err error) { - if debugSwitch { - println("############### readLoopPeekFailLocked") - } if pc.closed != nil { return }