From 28a6eb07467f0f577a50b300dd88a09d8c60bd7a Mon Sep 17 00:00:00 2001 From: fangdingjun Date: Fri, 13 Jul 2018 15:44:22 +0800 Subject: [PATCH] code clean up remove unused code use go slice to manage nghttp2_nv array send RST_STREAM when connection closed on CONNECT request --- _nghttp2.h | 12 +++- callbacks.go | 132 ++++++++++++++++++++++---------------------- conn.go | 139 ++++++++++++++++++++++++----------------------- data_provider.go | 37 ++++++------- nghttp2.c | 95 +++++++++++--------------------- stream.go | 91 ++++++++++++------------------- 6 files changed, 235 insertions(+), 271 deletions(-) diff --git a/_nghttp2.h b/_nghttp2.h index 436fcb4..d585e11 100644 --- a/_nghttp2.h +++ b/_nghttp2.h @@ -9,7 +9,8 @@ extern ssize_t onClientDataRecvCallback(void *, void *data, size_t); extern ssize_t onClientDataSendCallback(void *, void *data, size_t); -extern ssize_t onDataSourceReadCallback(void *, void *, size_t); +extern ssize_t onServerDataSourceReadCallback(void *, int, void *, size_t); +extern ssize_t onClientDataSourceReadCallback(void *, int, void *, size_t); extern int onClientDataChunkRecv(void *, int, void *, size_t); extern int onClientBeginHeaderCallback(void *, int); extern int onClientHeaderCallback(void *, int, void *, int, void *, int); @@ -17,6 +18,13 @@ extern int onClientHeadersDoneCallback(void *, int); extern int onClientStreamClose(void *, int); extern void onClientConnectionCloseCallback(void *user_data); +int _nghttp2_submit_response(nghttp2_session *sess, int streamid, + size_t nv, size_t nvlen, nghttp2_data_provider *dp); + +int _nghttp2_submit_request(nghttp2_session *session, const nghttp2_priority_spec *pri_spec, + size_t nva, size_t nvlen, + const nghttp2_data_provider *data_prd, void *stream_user_data); + extern ssize_t onServerDataRecvCallback(void *, void *data, size_t); extern ssize_t onServerDataSendCallback(void *, void *data, size_t); extern int onServerDataChunkRecv(void *, int, void *, size_t); @@ -34,7 +42,7 @@ struct nv_array }; void delete_nv_array(struct nv_array *a); -nghttp2_data_provider *new_data_provider(size_t data); +int data_provider_set_callback(size_t dp, size_t data, int t); int nv_array_set(struct nv_array *a, int index, char *name, char *value, diff --git a/callbacks.go b/callbacks.go index 04d657f..a1d7f6a 100644 --- a/callbacks.go +++ b/callbacks.go @@ -23,26 +23,6 @@ const ( NGHTTP2_ERR_DEFERRED = -508 ) -/* -// onServerDataRecvCallback callback function for libnghttp2 library -// want receive data from network. -// -//export onServerDataRecvCallback -func onServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, - length C.size_t) C.ssize_t { - conn := (*ServerConn)(ptr) - buf := make([]byte, int(length)) - n, err := conn.conn.Read(buf) - if err != nil { - return -1 - } - cbuf := C.CBytes(buf[:n]) - defer C.free(cbuf) - C.memcpy(data, cbuf, C.size_t(n)) - return C.ssize_t(n) -} -*/ - // onServerDataSendCallback callback function for libnghttp2 library // want send data to network. // @@ -50,7 +30,7 @@ func onServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, func onServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, length C.size_t) C.ssize_t { //log.Println("server data send") - conn := (*ServerConn)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) buf := C.GoBytes(data, C.int(length)) n, err := conn.conn.Write(buf) if err != nil { @@ -65,7 +45,7 @@ func onServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, //export onServerDataChunkRecv func onServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int, data unsafe.Pointer, length C.size_t) C.int { - conn := (*ServerConn)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) s, ok := conn.streams[int(streamID)] if !ok { return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE @@ -80,7 +60,7 @@ func onServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int, // //export onServerBeginHeaderCallback func onServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { - conn := (*ServerConn)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) var TLS tls.ConnectionState if tlsconn, ok := conn.conn.(*tls.Conn); ok { TLS = tlsconn.ConnectionState() @@ -113,7 +93,7 @@ func onServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { func onServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, name unsafe.Pointer, namelen C.int, value unsafe.Pointer, valuelen C.int) C.int { - conn := (*ServerConn)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) s, ok := conn.streams[int(streamID)] if !ok { return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE @@ -149,7 +129,7 @@ func onServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, //export onServerStreamEndCallback func onServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { - conn := (*ServerConn)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) s, ok := conn.streams[int(streamID)] if !ok { return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE @@ -169,7 +149,7 @@ func onServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { // //export onServerHeadersDoneCallback func onServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { - conn := (*ServerConn)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) s, ok := conn.streams[int(streamID)] if !ok { return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE @@ -190,8 +170,9 @@ func onServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { // //export onServerStreamClose func onServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { - conn := (*ServerConn)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) s, ok := conn.streams[int(streamID)] + //log.Printf("stream %d closed", int(streamID)) if !ok { return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE } @@ -202,18 +183,61 @@ func onServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { return NGHTTP2_NO_ERROR } -// onDataSourceReadCallback callback function for libnghttp2 library +// onClientDataSourceReadCallback callback function for libnghttp2 library +// want read data from data provider source, +// return NGHTTP2_ERR_DEFERRED will cause data frame defered, +// application later call nghttp2_session_resume_data will re-quene the data frame +// +//export onClientDataSourceReadCallback +func onClientDataSourceReadCallback(ptr unsafe.Pointer, streamID C.int, + buf unsafe.Pointer, length C.size_t) C.ssize_t { + //log.Println("onDataSourceReadCallback begin") + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) + s, ok := conn.streams[int(streamID)] + if !ok { + //log.Println("client dp callback, stream not exists") + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + gobuf := make([]byte, int(length)) + n, err := s.dp.Read(gobuf) + if err != nil { + if err == io.EOF { + //log.Println("onDataSourceReadCallback end") + return 0 + } + if err == errAgain { + //log.Println("onDataSourceReadCallback end") + s.dp.deferred = true + return NGHTTP2_ERR_DEFERRED + } + //log.Println("onDataSourceReadCallback end") + return NGHTTP2_ERR_CALLBACK_FAILURE + } + //cbuf := C.CBytes(gobuf) + //defer C.free(cbuf) + //C.memcpy(buf, cbuf, C.size_t(n)) + C.memcpy(buf, unsafe.Pointer(&gobuf[0]), C.size_t(n)) + //log.Println("onDataSourceReadCallback end") + return C.ssize_t(n) +} + +// onServerDataSourceReadCallback callback function for libnghttp2 library // want read data from data provider source, // return NGHTTP2_ERR_DEFERRED will cause data frame defered, // application later call nghttp2_session_resume_data will re-quene the data frame // -//export onDataSourceReadCallback -func onDataSourceReadCallback(ptr unsafe.Pointer, +//export onServerDataSourceReadCallback +func onServerDataSourceReadCallback(ptr unsafe.Pointer, streamID C.int, buf unsafe.Pointer, length C.size_t) C.ssize_t { //log.Println("onDataSourceReadCallback begin") - dp := (*dataProvider)(ptr) + conn := (*ServerConn)(unsafe.Pointer(uintptr(ptr))) + s, ok := conn.streams[int(streamID)] + if !ok { + //log.Println("server dp callback, stream not exists") + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } gobuf := make([]byte, int(length)) - n, err := dp.Read(gobuf) + n, err := s.dp.Read(gobuf) if err != nil { if err == io.EOF { //log.Println("onDataSourceReadCallback end") @@ -221,15 +245,16 @@ func onDataSourceReadCallback(ptr unsafe.Pointer, } if err == errAgain { //log.Println("onDataSourceReadCallback end") - dp.deferred = true + s.dp.deferred = true return NGHTTP2_ERR_DEFERRED } //log.Println("onDataSourceReadCallback end") return NGHTTP2_ERR_CALLBACK_FAILURE } - cbuf := C.CBytes(gobuf) - defer C.free(cbuf) - C.memcpy(buf, cbuf, C.size_t(n)) + //cbuf := C.CBytes(gobuf) + //defer C.free(cbuf) + //C.memcpy(buf, cbuf, C.size_t(n)) + C.memcpy(buf, unsafe.Pointer(&gobuf[0]), C.size_t(n)) //log.Println("onDataSourceReadCallback end") return C.ssize_t(n) } @@ -240,7 +265,7 @@ func onDataSourceReadCallback(ptr unsafe.Pointer, func onClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int, buf unsafe.Pointer, length C.size_t) C.int { //log.Println("onClientDataChunkRecv begin") - conn := (*ClientConn)(ptr) + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) gobuf := C.GoBytes(buf, C.int(length)) s, ok := conn.streams[int(streamID)] @@ -266,34 +291,13 @@ func onClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int, return C.int(length) } -/* -// onClientDataRecvCallback callback function for libnghttp2 library want read data from network. -// -//export onClientDataRecvCallback -func onClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { - //log.Println("data read req", int(size)) - conn := (*ClientConn)(ptr) - buf := make([]byte, int(size)) - //log.Println(conn.conn.RemoteAddr()) - n, err := conn.conn.Read(buf) - if err != nil { - //log.Println(err) - return -1 - } - cbuf := C.CBytes(buf) - //log.Println("read from network ", n, buf[:n]) - C.memcpy(data, cbuf, C.size_t(n)) - return C.ssize_t(n) -} -*/ - // onClientDataSendCallback callback function for libnghttp2 library want send data to network. // //export onClientDataSendCallback func onClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { //log.Println("onClientDataSendCallback begin") //log.Println("data write req ", int(size)) - conn := (*ClientConn)(ptr) + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) buf := C.GoBytes(data, C.int(size)) //log.Println(conn.conn.RemoteAddr()) n, err := conn.conn.Write(buf) @@ -312,7 +316,7 @@ func onClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si func onClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("onClientBeginHeaderCallback begin") //log.Printf("stream %d begin headers", int(streamID)) - conn := (*ClientConn)(ptr) + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) s, ok := conn.streams[int(streamID)] if !ok { @@ -343,7 +347,7 @@ func onClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, value unsafe.Pointer, valuelen C.int) C.int { //log.Println("onClientHeaderCallback begin") //log.Println("header") - conn := (*ClientConn)(ptr) + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) goname := string(C.GoBytes(name, namelen)) govalue := string(C.GoBytes(value, valuelen)) @@ -383,7 +387,7 @@ func onClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, func onClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("onClientHeadersDoneCallback begin") //log.Printf("stream %d headers done", int(streamID)) - conn := (*ClientConn)(ptr) + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) s, ok := conn.streams[int(streamID)] if !ok { //log.Println("onClientHeadersDoneCallback end") @@ -403,7 +407,7 @@ func onClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { func onClientStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("onClientStreamClose begin") //log.Printf("stream %d closed", int(streamID)) - conn := (*ClientConn)(ptr) + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) stream, ok := conn.streams[int(streamID)] if ok { @@ -421,7 +425,7 @@ func onClientStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { //export onClientConnectionCloseCallback func onClientConnectionCloseCallback(ptr unsafe.Pointer) { - conn := (*ClientConn)(ptr) + conn := (*ClientConn)(unsafe.Pointer(uintptr(ptr))) conn.err = io.EOF // signal all goroutings exit diff --git a/conn.go b/conn.go index ca615ad..a1b031b 100644 --- a/conn.go +++ b/conn.go @@ -123,14 +123,16 @@ func (c *ClientConn) run() { } //log.Printf("read %d bytes from network", n) lastDataRecv = time.Now() - d1 := C.CBytes(buf[:n]) + //d1 := C.CBytes(buf[:n]) c.lock.Lock() + //ret1 := C.nghttp2_session_mem_recv(c.session, + // (*C.uchar)(d1), C.size_t(n)) ret1 := C.nghttp2_session_mem_recv(c.session, - (*C.uchar)(d1), C.size_t(n)) + (*C.uchar)(unsafe.Pointer(&buf[0])), C.size_t(n)) c.lock.Unlock() - C.free(d1) + //C.free(d1) if int(ret1) < 0 { c.lock.Lock() c.err = fmt.Errorf("sesion recv error: %s", @@ -209,46 +211,36 @@ func (c *ClientConn) Connect(addr string) (cs *ClientStream, statusCode int, err if c.err != nil { return nil, http.StatusServiceUnavailable, c.err } - nvIndex := 0 - nvMax := 5 - nva := C.new_nv_array(C.size_t(nvMax)) - //log.Printf("%s %s", req.Method, req.RequestURI) - setNvArray(nva, nvIndex, ":method", "CONNECT", 0) - nvIndex++ - //setNvArray(nva, nvIndex, ":scheme", "https", 0) - //nvIndex++ - //log.Printf("header authority: %s", req.RequestURI) - setNvArray(nva, nvIndex, ":authority", addr, 0) - nvIndex++ + var nv = []C.nghttp2_nv{} + nv = append(nv, newNV(":method", "CONNECT")) + nv = append(nv, newNV(":authority", addr)) var dp *dataProvider - var cdp *C.nghttp2_data_provider - dp, cdp = newDataProvider(c.lock) + + s := &ClientStream{ + conn: c, + cdp: C.nghttp2_data_provider{}, + resch: make(chan *http.Response), + errch: make(chan error), + lock: new(sync.Mutex), + } + + dp = newDataProvider(unsafe.Pointer(&s.cdp), c.lock, 1) + s.dp = dp c.lock.Lock() - streamID := C.nghttp2_submit_request(c.session, nil, - nva.nv, C.size_t(nvIndex), cdp, nil) + streamID := C._nghttp2_submit_request(c.session, nil, + C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), &s.cdp, nil) c.lock.Unlock() - C.delete_nv_array(nva) if int(streamID) < 0 { return nil, http.StatusServiceUnavailable, fmt.Errorf( "submit request error: %s", C.GoString(C.nghttp2_strerror(streamID))) } - if dp != nil { - dp.streamID = int(streamID) - dp.session = c.session - } + dp.streamID = int(streamID) + dp.session = c.session + s.streamID = int(streamID) //log.Println("stream id ", int(streamID)) - s := &ClientStream{ - streamID: int(streamID), - conn: c, - dp: dp, - cdp: cdp, - resch: make(chan *http.Response), - errch: make(chan error), - lock: new(sync.Mutex), - } c.lock.Lock() c.streams[int(streamID)] = s @@ -262,6 +254,12 @@ func (c *ClientConn) Connect(addr string) (cs *ClientStream, statusCode int, err return nil, http.StatusServiceUnavailable, err case res := <-s.resch: if res != nil { + res.Request = &http.Request{ + Method: "CONNECT", + RequestURI: addr, + URL: &url.URL{}, + Host: addr, + } return s, res.StatusCode, nil } //log.Println("wait response, empty response") @@ -271,23 +269,31 @@ func (c *ClientConn) Connect(addr string) (cs *ClientStream, statusCode int, err } } +func newNV(name, value string) C.nghttp2_nv { + nv := C.nghttp2_nv{} + nameArr := make([]byte, len(name)+1) + valueArr := make([]byte, len(value)+1) + copy(nameArr, []byte(name)) + copy(valueArr, []byte(value)) + + nv.name = (*C.uchar)(unsafe.Pointer(&nameArr[0])) + nv.value = (*C.uchar)(unsafe.Pointer(&valueArr[0])) + nv.namelen = C.size_t(len(name)) + nv.valuelen = C.size_t(len(value)) + nv.flags = 0 + return nv +} + // CreateRequest submit a request and return a http.Response, func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { if c.err != nil { return nil, c.err } - nvIndex := 0 - nvMax := 25 - nva := C.new_nv_array(C.size_t(nvMax)) - setNvArray(nva, nvIndex, ":method", req.Method, 0) - nvIndex++ - if req.Method != "CONNECT" { - setNvArray(nva, nvIndex, ":scheme", "https", 0) - nvIndex++ - } - setNvArray(nva, nvIndex, ":authority", req.Host, 0) - nvIndex++ + nv := []C.nghttp2_nv{} + nv = append(nv, newNV(":method", req.Method)) + nv = append(nv, newNV(":scheme", "https")) + nv = append(nv, newNV(":authority", req.Host)) /* :path must starts with "/" @@ -299,10 +305,7 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { p = p + "?" + q } - if req.Method != "CONNECT" { - setNvArray(nva, nvIndex, ":path", p, 0) - nvIndex++ - } + nv = append(nv, newNV(":path", p)) //log.Printf("%s http://%s%s", req.Method, req.Host, p) for k, v := range req.Header { @@ -312,15 +315,23 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { continue } //log.Printf("header %s: %s", k, v) - setNvArray(nva, nvIndex, strings.Title(k), v[0], 0) - nvIndex++ + nv = append(nv, newNV(k, v[0])) } var dp *dataProvider - var cdp *C.nghttp2_data_provider + + s := &ClientStream{ + //streamID: int(streamID), + conn: c, + resch: make(chan *http.Response), + errch: make(chan error), + lock: new(sync.Mutex), + } if req.Method == "PUT" || req.Method == "POST" || req.Method == "CONNECT" { - dp, cdp = newDataProvider(c.lock) + s.cdp = C.nghttp2_data_provider{} + dp = newDataProvider(unsafe.Pointer(&s.cdp), c.lock, 1) + s.dp = dp go func() { io.Copy(dp, req.Body) dp.Close() @@ -328,17 +339,17 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { } c.lock.Lock() - streamID := C.nghttp2_submit_request(c.session, nil, - nva.nv, C.size_t(nvIndex), cdp, nil) + streamID := C._nghttp2_submit_request(c.session, nil, + C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), &s.cdp, nil) c.lock.Unlock() - C.delete_nv_array(nva) + //C.delete_nv_array(nva) if int(streamID) < 0 { return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(streamID))) } - + s.streamID = int(streamID) //log.Printf("new stream, id %d", int(streamID)) if dp != nil { @@ -346,16 +357,6 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { dp.session = c.session } - s := &ClientStream{ - streamID: int(streamID), - conn: c, - dp: dp, - cdp: cdp, - resch: make(chan *http.Response), - errch: make(chan error), - lock: new(sync.Mutex), - } - c.lock.Lock() c.streams[int(streamID)] = s c.streamCount++ @@ -528,14 +529,16 @@ func (c *ServerConn) Run() { break } - d1 := C.CBytes(buf[:n]) + //d1 := C.CBytes(buf[:n]) c.lock.Lock() + //ret1 := C.nghttp2_session_mem_recv(c.session, + // (*C.uchar)(d1), C.size_t(n)) ret1 := C.nghttp2_session_mem_recv(c.session, - (*C.uchar)(d1), C.size_t(n)) + (*C.uchar)(unsafe.Pointer(&buf[0])), C.size_t(n)) c.lock.Unlock() - C.free(d1) + //C.free(d1) if int(ret1) < 0 { c.lock.Lock() c.err = fmt.Errorf("sesion recv error: %s", diff --git a/data_provider.go b/data_provider.go index 190a630..30b748c 100644 --- a/data_provider.go +++ b/data_provider.go @@ -6,7 +6,9 @@ package nghttp2 import "C" import ( "bytes" + "errors" "io" + "log" "sync" "time" "unsafe" @@ -28,6 +30,10 @@ type dataProvider struct { // Read read from data provider func (dp *dataProvider) Read(buf []byte) (n int, err error) { + if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { + log.Println("db read invalid state") + return 0, errors.New("invalid state") + } dp.lock.Lock() defer dp.lock.Unlock() @@ -40,6 +46,10 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) { // Write provider data for data provider func (dp *dataProvider) Write(buf []byte) (n int, err error) { + if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { + log.Println("dp write invalid state") + return 0, errors.New("invalid state") + } dp.lock.Lock() defer dp.lock.Unlock() @@ -59,6 +69,10 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) { // Close end to provide data func (dp *dataProvider) Close() error { + if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { + log.Println("dp close, invalid state") + return errors.New("invalid state") + } dp.lock.Lock() defer dp.lock.Unlock() @@ -77,15 +91,15 @@ func (dp *dataProvider) Close() error { return nil } -func newDataProvider(sessionLock *sync.Mutex) ( - *dataProvider, *C.nghttp2_data_provider) { +func newDataProvider(cdp unsafe.Pointer, sessionLock *sync.Mutex, t int) *dataProvider { dp := &dataProvider{ buf: new(bytes.Buffer), lock: new(sync.Mutex), sessLock: sessionLock, } - cdp := C.new_data_provider(C.size_t(uintptr(unsafe.Pointer(dp)))) - return dp, cdp + C.data_provider_set_callback(C.size_t(uintptr(cdp)), + C.size_t(uintptr(unsafe.Pointer(dp))), C.int(t)) + return dp } // bodyProvider provide data for http body @@ -130,18 +144,3 @@ func (bp *bodyProvider) Close() error { bp.closed = true return nil } - -// setNvArray set the array for nghttp2_nv array -func setNvArray(a *C.struct_nv_array, index int, - name, value string, flags int) { - cname := C.CString(name) - cvalue := C.CString(value) - cnamelen := C.size_t(len(name)) - cvaluelen := C.size_t(len(value)) - cflags := C.int(flags) - - // note: cname and cvalue will freed in C.delete_nv_array - - C.nv_array_set(a, C.int(index), cname, - cvalue, cnamelen, cvaluelen, cflags) -} diff --git a/nghttp2.c b/nghttp2.c index ee8540b..dd0d91f 100644 --- a/nghttp2.c +++ b/nghttp2.c @@ -133,13 +133,6 @@ int send_server_connection_header(nghttp2_session *session) rv = nghttp2_submit_settings(session, NGHTTP2_FLAG_NONE, iv, ARRLEN(iv)); return rv; - /* - if (rv != 0) { - // warnx("Fatal error: %s", nghttp2_strerror(rv)); - return rv; - } - return 0; - */ } // send_callback send data to network @@ -148,14 +141,6 @@ static ssize_t client_send_callback(nghttp2_session *session, const uint8_t *dat { return onClientDataSendCallback(user_data, (void *)data, length); } -/* -// recv_callback read data from network -static ssize_t client_recv_callback(nghttp2_session *session, uint8_t *buf, - size_t length, int flags, void *user_data) -{ - return onClientDataRecvCallback(user_data, (void *)buf, length); -} -*/ static int on_client_header_callback(nghttp2_session *session, const nghttp2_frame *frame, const uint8_t *name, @@ -286,11 +271,23 @@ static int on_client_stream_close_callback(nghttp2_session *session, int32_t str return 0; } -static ssize_t data_source_read_callback(nghttp2_session *session, int32_t stream_id, - uint8_t *buf, size_t length, uint32_t *data_flags, - nghttp2_data_source *source, void *user_data) +static ssize_t on_client_data_source_read_callback(nghttp2_session *session, int32_t stream_id, + uint8_t *buf, size_t length, uint32_t *data_flags, + nghttp2_data_source *source, void *user_data) +{ + int ret = onClientDataSourceReadCallback(user_data, stream_id, buf, length); + if (ret == 0) + { + *data_flags = NGHTTP2_DATA_FLAG_EOF; + } + return ret; +} + +static ssize_t on_server_data_source_read_callback(nghttp2_session *session, int32_t stream_id, + uint8_t *buf, size_t length, uint32_t *data_flags, + nghttp2_data_source *source, void *user_data) { - int ret = onDataSourceReadCallback(source->ptr, buf, length); + int ret = onServerDataSourceReadCallback(user_data, stream_id, buf, length); if (ret == 0) { *data_flags = NGHTTP2_DATA_FLAG_EOF; @@ -403,57 +400,31 @@ int32_t submit_request(nghttp2_session *session, nghttp2_nv *hdrs, size_t hdrlen } #endif -struct nv_array *new_nv_array(size_t n) +int data_provider_set_callback(size_t cdp, size_t data, int type) { - struct nv_array *a = malloc(sizeof(struct nv_array)); - nghttp2_nv *nv = (nghttp2_nv *)malloc(n * sizeof(nghttp2_nv)); - memset(nv, 0, n * sizeof(nghttp2_nv)); - a->nv = nv; - a->len = n; - return a; -} - -int nv_array_set(struct nv_array *a, int index, - char *name, char *value, - size_t namelen, size_t valuelen, int flag) -{ - if (index > (a->len - 1)) + //nghttp2_data_provider *dp = malloc(sizeof(nghttp2_data_provider)); + nghttp2_data_provider *dp = (nghttp2_data_provider *)cdp; + dp->source.ptr = (void *)((int *)data); + if (type == 0) { - return -1; + dp->read_callback = on_server_data_source_read_callback; + } + else + { + dp->read_callback = on_client_data_source_read_callback; } - nghttp2_nv *nv = &((a->nv)[index]); - nv->name = name; - nv->value = value; - nv->namelen = namelen; - nv->valuelen = valuelen; - nv->flags = flag; return 0; } -void delete_nv_array(struct nv_array *a) +int _nghttp2_submit_response(nghttp2_session *sess, int streamid, + size_t nv, size_t nvlen, nghttp2_data_provider *dp) { - int i; - nghttp2_nv *nv; - for (i = 0; i < a->len; i++) - { - nv = &((a->nv)[i]); - if (nv->name != NULL) - { - free(nv->name); - } - if (nv->value != NULL) - { - free(nv->value); - } - } - free(a->nv); - free(a); + return nghttp2_submit_response(sess, streamid, (nghttp2_nv *)nv, nvlen, dp); } -nghttp2_data_provider *new_data_provider(size_t data) +int _nghttp2_submit_request(nghttp2_session *session, const nghttp2_priority_spec *pri_spec, + size_t nva, size_t nvlen, + const nghttp2_data_provider *data_prd, void *stream_user_data) { - nghttp2_data_provider *dp = malloc(sizeof(nghttp2_data_provider)); - dp->source.ptr = (void *)((int *)data); - dp->read_callback = data_source_read_callback; - return dp; + return nghttp2_submit_request(session, pri_spec, (nghttp2_nv *)nva, nvlen, data_prd, stream_user_data); } \ No newline at end of file diff --git a/stream.go b/stream.go index e6e22ae..bbfba34 100644 --- a/stream.go +++ b/stream.go @@ -7,6 +7,7 @@ import "C" import ( "fmt" "io" + "log" "net/http" "strings" "sync" @@ -17,17 +18,13 @@ import ( type ClientStream struct { streamID int conn *ClientConn - cdp *C.nghttp2_data_provider + cdp C.nghttp2_data_provider dp *dataProvider - // application read data from stream - //r *io.PipeReader - // recv stream data from session - //w *io.PipeWriter - res *http.Response - resch chan *http.Response - errch chan error - closed bool - lock *sync.Mutex + res *http.Response + resch chan *http.Response + errch chan error + closed bool + lock *sync.Mutex } // Read read stream data @@ -63,33 +60,21 @@ func (s *ClientStream) Close() error { case s.errch <- err: default: } - //log.Println("close stream resch") - //close(s.resch) - //log.Println("close stream errch") - //close(s.errch) - //log.Println("close pipe w") if s.res != nil && s.res.Body != nil { s.res.Body.Close() } //log.Println("close stream done") if s.dp != nil { s.dp.Close() - //s.conn.lock.Lock() - //C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0) - //s.conn.lock.Unlock() - s.dp = nil - } - if s.cdp != nil { - C.free(unsafe.Pointer(s.cdp)) - s.cdp = nil } - s.conn.lock.Lock() - if _, ok := s.conn.streams[s.streamID]; ok { - delete(s.conn.streams, s.streamID) + if s.res != nil && s.res.Request != nil && + s.res.Request.Method == "CONNECT" { + //log.Printf("send rst stream for %d", s.streamID) + s.conn.lock.Lock() + C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0) + s.conn.lock.Unlock() } - s.conn.lock.Unlock() - return nil } @@ -114,7 +99,7 @@ type ServerStream struct { // data provider dp *dataProvider - cdp *C.nghttp2_data_provider + cdp C.nghttp2_data_provider closed bool //buf *bytes.Buffer @@ -147,39 +132,35 @@ func (s *ServerStream) WriteHeader(code int) { s.responseSend = true s.statusCode = code - nvIndex := 0 - nvMax := 25 - - nva := C.new_nv_array(C.size_t(nvMax)) + var nv = []C.nghttp2_nv{} - setNvArray(nva, nvIndex, ":status", fmt.Sprintf("%d", code), 0) - nvIndex++ + nv = append(nv, newNV(":status", fmt.Sprintf("%d", code))) for k, v := range s.header { - if strings.ToLower(k) == "host" { + //log.Println(k, v[0]) + _k := strings.ToLower(k) + if _k == "host" || _k == "connection" || _k == "proxy-connection" || + _k == "transfer-encoding" { continue } - //log.Printf("header %s: %s", k, v) - setNvArray(nva, nvIndex, strings.Title(k), v[0], 0) - nvIndex++ + nv = append(nv, newNV(k, v[0])) } var dp *dataProvider - var cdp *C.nghttp2_data_provider - dp, cdp = newDataProvider(s.conn.lock) + dp = newDataProvider(unsafe.Pointer(&s.cdp), s.conn.lock, 0) dp.streamID = s.streamID dp.session = s.conn.session s.dp = dp - s.cdp = cdp s.conn.lock.Lock() - ret := C.nghttp2_submit_response( - s.conn.session, C.int(s.streamID), nva.nv, C.size_t(nvIndex), cdp) + ret := C._nghttp2_submit_response( + s.conn.session, C.int(s.streamID), + C.size_t(uintptr(unsafe.Pointer(&nv[0]))), + C.size_t(len(nv)), &s.cdp) s.conn.lock.Unlock() - C.delete_nv_array(nva) if int(ret) < 0 { panic(fmt.Sprintf("sumit response error %s", C.GoString(C.nghttp2_strerror(ret)))) @@ -203,26 +184,24 @@ func (s *ServerStream) Close() error { } s.closed = true - //C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0) if s.req.Body != nil { s.req.Body.Close() } if s.dp != nil { s.dp.Close() - s.dp = nil + //s.dp = nil } - if s.cdp != nil { - C.free(unsafe.Pointer(s.cdp)) - s.cdp = nil - } + if s.req.Method == "CONNECT" { + s.conn.lock.Lock() + s.conn.lock.Unlock() - s.conn.lock.Lock() - s.conn.lock.Unlock() - - if _, ok := s.conn.streams[s.streamID]; ok { - delete(s.conn.streams, s.streamID) + if _, ok := s.conn.streams[s.streamID]; ok { + log.Printf("send rst stream %d", s.streamID) + C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0) + delete(s.conn.streams, s.streamID) + } } //log.Printf("stream %d closed", s.streamID) return nil