make sure nghttp_session in one thread at same time

merge_conn
fangdingjun 6 years ago
parent 4125b26c35
commit 563e006303

@ -32,6 +32,8 @@ type ClientConn struct {
errch chan struct{}
exitch chan struct{}
err error
closed bool
streamCount int
}
// Client create http2 client
@ -57,8 +59,22 @@ func Client(c net.Conn) (*ClientConn, error) {
return conn, nil
}
// Error return current error on connection
func (c *ClientConn) Error() error {
c.lock.Lock()
defer c.lock.Unlock()
return c.err
}
// Close close the http2 connection
func (c *ClientConn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return nil
}
//log.Println("close client connection")
c.closed = true
for _, s := range c.streams {
s.Close()
}
@ -70,17 +86,17 @@ func (c *ClientConn) Close() error {
}
func (c *ClientConn) run() {
var wantRead int
var wantWrite int
var delay = 50 * time.Millisecond
var keepalive = 5 * time.Second
var ret C.int
var lastDataRecv time.Time
//defer c.Close()
defer close(c.errch)
datach := make(chan []byte)
errch := make(chan error)
errch := make(chan struct{}, 5)
// data read loop
go func() {
@ -90,16 +106,35 @@ func (c *ClientConn) run() {
select {
case <-c.exitch:
break readloop
case <-errch:
break readloop
default:
}
n, err := c.conn.Read(buf)
if err != nil {
errch <- err
c.lock.Lock()
c.err = err
c.lock.Unlock()
errch <- struct{}{}
break
}
datach <- buf[:n]
//log.Printf("read %d bytes from network", n)
lastDataRecv = time.Now()
d1 := C.CBytes(buf[:n])
c.lock.Lock()
ret1 := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(d1), C.size_t(n))
c.lock.Unlock()
C.free(d1)
if int(ret1) < 0 {
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
//log.Println(c.err)
break
}
}
}()
@ -109,13 +144,18 @@ func (c *ClientConn) run() {
select {
case <-c.exitch:
return
case <-errch:
return
case <-time.After(keepalive):
}
now := time.Now()
last := lastDataRecv
d := now.Sub(last)
if d > keepalive {
c.lock.Lock()
C.nghttp2_submit_ping(c.session, 0, nil)
c.lock.Unlock()
//log.Println("submit ping")
}
}
}()
@ -125,47 +165,99 @@ loop:
select {
case <-c.errch:
break loop
case err := <-errch:
c.err = err
case <-errch:
break loop
case <-c.exitch:
break loop
default:
}
wantWrite = int(C.nghttp2_session_want_write(c.session))
if wantWrite != 0 {
c.lock.Lock()
ret = C.nghttp2_session_send(c.session)
c.lock.Unlock()
if int(ret) < 0 {
c.lock.Lock()
c.err = fmt.Errorf("sesion send error: %s",
C.GoString(C.nghttp2_strerror(ret)))
c.lock.Unlock()
//log.Println(c.err)
errch <- struct{}{}
break
}
}
wantRead = int(C.nghttp2_session_want_read(c.session))
select {
case d := <-datach:
d1 := C.CBytes(d)
ret1 := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(d1), C.size_t(int(len(d))))
C.free(d1)
if int(ret1) < 0 {
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
//log.Println(c.err)
break loop
c.lock.Lock()
wantWrite = int(C.nghttp2_session_want_write(c.session))
c.lock.Unlock()
// make delay when no data read/write
if wantWrite == 0 {
time.Sleep(delay)
}
}
default:
}
// make delay when no data read/write
if wantRead == 0 && wantWrite == 0 {
select {
case <-time.After(delay):
// Connect submit a CONNECT request
func (c *ClientConn) Connect(req *http.Request) (*ClientStream, error) {
if c.err != nil {
return nil, c.err
}
nvIndex := 0
nvMax := 5
nva := C.new_nv_array(C.size_t(nvMax))
//log.Printf("%s %s", req.Method, req.RequestURI)
setNvArray(nva, nvIndex, ":method", req.Method, 0)
nvIndex++
//setNvArray(nva, nvIndex, ":scheme", "https", 0)
//nvIndex++
//log.Printf("header authority: %s", req.RequestURI)
setNvArray(nva, nvIndex, ":authority", req.RequestURI, 0)
nvIndex++
var dp *dataProvider
var cdp *C.nghttp2_data_provider
dp, cdp = newDataProvider(c.lock)
c.lock.Lock()
streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp)
c.lock.Unlock()
C.delete_nv_array(nva)
if int(streamID) < 0 {
return nil, fmt.Errorf("submit request error: %s",
C.GoString(C.nghttp2_strerror(streamID)))
}
if dp != nil {
dp.streamID = int(streamID)
dp.session = c.session
}
//log.Println("stream id ", int(streamID))
s := &ClientStream{
streamID: int(streamID),
conn: c,
dp: dp,
cdp: cdp,
resch: make(chan *http.Response),
errch: make(chan error),
lock: new(sync.Mutex),
}
c.lock.Lock()
c.streams[int(streamID)] = s
c.streamCount++
c.lock.Unlock()
//log.Printf("new stream id %d", int(streamID))
select {
case err := <-s.errch:
//log.Println("wait response, got ", err)
return nil, err
case res := <-s.resch:
if res != nil {
res.Request = req
return s, nil
}
//log.Println("wait response, empty response")
return nil, io.EOF
case <-c.errch:
return nil, fmt.Errorf("connection error")
}
}
@ -180,8 +272,10 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) {
nva := C.new_nv_array(C.size_t(nvMax))
setNvArray(nva, nvIndex, ":method", req.Method, 0)
nvIndex++
if req.Method != "CONNECT" {
setNvArray(nva, nvIndex, ":scheme", "https", 0)
nvIndex++
}
setNvArray(nva, nvIndex, ":authority", req.Host, 0)
nvIndex++
@ -190,10 +284,15 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) {
if q != "" {
p = p + "?" + q
}
if req.Method != "CONNECT" {
setNvArray(nva, nvIndex, ":path", p, 0)
nvIndex++
}
//log.Printf("%s http://%s%s", req.Method, req.Host, p)
for k, v := range req.Header {
if strings.ToLower(k) == "host" {
//log.Printf("header %s: %s\n", k, v[0])
_k := strings.ToLower(k)
if _k == "host" || _k == "connection" || _k == "proxy-connection" {
continue
}
//log.Printf("header %s: %s", k, v)
@ -202,41 +301,53 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) {
}
var dp *dataProvider
var cdp *C.nghttp2_data_provider
if req.Body != nil {
dp, cdp = newDataProvider()
if req.Method == "PUT" || req.Method == "POST" || req.Method == "CONNECT" {
dp, cdp = newDataProvider(c.lock)
go func() {
io.Copy(dp, req.Body)
dp.Close()
}()
}
c.lock.Lock()
streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp)
if dp != nil {
dp.streamID = int(streamID)
dp.session = c.session
}
c.lock.Unlock()
C.delete_nv_array(nva)
if int(streamID) < 0 {
return nil, fmt.Errorf("submit request error: %s",
C.GoString(C.nghttp2_strerror(streamID)))
}
//log.Println("stream id ", int(streamID))
//log.Printf("new stream, id %d", int(streamID))
if dp != nil {
dp.streamID = int(streamID)
dp.session = c.session
}
s := &ClientStream{
streamID: int(streamID),
conn: c,
dp: dp,
cdp: cdp,
resch: make(chan *http.Response),
errch: make(chan error),
lock: new(sync.Mutex),
}
c.lock.Lock()
c.streams[int(streamID)] = s
c.streamCount++
c.lock.Unlock()
select {
case err := <-s.errch:
return nil, err
case res := <-s.resch:
if res != nil {
res.Request = req
return res, nil
}
return nil, io.EOF
case <-c.errch:
return nil, fmt.Errorf("connection error")
}
@ -336,17 +447,14 @@ func (c *ServerConn) Close() error {
// Run run the server loop
func (c *ServerConn) Run() {
var wantRead int
var wantWrite int
var delay = 100 * time.Millisecond
var ret C.int
var shouldDelay bool
defer c.Close()
defer close(c.errch)
datach := make(chan []byte)
errch := make(chan error)
errch := make(chan struct{}, 5)
go func() {
buf := make([]byte, 16*1024)
@ -355,15 +463,37 @@ func (c *ServerConn) Run() {
select {
case <-c.exitch:
break readloop
case <-errch:
break readloop
default:
}
n, err := c.conn.Read(buf)
if err != nil {
errch <- err
c.lock.Lock()
c.err = err
c.lock.Unlock()
errch <- struct{}{}
break
}
d1 := C.CBytes(buf[:n])
c.lock.Lock()
ret1 := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(d1), C.size_t(n))
c.lock.Unlock()
C.free(d1)
if int(ret1) < 0 {
c.lock.Lock()
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
c.lock.Unlock()
//log.Println(c.err)
errch <- struct{}{}
break
}
datach <- buf[:n]
}
}()
@ -372,50 +502,33 @@ loop:
select {
case <-c.errch:
break loop
case err := <-errch:
c.err = err
case <-errch:
break loop
case <-c.exitch:
break loop
default:
}
wantWrite = int(C.nghttp2_session_want_write(c.session))
if wantWrite != 0 {
c.lock.Lock()
ret = C.nghttp2_session_send(c.session)
c.lock.Unlock()
if int(ret) < 0 {
c.lock.Lock()
c.err = fmt.Errorf("sesion send error: %s",
C.GoString(C.nghttp2_strerror(ret)))
c.lock.Unlock()
//log.Println(c.err)
errch <- struct{}{}
break
}
}
wantRead = int(C.nghttp2_session_want_read(c.session))
select {
case d := <-datach:
d1 := C.CBytes(d)
ret1 := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(d1), C.size_t(int(len(d))))
C.free(d1)
if int(ret1) < 0 {
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
//log.Println(c.err)
break loop
}
shouldDelay = false
default:
// want read but data not avaliable
if wantRead != 0 {
shouldDelay = true
}
}
c.lock.Lock()
wantWrite = int(C.nghttp2_session_want_write(c.session))
c.lock.Unlock()
// make delay when no data read/write
if (shouldDelay || wantRead == 0) && wantWrite == 0 {
if wantWrite == 0 {
time.Sleep(delay)
}
}

@ -6,6 +6,7 @@ package nghttp2
import "C"
import (
"bytes"
"io"
"sync"
"time"
"unsafe"
@ -19,8 +20,10 @@ type dataProvider struct {
buf *bytes.Buffer
closed bool
lock *sync.Mutex
sessLock *sync.Mutex
session *C.nghttp2_session
streamID int
deferred bool
}
// Read read from data provider
@ -39,7 +42,15 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) {
func (dp *dataProvider) Write(buf []byte) (n int, err error) {
dp.lock.Lock()
defer dp.lock.Unlock()
if dp.closed {
return 0, io.EOF
}
if dp.deferred {
dp.sessLock.Lock()
C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID))
dp.sessLock.Unlock()
dp.deferred = false
}
return dp.buf.Write(buf)
}
@ -47,16 +58,26 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) {
func (dp *dataProvider) Close() error {
dp.lock.Lock()
defer dp.lock.Unlock()
if dp.closed {
return nil
}
dp.closed = true
//log.Printf("dp close stream %d", dp.streamID)
if dp.deferred {
dp.sessLock.Lock()
C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID))
dp.sessLock.Unlock()
dp.deferred = false
}
return nil
}
func newDataProvider() (
func newDataProvider(sessionLock *sync.Mutex) (
*dataProvider, *C.nghttp2_data_provider) {
dp := &dataProvider{
buf: new(bytes.Buffer),
lock: new(sync.Mutex),
sessLock: sessionLock,
}
cdp := C.new_data_provider(C.size_t(uintptr(unsafe.Pointer(dp))))
return dp, cdp

@ -9,11 +9,14 @@ import (
"io"
"net/http"
"strings"
"sync"
"unsafe"
)
// ClientStream http2 client stream
type ClientStream struct {
streamID int
conn *ClientConn
cdp *C.nghttp2_data_provider
dp *dataProvider
// application read data from stream
@ -24,15 +27,27 @@ type ClientStream struct {
resch chan *http.Response
errch chan error
closed bool
lock *sync.Mutex
}
// Response return response of the current stream
func (s *ClientStream) Response() *http.Response {
return s.res
}
// Read read stream data
func (s *ClientStream) Read(buf []byte) (n int, err error) {
if s.closed || s.res == nil || s.res.Body == nil {
return 0, io.EOF
}
return s.res.Body.Read(buf)
}
// Write write data to stream
func (s *ClientStream) Write(buf []byte) (n int, err error) {
if s.closed {
return 0, io.EOF
}
if s.dp != nil {
return s.dp.Write(buf)
}
@ -41,26 +56,44 @@ func (s *ClientStream) Write(buf []byte) (n int, err error) {
// Close close the stream
func (s *ClientStream) Close() error {
//s.lock.Lock()
//defer s.lock.Unlock()
if s.closed {
return nil
}
s.closed = true
err := io.EOF
//log.Println("close stream")
//log.Printf("close stream %d", int(s.streamID))
select {
case s.errch <- err:
default:
}
//log.Println("close stream resch")
close(s.resch)
//close(s.resch)
//log.Println("close stream errch")
close(s.errch)
//close(s.errch)
//log.Println("close pipe w")
if s.res != nil && s.res.Body != nil {
s.res.Body.Close()
}
//log.Println("close stream done")
if s.dp != nil {
s.dp.Close()
//s.conn.lock.Lock()
//C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0)
//s.conn.lock.Unlock()
s.dp = nil
}
s.closed = true
if s.cdp != nil {
C.free(unsafe.Pointer(s.cdp))
s.cdp = nil
}
//s.conn.lock.Lock()
//defer s.conn.lock.Unlock()
//if _, ok := s.conn.streams[s.streamID]; ok {
//delete(s.conn.streams, s.streamID)
//}
return nil
}
@ -94,27 +127,26 @@ type ServerStream struct {
// Write write data to stream,
// implements http.ResponseWriter
func (s *ServerStream) Write(buf []byte) (int, error) {
if !s.responseSend {
s.WriteHeader(http.StatusOK)
}
/*
//log.Printf("stream %d, send %d bytes", s.streamID, len(buf))
if s.buf.Len() > 2048 {
s.dp.Write(s.buf.Bytes())
s.buf.Reset()
if s.closed {
return 0, io.EOF
}
if len(buf) < 2048 {
s.buf.Write(buf)
return len(buf), nil
if !s.responseSend {
s.WriteHeader(http.StatusOK)
}
*/
return s.dp.Write(buf)
}
// WriteHeader set response code and send reponse,
// implements http.ResponseWriter
func (s *ServerStream) WriteHeader(code int) {
if s.closed {
return
}
if s.responseSend {
return
}
s.responseSend = true
s.statusCode = code
nvIndex := 0
nvMax := 25
@ -132,18 +164,21 @@ func (s *ServerStream) WriteHeader(code int) {
}
var dp *dataProvider
var cdp *C.nghttp2_data_provider
dp, cdp = newDataProvider()
dp, cdp = newDataProvider(s.conn.lock)
dp.streamID = s.streamID
dp.session = s.conn.session
s.dp = dp
s.cdp = cdp
s.conn.lock.Lock()
ret := C.nghttp2_submit_response(
s.conn.session, C.int(s.streamID), nva.nv, C.size_t(nvIndex), cdp)
s.conn.lock.Unlock()
C.delete_nv_array(nva)
if int(ret) < 0 {
panic(fmt.Sprintf("sumit response error %s", C.GoString(C.nghttp2_strerror(ret))))
}
s.responseSend = true
//log.Printf("stream %d send response", s.streamID)
}
@ -161,14 +196,19 @@ func (s *ServerStream) Close() error {
if s.closed {
return nil
}
s.closed = true
//C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0)
if s.req.Body != nil {
s.req.Body.Close()
}
if s.dp != nil {
s.dp.Close()
s.dp = nil
}
if s.cdp != nil {
C.free(unsafe.Pointer(s.cdp))
s.cdp = nil
}
s.closed = true
//log.Printf("stream %d closed", s.streamID)
return nil
}

Loading…
Cancel
Save