diff --git a/callbacks.go b/callbacks.go index b9a0943..167e703 100644 --- a/callbacks.go +++ b/callbacks.go @@ -6,14 +6,23 @@ package nghttp2 import "C" import ( "bytes" + "crypto/tls" "io" "net/http" "net/url" + "strconv" "strings" "sync" "unsafe" ) +const ( + NGHTTP2_NO_ERROR = 0 + NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE = -521 + NGHTTP2_ERR_CALLBACK_FAILURE = -902 + NGHTTP2_ERR_DEFERRED = -508 +) + // OnServerDataRecvCallback callback function for libnghttp2 library // want receive data from network. // @@ -43,7 +52,7 @@ func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, buf := C.GoBytes(data, C.int(length)) n, err := conn.conn.Write(buf) if err != nil { - return -1 + return NGHTTP2_ERR_CALLBACK_FAILURE } //log.Println("send ", n, " bytes to network ", buf) return C.ssize_t(n) @@ -55,7 +64,10 @@ func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int, data unsafe.Pointer, length C.size_t) C.int { conn := (*ServerConn)(ptr) - s := conn.streams[int(streamID)] + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } bp := s.req.Body.(*bodyProvider) buf := C.GoBytes(data, C.int(length)) bp.Write(buf) @@ -67,20 +79,27 @@ func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int, //export OnServerBeginHeaderCallback func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { conn := (*ServerConn)(ptr) + var TLS tls.ConnectionState + if tlsconn, ok := conn.conn.(*tls.Conn); ok { + TLS = tlsconn.ConnectionState() + } + s := &ServerStream{ streamID: int(streamID), conn: conn, req: &http.Request{ - URL: &url.URL{}, + //URL: &url.URL{}, Header: http.Header{}, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, + RemoteAddr: conn.conn.RemoteAddr().String(), + TLS: &TLS, }, //buf: new(bytes.Buffer), } conn.streams[int(streamID)] = s - return 0 + return NGHTTP2_NO_ERROR } // OnServerHeaderCallback callback function for each header recv. @@ -90,7 +109,10 @@ func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, name unsafe.Pointer, namelen C.int, value unsafe.Pointer, valuelen C.int) C.int { conn := (*ServerConn)(ptr) - s := conn.streams[int(streamID)] + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } hdrname := C.GoStringN((*C.char)(name), namelen) hdrvalue := C.GoStringN((*C.char)(value), valuelen) hdrname = strings.ToLower(hdrname) @@ -98,22 +120,23 @@ func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, case ":method": s.req.Method = hdrvalue case ":scheme": - s.req.URL.Scheme = hdrvalue + // s.req.URL.Scheme = hdrvalue case ":path": s.req.RequestURI = hdrvalue u, _ := url.ParseRequestURI(s.req.RequestURI) - scheme := s.req.URL.Scheme - *(s.req.URL) = *u - if scheme != "" { - s.req.URL.Scheme = scheme - } + s.req.URL = u case ":authority": s.req.Host = hdrvalue + case "content-length": + s.req.Header.Add(hdrname, hdrvalue) + n, err := strconv.ParseInt(hdrvalue, 10, 64) + if err == nil { + s.req.ContentLength = n + } default: s.req.Header.Add(hdrname, hdrvalue) - } - return 0 + return NGHTTP2_NO_ERROR } // OnServerStreamEndCallback callback function for the stream when END_STREAM flag set @@ -122,7 +145,11 @@ func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { conn := (*ServerConn)(ptr) - s := conn.streams[int(streamID)] + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + s.streamEnd = true bp := s.req.Body.(*bodyProvider) if s.req.Method != "CONNECT" { @@ -130,7 +157,7 @@ func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("stream end flag set, begin to serve") go conn.serve(s) } - return 0 + return NGHTTP2_NO_ERROR } // OnServerHeadersDoneCallback callback function for the stream when all headers received. @@ -138,7 +165,10 @@ func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { //export OnServerHeadersDoneCallback func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { conn := (*ServerConn)(ptr) - s := conn.streams[int(streamID)] + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } s.headersDone = true bp := &bodyProvider{ buf: new(bytes.Buffer), @@ -148,7 +178,7 @@ func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { if s.req.Method == "CONNECT" { go conn.serve(s) } - return 0 + return NGHTTP2_NO_ERROR } // OnServerStreamClose callback function for the stream when closed. @@ -156,17 +186,20 @@ func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { //export OnServerStreamClose func OnServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { conn := (*ServerConn)(ptr) - s := conn.streams[int(streamID)] + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } conn.lock.Lock() delete(conn.streams, int(streamID)) conn.lock.Unlock() s.Close() - return 0 + return NGHTTP2_NO_ERROR } // OnDataSourceReadCallback callback function for libnghttp2 library // want read data from data provider source, -// return NGHTTP2_ERR_DEFERED will cause data frame defered, +// return NGHTTP2_ERR_DEFERRED will cause data frame defered, // application later call nghttp2_session_resume_data will re-quene the data frame // //export OnDataSourceReadCallback @@ -181,10 +214,9 @@ func OnDataSourceReadCallback(ptr unsafe.Pointer, return 0 } if err == errAgain { - // NGHTTP2_ERR_DEFERED - return -508 + return NGHTTP2_ERR_DEFERRED } - return -1 + return NGHTTP2_ERR_CALLBACK_FAILURE } cbuf := C.CBytes(gobuf) defer C.free(cbuf) @@ -200,8 +232,24 @@ func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int, //log.Println("on data recv") conn := (*ClientConn)(ptr) gobuf := C.GoBytes(buf, C.int(length)) - conn.onDataRecv(gobuf, int(streamID)) - return 0 + + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + if s.res.Body == nil { + //log.Println("empty body") + return C.int(length) + } + + if bp, ok := s.res.Body.(*bodyProvider); ok { + n, err := bp.Write(gobuf) + if err != nil { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + return C.int(n) + } + return C.int(length) } // OnClientDataRecvCallback callback function for libnghttp2 library want read data from network. @@ -233,8 +281,7 @@ func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si //log.Println(conn.conn.RemoteAddr()) n, err := conn.conn.Write(buf) if err != nil { - //log.Println(err) - return -1 + return NGHTTP2_ERR_CALLBACK_FAILURE } //log.Println("write data to network ", n) return C.ssize_t(n) @@ -246,8 +293,24 @@ func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("begin header") conn := (*ClientConn)(ptr) - conn.onBeginHeader(int(streamID)) - return 0 + + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + var TLS tls.ConnectionState + if tlsconn, ok := conn.conn.(*tls.Conn); ok { + TLS = tlsconn.ConnectionState() + } + s.res = &http.Response{ + Header: make(http.Header), + Body: &bodyProvider{ + buf: new(bytes.Buffer), + lock: new(sync.Mutex), + }, + TLS: &TLS, + } + return NGHTTP2_NO_ERROR } // OnClientHeaderCallback callback function for each header received. @@ -258,10 +321,35 @@ func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, value unsafe.Pointer, valuelen C.int) C.int { //log.Println("header") conn := (*ClientConn)(ptr) - goname := C.GoBytes(name, namelen) - govalue := C.GoBytes(value, valuelen) - conn.onHeader(int(streamID), string(goname), string(govalue)) - return 0 + goname := string(C.GoBytes(name, namelen)) + govalue := string(C.GoBytes(value, valuelen)) + + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + goname = strings.ToLower(goname) + switch goname { + case ":status": + statusCode, _ := strconv.Atoi(govalue) + s.res.StatusCode = statusCode + s.res.Status = http.StatusText(statusCode) + s.res.Proto = "HTTP/2.0" + s.res.ProtoMajor = 2 + s.res.ProtoMinor = 0 + case "content-length": + s.res.Header.Add(goname, govalue) + n, err := strconv.ParseInt(govalue, 10, 64) + if err == nil { + s.res.ContentLength = n + } + case "transfer-encoding": + s.res.Header.Add(goname, govalue) + s.res.TransferEncoding = append(s.res.TransferEncoding, govalue) + default: + s.res.Header.Add(goname, govalue) + } + return NGHTTP2_NO_ERROR } // OnClientHeadersDoneCallback callback function for the stream when all headers received. @@ -270,8 +358,12 @@ func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("frame recv") conn := (*ClientConn)(ptr) - conn.onHeadersDone(int(streamID)) - return 0 + s, ok := conn.streams[int(streamID)] + if !ok { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + s.resch <- s.res + return NGHTTP2_NO_ERROR } // OnClientStreamClose callback function for the stream when closed. @@ -280,6 +372,14 @@ func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { func OnClientStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("stream close") conn := (*ClientConn)(ptr) - conn.onStreamClose(int(streamID)) - return 0 + + stream, ok := conn.streams[int(streamID)] + if ok { + stream.Close() + conn.lock.Lock() + delete(conn.streams, int(streamID)) + conn.lock.Unlock() + return NGHTTP2_NO_ERROR + } + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE } diff --git a/conn.go b/conn.go index 609437a..f6f386b 100644 --- a/conn.go +++ b/conn.go @@ -6,14 +6,12 @@ package nghttp2 */ import "C" import ( - "bytes" "crypto/tls" "errors" "fmt" "io" "net" "net/http" - "strconv" "strings" "sync" "time" @@ -58,61 +56,6 @@ func NewClientConn(c net.Conn) (*ClientConn, error) { return conn, nil } -func (c *ClientConn) onDataRecv(buf []byte, streamID int) { - s := c.streams[streamID] - if s.res.Body == nil { - //log.Println("empty body") - return - } - - if bp, ok := s.res.Body.(*bodyProvider); ok { - bp.Write(buf) - } -} - -func (c *ClientConn) onBeginHeader(streamID int) { - s := c.streams[streamID] - - s.res = &http.Response{ - Header: make(http.Header), - Body: &bodyProvider{ - buf: new(bytes.Buffer), - lock: new(sync.Mutex), - }, - } -} - -func (c *ClientConn) onHeader(streamID int, name, value string) { - s := c.streams[streamID] - if name == ":status" { - statusCode, _ := strconv.Atoi(value) - s.res.StatusCode = statusCode - s.res.Status = http.StatusText(statusCode) - s.res.Proto = "HTTP/2.0" - s.res.ProtoMajor = 2 - s.res.ProtoMinor = 0 - return - } - s.res.Header.Add(name, value) - -} - -func (c *ClientConn) onHeadersDone(streamID int) { - s := c.streams[streamID] - s.resch <- s.res -} - -func (c *ClientConn) onStreamClose(streamID int) { - stream, ok := c.streams[streamID] - if ok { - stream.Close() - c.lock.Lock() - delete(c.streams, streamID) - c.lock.Unlock() - } - -} - // Close close the http2 connection func (c *ClientConn) Close() error { for _, s := range c.streams { @@ -291,6 +234,7 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { case err := <-s.errch: return nil, err case res := <-s.resch: + res.Request = req return res, nil case <-c.errch: return nil, fmt.Errorf("connection error") diff --git a/http2_test.go b/http2_test.go index 96f84c6..efa3f42 100644 --- a/http2_test.go +++ b/http2_test.go @@ -51,7 +51,8 @@ func TestHttp2Client(t *testing.T) { if res.StatusCode != http.StatusOK { t.Errorf("expect %d, got %d", http.StatusOK, res.StatusCode) } - res.Write(os.Stderr) + log.Printf("%+v", res) + //res.Write(os.Stderr) req, _ = http.NewRequest("GET", "https://www.simicloud.com/media/httpbin/get?a=b&c=d", nil) @@ -62,7 +63,8 @@ func TestHttp2Client(t *testing.T) { if res.StatusCode != http.StatusOK { t.Errorf("expect %d, got %d", http.StatusOK, res.StatusCode) } - res.Write(os.Stderr) + log.Printf("%+v", res) + //res.Write(os.Stderr) log.Println("end") }