package nghttp2

/*
#cgo pkg-config: libnghttp2
#include "_nghttp2.h"
*/
import "C"

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"runtime"
	"strings"
	"sync"
	"time"
	"unsafe"
)

// Conn is one HTTP/2 connection driven by a libnghttp2 session over a
// net.Conn. Server builds the server side (requests are dispatched to
// handler, see serve) and Client builds the client side (requests are issued
// with RoundTrip or Connect).
type Conn struct {
	conn     net.Conn
	session  *C.nghttp2_session
	handler  http.Handler
	isServer bool

	// lock must be held for every call into session and for every access to
	// the fields in this group.
	lock        *sync.Mutex
	streams     map[int]*stream
	streamCount int
	closed      bool
	running     bool
	err         error

	errch  chan error    // first fatal error from readloop/writeloop
	exitch chan struct{} // closed by Close
}

// Server wraps c as the server side of an HTTP/2 connection whose requests
// are dispatched to handler (http.DefaultServeMux when nil). It sends the
// connection header but does not start the event loop: the caller must call
// Run. If sending the connection header fails, the connection (including c)
// is closed and an error is returned.
func Server(c net.Conn, handler http.Handler) (*Conn, error) {
	conn := &Conn{
		conn:     c,
		handler:  handler,
		isServer: true,
		lock:     new(sync.Mutex),
		streams:  make(map[int]*stream),
		errch:    make(chan error, 1),
		exitch:   make(chan struct{}),
	}
	conn.session = C.init_nghttp2_server_session(C.size_t(uintptr(unsafe.Pointer(conn))))
	if conn.session == nil {
		return nil, errors.New("init server session failed")
	}
	if err := conn.sendConnectionHeader(); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}

// Client wraps c as the client side of an HTTP/2 connection, sends the
// connection header and starts Run in a new goroutine. If sending the
// connection header fails, the connection (including c) is closed and an
// error is returned.
func Client(c net.Conn) (*Conn, error) {
	conn := &Conn{
		conn:    c,
		lock:    new(sync.Mutex),
		streams: make(map[int]*stream),
		errch:   make(chan error, 1),
		exitch:  make(chan struct{}),
	}
	conn.session = C.init_nghttp2_client_session(C.size_t(uintptr(unsafe.Pointer(conn))))
	if conn.session == nil {
		return nil, errors.New("init client session failed")
	}
	if err := conn.sendConnectionHeader(); err != nil {
		conn.Close()
		return nil, err
	}
	go conn.Run()
	return conn, nil
}

// sendConnectionHeader runs the C helper send_connection_header on c's
// session and converts a negative nghttp2 return code into an error.
func (c *Conn) sendConnectionHeader() error {
	c.lock.Lock()
	ret := C.send_connection_header(c.session)
	c.lock.Unlock()
	if int(ret) < 0 {
		return fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret)))
	}
	return nil
}

// Error returns the fatal error that stopped Run, or nil if there is none.
func (c *Conn) Error() error {
	c.lock.Lock()
	defer c.lock.Unlock()
	return c.err
}

// CanTakeNewRequest reports whether another request may be started on c:
// the connection must not be closed or failed, and streamCount must not
// exceed (2<<31-1)/2.
func (c *Conn) CanTakeNewRequest() bool {
	const maxStreams = (2<<31 - 1) / 2
	c.lock.Lock()
	defer c.lock.Unlock()
	return !c.closed && c.err == nil && c.streamCount <= maxStreams
}

