use context to manage conn/stream close

master
fangdingjun 6 years ago
parent 2696323ff3
commit f23c78ede5

@ -6,6 +6,7 @@ package nghttp2
import "C" import "C"
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io" "io"
@ -115,7 +116,8 @@ func onDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
// onDataSendCallback callback function for libnghttp2 library want send data to network. // onDataSendCallback callback function for libnghttp2 library want send data to network.
// //
//export onDataSendCallback //export onDataSendCallback
func onDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { func onDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer,
size C.size_t) C.ssize_t {
//log.Println("onDataSendCallback begin") //log.Println("onDataSendCallback begin")
//log.Println("data write req ", int(size)) //log.Println("data write req ", int(size))
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
@ -177,6 +179,7 @@ func onBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
TLS: &TLS, TLS: &TLS,
}, },
} }
s.ctx, s.cancel = context.WithCancel(context.Background())
s.request.Body = s.bp s.request.Body = s.bp
//log.Printf("new stream %d", int(streamID)) //log.Printf("new stream %d", int(streamID))
conn.streams[int(streamID)] = s conn.streams[int(streamID)] = s
@ -245,9 +248,11 @@ func onHeaderCallback(ptr unsafe.Pointer, streamID C.int,
case "transfer-encoding": case "transfer-encoding":
header.Add(goname, govalue) header.Add(goname, govalue)
if conn.isServer { if conn.isServer {
s.request.TransferEncoding = append(s.response.TransferEncoding, govalue) s.request.TransferEncoding = append(
s.response.TransferEncoding, govalue)
} else { } else {
s.response.TransferEncoding = append(s.response.TransferEncoding, govalue) s.response.TransferEncoding = append(
s.response.TransferEncoding, govalue)
} }
default: default:
header.Add(goname, govalue) header.Add(goname, govalue)

@ -7,6 +7,7 @@ package nghttp2
import "C" import "C"
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -27,14 +28,14 @@ type Conn struct {
session *C.nghttp2_session session *C.nghttp2_session
streams map[int]*stream streams map[int]*stream
streamCount int streamCount int
closed bool
isServer bool isServer bool
running bool running bool
handler http.Handler handler http.Handler
lock *sync.Mutex lock *sync.Mutex
err error err error
errch chan error errch chan error
exitch chan struct{} ctx context.Context
cancel context.CancelFunc
} }
// Dial connect to addr and create a http2 client Conn // Dial connect to addr and create a http2 client Conn
@ -75,21 +76,25 @@ func Server(c net.Conn, handler http.Handler) (*Conn, error) {
conn: c, conn: c,
handler: handler, handler: handler,
errch: make(chan error), errch: make(chan error),
exitch: make(chan struct{}),
lock: new(sync.Mutex), lock: new(sync.Mutex),
isServer: true, isServer: true,
streams: make(map[int]*stream), streams: make(map[int]*stream),
} }
conn.ctx, conn.cancel = context.WithCancel(context.Background())
//log.Printf("new conn %x", uintptr(unsafe.Pointer(conn))) //log.Printf("new conn %x", uintptr(unsafe.Pointer(conn)))
runtime.SetFinalizer(conn, (*Conn).free) runtime.SetFinalizer(conn, (*Conn).free)
conn.session = C.init_nghttp2_server_session(C.size_t(uintptr(unsafe.Pointer(conn)))) conn.session = C.init_nghttp2_server_session(
C.size_t(uintptr(unsafe.Pointer(conn))))
if conn.session == nil { if conn.session == nil {
return nil, errors.New("init server session failed") return nil, errors.New("init server session failed")
} }
ret := C.send_connection_header(conn.session) ret := C.send_connection_header(conn.session)
if int(ret) < 0 { if int(ret) < 0 {
conn.Close() conn.Close()
return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret))) return nil, fmt.Errorf("send settings error: %s",
C.GoString(C.nghttp2_strerror(ret)))
} }
return conn, nil return conn, nil
} }
@ -103,20 +108,24 @@ func Client(c net.Conn) (*Conn, error) {
conn := &Conn{ conn := &Conn{
conn: c, conn: c,
errch: make(chan error), errch: make(chan error),
exitch: make(chan struct{}),
lock: new(sync.Mutex), lock: new(sync.Mutex),
streams: make(map[int]*stream), streams: make(map[int]*stream),
} }
conn.ctx, conn.cancel = context.WithCancel(context.Background())
//log.Printf("new conn %x", uintptr(unsafe.Pointer(conn))) //log.Printf("new conn %x", uintptr(unsafe.Pointer(conn)))
runtime.SetFinalizer(conn, (*Conn).free) runtime.SetFinalizer(conn, (*Conn).free)
conn.session = C.init_nghttp2_client_session(C.size_t(uintptr(unsafe.Pointer(conn)))) conn.session = C.init_nghttp2_client_session(
C.size_t(uintptr(unsafe.Pointer(conn))))
if conn.session == nil { if conn.session == nil {
return nil, errors.New("init server session failed") return nil, errors.New("init server session failed")
} }
ret := C.send_connection_header(conn.session) ret := C.send_connection_header(conn.session)
if int(ret) < 0 { if int(ret) < 0 {
conn.Close() conn.Close()
return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret))) return nil, fmt.Errorf("send settings error: %s",
C.GoString(C.nghttp2_strerror(ret)))
} }
go conn.Run() go conn.Run()
return conn, nil return conn, nil
@ -145,13 +154,9 @@ func HTTP2Handler(srv *http.Server, conn *tls.Conn, handler http.Handler) {
func (c *Conn) free() { func (c *Conn) free() {
//log.Printf("free conn %x", uintptr(unsafe.Pointer(c))) //log.Printf("free conn %x", uintptr(unsafe.Pointer(c)))
if !c.closed { if !c.isClosed() {
c.Close() c.Close()
} }
c.conn = nil
c.session = nil
c.streams = nil
c.lock = nil
} }
// Error return conn error // Error return conn error
@ -188,7 +193,8 @@ func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) {
nv = append(nv, newNV(":path", p)) nv = append(nv, newNV(":path", p))
for k, v := range req.Header { for k, v := range req.Header {
_k := strings.ToLower(k) _k := strings.ToLower(k)
if _k == "connection" || _k == "proxy-connection" || _k == "transfer-encoding" { if _k == "connection" || _k == "proxy-connection" ||
_k == "transfer-encoding" {
continue continue
} }
nv = append(nv, newNV(k, v[0])) nv = append(nv, newNV(k, v[0]))
@ -227,20 +233,23 @@ func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) {
s.request = req s.request = req
res.Request = s.request res.Request = s.request
return res, nil return res, nil
case <-c.exitch: case <-c.ctx.Done():
return nil, errors.New("connection closed") return nil, errors.New("connection closed")
} }
} }
func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*stream, error) { func (c *Conn) submitRequest(nv []C.nghttp2_nv,
cdp *C.nghttp2_data_provider) (*stream, error) {
c.lock.Lock() c.lock.Lock()
ret := C._nghttp2_submit_request(c.session, nil, ret := C._nghttp2_submit_request(c.session, nil,
C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), cdp, nil) C.size_t(uintptr(unsafe.Pointer(&nv[0]))),
C.size_t(len(nv)), cdp, nil)
c.lock.Unlock() c.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(ret))) return nil, fmt.Errorf("submit request error: %s",
C.GoString(C.nghttp2_strerror(ret)))
} }
streamID := int(ret) streamID := int(ret)
s := &stream{ s := &stream{
@ -252,6 +261,7 @@ func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*
}, },
resch: make(chan *http.Response), resch: make(chan *http.Response),
} }
s.ctx, s.cancel = context.WithCancel(context.Background())
if cdp != nil { if cdp != nil {
s.cdp = *cdp s.cdp = *cdp
} }
@ -291,7 +301,8 @@ func (c *Conn) Connect(addr string) (conn net.Conn, statusCode int, err error) {
select { select {
case res := <-s.resch: case res := <-s.resch:
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
return nil, res.StatusCode, fmt.Errorf("http error code %d", res.StatusCode) return nil, res.StatusCode, fmt.Errorf(
"http error code %d", res.StatusCode)
} }
s.request = &http.Request{ s.request = &http.Request{
Method: "CONNECT", Method: "CONNECT",
@ -301,8 +312,9 @@ func (c *Conn) Connect(addr string) (conn net.Conn, statusCode int, err error) {
} }
res.Request = s.request res.Request = s.request
return s, res.StatusCode, nil return s, res.StatusCode, nil
case <-c.exitch: case <-c.ctx.Done():
return nil, http.StatusServiceUnavailable, errors.New("connection closed") return nil, http.StatusServiceUnavailable,
errors.New("connection closed")
} }
} }
@ -323,7 +335,7 @@ func (c *Conn) Run() {
case err := <-c.errch: case err := <-c.errch:
c.err = err c.err = err
return return
case <-c.exitch: case <-c.ctx.Done():
return return
} }
} }
@ -345,11 +357,11 @@ func (c *Conn) serve(s *stream) {
// Close close the connection // Close close the connection
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.lock.Lock() c.lock.Lock()
if c.closed { if c.isClosed() {
c.lock.Unlock() c.lock.Unlock()
return nil return nil
} }
c.closed = true c.cancel()
c.lock.Unlock() c.lock.Unlock()
// stream.Close may require the conn.Lock // stream.Close may require the conn.Lock
@ -368,7 +380,6 @@ func (c *Conn) Close() error {
C.nghttp2_session_del(c.session) C.nghttp2_session_del(c.session)
c.lock.Unlock() c.lock.Unlock()
close(c.exitch)
c.conn.Close() c.conn.Close()
return nil return nil
} }
@ -383,12 +394,9 @@ func (c *Conn) errorNotify(err error) {
func (c *Conn) readloop() { func (c *Conn) readloop() {
buf := make([]byte, 16*1024) buf := make([]byte, 16*1024)
for { for {
select { if c.isClosed() {
case <-c.exitch:
return return
default:
} }
n, err := c.conn.Read(buf) n, err := c.conn.Read(buf)
if err != nil { if err != nil {
c.errorNotify(err) c.errorNotify(err)
@ -396,16 +404,17 @@ func (c *Conn) readloop() {
} }
c.lock.Lock() c.lock.Lock()
if c.closed { // check again
if c.isClosed() {
c.lock.Unlock() c.lock.Unlock()
return return
} }
ret := C.nghttp2_session_mem_recv(c.session, ret := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(unsafe.Pointer(&buf[0])), C.size_t(n)) (*C.uchar)(unsafe.Pointer(&buf[0])), C.size_t(n))
c.lock.Unlock() c.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
err = fmt.Errorf("http2 recv error: %s", C.GoString(C.nghttp2_strerror(C.int(ret)))) err = fmt.Errorf("http2 recv error: %s",
C.GoString(C.nghttp2_strerror(C.int(ret))))
c.errorNotify(err) c.errorNotify(err)
return return
} }
@ -418,16 +427,16 @@ func (c *Conn) writeloop() {
var delay = 50 * time.Millisecond var delay = 50 * time.Millisecond
for { for {
select { c.lock.Lock()
case <-c.exitch: if c.isClosed() {
c.lock.Unlock()
return return
default:
} }
c.lock.Lock()
ret = C.nghttp2_session_send(c.session) ret = C.nghttp2_session_send(c.session)
c.lock.Unlock() c.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
err = fmt.Errorf("http2 send error: %s", C.GoString(C.nghttp2_strerror(C.int(ret)))) err = fmt.Errorf("http2 send error: %s",
C.GoString(C.nghttp2_strerror(C.int(ret))))
c.errorNotify(err) c.errorNotify(err)
return return
} }
@ -440,3 +449,12 @@ func (c *Conn) writeloop() {
} }
} }
} }
func (c *Conn) isClosed() bool {
select {
case <-c.ctx.Done():
return true
default:
}
return false
}

