diff --git a/client.c b/client.c index 38bd807..7296c05 100644 --- a/client.c +++ b/client.c @@ -1,6 +1,5 @@ #include "_nghttp2.h" -#define ARRLEN(x) (sizeof(x) / sizeof(x[0])) // send_callback send data to network static ssize_t client_send_callback(nghttp2_session *session, const uint8_t *data, @@ -113,7 +112,7 @@ static int on_client_frame_recv_callback(nghttp2_session *session, if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) { //fprintf(stderr, "All headers received\n"); - OnClientFrameRecvCallback(user_data, frame->hd.stream_id); + OnClientHeadersDoneCallback(user_data, frame->hd.stream_id); } break; case NGHTTP2_RST_STREAM: @@ -140,7 +139,7 @@ static int on_client_stream_close_callback(nghttp2_session *session, int32_t str return 0; } -ssize_t data_source_read_callback(nghttp2_session *session, int32_t stream_id, +static ssize_t data_source_read_callback(nghttp2_session *session, int32_t stream_id, uint8_t *buf, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void *user_data) { diff --git a/client.go b/client.go index 06def36..b51aa13 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,8 @@ package nghttp2 */ import "C" import ( + "bytes" + "errors" "fmt" "io" "log" @@ -18,6 +20,10 @@ import ( "unsafe" ) +var ( + errAgain = errors.New("again") +) + // ClientConn http2 connection type ClientConn struct { session *C.nghttp2_session @@ -49,7 +55,13 @@ type dataProvider struct { // drain the data r io.Reader // provider the data - w io.Writer + w io.Writer + datach chan []byte + errch chan error + buf *bytes.Buffer + run bool + streamID int + session *C.nghttp2_session } // NewClientConn create http2 client @@ -238,6 +250,10 @@ func (c *ClientConn) CreateRequest(req *http.Request) (*http.Response, error) { dp, cdp = newDataProvider(req.Body, nil) } streamID := C.submit_request(c.session, nva.nv, C.size_t(nvIndex), cdp) + if dp != nil { + dp.streamID = int(streamID) + dp.session = c.session + } C.delete_nv_array(nva) if int(streamID) < 0 { return nil, fmt.Errorf("submit request error: %s", @@ -282,10 +298,46 @@ func setNvArray(a *C.struct_nv_array, index int, cvalue, cnamelen, cvaluelen, cflags) } +func (dp *dataProvider) start() { + buf := make([]byte, 4096) + for { + n, err := dp.r.Read(buf) + if err != nil { + dp.errch <- err + break + } + dp.datach <- buf[:n] + C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID)) + } +} + +// Read read from data provider +// this emulate a unblocking reading +// if data is not avaliable return errAgain func (dp *dataProvider) Read(buf []byte) (n int, err error) { - return dp.r.Read(buf) + if !dp.run { + go dp.start() + dp.run = true + time.Sleep(100 * time.Millisecond) + } + + select { + case err := <-dp.errch: + //log.Println("d err ", err) + return 0, err + case b := <-dp.datach: + dp.buf.Write(b) + default: + } + n, err = dp.buf.Read(buf) + if err != nil { + //log.Println(err) + return 0, errAgain + } + return } +// Write provider data for data provider func (dp *dataProvider) Write(buf []byte) (n int, err error) { if dp.w == nil { return 0, fmt.Errorf("write not supported") @@ -295,7 +347,12 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) { func newDataProvider(r io.Reader, w io.Writer) ( *dataProvider, *C.nghttp2_data_provider) { - dp := &dataProvider{r, w} + dp := &dataProvider{ + r: r, w: w, + errch: make(chan error), + datach: make(chan []byte), + buf: new(bytes.Buffer), + } cdp := C.new_data_provider(C.size_t(uintptr(unsafe.Pointer(dp)))) return dp, cdp } @@ -371,6 +428,10 @@ func DataSourceRead(ptr unsafe.Pointer, if err == io.EOF { return 0 } + if err == errAgain { + // NGHTTP2_ERR_DEFERED + return -508 + } return -1 } cbuf := C.CBytes(gobuf) @@ -446,9 +507,9 @@ func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, return 0 } -// OnClientFrameRecvCallback callback function for begion to recv data -//export OnClientFrameRecvCallback -func OnClientFrameRecvCallback(ptr unsafe.Pointer, streamID C.int) C.int { +// OnClientHeadersDoneCallback callback function for begion to recv data +//export OnClientHeadersDoneCallback +func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("frame recv") conn := (*ClientConn)(ptr) conn.onFrameRecv(int(streamID))