// RoundTrip implements http.RoundTripper. It opens a new stream for req,
// sends the request headers and blocks until the response arrives on that
// stream or the connection is closed.
//
// For POST and PUT requests, and for any request with a non-nil Body other
// than http.NoBody, the body is streamed from a separate goroutine. req.Body
// is always closed. The :scheme pseudo-header is always "https".
func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) {
	closeBody := func() {
		if req.Body != nil {
			req.Body.Close()
		}
	}
	if req.URL == nil {
		closeBody()
		return nil, errors.New("http2: nil Request.URL")
	}

	method := req.Method
	if method == "" {
		method = "GET" // net/http treats an empty method as GET
	}
	authority := req.Host
	if authority == "" {
		authority = req.URL.Host
	}
	hasBody := method == "POST" || method == "PUT" ||
		(req.Body != nil && req.Body != http.NoBody)

	nv := make([]C.nghttp2_nv, 0, 4+len(req.Header))
	nv = append(nv, newNV(":method", method))
	nv = append(nv, newNV(":authority", authority))
	nv = append(nv, newNV(":scheme", "https"))
	// RequestURI keeps the path escaped, preserves the raw query as sent and
	// yields "/" for an empty path.
	nv = append(nv, newNV(":path", req.URL.RequestURI()))
	for k, vv := range req.Header {
		// HTTP/2 field names must be lower case (RFC 7540 section 8.1.2).
		name := strings.ToLower(k)
		switch name {
		case "connection", "proxy-connection", "keep-alive", "transfer-encoding", "upgrade", "host":
			// Connection-specific fields are forbidden in HTTP/2, and the
			// host is carried by :authority.
			continue
		}
		for _, v := range vv {
			nv = append(nv, newNV(name, v))
		}
	}

	cdp := C.nghttp2_data_provider{}
	dp := newDataProvider(unsafe.Pointer(&cdp), c.lock, 1)
	dp.session = c.session
	var bodyProvider *C.nghttp2_data_provider
	if hasBody {
		bodyProvider = &cdp
	}

	// Submit and register under one critical section: callbacks run under
	// c.lock, so they can never observe the stream ID before it is in
	// c.streams with its data provider attached.
	c.lock.Lock()
	s, err := c.submitRequestLocked(nv, bodyProvider)
	if err == nil {
		s.dp = dp
		dp.streamID = s.streamID
		c.addStreamLocked(s)
	}
	c.lock.Unlock()
	if err != nil {
		closeBody()
		return nil, err
	}

	if hasBody {
		go func() {
			if req.Body != nil {
				// Copy errors are deliberately ignored; dp.Close is still
				// called to end the body.
				io.Copy(dp, req.Body)
				req.Body.Close()
			}
			dp.Close()
		}()
	} else {
		closeBody()
	}

	select {
	case res := <-s.resch:
		s.request = req
		res.Request = req
		return res, nil
	case <-c.exitch:
		return nil, errors.New("connection closed")
	}
}

// submitRequest locks c and calls submitRequestLocked. The returned stream
// is not registered in c.streams.
func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*stream, error) {
	c.lock.Lock()
	defer c.lock.Unlock()
	return c.submitRequestLocked(nv, cdp)
}

// submitRequestLocked submits a request with header list nv and optional
// data provider cdp (nil when the request carries no body) and returns a new
// stream for the stream ID assigned by nghttp2. The stream is not added to
// c.streams. The caller must hold c.lock.
func (c *Conn) submitRequestLocked(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*stream, error) {
	if c.closed {
		// The session has been (or is being) freed by Close.
		return nil, errors.New("connection closed")
	}
	if len(nv) == 0 {
		return nil, errors.New("submit request error: empty header list")
	}
	ret := C._nghttp2_submit_request(c.session, nil,
		C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), cdp, nil)
	// nv crosses into C as a plain integer, which does not keep its backing
	// array reachable; keep it alive until the call has returned.
	runtime.KeepAlive(nv)
	if int(ret) < 0 {
		return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(ret)))
	}
	s := &stream{
		streamID: int(ret),
		conn:     c,
		bp: &bodyProvider{
			buf:  new(bytes.Buffer),
			lock: new(sync.Mutex),
		},
		resch: make(chan *http.Response),
	}
	if cdp != nil {
		s.cdp = *cdp
	}
	return s, nil
}

// addStreamLocked records s in c.streams and increments c.streamCount.
// The caller must hold c.lock.
func (c *Conn) addStreamLocked(s *stream) {
	c.streams[s.streamID] = s
	c.streamCount++
}

// Connect sends a CONNECT request for addr and, once the peer answers 200,
// returns the stream as a net.Conn tunnel together with the status code.
// Any other status is returned as an error. Submission failures report
// http.StatusBadGateway; a connection closed before the response reports
// http.StatusServiceUnavailable.
func (c *Conn) Connect(addr string) (net.Conn, int, error) {
	nv := []C.nghttp2_nv{}
	nv = append(nv, newNV(":method", "CONNECT"))
	nv = append(nv, newNV(":authority", addr))

	cdp := C.nghttp2_data_provider{}
	dp := newDataProvider(unsafe.Pointer(&cdp), c.lock, 1)
	dp.session = c.session

	// Same atomic submit-and-register as RoundTrip.
	c.lock.Lock()
	s, err := c.submitRequestLocked(nv, &cdp)
	if err == nil {
		s.dp = dp
		dp.streamID = s.streamID
		c.addStreamLocked(s)
	}
	c.lock.Unlock()
	if err != nil {
		return nil, http.StatusBadGateway, err
	}

	select {
	case res := <-s.resch:
		if res.StatusCode != http.StatusOK {
			return nil, res.StatusCode, fmt.Errorf("http error code %d", res.StatusCode)
		}
		s.request = &http.Request{
			Method:     "CONNECT",
			Host:       addr,
			RequestURI: addr,
			URL:        &url.URL{},
		}
		res.Request = s.request
		return s, res.StatusCode, nil
	case <-c.exitch:
		return nil, http.StatusServiceUnavailable, errors.New("connection closed")
	}
}

