diff --git a/conn.go b/conn.go index fdfb935..eb7e754 100644 --- a/conn.go +++ b/conn.go @@ -25,13 +25,15 @@ var ( // ClientConn http2 client connection type ClientConn struct { - session *C.nghttp2_session - conn net.Conn - streams map[int]*ClientStream - lock *sync.Mutex - errch chan struct{} - exitch chan struct{} - err error + session *C.nghttp2_session + conn net.Conn + streams map[int]*ClientStream + lock *sync.Mutex + errch chan struct{} + exitch chan struct{} + err error + closed bool + streamCount int } // Client create http2 client @@ -57,8 +59,22 @@ func Client(c net.Conn) (*ClientConn, error) { return conn, nil } +// Error return current error on connection +func (c *ClientConn) Error() error { + c.lock.Lock() + defer c.lock.Unlock() + return c.err +} + // Close close the http2 connection func (c *ClientConn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + if c.closed { + return nil + } + //log.Println("close client connection") + c.closed = true for _, s := range c.streams { s.Close() } @@ -70,17 +86,17 @@ func (c *ClientConn) Close() error { } func (c *ClientConn) run() { - var wantRead int var wantWrite int var delay = 50 * time.Millisecond var keepalive = 5 * time.Second var ret C.int var lastDataRecv time.Time + //defer c.Close() + defer close(c.errch) - datach := make(chan []byte) - errch := make(chan error) + errch := make(chan struct{}, 5) // data read loop go func() { @@ -90,16 +106,35 @@ func (c *ClientConn) run() { select { case <-c.exitch: break readloop + case <-errch: + break readloop default: } n, err := c.conn.Read(buf) if err != nil { - errch <- err + c.lock.Lock() + c.err = err + c.lock.Unlock() + errch <- struct{}{} break } - datach <- buf[:n] + //log.Printf("read %d bytes from network", n) lastDataRecv = time.Now() + d1 := C.CBytes(buf[:n]) + + c.lock.Lock() + ret1 := C.nghttp2_session_mem_recv(c.session, + (*C.uchar)(d1), C.size_t(n)) + c.lock.Unlock() + + C.free(d1) + if int(ret1) < 0 { + c.err = fmt.Errorf("sesion recv error: %s", + C.GoString(C.nghttp2_strerror(ret))) + //log.Println(c.err) + break + } } }() @@ -109,13 +144,18 @@ func (c *ClientConn) run() { select { case <-c.exitch: return + case <-errch: + return case <-time.After(keepalive): } now := time.Now() last := lastDataRecv d := now.Sub(last) if d > keepalive { + c.lock.Lock() C.nghttp2_submit_ping(c.session, 0, nil) + c.lock.Unlock() + //log.Println("submit ping") } } }() @@ -125,47 +165,99 @@ loop: select { case <-c.errch: break loop - case err := <-errch: - c.err = err + case <-errch: break loop case <-c.exitch: break loop default: } - wantWrite = int(C.nghttp2_session_want_write(c.session)) - if wantWrite != 0 { - ret = C.nghttp2_session_send(c.session) - if int(ret) < 0 { - c.err = fmt.Errorf("sesion send error: %s", - C.GoString(C.nghttp2_strerror(ret))) - //log.Println(c.err) - break - } + c.lock.Lock() + ret = C.nghttp2_session_send(c.session) + c.lock.Unlock() + + if int(ret) < 0 { + c.lock.Lock() + c.err = fmt.Errorf("sesion send error: %s", + C.GoString(C.nghttp2_strerror(ret))) + c.lock.Unlock() + //log.Println(c.err) + errch <- struct{}{} + break } - wantRead = int(C.nghttp2_session_want_read(c.session)) - select { - case d := <-datach: - d1 := C.CBytes(d) - ret1 := C.nghttp2_session_mem_recv(c.session, - (*C.uchar)(d1), C.size_t(int(len(d)))) - C.free(d1) - if int(ret1) < 0 { - c.err = fmt.Errorf("sesion recv error: %s", - C.GoString(C.nghttp2_strerror(ret))) - //log.Println(c.err) - break loop - } - default: - } + c.lock.Lock() + wantWrite = int(C.nghttp2_session_want_write(c.session)) + c.lock.Unlock() // make delay when no data read/write - if wantRead == 0 && wantWrite == 0 { - select { - case <-time.After(delay): - } + if wantWrite == 0 { + time.Sleep(delay) + } + } +} + +// Connect submit a CONNECT request +func (c *ClientConn) Connect(req *http.Request) (*ClientStream, error) { + if c.err != nil { + return nil, c.err + } + nvIndex := 0 + nvMax := 5 + nva := C.new_nv_array(C.size_t(nvMax)) + //log.Printf("%s %s", req.Method, req.RequestURI) + setNvArray(nva, nvIndex, ":method", req.Method, 0) + nvIndex++ + //setNvArray(nva, nvIndex, ":scheme", "https", 0) + //nvIndex++ + //log.Printf("header authority: %s", req.RequestURI) + setNvArray(nva, nvIndex, ":authority", req.RequestURI, 0) + nvIndex++ + var dp *dataProvider + var cdp *C.nghttp2_data_provider + dp, cdp = newDataProvider(c.lock) + + c.lock.Lock() + streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp) + c.lock.Unlock() + + C.delete_nv_array(nva) + if int(streamID) < 0 { + return nil, fmt.Errorf("submit request error: %s", + C.GoString(C.nghttp2_strerror(streamID))) + } + if dp != nil { + dp.streamID = int(streamID) + dp.session = c.session + } + //log.Println("stream id ", int(streamID)) + s := &ClientStream{ + streamID: int(streamID), + conn: c, + dp: dp, + cdp: cdp, + resch: make(chan *http.Response), + errch: make(chan error), + lock: new(sync.Mutex), + } + c.lock.Lock() + c.streams[int(streamID)] = s + c.streamCount++ + c.lock.Unlock() + //log.Printf("new stream id %d", int(streamID)) + select { + case err := <-s.errch: + //log.Println("wait response, got ", err) + return nil, err + case res := <-s.resch: + if res != nil { + res.Request = req + return s, nil } + //log.Println("wait response, empty response") + return nil, io.EOF + case <-c.errch: + return nil, fmt.Errorf("connection error") } } @@ -180,8 +272,10 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { nva := C.new_nv_array(C.size_t(nvMax)) setNvArray(nva, nvIndex, ":method", req.Method, 0) nvIndex++ - setNvArray(nva, nvIndex, ":scheme", "https", 0) - nvIndex++ + if req.Method != "CONNECT" { + setNvArray(nva, nvIndex, ":scheme", "https", 0) + nvIndex++ + } setNvArray(nva, nvIndex, ":authority", req.Host, 0) nvIndex++ @@ -190,10 +284,15 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { if q != "" { p = p + "?" + q } - setNvArray(nva, nvIndex, ":path", p, 0) - nvIndex++ + if req.Method != "CONNECT" { + setNvArray(nva, nvIndex, ":path", p, 0) + nvIndex++ + } + //log.Printf("%s http://%s%s", req.Method, req.Host, p) for k, v := range req.Header { - if strings.ToLower(k) == "host" { + //log.Printf("header %s: %s\n", k, v[0]) + _k := strings.ToLower(k) + if _k == "host" || _k == "connection" || _k == "proxy-connection" { continue } //log.Printf("header %s: %s", k, v) @@ -202,41 +301,53 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { } var dp *dataProvider var cdp *C.nghttp2_data_provider - if req.Body != nil { - dp, cdp = newDataProvider() + if req.Method == "PUT" || req.Method == "POST" || req.Method == "CONNECT" { + dp, cdp = newDataProvider(c.lock) go func() { io.Copy(dp, req.Body) dp.Close() }() } + + c.lock.Lock() streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp) - if dp != nil { - dp.streamID = int(streamID) - dp.session = c.session - } + c.lock.Unlock() + C.delete_nv_array(nva) if int(streamID) < 0 { return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(streamID))) } - //log.Println("stream id ", int(streamID)) + + //log.Printf("new stream, id %d", int(streamID)) + + if dp != nil { + dp.streamID = int(streamID) + dp.session = c.session + } s := &ClientStream{ streamID: int(streamID), + conn: c, dp: dp, cdp: cdp, resch: make(chan *http.Response), errch: make(chan error), + lock: new(sync.Mutex), } c.lock.Lock() c.streams[int(streamID)] = s + c.streamCount++ c.lock.Unlock() select { case err := <-s.errch: return nil, err case res := <-s.resch: - res.Request = req - return res, nil + if res != nil { + res.Request = req + return res, nil + } + return nil, io.EOF case <-c.errch: return nil, fmt.Errorf("connection error") } @@ -336,17 +447,14 @@ func (c *ServerConn) Close() error { // Run run the server loop func (c *ServerConn) Run() { - var wantRead int var wantWrite int var delay = 100 * time.Millisecond var ret C.int - var shouldDelay bool defer c.Close() defer close(c.errch) - datach := make(chan []byte) - errch := make(chan error) + errch := make(chan struct{}, 5) go func() { buf := make([]byte, 16*1024) @@ -355,15 +463,37 @@ func (c *ServerConn) Run() { select { case <-c.exitch: break readloop + case <-errch: + break readloop default: } n, err := c.conn.Read(buf) if err != nil { - errch <- err + c.lock.Lock() + c.err = err + c.lock.Unlock() + errch <- struct{}{} + break + } + + d1 := C.CBytes(buf[:n]) + + c.lock.Lock() + ret1 := C.nghttp2_session_mem_recv(c.session, + (*C.uchar)(d1), C.size_t(n)) + c.lock.Unlock() + + C.free(d1) + if int(ret1) < 0 { + c.lock.Lock() + c.err = fmt.Errorf("sesion recv error: %s", + C.GoString(C.nghttp2_strerror(ret))) + c.lock.Unlock() + //log.Println(c.err) + errch <- struct{}{} break } - datach <- buf[:n] } }() @@ -372,50 +502,33 @@ loop: select { case <-c.errch: break loop - case err := <-errch: - c.err = err + case <-errch: break loop case <-c.exitch: break loop default: } - wantWrite = int(C.nghttp2_session_want_write(c.session)) - if wantWrite != 0 { - ret = C.nghttp2_session_send(c.session) - if int(ret) < 0 { - c.err = fmt.Errorf("sesion send error: %s", - C.GoString(C.nghttp2_strerror(ret))) - //log.Println(c.err) - break - } - } - - wantRead = int(C.nghttp2_session_want_read(c.session)) - select { - case d := <-datach: - d1 := C.CBytes(d) - ret1 := C.nghttp2_session_mem_recv(c.session, - (*C.uchar)(d1), C.size_t(int(len(d)))) - C.free(d1) - if int(ret1) < 0 { - c.err = fmt.Errorf("sesion recv error: %s", - C.GoString(C.nghttp2_strerror(ret))) - //log.Println(c.err) - break loop - } - shouldDelay = false - default: - // want read but data not avaliable - if wantRead != 0 { - shouldDelay = true - } + c.lock.Lock() + ret = C.nghttp2_session_send(c.session) + c.lock.Unlock() + + if int(ret) < 0 { + c.lock.Lock() + c.err = fmt.Errorf("sesion send error: %s", + C.GoString(C.nghttp2_strerror(ret))) + c.lock.Unlock() + //log.Println(c.err) + errch <- struct{}{} + break } + c.lock.Lock() wantWrite = int(C.nghttp2_session_want_write(c.session)) + c.lock.Unlock() // make delay when no data read/write - if (shouldDelay || wantRead == 0) && wantWrite == 0 { + if wantWrite == 0 { time.Sleep(delay) } } diff --git a/data_provider.go b/data_provider.go index 6ded9b7..3e053c7 100644 --- a/data_provider.go +++ b/data_provider.go @@ -6,6 +6,7 @@ package nghttp2 import "C" import ( "bytes" + "io" "sync" "time" "unsafe" @@ -19,8 +20,10 @@ type dataProvider struct { buf *bytes.Buffer closed bool lock *sync.Mutex + sessLock *sync.Mutex session *C.nghttp2_session streamID int + deferred bool } // Read read from data provider @@ -39,7 +42,15 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) { func (dp *dataProvider) Write(buf []byte) (n int, err error) { dp.lock.Lock() defer dp.lock.Unlock() - C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) + if dp.closed { + return 0, io.EOF + } + if dp.deferred { + dp.sessLock.Lock() + C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) + dp.sessLock.Unlock() + dp.deferred = false + } return dp.buf.Write(buf) } @@ -47,16 +58,26 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) { func (dp *dataProvider) Close() error { dp.lock.Lock() defer dp.lock.Unlock() + if dp.closed { + return nil + } dp.closed = true - C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) + //log.Printf("dp close stream %d", dp.streamID) + if dp.deferred { + dp.sessLock.Lock() + C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) + dp.sessLock.Unlock() + dp.deferred = false + } return nil } -func newDataProvider() ( +func newDataProvider(sessionLock *sync.Mutex) ( *dataProvider, *C.nghttp2_data_provider) { dp := &dataProvider{ - buf: new(bytes.Buffer), - lock: new(sync.Mutex), + buf: new(bytes.Buffer), + lock: new(sync.Mutex), + sessLock: sessionLock, } cdp := C.new_data_provider(C.size_t(uintptr(unsafe.Pointer(dp)))) return dp, cdp diff --git a/stream.go b/stream.go index f28fa18..235316a 100644 --- a/stream.go +++ b/stream.go @@ -9,11 +9,14 @@ import ( "io" "net/http" "strings" + "sync" + "unsafe" ) // ClientStream http2 client stream type ClientStream struct { streamID int + conn *ClientConn cdp *C.nghttp2_data_provider dp *dataProvider // application read data from stream @@ -24,15 +27,27 @@ type ClientStream struct { resch chan *http.Response errch chan error closed bool + lock *sync.Mutex +} + +// Response return response of the current stream +func (s *ClientStream) Response() *http.Response { + return s.res } // Read read stream data func (s *ClientStream) Read(buf []byte) (n int, err error) { + if s.closed || s.res == nil || s.res.Body == nil { + return 0, io.EOF + } return s.res.Body.Read(buf) } // Write write data to stream func (s *ClientStream) Write(buf []byte) (n int, err error) { + if s.closed { + return 0, io.EOF + } if s.dp != nil { return s.dp.Write(buf) } @@ -41,26 +56,44 @@ func (s *ClientStream) Write(buf []byte) (n int, err error) { // Close close the stream func (s *ClientStream) Close() error { + //s.lock.Lock() + //defer s.lock.Unlock() if s.closed { return nil } + s.closed = true err := io.EOF - //log.Println("close stream") + //log.Printf("close stream %d", int(s.streamID)) select { case s.errch <- err: default: } //log.Println("close stream resch") - close(s.resch) + //close(s.resch) //log.Println("close stream errch") - close(s.errch) + //close(s.errch) //log.Println("close pipe w") - s.res.Body.Close() + if s.res != nil && s.res.Body != nil { + s.res.Body.Close() + } //log.Println("close stream done") if s.dp != nil { s.dp.Close() + //s.conn.lock.Lock() + //C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0) + //s.conn.lock.Unlock() + s.dp = nil } - s.closed = true + if s.cdp != nil { + C.free(unsafe.Pointer(s.cdp)) + s.cdp = nil + } + + //s.conn.lock.Lock() + //defer s.conn.lock.Unlock() + //if _, ok := s.conn.streams[s.streamID]; ok { + //delete(s.conn.streams, s.streamID) + //} return nil } @@ -94,27 +127,26 @@ type ServerStream struct { // Write write data to stream, // implements http.ResponseWriter func (s *ServerStream) Write(buf []byte) (int, error) { + if s.closed { + return 0, io.EOF + } + if !s.responseSend { s.WriteHeader(http.StatusOK) } - /* - //log.Printf("stream %d, send %d bytes", s.streamID, len(buf)) - if s.buf.Len() > 2048 { - s.dp.Write(s.buf.Bytes()) - s.buf.Reset() - } - - if len(buf) < 2048 { - s.buf.Write(buf) - return len(buf), nil - } - */ return s.dp.Write(buf) } // WriteHeader set response code and send reponse, // implements http.ResponseWriter func (s *ServerStream) WriteHeader(code int) { + if s.closed { + return + } + if s.responseSend { + return + } + s.responseSend = true s.statusCode = code nvIndex := 0 nvMax := 25 @@ -132,18 +164,21 @@ func (s *ServerStream) WriteHeader(code int) { } var dp *dataProvider var cdp *C.nghttp2_data_provider - dp, cdp = newDataProvider() + dp, cdp = newDataProvider(s.conn.lock) dp.streamID = s.streamID dp.session = s.conn.session s.dp = dp s.cdp = cdp + + s.conn.lock.Lock() ret := C.nghttp2_submit_response( s.conn.session, C.int(s.streamID), nva.nv, C.size_t(nvIndex), cdp) + s.conn.lock.Unlock() + C.delete_nv_array(nva) if int(ret) < 0 { panic(fmt.Sprintf("sumit response error %s", C.GoString(C.nghttp2_strerror(ret)))) } - s.responseSend = true //log.Printf("stream %d send response", s.streamID) } @@ -161,14 +196,19 @@ func (s *ServerStream) Close() error { if s.closed { return nil } + s.closed = true //C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0) if s.req.Body != nil { s.req.Body.Close() } if s.dp != nil { s.dp.Close() + s.dp = nil + } + if s.cdp != nil { + C.free(unsafe.Pointer(s.cdp)) + s.cdp = nil } - s.closed = true //log.Printf("stream %d closed", s.streamID) return nil }