From f23c78ede5e400a06d31f13399da9e02c9e941d1 Mon Sep 17 00:00:00 2001 From: fangdingjun Date: Thu, 9 Aug 2018 14:48:35 +0800 Subject: [PATCH] use context to manage conn/stream close --- callbacks.go | 11 ++++-- conn.go | 92 +++++++++++++++++++++++++++++------------------- data_provider.go | 20 +++++++---- stream.go | 66 ++++++++++++++++++---------------- 4 files changed, 111 insertions(+), 78 deletions(-) diff --git a/callbacks.go b/callbacks.go index 16f4807..7755036 100644 --- a/callbacks.go +++ b/callbacks.go @@ -6,6 +6,7 @@ package nghttp2 import "C" import ( "bytes" + "context" "crypto/tls" "errors" "io" @@ -115,7 +116,8 @@ func onDataChunkRecv(ptr unsafe.Pointer, streamID C.int, // onDataSendCallback callback function for libnghttp2 library want send data to network. // //export onDataSendCallback -func onDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { +func onDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, + size C.size_t) C.ssize_t { //log.Println("onDataSendCallback begin") //log.Println("data write req ", int(size)) conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) @@ -177,6 +179,7 @@ func onBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { TLS: &TLS, }, } + s.ctx, s.cancel = context.WithCancel(context.Background()) s.request.Body = s.bp //log.Printf("new stream %d", int(streamID)) conn.streams[int(streamID)] = s @@ -245,9 +248,11 @@ func onHeaderCallback(ptr unsafe.Pointer, streamID C.int, case "transfer-encoding": header.Add(goname, govalue) if conn.isServer { - s.request.TransferEncoding = append(s.response.TransferEncoding, govalue) + s.request.TransferEncoding = append( + s.response.TransferEncoding, govalue) } else { - s.response.TransferEncoding = append(s.response.TransferEncoding, govalue) + s.response.TransferEncoding = append( + s.response.TransferEncoding, govalue) } default: header.Add(goname, govalue) diff --git a/conn.go b/conn.go index 1c2f012..1540f8a 100644 --- a/conn.go +++ b/conn.go @@ -7,6 +7,7 @@ package nghttp2 import "C" import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -27,14 +28,14 @@ type Conn struct { session *C.nghttp2_session streams map[int]*stream streamCount int - closed bool isServer bool running bool handler http.Handler lock *sync.Mutex err error errch chan error - exitch chan struct{} + ctx context.Context + cancel context.CancelFunc } // Dial connect to addr and create a http2 client Conn @@ -75,21 +76,25 @@ func Server(c net.Conn, handler http.Handler) (*Conn, error) { conn: c, handler: handler, errch: make(chan error), - exitch: make(chan struct{}), lock: new(sync.Mutex), isServer: true, streams: make(map[int]*stream), } + + conn.ctx, conn.cancel = context.WithCancel(context.Background()) + //log.Printf("new conn %x", uintptr(unsafe.Pointer(conn))) runtime.SetFinalizer(conn, (*Conn).free) - conn.session = C.init_nghttp2_server_session(C.size_t(uintptr(unsafe.Pointer(conn)))) + conn.session = C.init_nghttp2_server_session( + C.size_t(uintptr(unsafe.Pointer(conn)))) if conn.session == nil { return nil, errors.New("init server session failed") } ret := C.send_connection_header(conn.session) if int(ret) < 0 { conn.Close() - return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret))) + return nil, fmt.Errorf("send settings error: %s", + C.GoString(C.nghttp2_strerror(ret))) } return conn, nil } @@ -103,20 +108,24 @@ func Client(c net.Conn) (*Conn, error) { conn := &Conn{ conn: c, errch: make(chan error), - exitch: make(chan struct{}), lock: new(sync.Mutex), streams: make(map[int]*stream), } + + conn.ctx, conn.cancel = context.WithCancel(context.Background()) + //log.Printf("new conn %x", uintptr(unsafe.Pointer(conn))) runtime.SetFinalizer(conn, (*Conn).free) - conn.session = C.init_nghttp2_client_session(C.size_t(uintptr(unsafe.Pointer(conn)))) + conn.session = C.init_nghttp2_client_session( + C.size_t(uintptr(unsafe.Pointer(conn)))) if conn.session == nil { return nil, errors.New("init server session failed") } ret := C.send_connection_header(conn.session) if int(ret) < 0 { conn.Close() - return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret))) + return nil, fmt.Errorf("send settings error: %s", + C.GoString(C.nghttp2_strerror(ret))) } go conn.Run() return conn, nil @@ -145,13 +154,9 @@ func HTTP2Handler(srv *http.Server, conn *tls.Conn, handler http.Handler) { func (c *Conn) free() { //log.Printf("free conn %x", uintptr(unsafe.Pointer(c))) - if !c.closed { + if !c.isClosed() { c.Close() } - c.conn = nil - c.session = nil - c.streams = nil - c.lock = nil } // Error return conn error @@ -188,7 +193,8 @@ func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) { nv = append(nv, newNV(":path", p)) for k, v := range req.Header { _k := strings.ToLower(k) - if _k == "connection" || _k == "proxy-connection" || _k == "transfer-encoding" { + if _k == "connection" || _k == "proxy-connection" || + _k == "transfer-encoding" { continue } nv = append(nv, newNV(k, v[0])) @@ -227,20 +233,23 @@ func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) { s.request = req res.Request = s.request return res, nil - case <-c.exitch: + case <-c.ctx.Done(): return nil, errors.New("connection closed") } } -func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*stream, error) { +func (c *Conn) submitRequest(nv []C.nghttp2_nv, + cdp *C.nghttp2_data_provider) (*stream, error) { c.lock.Lock() ret := C._nghttp2_submit_request(c.session, nil, - C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), cdp, nil) + C.size_t(uintptr(unsafe.Pointer(&nv[0]))), + C.size_t(len(nv)), cdp, nil) c.lock.Unlock() if int(ret) < 0 { - return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(ret))) + return nil, fmt.Errorf("submit request error: %s", + C.GoString(C.nghttp2_strerror(ret))) } streamID := int(ret) s := &stream{ @@ -252,6 +261,7 @@ func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (* }, resch: make(chan *http.Response), } + s.ctx, s.cancel = context.WithCancel(context.Background()) if cdp != nil { s.cdp = *cdp } @@ -291,7 +301,8 @@ func (c *Conn) Connect(addr string) (conn net.Conn, statusCode int, err error) { select { case res := <-s.resch: if res.StatusCode != http.StatusOK { - return nil, res.StatusCode, fmt.Errorf("http error code %d", res.StatusCode) + return nil, res.StatusCode, fmt.Errorf( + "http error code %d", res.StatusCode) } s.request = &http.Request{ Method: "CONNECT", @@ -301,8 +312,9 @@ func (c *Conn) Connect(addr string) (conn net.Conn, statusCode int, err error) { } res.Request = s.request return s, res.StatusCode, nil - case <-c.exitch: - return nil, http.StatusServiceUnavailable, errors.New("connection closed") + case <-c.ctx.Done(): + return nil, http.StatusServiceUnavailable, + errors.New("connection closed") } } @@ -323,7 +335,7 @@ func (c *Conn) Run() { case err := <-c.errch: c.err = err return - case <-c.exitch: + case <-c.ctx.Done(): return } } @@ -345,11 +357,11 @@ func (c *Conn) serve(s *stream) { // Close close the connection func (c *Conn) Close() error { c.lock.Lock() - if c.closed { + if c.isClosed() { c.lock.Unlock() return nil } - c.closed = true + c.cancel() c.lock.Unlock() // stream.Close may require the conn.Lock @@ -368,7 +380,6 @@ func (c *Conn) Close() error { C.nghttp2_session_del(c.session) c.lock.Unlock() - close(c.exitch) c.conn.Close() return nil } @@ -383,12 +394,9 @@ func (c *Conn) errorNotify(err error) { func (c *Conn) readloop() { buf := make([]byte, 16*1024) for { - select { - case <-c.exitch: + if c.isClosed() { return - default: } - n, err := c.conn.Read(buf) if err != nil { c.errorNotify(err) @@ -396,16 +404,17 @@ func (c *Conn) readloop() { } c.lock.Lock() - if c.closed { + // check again + if c.isClosed() { c.lock.Unlock() return } - ret := C.nghttp2_session_mem_recv(c.session, (*C.uchar)(unsafe.Pointer(&buf[0])), C.size_t(n)) c.lock.Unlock() if int(ret) < 0 { - err = fmt.Errorf("http2 recv error: %s", C.GoString(C.nghttp2_strerror(C.int(ret)))) + err = fmt.Errorf("http2 recv error: %s", + C.GoString(C.nghttp2_strerror(C.int(ret)))) c.errorNotify(err) return } @@ -418,16 +427,16 @@ func (c *Conn) writeloop() { var delay = 50 * time.Millisecond for { - select { - case <-c.exitch: + c.lock.Lock() + if c.isClosed() { + c.lock.Unlock() return - default: } - c.lock.Lock() ret = C.nghttp2_session_send(c.session) c.lock.Unlock() if int(ret) < 0 { - err = fmt.Errorf("http2 send error: %s", C.GoString(C.nghttp2_strerror(C.int(ret)))) + err = fmt.Errorf("http2 send error: %s", + C.GoString(C.nghttp2_strerror(C.int(ret)))) c.errorNotify(err) return } @@ -440,3 +449,12 @@ func (c *Conn) writeloop() { } } } + +func (c *Conn) isClosed() bool { + select { + case <-c.ctx.Done(): + return true + default: + } + return false +} diff --git a/data_provider.go b/data_provider.go index b139e99..3c755cd 100644 --- a/data_provider.go +++ b/data_provider.go @@ -30,7 +30,8 @@ type dataProvider struct { // Read read from data provider func (dp *dataProvider) Read(buf []byte) (n int, err error) { - if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { + if dp.buf == nil || dp.lock == nil || + dp.sessLock == nil || dp.session == nil { log.Println("dp read invalid state") return 0, errors.New("invalid state") } @@ -48,14 +49,15 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) { // Write provider data for data provider func (dp *dataProvider) Write(buf []byte) (n int, err error) { - if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { + if dp.buf == nil || dp.lock == nil || + dp.sessLock == nil || dp.session == nil { log.Println("dp write invalid state") return 0, errors.New("invalid state") } // make sure the buffer not too large delay := 10 * time.Millisecond - maxBufSize := 1 * 1024 * 1024 + maxBufSize := 4 * 1024 for { dp.lock.Lock() _len := dp.buf.Len() @@ -80,7 +82,8 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) { n, err = dp.buf.Write(buf) if dp.deferred { dp.sessLock.Lock() - C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) + C.nghttp2_session_resume_data( + dp.session, C.int(dp.streamID)) dp.sessLock.Unlock() //log.Println("resume") @@ -91,7 +94,8 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) { // Close end to provide data func (dp *dataProvider) Close() error { - if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { + if dp.buf == nil || dp.lock == nil || + dp.sessLock == nil || dp.session == nil { log.Println("dp close, invalid state") return errors.New("invalid state") } @@ -105,7 +109,8 @@ func (dp *dataProvider) Close() error { //log.Printf("dp close stream %d", dp.streamID) if dp.deferred { dp.sessLock.Lock() - C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) + C.nghttp2_session_resume_data( + dp.session, C.int(dp.streamID)) dp.sessLock.Unlock() dp.deferred = false @@ -113,7 +118,8 @@ func (dp *dataProvider) Close() error { return nil } -func newDataProvider(cdp unsafe.Pointer, sessionLock *sync.Mutex, t int) *dataProvider { +func newDataProvider(cdp unsafe.Pointer, + sessionLock *sync.Mutex, t int) *dataProvider { dp := &dataProvider{ buf: new(bytes.Buffer), lock: new(sync.Mutex), diff --git a/stream.go b/stream.go index f044685..2745157 100644 --- a/stream.go +++ b/stream.go @@ -5,6 +5,7 @@ package nghttp2 */ import "C" import ( + "context" "errors" "fmt" "io" @@ -25,37 +26,31 @@ type stream struct { resch chan *http.Response headersEnd bool streamEnd bool - closed bool cdp C.nghttp2_data_provider + ctx context.Context + cancel context.CancelFunc } var _ net.Conn = &stream{} +func (s *stream) isClosed() bool { + select { + case <-s.ctx.Done(): + return true + default: + } + return false +} + func (s *stream) free() { //log.Printf("stream free %d", s.streamID) - if !s.closed { + if !s.isClosed() { s.Close() } - s.conn = nil - if s.dp != nil { - s.dp.buf = nil - s.dp.lock = nil - s.dp.sessLock = nil - s.dp.session = nil - s.dp = nil - } - if s.bp != nil { - s.bp.buf = nil - s.bp.lock = nil - s.bp = nil - } - s.request = nil - s.response = nil - s.resch = nil } func (s *stream) Read(buf []byte) (int, error) { - if s.closed { + if s.isClosed() { return 0, io.EOF } if s.bp != nil { @@ -65,7 +60,7 @@ func (s *stream) Read(buf []byte) (int, error) { } func (s *stream) WriteHeader(code int) { - if s.closed { + if s.isClosed() { return } if s.response == nil { @@ -87,28 +82,33 @@ func (s *stream) WriteHeader(code int) { nv = append(nv, newNV(":status", fmt.Sprintf("%d", code))) for k, v := range s.response.Header { _k := strings.ToLower(k) - if _k == "host" || _k == "connection" || _k == "transfer-encoding" { + if _k == "host" || _k == "connection" || + _k == "transfer-encoding" { continue } nv = append(nv, newNV(k, v[0])) } s.cdp = C.nghttp2_data_provider{} - s.dp = newDataProvider(unsafe.Pointer(&s.cdp), s.conn.lock, 0) + s.dp = newDataProvider(unsafe.Pointer(&s.cdp), + s.conn.lock, 0) s.dp.session = s.conn.session s.dp.streamID = s.streamID s.conn.lock.Lock() - if s.conn.closed { + if s.conn.isClosed() { s.conn.lock.Unlock() return } - ret := C._nghttp2_submit_response(s.conn.session, C.int(s.streamID), - C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), &s.cdp) + ret := C._nghttp2_submit_response(s.conn.session, + C.int(s.streamID), + C.size_t(uintptr(unsafe.Pointer(&nv[0]))), + C.size_t(len(nv)), &s.cdp) s.conn.lock.Unlock() if int(ret) < 0 { - panic(fmt.Sprintf("submit response error: %s", C.GoString(C.nghttp2_strerror(ret)))) + panic(fmt.Sprintf("submit response error: %s", + C.GoString(C.nghttp2_strerror(ret)))) } } @@ -125,10 +125,11 @@ func (s *stream) Header() http.Header { } func (s *stream) Write(buf []byte) (int, error) { - if s.closed { + if s.isClosed() { return 0, io.EOF } - if s.conn.isServer && (s.response == nil || s.response.StatusCode == 0) { + if s.conn.isServer && (s.response == nil || + s.response.StatusCode == 0) { s.WriteHeader(http.StatusOK) } @@ -139,10 +140,12 @@ func (s *stream) Write(buf []byte) (int, error) { } func (s *stream) Close() error { - if s.closed { + if s.isClosed() { return nil } - s.closed = true + + s.cancel() + if s.dp != nil { s.dp.Close() } @@ -157,7 +160,8 @@ func (s *stream) Close() error { if s.request != nil && s.request.Method == "CONNECT" { //log.Println("rst stream") s.conn.lock.Lock() - C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 8) + C.nghttp2_submit_rst_stream(s.conn.session, 0, + C.int(s.streamID), 8) s.conn.lock.Unlock() } return nil