// Run starts the read and write loops and blocks until one of them reports
// an error (kept and returned by Error) or the connection is closed; c is
// then closed. Calls made while Run is running, or after it has run, return
// immediately.
func (c *Conn) Run() {
	c.lock.Lock()
	if c.running {
		c.lock.Unlock()
		return
	}
	c.running = true
	c.lock.Unlock()
	defer c.Close()

	go c.readloop()
	go c.writeloop()

	select {
	case err := <-c.errch:
		c.lock.Lock()
		c.err = err
		c.lock.Unlock()
	case <-c.exitch:
	}
}

// serve dispatches the request on stream s to c.handler (or
// http.DefaultServeMux when no handler was given) and closes s when the
// handler returns.
func (c *Conn) serve(s *stream) {
	handler := c.handler
	if handler == nil {
		handler = http.DefaultServeMux
	}
	handler.ServeHTTP(s, s.request)
	s.Close()
}

// Close closes every tracked stream, terminates and frees the nghttp2
// session, wakes goroutines waiting on the connection and closes the
// underlying net.Conn. It is safe to call concurrently and more than once;
// only the first call does any work. It always returns nil.
func (c *Conn) Close() error {
	c.lock.Lock()
	if c.closed {
		c.lock.Unlock()
		return nil
	}
	// Setting closed under the lock makes the loops and submitRequestLocked
	// stop touching the session before it is deleted below.
	c.closed = true
	streams := make([]*stream, 0, len(c.streams))
	for _, s := range c.streams {
		streams = append(streams, s)
	}
	c.lock.Unlock()

	// Streams are closed without holding c.lock, as before.
	for _, s := range streams {
		s.Close()
	}

	c.lock.Lock()
	// NOTE(review): terminate_session only queues GOAWAY; it is never
	// flushed because the session is deleted right away.
	C.nghttp2_session_terminate_session(c.session, 0)
	C.nghttp2_session_del(c.session)
	c.lock.Unlock()

	close(c.exitch)
	c.conn.Close()
	return nil
}

// errorNotify offers err to Run without blocking. errch (buffered by Server
// and Client) keeps the first undelivered error; later ones are dropped.
func (c *Conn) errorNotify(err error) {
	select {
	case c.errch <- err:
	default:
	}
}

// readloop reads from the net.Conn and feeds the bytes to
// nghttp2_session_mem_recv until a read or protocol error (reported through
// errorNotify) or until c is closed. Close closes the net.Conn, which
// unblocks a pending Read.
func (c *Conn) readloop() {
	for {
		// A fresh buffer per read, as before, in case a callback retains a
		// reference into the received bytes.
		buf := make([]byte, 16*1024)
		n, err := c.conn.Read(buf)
		// Process data before the error: Read may return both.
		if n > 0 {
			c.lock.Lock()
			if c.closed {
				c.lock.Unlock()
				return
			}
			ret := C.nghttp2_session_mem_recv(c.session, (*C.uchar)(unsafe.Pointer(&buf[0])), C.size_t(n))
			c.lock.Unlock()
			if int(ret) < 0 {
				c.errorNotify(fmt.Errorf("http2 recv error: %s", C.GoString(C.nghttp2_strerror(C.int(ret)))))
				return
			}
		}
		if err != nil {
			c.errorNotify(err)
			return
		}
	}
}

// writeloop repeatedly lets nghttp2 send pending frames. When the session
// has nothing more to write it sleeps for 50ms before polling again. It
// stops on a send error (reported through errorNotify) or when c is closed.
func (c *Conn) writeloop() {
	const idleDelay = 50 * time.Millisecond
	for {
		select {
		case <-c.exitch:
			return
		default:
		}

		c.lock.Lock()
		if c.closed {
			c.lock.Unlock()
			return
		}
		ret := C.nghttp2_session_send(c.session)
		if int(ret) < 0 {
			c.lock.Unlock()
			c.errorNotify(fmt.Errorf("http2 send error: %s", C.GoString(C.nghttp2_strerror(C.int(ret)))))
			return
		}
		wantWrite := C.nghttp2_session_want_write(c.session)
		c.lock.Unlock()

		if int(wantWrite) == 0 {
			time.Sleep(idleDelay)
		}
	}
}