make sure nghttp_session in one thread at same time

merge_conn
fangdingjun 6 years ago
parent 4125b26c35
commit 563e006303

@ -32,6 +32,8 @@ type ClientConn struct {
errch chan struct{} errch chan struct{}
exitch chan struct{} exitch chan struct{}
err error err error
closed bool
streamCount int
} }
// Client create http2 client // Client create http2 client
@ -57,8 +59,22 @@ func Client(c net.Conn) (*ClientConn, error) {
return conn, nil return conn, nil
} }
// Error return current error on connection
func (c *ClientConn) Error() error {
c.lock.Lock()
defer c.lock.Unlock()
return c.err
}
// Close close the http2 connection // Close close the http2 connection
func (c *ClientConn) Close() error { func (c *ClientConn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return nil
}
//log.Println("close client connection")
c.closed = true
for _, s := range c.streams { for _, s := range c.streams {
s.Close() s.Close()
} }
@ -70,17 +86,17 @@ func (c *ClientConn) Close() error {
} }
func (c *ClientConn) run() { func (c *ClientConn) run() {
var wantRead int
var wantWrite int var wantWrite int
var delay = 50 * time.Millisecond var delay = 50 * time.Millisecond
var keepalive = 5 * time.Second var keepalive = 5 * time.Second
var ret C.int var ret C.int
var lastDataRecv time.Time var lastDataRecv time.Time
//defer c.Close()
defer close(c.errch) defer close(c.errch)
datach := make(chan []byte) errch := make(chan struct{}, 5)
errch := make(chan error)
// data read loop // data read loop
go func() { go func() {
@ -90,16 +106,35 @@ func (c *ClientConn) run() {
select { select {
case <-c.exitch: case <-c.exitch:
break readloop break readloop
case <-errch:
break readloop
default: default:
} }
n, err := c.conn.Read(buf) n, err := c.conn.Read(buf)
if err != nil { if err != nil {
errch <- err c.lock.Lock()
c.err = err
c.lock.Unlock()
errch <- struct{}{}
break break
} }
datach <- buf[:n] //log.Printf("read %d bytes from network", n)
lastDataRecv = time.Now() lastDataRecv = time.Now()
d1 := C.CBytes(buf[:n])
c.lock.Lock()
ret1 := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(d1), C.size_t(n))
c.lock.Unlock()
C.free(d1)
if int(ret1) < 0 {
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
//log.Println(c.err)
break
}
} }
}() }()
@ -109,13 +144,18 @@ func (c *ClientConn) run() {
select { select {
case <-c.exitch: case <-c.exitch:
return return
case <-errch:
return
case <-time.After(keepalive): case <-time.After(keepalive):
} }
now := time.Now() now := time.Now()
last := lastDataRecv last := lastDataRecv
d := now.Sub(last) d := now.Sub(last)
if d > keepalive { if d > keepalive {
c.lock.Lock()
C.nghttp2_submit_ping(c.session, 0, nil) C.nghttp2_submit_ping(c.session, 0, nil)
c.lock.Unlock()
//log.Println("submit ping")
} }
} }
}() }()
@ -125,47 +165,99 @@ loop:
select { select {
case <-c.errch: case <-c.errch:
break loop break loop
case err := <-errch: case <-errch:
c.err = err
break loop break loop
case <-c.exitch: case <-c.exitch:
break loop break loop
default: default:
} }
wantWrite = int(C.nghttp2_session_want_write(c.session)) c.lock.Lock()
if wantWrite != 0 {
ret = C.nghttp2_session_send(c.session) ret = C.nghttp2_session_send(c.session)
c.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
c.lock.Lock()
c.err = fmt.Errorf("sesion send error: %s", c.err = fmt.Errorf("sesion send error: %s",
C.GoString(C.nghttp2_strerror(ret))) C.GoString(C.nghttp2_strerror(ret)))
c.lock.Unlock()
//log.Println(c.err) //log.Println(c.err)
errch <- struct{}{}
break break
} }
}
wantRead = int(C.nghttp2_session_want_read(c.session)) c.lock.Lock()
select { wantWrite = int(C.nghttp2_session_want_write(c.session))
case d := <-datach: c.lock.Unlock()
d1 := C.CBytes(d)
ret1 := C.nghttp2_session_mem_recv(c.session, // make delay when no data read/write
(*C.uchar)(d1), C.size_t(int(len(d)))) if wantWrite == 0 {
C.free(d1) time.Sleep(delay)
if int(ret1) < 0 { }
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
//log.Println(c.err)
break loop
} }
default:
} }
// make delay when no data read/write // Connect submit a CONNECT request
if wantRead == 0 && wantWrite == 0 { func (c *ClientConn) Connect(req *http.Request) (*ClientStream, error) {
select { if c.err != nil {
case <-time.After(delay): return nil, c.err
}
nvIndex := 0
nvMax := 5
nva := C.new_nv_array(C.size_t(nvMax))
//log.Printf("%s %s", req.Method, req.RequestURI)
setNvArray(nva, nvIndex, ":method", req.Method, 0)
nvIndex++
//setNvArray(nva, nvIndex, ":scheme", "https", 0)
//nvIndex++
//log.Printf("header authority: %s", req.RequestURI)
setNvArray(nva, nvIndex, ":authority", req.RequestURI, 0)
nvIndex++
var dp *dataProvider
var cdp *C.nghttp2_data_provider
dp, cdp = newDataProvider(c.lock)
c.lock.Lock()
streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp)
c.lock.Unlock()
C.delete_nv_array(nva)
if int(streamID) < 0 {
return nil, fmt.Errorf("submit request error: %s",
C.GoString(C.nghttp2_strerror(streamID)))
}
if dp != nil {
dp.streamID = int(streamID)
dp.session = c.session
}
//log.Println("stream id ", int(streamID))
s := &ClientStream{
streamID: int(streamID),
conn: c,
dp: dp,
cdp: cdp,
resch: make(chan *http.Response),
errch: make(chan error),
lock: new(sync.Mutex),
} }
c.lock.Lock()
c.streams[int(streamID)] = s
c.streamCount++
c.lock.Unlock()
//log.Printf("new stream id %d", int(streamID))
select {
case err := <-s.errch:
//log.Println("wait response, got ", err)
return nil, err
case res := <-s.resch:
if res != nil {
res.Request = req
return s, nil
} }
//log.Println("wait response, empty response")
return nil, io.EOF
case <-c.errch:
return nil, fmt.Errorf("connection error")
} }
} }
@ -180,8 +272,10 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) {
nva := C.new_nv_array(C.size_t(nvMax)) nva := C.new_nv_array(C.size_t(nvMax))
setNvArray(nva, nvIndex, ":method", req.Method, 0) setNvArray(nva, nvIndex, ":method", req.Method, 0)
nvIndex++ nvIndex++
if req.Method != "CONNECT" {
setNvArray(nva, nvIndex, ":scheme", "https", 0) setNvArray(nva, nvIndex, ":scheme", "https", 0)
nvIndex++ nvIndex++
}
setNvArray(nva, nvIndex, ":authority", req.Host, 0) setNvArray(nva, nvIndex, ":authority", req.Host, 0)
nvIndex++ nvIndex++
@ -190,10 +284,15 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) {
if q != "" { if q != "" {
p = p + "?" + q p = p + "?" + q
} }
if req.Method != "CONNECT" {
setNvArray(nva, nvIndex, ":path", p, 0) setNvArray(nva, nvIndex, ":path", p, 0)
nvIndex++ nvIndex++
}
//log.Printf("%s http://%s%s", req.Method, req.Host, p)
for k, v := range req.Header { for k, v := range req.Header {
if strings.ToLower(k) == "host" { //log.Printf("header %s: %s\n", k, v[0])
_k := strings.ToLower(k)
if _k == "host" || _k == "connection" || _k == "proxy-connection" {
continue continue
} }
//log.Printf("header %s: %s", k, v) //log.Printf("header %s: %s", k, v)
@ -202,41 +301,53 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) {
} }
var dp *dataProvider var dp *dataProvider
var cdp *C.nghttp2_data_provider var cdp *C.nghttp2_data_provider
if req.Body != nil { if req.Method == "PUT" || req.Method == "POST" || req.Method == "CONNECT" {
dp, cdp = newDataProvider() dp, cdp = newDataProvider(c.lock)
go func() { go func() {
io.Copy(dp, req.Body) io.Copy(dp, req.Body)
dp.Close() dp.Close()
}() }()
} }
c.lock.Lock()
streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp) streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp)
if dp != nil { c.lock.Unlock()
dp.streamID = int(streamID)
dp.session = c.session
}
C.delete_nv_array(nva) C.delete_nv_array(nva)
if int(streamID) < 0 { if int(streamID) < 0 {
return nil, fmt.Errorf("submit request error: %s", return nil, fmt.Errorf("submit request error: %s",
C.GoString(C.nghttp2_strerror(streamID))) C.GoString(C.nghttp2_strerror(streamID)))
} }
//log.Println("stream id ", int(streamID))
//log.Printf("new stream, id %d", int(streamID))
if dp != nil {
dp.streamID = int(streamID)
dp.session = c.session
}
s := &ClientStream{ s := &ClientStream{
streamID: int(streamID), streamID: int(streamID),
conn: c,
dp: dp, dp: dp,
cdp: cdp, cdp: cdp,
resch: make(chan *http.Response), resch: make(chan *http.Response),
errch: make(chan error), errch: make(chan error),
lock: new(sync.Mutex),
} }
c.lock.Lock() c.lock.Lock()
c.streams[int(streamID)] = s c.streams[int(streamID)] = s
c.streamCount++
c.lock.Unlock() c.lock.Unlock()
select { select {
case err := <-s.errch: case err := <-s.errch:
return nil, err return nil, err
case res := <-s.resch: case res := <-s.resch:
if res != nil {
res.Request = req res.Request = req
return res, nil return res, nil
}
return nil, io.EOF
case <-c.errch: case <-c.errch:
return nil, fmt.Errorf("connection error") return nil, fmt.Errorf("connection error")
} }
@ -336,17 +447,14 @@ func (c *ServerConn) Close() error {
// Run run the server loop // Run run the server loop
func (c *ServerConn) Run() { func (c *ServerConn) Run() {
var wantRead int
var wantWrite int var wantWrite int
var delay = 100 * time.Millisecond var delay = 100 * time.Millisecond
var ret C.int var ret C.int
var shouldDelay bool
defer c.Close() defer c.Close()
defer close(c.errch) defer close(c.errch)
datach := make(chan []byte) errch := make(chan struct{}, 5)
errch := make(chan error)
go func() { go func() {
buf := make([]byte, 16*1024) buf := make([]byte, 16*1024)
@ -355,15 +463,37 @@ func (c *ServerConn) Run() {
select { select {
case <-c.exitch: case <-c.exitch:
break readloop break readloop
case <-errch:
break readloop
default: default:
} }
n, err := c.conn.Read(buf) n, err := c.conn.Read(buf)
if err != nil { if err != nil {
errch <- err c.lock.Lock()
c.err = err
c.lock.Unlock()
errch <- struct{}{}
break
}
d1 := C.CBytes(buf[:n])
c.lock.Lock()
ret1 := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(d1), C.size_t(n))
c.lock.Unlock()
C.free(d1)
if int(ret1) < 0 {
c.lock.Lock()
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
c.lock.Unlock()
//log.Println(c.err)
errch <- struct{}{}
break break
} }
datach <- buf[:n]
} }
}() }()
@ -372,50 +502,33 @@ loop:
select { select {
case <-c.errch: case <-c.errch:
break loop break loop
case err := <-errch: case <-errch:
c.err = err
break loop break loop
case <-c.exitch: case <-c.exitch:
break loop break loop
default: default:
} }
wantWrite = int(C.nghttp2_session_want_write(c.session)) c.lock.Lock()
if wantWrite != 0 {
ret = C.nghttp2_session_send(c.session) ret = C.nghttp2_session_send(c.session)
c.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
c.lock.Lock()
c.err = fmt.Errorf("sesion send error: %s", c.err = fmt.Errorf("sesion send error: %s",
C.GoString(C.nghttp2_strerror(ret))) C.GoString(C.nghttp2_strerror(ret)))
c.lock.Unlock()
//log.Println(c.err) //log.Println(c.err)
errch <- struct{}{}
break break
} }
}
wantRead = int(C.nghttp2_session_want_read(c.session))
select {
case d := <-datach:
d1 := C.CBytes(d)
ret1 := C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(d1), C.size_t(int(len(d))))
C.free(d1)
if int(ret1) < 0 {
c.err = fmt.Errorf("sesion recv error: %s",
C.GoString(C.nghttp2_strerror(ret)))
//log.Println(c.err)
break loop
}
shouldDelay = false
default:
// want read but data not avaliable
if wantRead != 0 {
shouldDelay = true
}
}
c.lock.Lock()
wantWrite = int(C.nghttp2_session_want_write(c.session)) wantWrite = int(C.nghttp2_session_want_write(c.session))
c.lock.Unlock()
// make delay when no data read/write // make delay when no data read/write
if (shouldDelay || wantRead == 0) && wantWrite == 0 { if wantWrite == 0 {
time.Sleep(delay) time.Sleep(delay)
} }
} }

