diff --git a/callbacks.go b/callbacks.go index 4544730..e2402f7 100644 --- a/callbacks.go +++ b/callbacks.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/url" + "runtime" "strconv" "strings" "sync" @@ -160,9 +161,11 @@ func onBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { }, } s.request.Body = s.bp - + //log.Printf("new stream %d", int(streamID)) conn.streams[int(streamID)] = s + runtime.SetFinalizer(s, (*stream).free) + //log.Println("onBeginHeaderCallback end") return NGHTTP2_NO_ERROR } @@ -274,6 +277,7 @@ func onStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { stream, ok := conn.streams[int(streamID)] if ok { go stream.Close() + //log.Printf("remove stream %d", int(streamID)) //conn.lock.Lock() delete(conn.streams, int(streamID)) //go stream.Close() @@ -291,7 +295,7 @@ func onConnectionCloseCallback(ptr unsafe.Pointer) { conn.err = io.EOF // signal all goroutings exit - for i := 0; i < 4; i++ { + for i := 0; i < 6; i++ { select { case conn.exitch <- struct{}{}: default: diff --git a/conn.go b/conn.go index 81385ad..5fc1622 100644 --- a/conn.go +++ b/conn.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "net/url" + "runtime" "strings" "sync" "time" @@ -79,6 +80,8 @@ func Server(c net.Conn, handler http.Handler) (*Conn, error) { isServer: true, streams: make(map[int]*stream), } + //log.Printf("new conn %x", uintptr(unsafe.Pointer(conn))) + runtime.SetFinalizer(conn, (*Conn).free) conn.session = C.init_nghttp2_server_session(C.size_t(uintptr(unsafe.Pointer(conn)))) if conn.session == nil { return nil, errors.New("init server session failed") @@ -104,6 +107,8 @@ func Client(c net.Conn) (*Conn, error) { lock: new(sync.Mutex), streams: make(map[int]*stream), } + //log.Printf("new conn %x", uintptr(unsafe.Pointer(conn))) + runtime.SetFinalizer(conn, (*Conn).free) conn.session = C.init_nghttp2_client_session(C.size_t(uintptr(unsafe.Pointer(conn)))) if conn.session == nil { return nil, errors.New("init server session failed") @@ -138,6 +143,17 @@ func HTTP2Handler(srv *http.Server, conn *tls.Conn, handler http.Handler) { h2conn.Run() } +func (c *Conn) free() { + //log.Printf("free conn %x", uintptr(unsafe.Pointer(c))) + if !c.closed { + c.Close() + } + c.conn = nil + c.session = nil + c.streams = nil + c.lock = nil +} + // Error return conn error func (c *Conn) Error() error { c.lock.Lock() @@ -239,6 +255,7 @@ func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (* if cdp != nil { s.cdp = *cdp } + runtime.SetFinalizer(s, (*stream).free) return s, nil } @@ -335,6 +352,10 @@ func (c *Conn) Close() error { s.Close() } + for n := range c.streams { + delete(c.streams, n) + } + c.lock.Lock() C.nghttp2_session_terminate_session(c.session, 0) C.nghttp2_session_del(c.session) @@ -353,27 +374,27 @@ func (c *Conn) errorNotify(err error) { } func (c *Conn) readloop() { - type data struct { - buf []byte - err error - } var ret C.ssize_t var err error - var d data - - datach := make(chan data) + datach := make(chan []byte) go func() { - d1 := data{} var n int var err1 error for { + select { + case <-c.exitch: + return + default: + } buf := make([]byte, 16*1024) n, err1 = c.conn.Read(buf) - d1.buf = buf[:n] - d1.err = err1 - datach <- d1 + if err1 != nil { + c.errorNotify(err1) + break + } + datach <- buf[:n] } }() @@ -381,14 +402,14 @@ func (c *Conn) readloop() { select { case <-c.exitch: return - case d = <-datach: - if d.err != nil { - c.errorNotify(d.err) + case d := <-datach: + c.lock.Lock() + if c.closed { + c.lock.Unlock() return } - c.lock.Lock() ret = C.nghttp2_session_mem_recv(c.session, - (*C.uchar)(unsafe.Pointer(&d.buf[0])), C.size_t(len(d.buf))) + (*C.uchar)(unsafe.Pointer(&d[0])), C.size_t(len(d))) c.lock.Unlock() if int(ret) < 0 { err = fmt.Errorf("http2 recv error: %s", C.GoString(C.nghttp2_strerror(C.int(ret)))) @@ -403,6 +424,7 @@ func (c *Conn) writeloop() { var ret C.int var err error var delay = 50 * time.Millisecond + for { select { case <-c.exitch: diff --git a/data_provider.go b/data_provider.go index 42b8c93..ca7a986 100644 --- a/data_provider.go +++ b/data_provider.go @@ -7,6 +7,7 @@ import "C" import ( "bytes" "errors" + "io" "log" "sync" "time" @@ -51,6 +52,24 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) { log.Println("dp write invalid state") return 0, errors.New("invalid state") } + + // make sure the buffer not too large + delay := 10 * time.Millisecond + maxBufSize := 1 * 1024 * 1024 + for { + dp.lock.Lock() + _len := dp.buf.Len() + closed := dp.closed + dp.lock.Unlock() + if closed { + return 0, io.EOF + } + if _len < maxBufSize { + break + } + time.Sleep(delay) + } + dp.lock.Lock() defer dp.lock.Unlock() diff --git a/stream.go b/stream.go index 3d5cd70..0ccfd96 100644 --- a/stream.go +++ b/stream.go @@ -31,6 +31,29 @@ type stream struct { var _ net.Conn = &stream{} +func (s *stream) free() { + //log.Printf("stream free %d", s.streamID) + if !s.closed { + s.Close() + } + s.conn = nil + if s.dp != nil { + s.dp.buf = nil + s.dp.lock = nil + s.dp.sessLock = nil + s.dp.session = nil + s.dp = nil + } + if s.bp != nil { + s.bp.buf = nil + s.bp.lock = nil + s.bp = nil + } + s.request = nil + s.response = nil + s.resch = nil +} + func (s *stream) Read(buf []byte) (int, error) { if s.closed { return 0, io.EOF