@ -30,7 +30,8 @@ type dataProvider struct {
// Read read from data provider // Read read from data provider
func (dp *dataProvider) Read(buf []byte) (n int, err error) { func (dp *dataProvider) Read(buf []byte) (n int, err error) {
if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { if dp.buf == nil || dp.lock == nil ||
dp.sessLock == nil || dp.session == nil {
log.Println("dp read invalid state") log.Println("dp read invalid state")
return 0, errors.New("invalid state") return 0, errors.New("invalid state")
} }
@ -48,14 +49,15 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) {
// Write provider data for data provider // Write provider data for data provider
func (dp *dataProvider) Write(buf []byte) (n int, err error) { func (dp *dataProvider) Write(buf []byte) (n int, err error) {
if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { if dp.buf == nil || dp.lock == nil ||
dp.sessLock == nil || dp.session == nil {
log.Println("dp write invalid state") log.Println("dp write invalid state")
return 0, errors.New("invalid state") return 0, errors.New("invalid state")
} }
// make sure the buffer not too large // make sure the buffer not too large
delay := 10 * time.Millisecond delay := 10 * time.Millisecond
maxBufSize := 1 * 1024 * 1024 maxBufSize := 4 * 1024
for { for {
dp.lock.Lock() dp.lock.Lock()
_len := dp.buf.Len() _len := dp.buf.Len()
@ -80,7 +82,8 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) {
n, err = dp.buf.Write(buf) n, err = dp.buf.Write(buf)
if dp.deferred { if dp.deferred {
dp.sessLock.Lock() dp.sessLock.Lock()
C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) C.nghttp2_session_resume_data(
dp.session, C.int(dp.streamID))
dp.sessLock.Unlock() dp.sessLock.Unlock()
//log.Println("resume") //log.Println("resume")
@ -91,7 +94,8 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) {
// Close end to provide data // Close end to provide data
func (dp *dataProvider) Close() error { func (dp *dataProvider) Close() error {
if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { if dp.buf == nil || dp.lock == nil ||
dp.sessLock == nil || dp.session == nil {
log.Println("dp close, invalid state") log.Println("dp close, invalid state")
return errors.New("invalid state") return errors.New("invalid state")
} }
@ -105,7 +109,8 @@ func (dp *dataProvider) Close() error {
//log.Printf("dp close stream %d", dp.streamID) //log.Printf("dp close stream %d", dp.streamID)
if dp.deferred { if dp.deferred {
dp.sessLock.Lock() dp.sessLock.Lock()
C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) C.nghttp2_session_resume_data(
dp.session, C.int(dp.streamID))
dp.sessLock.Unlock() dp.sessLock.Unlock()
dp.deferred = false dp.deferred = false
@ -113,7 +118,8 @@ func (dp *dataProvider) Close() error {
return nil return nil
} }
func newDataProvider(cdp unsafe.Pointer, sessionLock *sync.Mutex, t int) *dataProvider { func newDataProvider(cdp unsafe.Pointer,
sessionLock *sync.Mutex, t int) *dataProvider {
dp := &dataProvider{ dp := &dataProvider{
buf: new(bytes.Buffer), buf: new(bytes.Buffer),
lock: new(sync.Mutex), lock: new(sync.Mutex),

@ -5,6 +5,7 @@ package nghttp2
*/ */
import "C" import "C"
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -25,37 +26,31 @@ type stream struct {
resch chan *http.Response resch chan *http.Response
headersEnd bool headersEnd bool
streamEnd bool streamEnd bool
closed bool
cdp C.nghttp2_data_provider cdp C.nghttp2_data_provider
ctx context.Context
cancel context.CancelFunc
} }
var _ net.Conn = &stream{} var _ net.Conn = &stream{}
func (s *stream) isClosed() bool {
select {
case <-s.ctx.Done():
return true
default:
}
return false
}
func (s *stream) free() { func (s *stream) free() {
//log.Printf("stream free %d", s.streamID) //log.Printf("stream free %d", s.streamID)
if !s.closed { if !s.isClosed() {
s.Close() s.Close()
} }
s.conn = nil
if s.dp != nil {
s.dp.buf = nil
s.dp.lock = nil
s.dp.sessLock = nil
s.dp.session = nil
s.dp = nil
}
if s.bp != nil {
s.bp.buf = nil
s.bp.lock = nil
s.bp = nil
}
s.request = nil
s.response = nil
s.resch = nil
} }
func (s *stream) Read(buf []byte) (int, error) { func (s *stream) Read(buf []byte) (int, error) {
if s.closed { if s.isClosed() {
return 0, io.EOF return 0, io.EOF
} }
if s.bp != nil { if s.bp != nil {
@ -65,7 +60,7 @@ func (s *stream) Read(buf []byte) (int, error) {
} }
func (s *stream) WriteHeader(code int) { func (s *stream) WriteHeader(code int) {
if s.closed { if s.isClosed() {
return return
} }
if s.response == nil { if s.response == nil {
@ -87,28 +82,33 @@ func (s *stream) WriteHeader(code int) {
nv = append(nv, newNV(":status", fmt.Sprintf("%d", code))) nv = append(nv, newNV(":status", fmt.Sprintf("%d", code)))
for k, v := range s.response.Header { for k, v := range s.response.Header {
_k := strings.ToLower(k) _k := strings.ToLower(k)
if _k == "host" || _k == "connection" || _k == "transfer-encoding" { if _k == "host" || _k == "connection" ||
_k == "transfer-encoding" {
continue continue
} }
nv = append(nv, newNV(k, v[0])) nv = append(nv, newNV(k, v[0]))
} }
s.cdp = C.nghttp2_data_provider{} s.cdp = C.nghttp2_data_provider{}
s.dp = newDataProvider(unsafe.Pointer(&s.cdp), s.conn.lock, 0) s.dp = newDataProvider(unsafe.Pointer(&s.cdp),
s.conn.lock, 0)
s.dp.session = s.conn.session s.dp.session = s.conn.session
s.dp.streamID = s.streamID s.dp.streamID = s.streamID
s.conn.lock.Lock() s.conn.lock.Lock()
if s.conn.closed { if s.conn.isClosed() {
s.conn.lock.Unlock() s.conn.lock.Unlock()
return return
} }
ret := C._nghttp2_submit_response(s.conn.session, C.int(s.streamID), ret := C._nghttp2_submit_response(s.conn.session,
C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), &s.cdp) C.int(s.streamID),
C.size_t(uintptr(unsafe.Pointer(&nv[0]))),
C.size_t(len(nv)), &s.cdp)
s.conn.lock.Unlock() s.conn.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
panic(fmt.Sprintf("submit response error: %s", C.GoString(C.nghttp2_strerror(ret)))) panic(fmt.Sprintf("submit response error: %s",
C.GoString(C.nghttp2_strerror(ret))))
} }
} }
@ -125,10 +125,11 @@ func (s *stream) Header() http.Header {
} }
func (s *stream) Write(buf []byte) (int, error) { func (s *stream) Write(buf []byte) (int, error) {
if s.closed { if s.isClosed() {
return 0, io.EOF return 0, io.EOF
} }
if s.conn.isServer && (s.response == nil || s.response.StatusCode == 0) { if s.conn.isServer && (s.response == nil ||
s.response.StatusCode == 0) {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
@ -139,10 +140,12 @@ func (s *stream) Write(buf []byte) (int, error) {
} }
func (s *stream) Close() error { func (s *stream) Close() error {
if s.closed { if s.isClosed() {
return nil return nil
} }
s.closed = true
s.cancel()
if s.dp != nil { if s.dp != nil {
s.dp.Close() s.dp.Close()
} }
@ -157,7 +160,8 @@ func (s *stream) Close() error {
if s.request != nil && s.request.Method == "CONNECT" { if s.request != nil && s.request.Method == "CONNECT" {
//log.Println("rst stream") //log.Println("rst stream")
s.conn.lock.Lock() s.conn.lock.Lock()
C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 8) C.nghttp2_submit_rst_stream(s.conn.session, 0,
C.int(s.streamID), 8)
s.conn.lock.Unlock() s.conn.lock.Unlock()
} }
return nil return nil

Loading…
Cancel
Save