diff --git a/callbacks.go b/callbacks.go index 8823636..51a49d3 100644 --- a/callbacks.go +++ b/callbacks.go @@ -5,12 +5,15 @@ package nghttp2 */ import "C" import ( + "bytes" "crypto/tls" "errors" "io" "net/http" + "net/url" "strconv" "strings" + "sync" "unsafe" ) @@ -49,7 +52,7 @@ func onDataSourceReadCallback(ptr unsafe.Pointer, streamID C.int, } if err == errAgain { //log.Println("onDataSourceReadCallback end") - s.dp.deferred = true + //s.dp.deferred = true return NGHTTP2_ERR_DEFERRED } //log.Println("onDataSourceReadCallback end") @@ -82,7 +85,7 @@ func onDataChunkRecv(ptr unsafe.Pointer, streamID C.int, //log.Println("onDataChunkRecv end") return C.int(length) } - + //log.Println("bp write") n, err := s.bp.Write(gobuf) if err != nil { return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE @@ -118,34 +121,47 @@ func onBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Printf("stream %d begin headers", int(streamID)) conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) - s, ok := conn.streams[int(streamID)] - if !ok { - //log.Println("onBeginHeaderCallback end") - return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE - } - var TLS tls.ConnectionState - if tlsconn, ok := conn.conn.(*tls.Conn); ok { - TLS = tlsconn.ConnectionState() - } - if conn.isServer { - s.request = &http.Request{ - Header: make(http.Header), + // client + if !conn.isServer { + s, ok := conn.streams[int(streamID)] + if !ok { + //log.Println("onBeginHeaderCallback end") + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + var TLS tls.ConnectionState + if tlsconn, ok := conn.conn.(*tls.Conn); ok { + TLS = tlsconn.ConnectionState() + } + s.response = &http.Response{ Proto: "HTTP/2", ProtoMajor: 2, ProtoMinor: 0, - TLS: &TLS, + Header: make(http.Header), Body: s.bp, + TLS: &TLS, } return NGHTTP2_NO_ERROR } - s.response = &http.Response{ - Proto: "HTTP/2", - ProtoMajor: 2, - ProtoMinor: 0, - Header: make(http.Header), - Body: s.bp, - TLS: &TLS, + + // server + s := &stream{ + streamID: int(streamID), + conn: conn, + bp: &bodyProvider{ + buf: new(bytes.Buffer), + lock: new(sync.Mutex), + }, + request: &http.Request{ + Header: make(http.Header), + Proto: "HTTP/2", + ProtoMajor: 2, + ProtoMinor: 0, + }, } + s.request.Body = s.bp + + conn.streams[int(streamID)] = s + //log.Println("onBeginHeaderCallback end") return NGHTTP2_NO_ERROR } @@ -157,7 +173,7 @@ func onHeaderCallback(ptr unsafe.Pointer, streamID C.int, name unsafe.Pointer, namelen C.int, value unsafe.Pointer, valuelen C.int) C.int { //log.Println("onHeaderCallback begin") - //log.Println("header") + //log.Printf("header %d", int(streamID)) conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) goname := string(C.GoBytes(name, namelen)) govalue := string(C.GoBytes(value, valuelen)) @@ -176,7 +192,16 @@ func onHeaderCallback(ptr unsafe.Pointer, streamID C.int, s.request.Host = govalue case ":path": s.request.RequestURI = govalue + u, err := url.Parse(govalue) + if err != nil { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } + s.request.URL = u case ":status": + if s.response == nil { + //log.Println("empty response") + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE + } statusCode, _ := strconv.Atoi(govalue) s.response.StatusCode = statusCode s.response.Status = http.StatusText(statusCode) @@ -208,7 +233,11 @@ func onHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("onHeadersDoneCallback end") return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE } + s.headersEnd = true if conn.isServer { + if s.request.Method == "CONNECT" { + go conn.serve(s) + } return NGHTTP2_NO_ERROR } select { @@ -257,5 +286,19 @@ func onConnectionCloseCallback(ptr unsafe.Pointer) { //export onStreamEndCallback func onStreamEndCallback(ptr unsafe.Pointer, streamID C.int) { + conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) + stream, ok := conn.streams[int(streamID)] + if !ok { + return + } + stream.streamEnd = true + stream.bp.Close() + + if stream.conn.isServer { + if stream.request.Method != "CONNECT" { + go conn.serve(stream) + } + return + } } diff --git a/conn.go b/conn.go index 2a2f4b3..9af5a7a 100644 --- a/conn.go +++ b/conn.go @@ -6,10 +6,13 @@ package nghttp2 */ import "C" import ( + "bytes" "errors" "fmt" + "io" "net" "net/http" + "net/url" "sync" "time" "unsafe" @@ -30,14 +33,189 @@ type Conn struct { exitch chan struct{} } +// Server create server side http2 connection +func Server(c net.Conn, handler http.Handler) (*Conn, error) { + conn := &Conn{ + conn: c, + handler: handler, + errch: make(chan error), + exitch: make(chan struct{}), + lock: new(sync.Mutex), + isServer: true, + streams: make(map[int]*stream), + } + conn.session = C.init_nghttp2_server_session(C.size_t(uintptr(unsafe.Pointer(conn)))) + if conn.session == nil { + return nil, errors.New("init server session failed") + } + ret := C.send_connection_header(conn.session) + if int(ret) < 0 { + conn.Close() + return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret))) + } + return conn, nil +} + +// Client create client side http2 connection +func Client(c net.Conn) (*Conn, error) { + conn := &Conn{ + conn: c, + errch: make(chan error), + exitch: make(chan struct{}), + lock: new(sync.Mutex), + streams: make(map[int]*stream), + } + conn.session = C.init_nghttp2_client_session(C.size_t(uintptr(unsafe.Pointer(conn)))) + if conn.session == nil { + return nil, errors.New("init server session failed") + } + ret := C.send_connection_header(conn.session) + if int(ret) < 0 { + conn.Close() + return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret))) + } + go conn.Run() + return conn, nil +} + +// Error return conn error +func (c *Conn) Error() error { + c.lock.Lock() + defer c.lock.Unlock() + return c.err +} + +// CanTakeNewRequest check if conn can create new request +func (c *Conn) CanTakeNewRequest() bool { + if c.streamCount > ((2<<31 - 1) / 2) { + return false + } + if c.err != nil { + return false + } + return true +} + // RoundTrip submit http request and return the response func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) { - return nil, errors.New("not implement") + nv := []C.nghttp2_nv{} + + nv = append(nv, newNV(":method", req.Method)) + nv = append(nv, newNV(":authority", req.Host)) + nv = append(nv, newNV(":scheme", "https")) + + p := req.URL.Path + q := req.URL.Query().Encode() + if q != "" { + p = fmt.Sprintf("%s?%s", p, q) + } + nv = append(nv, newNV(":path", p)) + + cdp := C.nghttp2_data_provider{} + dp := newDataProvider(unsafe.Pointer(&cdp), c.lock, 1) + dp.session = c.session + var _cdp *C.nghttp2_data_provider + if req.Method == "POST" || req.Method == "PUT" { + _cdp = &cdp + } + s, err := c.submitRequest(nv, _cdp) + if err != nil { + return nil, err + } + s.dp = dp + s.dp.streamID = s.streamID + + c.lock.Lock() + c.streams[s.streamID] = s + c.streamCount++ + c.lock.Unlock() + if req.Method == "POST" || req.Method == "PUT" { + go func() { + io.Copy(dp, req.Body) + dp.Close() + }() + } + select { + case res := <-s.resch: + /* + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http error code %d", res.StatusCode) + } + */ + s.request = req + res.Request = s.request + return res, nil + case <-c.exitch: + return nil, errors.New("connection closed") + } +} + +func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*stream, error) { + + c.lock.Lock() + ret := C._nghttp2_submit_request(c.session, nil, + C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), cdp, nil) + c.lock.Unlock() + + if int(ret) < 0 { + return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(ret))) + } + streamID := int(ret) + s := &stream{ + streamID: streamID, + conn: c, + bp: &bodyProvider{ + buf: new(bytes.Buffer), + lock: new(sync.Mutex), + }, + resch: make(chan *http.Response), + } + if cdp != nil { + s.cdp = *cdp + } + return s, nil } // Connect submit connect request -func (c *Conn) Connect(addr string) (net.Conn, error) { - return nil, errors.New("not implement") +func (c *Conn) Connect(addr string) (net.Conn, int, error) { + nv := []C.nghttp2_nv{} + + nv = append(nv, newNV(":method", "CONNECT")) + nv = append(nv, newNV(":authority", addr)) + + cdp := C.nghttp2_data_provider{} + dp := newDataProvider(unsafe.Pointer(&cdp), c.lock, 1) + dp.session = c.session + + s, err := c.submitRequest(nv, &cdp) + if err != nil { + return nil, http.StatusBadGateway, err + } + s.dp = dp + c.lock.Lock() + c.streams[s.streamID] = s + c.streamCount++ + c.lock.Unlock() + + s.dp.streamID = s.streamID + + select { + case res := <-s.resch: + if res.StatusCode != http.StatusOK { + return nil, res.StatusCode, fmt.Errorf("http error code %d", res.StatusCode) + } + s.request = &http.Request{ + Method: "CONNECT", + Host: addr, + RequestURI: addr, + URL: &url.URL{}, + } + res.Request = s.request + return s, res.StatusCode, nil + case <-c.exitch: + return nil, http.StatusServiceUnavailable, errors.New("connection closed") + } + } // Run run the event loop @@ -58,6 +236,15 @@ func (c *Conn) Run() { } } +func (c *Conn) serve(s *stream) { + var handler = c.handler + if handler == nil { + handler = http.DefaultServeMux + } + handler.ServeHTTP(s, s.request) + s.Close() +} + // Close close the connection func (c *Conn) Close() error { if c.closed { @@ -67,6 +254,12 @@ func (c *Conn) Close() error { for _, s := range c.streams { s.Close() } + + c.lock.Lock() + C.nghttp2_session_terminate_session(c.session, 0) + C.nghttp2_session_del(c.session) + c.lock.Unlock() + close(c.exitch) c.conn.Close() return nil @@ -148,6 +341,7 @@ func (c *Conn) writeloop() { wantWrite := C.nghttp2_session_want_write(c.session) c.lock.Unlock() if int(wantWrite) == 0 { + //log.Println("write loop, sleep") time.Sleep(delay) } } diff --git a/data_provider.go b/data_provider.go index 9d1efb2..42b8c93 100644 --- a/data_provider.go +++ b/data_provider.go @@ -7,7 +7,6 @@ import "C" import ( "bytes" "errors" - "io" "log" "sync" "time" @@ -31,7 +30,7 @@ type dataProvider struct { // Read read from data provider func (dp *dataProvider) Read(buf []byte) (n int, err error) { if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil { - log.Println("db read invalid state") + log.Println("dp read invalid state") return 0, errors.New("invalid state") } dp.lock.Lock() @@ -39,6 +38,8 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) { n, err = dp.buf.Read(buf) if err != nil && !dp.closed { + //log.Println("deferred") + dp.deferred = true return 0, errAgain } return @@ -53,18 +54,20 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) { dp.lock.Lock() defer dp.lock.Unlock() - if dp.closed { - return 0, io.EOF - } + //if dp.closed { + // return 0, io.EOF + //} + n, err = dp.buf.Write(buf) if dp.deferred { dp.sessLock.Lock() C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) dp.sessLock.Unlock() + //log.Println("resume") dp.deferred = false } - return dp.buf.Write(buf) + return } // Close end to provide data diff --git a/nghttp2.c b/nghttp2.c index 3b8f301..7af9e7f 100644 --- a/nghttp2.c +++ b/nghttp2.c @@ -23,8 +23,8 @@ int on_invalid_frame_recv_callback(nghttp2_session *session, return 0; } static ssize_t on_data_source_read_callback(nghttp2_session *session, int32_t stream_id, - uint8_t *buf, size_t length, uint32_t *data_flags, - nghttp2_data_source *source, void *user_data) + uint8_t *buf, size_t length, uint32_t *data_flags, + nghttp2_data_source *source, void *user_data) { int ret = onDataSourceReadCallback(user_data, stream_id, buf, length); if (ret == 0) @@ -48,7 +48,8 @@ static int on_frame_recv_callback(nghttp2_session *session, switch (frame->hd.type) { case NGHTTP2_HEADERS: - if (frame->headers.cat == NGHTTP2_HCAT_REQUEST) + if (frame->headers.cat == NGHTTP2_HCAT_REQUEST || + frame->headers.cat == NGHTTP2_HCAT_RESPONSE) { onHeadersDoneCallback(user_data, frame->hd.stream_id); } @@ -82,7 +83,8 @@ static int on_header_callback(nghttp2_session *session, switch (frame->hd.type) { case NGHTTP2_HEADERS: - if (frame->headers.cat == NGHTTP2_HCAT_REQUEST) + if (frame->headers.cat == NGHTTP2_HCAT_REQUEST || + frame->headers.cat == NGHTTP2_HCAT_RESPONSE) { onHeaderCallback(user_data, frame->hd.stream_id, (void *)name, namelen, (void *)value, valuelen); @@ -109,7 +111,8 @@ static int on_begin_headers_callback(nghttp2_session *session, switch (frame->hd.type) { case NGHTTP2_HEADERS: - if (frame->headers.cat == NGHTTP2_HCAT_REQUEST) + if (frame->headers.cat == NGHTTP2_HCAT_REQUEST || + frame->headers.cat == NGHTTP2_HCAT_RESPONSE) { onBeginHeaderCallback(user_data, frame->hd.stream_id); } @@ -123,6 +126,7 @@ nghttp2_session *init_nghttp2_server_session(size_t data) nghttp2_session_callbacks *callbacks; nghttp2_session *session; + nghttp2_session_callbacks_new(&callbacks); init_nghttp2_callbacks(callbacks); nghttp2_session_server_new(&session, callbacks, (void *)((int *)(data))); @@ -177,8 +181,26 @@ int send_connection_header(nghttp2_session *session) return rv; } -int data_provider_set_callback(size_t dp, size_t data, int t){ - nghttp2_data_provider *cdp = (nghttp2_data_provider*)dp; +int data_provider_set_callback(size_t dp, size_t data, int t) +{ + nghttp2_data_provider *cdp = (nghttp2_data_provider *)dp; cdp->source.ptr = (void *)data; - cdp->read_callback=on_data_source_read_callback; + cdp->read_callback = on_data_source_read_callback; +} + +int _nghttp2_submit_response(nghttp2_session *sess, int streamid, + size_t nv, size_t nvlen, nghttp2_data_provider *dp) +{ + + return nghttp2_submit_response(sess, streamid, (nghttp2_nv *)nv, nvlen, dp); +} + +int _nghttp2_submit_request(nghttp2_session *session, const nghttp2_priority_spec *pri_spec, + size_t nva, size_t nvlen, + const nghttp2_data_provider *data_prd, void *stream_user_data) +{ + + return nghttp2_submit_request(session, pri_spec, + (nghttp2_nv *)nva, nvlen, + data_prd, stream_user_data); } \ No newline at end of file diff --git a/stream.go b/stream.go index e43b234..17d6513 100644 --- a/stream.go +++ b/stream.go @@ -1,48 +1,147 @@ package nghttp2 +/* +#include "_nghttp2.h" +*/ +import "C" import ( "errors" + "fmt" "net" "net/http" + "strings" "time" + "unsafe" ) type stream struct { - streamID int - conn *Conn - dp *dataProvider - bp *bodyProvider - request *http.Request - response *http.Response - resch chan *http.Response + streamID int + conn *Conn + dp *dataProvider + bp *bodyProvider + request *http.Request + response *http.Response + resch chan *http.Response + headersEnd bool + streamEnd bool + closed bool + cdp C.nghttp2_data_provider } var _ net.Conn = &stream{} func (s *stream) Read(buf []byte) (int, error) { - return 0, errors.New("not implement") + if s.bp != nil { + return s.bp.Read(buf) + } + return 0, errors.New("empty body") +} + +func (s *stream) WriteHeader(code int) { + if s.response == nil { + s.response = &http.Response{ + Proto: "http/2", + ProtoMajor: 2, + ProtoMinor: 0, + Header: make(http.Header), + } + } + if s.response.StatusCode != 0 { + return + } + + s.response.StatusCode = code + s.response.Status = http.StatusText(code) + + nv := []C.nghttp2_nv{} + nv = append(nv, newNV(":status", fmt.Sprintf("%d", code))) + for k, v := range s.response.Header { + _k := strings.ToLower(k) + if _k == "host" || _k == "connection" || _k == "transfer-encoding" { + continue + } + nv = append(nv, newNV(k, v[0])) + } + + s.cdp = C.nghttp2_data_provider{} + s.dp = newDataProvider(unsafe.Pointer(&s.cdp), s.conn.lock, 0) + s.dp.session = s.conn.session + s.dp.streamID = s.streamID + + s.conn.lock.Lock() + ret := C._nghttp2_submit_response(s.conn.session, C.int(s.streamID), + C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), &s.cdp) + s.conn.lock.Unlock() + + if int(ret) < 0 { + panic(fmt.Sprintf("submit response error: %s", C.GoString(C.nghttp2_strerror(ret)))) + } +} + +func (s *stream) Header() http.Header { + if s.response == nil { + s.response = &http.Response{ + Proto: "http/2", + ProtoMajor: 2, + ProtoMinor: 0, + Header: make(http.Header), + } + } + return s.response.Header } + func (s *stream) Write(buf []byte) (int, error) { - if s.conn.isServer { - return 0, errors.New("not implement") + if s.conn.isServer && s.response == nil { + s.WriteHeader(http.StatusOK) + } + + if s.dp != nil { + return s.dp.Write(buf) } - return 0, errors.New("not implement") + return 0, errors.New("empty dp") } + func (s *stream) Close() error { + if s.closed { + return nil + } + s.closed = true + if s.dp != nil { + s.dp.Close() + } + if s.bp != nil { + s.bp.Close() + } + //s.conn.lock.Lock() + //if _, ok := s.conn.streams[s.streamID]; ok { + // delete(s.conn.streams, s.streamID) + ///} + //s.conn.lock.Unlock() + if s.request != nil && s.request.Method == "CONNECT" { + //log.Println("rst stream") + s.conn.lock.Lock() + C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 8) + s.conn.lock.Unlock() + } return nil } + func (s *stream) LocalAddr() net.Addr { return nil } + func (s *stream) RemoteAddr() net.Addr { return nil } + func (s *stream) SetDeadline(t time.Time) error { return errors.New("not implement") } + func (s *stream) SetReadDeadline(t time.Time) error { return errors.New("not implement") } + func (s *stream) SetWriteDeadline(t time.Time) error { return errors.New("not implement") }