@ -6,6 +6,7 @@ package nghttp2
import "C" import "C"
import ( import (
"bytes" "bytes"
"io"
"sync" "sync"
"time" "time"
"unsafe" "unsafe"
@ -19,8 +20,10 @@ type dataProvider struct {
buf *bytes.Buffer buf *bytes.Buffer
closed bool closed bool
lock *sync.Mutex lock *sync.Mutex
sessLock *sync.Mutex
session *C.nghttp2_session session *C.nghttp2_session
streamID int streamID int
deferred bool
} }
// Read read from data provider // Read read from data provider
@ -39,7 +42,15 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) {
func (dp *dataProvider) Write(buf []byte) (n int, err error) { func (dp *dataProvider) Write(buf []byte) (n int, err error) {
dp.lock.Lock() dp.lock.Lock()
defer dp.lock.Unlock() defer dp.lock.Unlock()
if dp.closed {
return 0, io.EOF
}
if dp.deferred {
dp.sessLock.Lock()
C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID))
dp.sessLock.Unlock()
dp.deferred = false
}
return dp.buf.Write(buf) return dp.buf.Write(buf)
} }
@ -47,16 +58,26 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) {
func (dp *dataProvider) Close() error { func (dp *dataProvider) Close() error {
dp.lock.Lock() dp.lock.Lock()
defer dp.lock.Unlock() defer dp.lock.Unlock()
if dp.closed {
return nil
}
dp.closed = true dp.closed = true
//log.Printf("dp close stream %d", dp.streamID)
if dp.deferred {
dp.sessLock.Lock()
C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID))
dp.sessLock.Unlock()
dp.deferred = false
}
return nil return nil
} }
func newDataProvider() ( func newDataProvider(sessionLock *sync.Mutex) (
*dataProvider, *C.nghttp2_data_provider) { *dataProvider, *C.nghttp2_data_provider) {
dp := &dataProvider{ dp := &dataProvider{
buf: new(bytes.Buffer), buf: new(bytes.Buffer),
lock: new(sync.Mutex), lock: new(sync.Mutex),
sessLock: sessionLock,
} }
cdp := C.new_data_provider(C.size_t(uintptr(unsafe.Pointer(dp)))) cdp := C.new_data_provider(C.size_t(uintptr(unsafe.Pointer(dp))))
return dp, cdp return dp, cdp

@ -9,11 +9,14 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync"
"unsafe"
) )
// ClientStream http2 client stream // ClientStream http2 client stream
type ClientStream struct { type ClientStream struct {
streamID int streamID int
conn *ClientConn
cdp *C.nghttp2_data_provider cdp *C.nghttp2_data_provider
dp *dataProvider dp *dataProvider
// application read data from stream // application read data from stream
@ -24,15 +27,27 @@ type ClientStream struct {
resch chan *http.Response resch chan *http.Response
errch chan error errch chan error
closed bool closed bool
lock *sync.Mutex
}
// Response return response of the current stream
func (s *ClientStream) Response() *http.Response {
return s.res
} }
// Read read stream data // Read read stream data
func (s *ClientStream) Read(buf []byte) (n int, err error) { func (s *ClientStream) Read(buf []byte) (n int, err error) {
if s.closed || s.res == nil || s.res.Body == nil {
return 0, io.EOF
}
return s.res.Body.Read(buf) return s.res.Body.Read(buf)
} }
// Write write data to stream // Write write data to stream
func (s *ClientStream) Write(buf []byte) (n int, err error) { func (s *ClientStream) Write(buf []byte) (n int, err error) {
if s.closed {
return 0, io.EOF
}
if s.dp != nil { if s.dp != nil {
return s.dp.Write(buf) return s.dp.Write(buf)
} }
@ -41,26 +56,44 @@ func (s *ClientStream) Write(buf []byte) (n int, err error) {
// Close close the stream // Close close the stream
func (s *ClientStream) Close() error { func (s *ClientStream) Close() error {
//s.lock.Lock()
//defer s.lock.Unlock()
if s.closed { if s.closed {
return nil return nil
} }
s.closed = true
err := io.EOF err := io.EOF
//log.Println("close stream") //log.Printf("close stream %d", int(s.streamID))
select { select {
case s.errch <- err: case s.errch <- err:
default: default:
} }
//log.Println("close stream resch") //log.Println("close stream resch")
close(s.resch) //close(s.resch)
//log.Println("close stream errch") //log.Println("close stream errch")
close(s.errch) //close(s.errch)
//log.Println("close pipe w") //log.Println("close pipe w")
if s.res != nil && s.res.Body != nil {
s.res.Body.Close() s.res.Body.Close()
}
//log.Println("close stream done") //log.Println("close stream done")
if s.dp != nil { if s.dp != nil {
s.dp.Close() s.dp.Close()
//s.conn.lock.Lock()
//C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0)
//s.conn.lock.Unlock()
s.dp = nil
} }
s.closed = true if s.cdp != nil {
C.free(unsafe.Pointer(s.cdp))
s.cdp = nil
}
//s.conn.lock.Lock()
//defer s.conn.lock.Unlock()
//if _, ok := s.conn.streams[s.streamID]; ok {
//delete(s.conn.streams, s.streamID)
//}
return nil return nil
} }
@ -94,27 +127,26 @@ type ServerStream struct {
// Write write data to stream, // Write write data to stream,
// implements http.ResponseWriter // implements http.ResponseWriter
func (s *ServerStream) Write(buf []byte) (int, error) { func (s *ServerStream) Write(buf []byte) (int, error) {
if !s.responseSend { if s.closed {
s.WriteHeader(http.StatusOK) return 0, io.EOF
}
/*
//log.Printf("stream %d, send %d bytes", s.streamID, len(buf))
if s.buf.Len() > 2048 {
s.dp.Write(s.buf.Bytes())
s.buf.Reset()
} }
if len(buf) < 2048 { if !s.responseSend {
s.buf.Write(buf) s.WriteHeader(http.StatusOK)
return len(buf), nil
} }
*/
return s.dp.Write(buf) return s.dp.Write(buf)
} }
// WriteHeader set response code and send reponse, // WriteHeader set response code and send reponse,
// implements http.ResponseWriter // implements http.ResponseWriter
func (s *ServerStream) WriteHeader(code int) { func (s *ServerStream) WriteHeader(code int) {
if s.closed {
return
}
if s.responseSend {
return
}
s.responseSend = true
s.statusCode = code s.statusCode = code
nvIndex := 0 nvIndex := 0
nvMax := 25 nvMax := 25
@ -132,18 +164,21 @@ func (s *ServerStream) WriteHeader(code int) {
} }
var dp *dataProvider var dp *dataProvider
var cdp *C.nghttp2_data_provider var cdp *C.nghttp2_data_provider
dp, cdp = newDataProvider() dp, cdp = newDataProvider(s.conn.lock)
dp.streamID = s.streamID dp.streamID = s.streamID
dp.session = s.conn.session dp.session = s.conn.session
s.dp = dp s.dp = dp
s.cdp = cdp s.cdp = cdp
s.conn.lock.Lock()
ret := C.nghttp2_submit_response( ret := C.nghttp2_submit_response(
s.conn.session, C.int(s.streamID), nva.nv, C.size_t(nvIndex), cdp) s.conn.session, C.int(s.streamID), nva.nv, C.size_t(nvIndex), cdp)
s.conn.lock.Unlock()
C.delete_nv_array(nva) C.delete_nv_array(nva)
if int(ret) < 0 { if int(ret) < 0 {
panic(fmt.Sprintf("sumit response error %s", C.GoString(C.nghttp2_strerror(ret)))) panic(fmt.Sprintf("sumit response error %s", C.GoString(C.nghttp2_strerror(ret))))
} }
s.responseSend = true
//log.Printf("stream %d send response", s.streamID) //log.Printf("stream %d send response", s.streamID)
} }
@ -161,14 +196,19 @@ func (s *ServerStream) Close() error {
if s.closed { if s.closed {
return nil return nil
} }
s.closed = true
//C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0) //C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 0)
if s.req.Body != nil { if s.req.Body != nil {
s.req.Body.Close() s.req.Body.Close()
} }
if s.dp != nil { if s.dp != nil {
s.dp.Close() s.dp.Close()
s.dp = nil
}
if s.cdp != nil {
C.free(unsafe.Pointer(s.cdp))
s.cdp = nil
} }
s.closed = true
//log.Printf("stream %d closed", s.streamID) //log.Printf("stream %d closed", s.streamID)
return nil return nil
} }

Loading…
Cancel
Save