Compare commits

..

No commits in common. 'master' and 'merge_conn' have entirely different histories.

@ -10,7 +10,7 @@ server usage example:
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
l, err := tls.Listen("tcp", "127.0.0.1:1100", &tls.Config{ l, err := tls.Listen("tcp", "127.0.0.1:1100", &tls.Config{
@ -18,7 +18,7 @@ server usage example:
NextProtos: []string{"h2"}, NextProtos: []string{"h2"},
}) })
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
addr := l.Addr().String() addr := l.Addr().String()
@ -30,7 +30,7 @@ server usage example:
hdr.Set("aa", "bb") hdr.Set("aa", "bb")
d, err := ioutil.ReadAll(r.Body) d, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
log.Println(err) ln(err)
return return
} }
w.Write(d) w.Write(d)
@ -43,7 +43,7 @@ server usage example:
} }
h2conn, err := NewServerConn(c, nil) h2conn, err := NewServerConn(c, nil)
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
f("%+v", h2conn) f("%+v", h2conn)
go h2conn.Run() go h2conn.Run()
@ -57,18 +57,18 @@ client usage example:
ServerName: "nghttp2.org", ServerName: "nghttp2.org",
}) })
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
cstate := conn.ConnectionState() cstate := conn.ConnectionState()
if cstate.NegotiatedProtocol != "h2" { if cstate.NegotiatedProtocol != "h2" {
log.Fatal("no http2 on server") t.Fatal("no http2 on server")
} }
h2conn, err := NewClientConn(conn) h2conn, err := NewClientConn(conn)
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
param := url.Values{} param := url.Values{}
@ -79,18 +79,18 @@ client usage example:
"https://nghttp2.org/httpbin/post?a=b&c=d", "https://nghttp2.org/httpbin/post?a=b&c=d",
data) data)
log.Printf("%+v", req) f("%+v", req)
req.Header.Set("user-agent", "go-nghttp2/1.0") req.Header.Set("user-agent", "go-nghttp2/1.0")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
res, err := h2conn.CreateRequest(req) res, err := h2conn.CreateRequest(req)
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
log.Printf("expect %d, got %d", http.StatusOK, res.StatusCode) t.Errorf("expect %d, got %d", http.StatusOK, res.StatusCode)
} }
res.Write(os.Stderr) res.Write(os.Stderr)
@ -99,7 +99,7 @@ co-work with net/http server example:
l, err := net.Listen("tcp", "127.0.0.1:1222") l, err := net.Listen("tcp", "127.0.0.1:1222")
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
srv := &http.Server{ srv := &http.Server{
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{

@ -1,338 +1,319 @@
package nghttp2 package nghttp2
/* /*
#include "_nghttp2.h" #include "_nghttp2.h"
*/ */
import "C" import "C"
import ( import (
"bytes" "bytes"
"context" "crypto/tls"
"crypto/tls" "errors"
"errors" "io"
"io" "net/http"
"net/http" "net/url"
"net/url" "strconv"
"runtime" "strings"
"strconv" "sync"
"strings" "unsafe"
"sync" )
"unsafe"
) var (
errAgain = errors.New("again")
var ( )
errAgain = errors.New("again")
) const (
NGHTTP2_NO_ERROR = 0
const ( NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE = -521
NGHTTP2_NO_ERROR = 0 NGHTTP2_ERR_CALLBACK_FAILURE = -902
NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE = -521 NGHTTP2_ERR_DEFERRED = -508
NGHTTP2_ERR_CALLBACK_FAILURE = -902 )
NGHTTP2_ERR_DEFERRED = -508
) // onDataSourceReadCallback callback function for libnghttp2 library
// want read data from data provider source,
/* // return NGHTTP2_ERR_DEFERRED will cause data frame defered,
var bufPool = &sync.Pool{ // application later call nghttp2_session_resume_data will re-quene the data frame
New: func() interface{} { //
return make([]byte, 16*1024) //export onDataSourceReadCallback
}, func onDataSourceReadCallback(ptr unsafe.Pointer, streamID C.int,
} buf unsafe.Pointer, length C.size_t) C.ssize_t {
*/ //log.Println("onDataSourceReadCallback begin")
conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
// onDataSourceReadCallback callback function for libnghttp2 library s, ok := conn.streams[int(streamID)]
// want read data from data provider source, if !ok {
// return NGHTTP2_ERR_DEFERRED will cause data frame defered, //log.Println("client dp callback, stream not exists")
// application later call nghttp2_session_resume_data will re-quene the data frame return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
// }
//export onDataSourceReadCallback gobuf := make([]byte, int(length))
func onDataSourceReadCallback(ptr unsafe.Pointer, streamID C.int, n, err := s.dp.Read(gobuf)
buf unsafe.Pointer, length C.size_t) C.ssize_t { if err != nil {
//log.Println("onDataSourceReadCallback begin") if err == io.EOF {
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) //log.Println("onDataSourceReadCallback end")
s, ok := conn.streams[int(streamID)] return 0
if !ok { }
//log.Println("client dp callback, stream not exists") if err == errAgain {
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE //log.Println("onDataSourceReadCallback end")
} //s.dp.deferred = true
gobuf := make([]byte, int(length)) return NGHTTP2_ERR_DEFERRED
/* }
_length := int(length) //log.Println("onDataSourceReadCallback end")
gobuf := bufPool.Get().([]byte) return NGHTTP2_ERR_CALLBACK_FAILURE
if len(gobuf) < _length { }
gobuf = make([]byte, _length) //cbuf := C.CBytes(gobuf)
} //defer C.free(cbuf)
defer bufPool.Put(gobuf) //C.memcpy(buf, cbuf, C.size_t(n))
*/ C.memcpy(buf, unsafe.Pointer(&gobuf[0]), C.size_t(n))
//log.Println("onDataSourceReadCallback end")
n, err := s.dp.Read(gobuf[0:]) return C.ssize_t(n)
if err != nil { }
if err == io.EOF {
//log.Println("onDataSourceReadCallback end") // onDataChunkRecv callback function for libnghttp2 library data chunk received.
return 0 //
} //export onDataChunkRecv
if err == errAgain { func onDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
//log.Println("onDataSourceReadCallback end") buf unsafe.Pointer, length C.size_t) C.int {
//s.dp.deferred = true //log.Println("onDataChunkRecv begin")
return NGHTTP2_ERR_DEFERRED conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
} gobuf := C.GoBytes(buf, C.int(length))
//log.Println("onDataSourceReadCallback end")
return NGHTTP2_ERR_CALLBACK_FAILURE s, ok := conn.streams[int(streamID)]
} if !ok {
//cbuf := C.CBytes(gobuf) //log.Println("onDataChunkRecv end")
//defer C.free(cbuf) return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
//C.memcpy(buf, cbuf, C.size_t(n)) }
C.memcpy(buf, unsafe.Pointer(&gobuf[0]), C.size_t(n)) if s.bp == nil {
//log.Println("onDataSourceReadCallback end") //log.Println("empty body")
return C.ssize_t(n) //log.Println("onDataChunkRecv end")
} return C.int(length)
}
// onDataChunkRecv callback function for libnghttp2 library data chunk received. //log.Println("bp write")
// n, err := s.bp.Write(gobuf)
//export onDataChunkRecv if err != nil {
func onDataChunkRecv(ptr unsafe.Pointer, streamID C.int, return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
buf unsafe.Pointer, length C.size_t) C.int { }
//log.Println("onDataChunkRecv begin") //log.Println("onDataChunkRecv end")
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) return C.int(n)
gobuf := C.GoBytes(buf, C.int(length)) }
s, ok := conn.streams[int(streamID)] // onDataSendCallback callback function for libnghttp2 library want send data to network.
if !ok { //
//log.Println("onDataChunkRecv end") //export onDataSendCallback
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE func onDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t {
} //log.Println("onDataSendCallback begin")
if s.bp == nil { //log.Println("data write req ", int(size))
//log.Println("empty body") conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
//log.Println("onDataChunkRecv end") buf := C.GoBytes(data, C.int(size))
return C.int(length) //log.Println(conn.conn.RemoteAddr())
} n, err := conn.conn.Write(buf)
//log.Println("bp write") if err != nil {
n, err := s.bp.Write(gobuf) //log.Println("onDataSendCallback end")
if err != nil { return NGHTTP2_ERR_CALLBACK_FAILURE
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE }
} //log.Printf("write %d bytes to network ", n)
//log.Println("onDataChunkRecv end") //log.Println("onDataSendCallback end")
return C.int(n) return C.ssize_t(n)
} }
// onDataSendCallback callback function for libnghttp2 library want send data to network. // onBeginHeaderCallback callback function for begin header receive.
// //
//export onDataSendCallback //export onBeginHeaderCallback
func onDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, func onBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
size C.size_t) C.ssize_t { //log.Println("onBeginHeaderCallback begin")
//log.Println("onDataSendCallback begin") //log.Printf("stream %d begin headers", int(streamID))
//log.Println("data write req ", int(size)) conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
buf := C.GoBytes(data, C.int(size)) var TLS tls.ConnectionState
//log.Println(conn.conn.RemoteAddr()) if tlsconn, ok := conn.conn.(*tls.Conn); ok {
n, err := conn.conn.Write(buf) TLS = tlsconn.ConnectionState()
if err != nil { }
//log.Println("onDataSendCallback end") // client
return NGHTTP2_ERR_CALLBACK_FAILURE if !conn.isServer {
} s, ok := conn.streams[int(streamID)]
//log.Printf("write %d bytes to network ", n) if !ok {
//log.Println("onDataSendCallback end") //log.Println("onBeginHeaderCallback end")
return C.ssize_t(n) return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
} }
s.response = &http.Response{
// onBeginHeaderCallback callback function for begin header receive. Proto: "HTTP/2",
// ProtoMajor: 2,
//export onBeginHeaderCallback ProtoMinor: 0,
func onBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { Header: make(http.Header),
//log.Println("onBeginHeaderCallback begin") Body: s.bp,
//log.Printf("stream %d begin headers", int(streamID)) TLS: &TLS,
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) }
return NGHTTP2_NO_ERROR
var TLS tls.ConnectionState }
if tlsconn, ok := conn.conn.(*tls.Conn); ok {
TLS = tlsconn.ConnectionState() // server
} s := &stream{
// client streamID: int(streamID),
if !conn.isServer { conn: conn,
s, ok := conn.streams[int(streamID)] bp: &bodyProvider{
if !ok { buf: new(bytes.Buffer),
//log.Println("onBeginHeaderCallback end") lock: new(sync.Mutex),
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE },
} request: &http.Request{
s.response = &http.Response{ Header: make(http.Header),
Proto: "HTTP/2", Proto: "HTTP/2",
ProtoMajor: 2, ProtoMajor: 2,
ProtoMinor: 0, ProtoMinor: 0,
Header: make(http.Header), TLS: &TLS,
Body: s.bp, },
TLS: &TLS, }
} s.request.Body = s.bp
return NGHTTP2_NO_ERROR
} conn.streams[int(streamID)] = s
// server //log.Println("onBeginHeaderCallback end")
s := &stream{ return NGHTTP2_NO_ERROR
streamID: int(streamID), }
conn: conn,
bp: &bodyProvider{ // onHeaderCallback callback function for each header received.
buf: new(bytes.Buffer), //
lock: new(sync.Mutex), //export onHeaderCallback
}, func onHeaderCallback(ptr unsafe.Pointer, streamID C.int,
request: &http.Request{ name unsafe.Pointer, namelen C.int,
Header: make(http.Header), value unsafe.Pointer, valuelen C.int) C.int {
Proto: "HTTP/2", //log.Println("onHeaderCallback begin")
ProtoMajor: 2, //log.Printf("header %d", int(streamID))
ProtoMinor: 0, conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
TLS: &TLS, goname := string(C.GoBytes(name, namelen))
}, govalue := string(C.GoBytes(value, valuelen))
}
s.ctx, s.cancel = context.WithCancel(context.Background()) s, ok := conn.streams[int(streamID)]
s.request.Body = s.bp if !ok {
//log.Printf("new stream %d", int(streamID)) //log.Println("onHeaderCallback end")
conn.streams[int(streamID)] = s return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
}
runtime.SetFinalizer(s, (*stream).free) var header http.Header
if conn.isServer {
//log.Println("onBeginHeaderCallback end") header = s.request.Header
return NGHTTP2_NO_ERROR } else {
} header = s.response.Header
}
// onHeaderCallback callback function for each header received. goname = strings.ToLower(goname)
// switch goname {
//export onHeaderCallback case ":method":
func onHeaderCallback(ptr unsafe.Pointer, streamID C.int, s.request.Method = govalue
name unsafe.Pointer, namelen C.int, case ":scheme":
value unsafe.Pointer, valuelen C.int) C.int { case ":authority":
//log.Println("onHeaderCallback begin") s.request.Host = govalue
//log.Printf("header %d", int(streamID)) case ":path":
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) s.request.RequestURI = govalue
goname := string(C.GoBytes(name, namelen)) u, err := url.Parse(govalue)
govalue := string(C.GoBytes(value, valuelen)) if err != nil {
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
s, ok := conn.streams[int(streamID)] }
if !ok { s.request.URL = u
//log.Println("onHeaderCallback end") case ":status":
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE if s.response == nil {
} //log.Println("empty response")
var header http.Header return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
if conn.isServer { }
header = s.request.Header statusCode, _ := strconv.Atoi(govalue)
} else { s.response.StatusCode = statusCode
header = s.response.Header s.response.Status = http.StatusText(statusCode)
} case "content-length":
goname = strings.ToLower(goname) header.Add(goname, govalue)
switch goname { n, err := strconv.ParseInt(govalue, 10, 64)
case ":method": if err == nil {
s.request.Method = govalue if conn.isServer {
case ":scheme": s.request.ContentLength = n
case ":authority": } else {
s.request.Host = govalue s.response.ContentLength = n
case ":path": }
s.request.RequestURI = govalue }
u, err := url.Parse(govalue) case "transfer-encoding":
if err != nil { header.Add(goname, govalue)
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE if conn.isServer {
} s.request.TransferEncoding = append(s.response.TransferEncoding, govalue)
s.request.URL = u } else {
case ":status": s.response.TransferEncoding = append(s.response.TransferEncoding, govalue)
if s.response == nil { }
//log.Println("empty response") default:
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE header.Add(goname, govalue)
} }
statusCode, _ := strconv.Atoi(govalue) //log.Println("onHeaderCallback end")
s.response.StatusCode = statusCode return NGHTTP2_NO_ERROR
s.response.Status = http.StatusText(statusCode) }
case "content-length":
header.Add(goname, govalue) // onHeadersDoneCallback callback function for the stream when all headers received.
n, err := strconv.ParseInt(govalue, 10, 64) //
if err == nil { //export onHeadersDoneCallback
if conn.isServer { func onHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
s.request.ContentLength = n //log.Println("onHeadersDoneCallback begin")
} else { //log.Printf("stream %d headers done", int(streamID))
s.response.ContentLength = n conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
} s, ok := conn.streams[int(streamID)]
} if !ok {
case "transfer-encoding": //log.Println("onHeadersDoneCallback end")
header.Add(goname, govalue) return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
if conn.isServer { }
s.request.TransferEncoding = append( s.headersEnd = true
s.response.TransferEncoding, govalue) if conn.isServer {
} else { if s.request.Method == "CONNECT" {
s.response.TransferEncoding = append( go conn.serve(s)
s.response.TransferEncoding, govalue) }
} return NGHTTP2_NO_ERROR
default: }
header.Add(goname, govalue) select {
} case s.resch <- s.response:
//log.Println("onHeaderCallback end") default:
return NGHTTP2_NO_ERROR }
} //log.Println("onHeadersDoneCallback end")
return NGHTTP2_NO_ERROR
// onHeadersDoneCallback callback function for the stream when all headers received. }
//
//export onHeadersDoneCallback // onStreamClose callback function for the stream when closed.
func onHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { //
//log.Println("onHeadersDoneCallback begin") //export onStreamClose
//log.Printf("stream %d headers done", int(streamID)) func onStreamClose(ptr unsafe.Pointer, streamID C.int) C.int {
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) //log.Println("onStreamClose begin")
s, ok := conn.streams[int(streamID)] //log.Printf("stream %d closed", int(streamID))
if !ok { conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
//log.Println("onHeadersDoneCallback end")
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE stream, ok := conn.streams[int(streamID)]
} if ok {
s.headersEnd = true go stream.Close()
if conn.isServer { //conn.lock.Lock()
if s.request.Method == "CONNECT" { delete(conn.streams, int(streamID))
go conn.serve(s) //go stream.Close()
} //conn.lock.Unlock()
return NGHTTP2_NO_ERROR //log.Println("onStreamClose end")
} return NGHTTP2_NO_ERROR
select { }
case s.resch <- s.response: //log.Println("onStreamClose end")
default: return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
} }
//log.Println("onHeadersDoneCallback end")
return NGHTTP2_NO_ERROR //export onConnectionCloseCallback
} func onConnectionCloseCallback(ptr unsafe.Pointer) {
conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
// onStreamClose callback function for the stream when closed. conn.err = io.EOF
//
//export onStreamClose // signal all goroutings exit
func onStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { for i := 0; i < 4; i++ {
//log.Println("onStreamClose begin") select {
//log.Printf("stream %d closed", int(streamID)) case conn.exitch <- struct{}{}:
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) default:
}
stream, ok := conn.streams[int(streamID)] }
if ok { }
go stream.Close()
//log.Printf("remove stream %d", int(streamID)) //export onStreamEndCallback
//conn.lock.Lock() func onStreamEndCallback(ptr unsafe.Pointer, streamID C.int) {
delete(conn.streams, int(streamID)) conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
//go stream.Close() stream, ok := conn.streams[int(streamID)]
//conn.lock.Unlock() if !ok {
//log.Println("onStreamClose end") return
return NGHTTP2_NO_ERROR }
} stream.streamEnd = true
//log.Println("onStreamClose end")
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE stream.bp.Close()
}
if stream.conn.isServer {
//export onConnectionCloseCallback if stream.request.Method != "CONNECT" {
func onConnectionCloseCallback(ptr unsafe.Pointer) { go conn.serve(stream)
conn := (*Conn)(unsafe.Pointer(uintptr(ptr))) }
conn.err = io.EOF return
conn.Close() }
} }
//export onStreamEndCallback
func onStreamEndCallback(ptr unsafe.Pointer, streamID C.int) {
conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
stream, ok := conn.streams[int(streamID)]
if !ok {
return
}
stream.streamEnd = true
stream.bp.Close()
if stream.conn.isServer {
if stream.request.Method != "CONNECT" {
go conn.serve(stream)
}
return
}
}

@ -1,85 +0,0 @@
package nghttp2
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strings"
"sync"
)
// Transport the nghttp2 RoundTripper implement
type Transport struct {
TLSConfig *tls.Config
DialTLS func(network, addr string, cfg *tls.Config) (*tls.Conn, error)
cacheConn map[string]*Conn
mu sync.Mutex
}
// RoundTrip send req and get res
func (tr *Transport) RoundTrip(req *http.Request) (res *http.Response, err error) {
h2conn, err := tr.getConn(req)
if err != nil {
return nil, err
}
return h2conn.RoundTrip(req)
}
func (tr *Transport) getConn(req *http.Request) (*Conn, error) {
tr.mu.Lock()
defer tr.mu.Unlock()
if tr.cacheConn == nil {
tr.cacheConn = map[string]*Conn{}
}
k := req.URL.Host
if c, ok := tr.cacheConn[k]; ok {
if c.CanTakeNewRequest() {
return c, nil
}
delete(tr.cacheConn, k)
c.Close()
}
c, err := tr.createConn(k)
if err == nil {
tr.cacheConn[k] = c
}
return c, err
}
func (tr *Transport) createConn(host string) (*Conn, error) {
dial := tls.Dial
if tr.DialTLS != nil {
dial = tr.DialTLS
}
cfg := tr.TLSConfig
if cfg == nil {
h, _, err := net.SplitHostPort(host)
if err != nil {
h = host
}
cfg = &tls.Config{
ServerName: h,
NextProtos: []string{"h2"},
}
}
if !strings.Contains(host, ":") {
host = fmt.Sprintf("%s:443", host)
}
conn, err := dial("tcp", host, cfg)
if err != nil {
return nil, err
}
if err = conn.Handshake(); err != nil {
return nil, err
}
state := conn.ConnectionState()
if state.NegotiatedProtocol != "h2" {
conn.Close()
return nil, errors.New("http2 is not supported")
}
return Client(conn)
}

@ -7,7 +7,6 @@ package nghttp2
import "C" import "C"
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -15,7 +14,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -28,104 +26,56 @@ type Conn struct {
session *C.nghttp2_session session *C.nghttp2_session
streams map[int]*stream streams map[int]*stream
streamCount int streamCount int
closed bool
isServer bool isServer bool
running bool running bool
handler http.Handler handler http.Handler
lock *sync.Mutex lock *sync.Mutex
err error err error
errch chan error errch chan error
ctx context.Context exitch chan struct{}
cancel context.CancelFunc
} }
// Dial connect to addr and create a http2 client Conn // Server create server side http2 connection
//
// the Conn.Run have already called, should not call it again
func Dial(network, addr string, cfg *tls.Config) (*Conn, error) {
nextProto := []string{"h2"}
if cfg == nil {
_addr := addr
h, _, err := net.SplitHostPort(addr)
if err == nil {
_addr = h
}
cfg = &tls.Config{ServerName: _addr}
}
cfg.NextProtos = nextProto
conn, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
if err := conn.Handshake(); err != nil {
return nil, err
}
state := conn.ConnectionState()
if state.NegotiatedProtocol != "h2" {
return nil, errors.New("server not support http2")
}
return Client(conn)
}
// Server create server side http2 connection on c
//
// c must be TLS connection and negotiated for h2
//
// the Conn.Run not called, you must run it
func Server(c net.Conn, handler http.Handler) (*Conn, error) { func Server(c net.Conn, handler http.Handler) (*Conn, error) {
conn := &Conn{ conn := &Conn{
conn: c, conn: c,
handler: handler, handler: handler,
errch: make(chan error), errch: make(chan error),
exitch: make(chan struct{}),
lock: new(sync.Mutex), lock: new(sync.Mutex),
isServer: true, isServer: true,
streams: make(map[int]*stream), streams: make(map[int]*stream),
} }
conn.session = C.init_nghttp2_server_session(C.size_t(uintptr(unsafe.Pointer(conn))))
conn.ctx, conn.cancel = context.WithCancel(context.Background())
//log.Printf("new conn %x", uintptr(unsafe.Pointer(conn)))
runtime.SetFinalizer(conn, (*Conn).free)
conn.session = C.init_nghttp2_server_session(
C.size_t(uintptr(unsafe.Pointer(conn))))
if conn.session == nil { if conn.session == nil {
return nil, errors.New("init server session failed") return nil, errors.New("init server session failed")
} }
ret := C.send_connection_header(conn.session) ret := C.send_connection_header(conn.session)
if int(ret) < 0 { if int(ret) < 0 {
conn.Close() conn.Close()
return nil, fmt.Errorf("send settings error: %s", return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret)))
C.GoString(C.nghttp2_strerror(ret)))
} }
return conn, nil return conn, nil
} }
// Client create client side http2 connection on c // Client create client side http2 connection
//
// c must be TLS connection and negotiated for h2
//
// the Conn.Run have alread called, you should not call it again
func Client(c net.Conn) (*Conn, error) { func Client(c net.Conn) (*Conn, error) {
conn := &Conn{ conn := &Conn{
conn: c, conn: c,
errch: make(chan error), errch: make(chan error),
exitch: make(chan struct{}),
lock: new(sync.Mutex), lock: new(sync.Mutex),
streams: make(map[int]*stream), streams: make(map[int]*stream),
} }
conn.session = C.init_nghttp2_client_session(C.size_t(uintptr(unsafe.Pointer(conn))))
conn.ctx, conn.cancel = context.WithCancel(context.Background())
//log.Printf("new conn %x", uintptr(unsafe.Pointer(conn)))
runtime.SetFinalizer(conn, (*Conn).free)
conn.session = C.init_nghttp2_client_session(
C.size_t(uintptr(unsafe.Pointer(conn))))
if conn.session == nil { if conn.session == nil {
return nil, errors.New("init server session failed") return nil, errors.New("init server session failed")
} }
ret := C.send_connection_header(conn.session) ret := C.send_connection_header(conn.session)
if int(ret) < 0 { if int(ret) < 0 {
conn.Close() conn.Close()
return nil, fmt.Errorf("send settings error: %s", return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret)))
C.GoString(C.nghttp2_strerror(ret)))
} }
go conn.Run() go conn.Run()
return conn, nil return conn, nil
@ -152,13 +102,6 @@ func HTTP2Handler(srv *http.Server, conn *tls.Conn, handler http.Handler) {
h2conn.Run() h2conn.Run()
} }
func (c *Conn) free() {
//log.Printf("free conn %x", uintptr(unsafe.Pointer(c)))
if !c.isClosed() {
c.Close()
}
}
// Error return conn error // Error return conn error
func (c *Conn) Error() error { func (c *Conn) Error() error {
c.lock.Lock() c.lock.Lock()
@ -193,8 +136,7 @@ func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) {
nv = append(nv, newNV(":path", p)) nv = append(nv, newNV(":path", p))
for k, v := range req.Header { for k, v := range req.Header {
_k := strings.ToLower(k) _k := strings.ToLower(k)
if _k == "connection" || _k == "proxy-connection" || if _k == "connection" || _k == "proxy-connection" || _k == "transfer-encoding" {
_k == "transfer-encoding" {
continue continue
} }
nv = append(nv, newNV(k, v[0])) nv = append(nv, newNV(k, v[0]))
@ -223,9 +165,6 @@ func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) {
dp.Close() dp.Close()
}() }()
} }
s.request = req
select { select {
case res := <-s.resch: case res := <-s.resch:
/* /*
@ -233,25 +172,23 @@ func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("http error code %d", res.StatusCode) return nil, fmt.Errorf("http error code %d", res.StatusCode)
} }
*/ */
s.request = req
res.Request = s.request res.Request = s.request
return res, nil return res, nil
case <-c.ctx.Done(): case <-c.exitch:
return nil, errors.New("connection closed") return nil, errors.New("connection closed")
} }
} }
func (c *Conn) submitRequest(nv []C.nghttp2_nv, func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*stream, error) {
cdp *C.nghttp2_data_provider) (*stream, error) {
c.lock.Lock() c.lock.Lock()
ret := C._nghttp2_submit_request(c.session, nil, ret := C._nghttp2_submit_request(c.session, nil,
C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), cdp, nil)
C.size_t(len(nv)), cdp, nil)
c.lock.Unlock() c.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
return nil, fmt.Errorf("submit request error: %s", return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(ret)))
C.GoString(C.nghttp2_strerror(ret)))
} }
streamID := int(ret) streamID := int(ret)
s := &stream{ s := &stream{
@ -263,22 +200,14 @@ func (c *Conn) submitRequest(nv []C.nghttp2_nv,
}, },
resch: make(chan *http.Response), resch: make(chan *http.Response),
} }
s.ctx, s.cancel = context.WithCancel(context.Background())
if cdp != nil { if cdp != nil {
s.cdp = *cdp s.cdp = *cdp
} }
runtime.SetFinalizer(s, (*stream).free)
return s, nil return s, nil
} }
// Connect submit connect request // Connect submit connect request
// func (c *Conn) Connect(addr string) (net.Conn, int, error) {
// like "CONNECT host:port" on http/1.1
//
// statusCode is the http status code the server returned
//
// c bounds to the remote host of addr
func (c *Conn) Connect(addr string) (conn net.Conn, statusCode int, err error) {
nv := []C.nghttp2_nv{} nv := []C.nghttp2_nv{}
nv = append(nv, newNV(":method", "CONNECT")) nv = append(nv, newNV(":method", "CONNECT"))
@ -303,8 +232,7 @@ func (c *Conn) Connect(addr string) (conn net.Conn, statusCode int, err error) {
select { select {
case res := <-s.resch: case res := <-s.resch:
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
return nil, res.StatusCode, fmt.Errorf( return nil, res.StatusCode, fmt.Errorf("http error code %d", res.StatusCode)
"http error code %d", res.StatusCode)
} }
s.request = &http.Request{ s.request = &http.Request{
Method: "CONNECT", Method: "CONNECT",
@ -314,9 +242,8 @@ func (c *Conn) Connect(addr string) (conn net.Conn, statusCode int, err error) {
} }
res.Request = s.request res.Request = s.request
return s, res.StatusCode, nil return s, res.StatusCode, nil
case <-c.ctx.Done(): case <-c.exitch:
return nil, http.StatusServiceUnavailable, return nil, http.StatusServiceUnavailable, errors.New("connection closed")
errors.New("connection closed")
} }
} }
@ -337,7 +264,7 @@ func (c *Conn) Run() {
case err := <-c.errch: case err := <-c.errch:
c.err = err c.err = err
return return
case <-c.ctx.Done(): case <-c.exitch:
return return
} }
} }
@ -348,40 +275,26 @@ func (c *Conn) serve(s *stream) {
if handler == nil { if handler == nil {
handler = http.DefaultServeMux handler = http.DefaultServeMux
} }
s.request.RemoteAddr = c.conn.RemoteAddr().String()
if s.request.URL == nil {
s.request.URL = &url.URL{}
}
handler.ServeHTTP(s, s.request) handler.ServeHTTP(s, s.request)
s.Close() s.Close()
} }
// Close close the connection // Close close the connection
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.lock.Lock() if c.closed {
if c.isClosed() {
c.lock.Unlock()
return nil return nil
} }
c.cancel() c.closed = true
c.lock.Unlock()
// stream.Close may require the conn.Lock
// so must not hold the lock here
for _, s := range c.streams { for _, s := range c.streams {
s.Close() s.Close()
} }
c.lock.Lock() c.lock.Lock()
for n := range c.streams {
delete(c.streams, n)
}
C.nghttp2_session_terminate_session(c.session, 0) C.nghttp2_session_terminate_session(c.session, 0)
C.nghttp2_session_del(c.session) C.nghttp2_session_del(c.session)
c.lock.Unlock() c.lock.Unlock()
close(c.exitch)
c.conn.Close() c.conn.Close()
return nil return nil
} }
@ -394,31 +307,48 @@ func (c *Conn) errorNotify(err error) {
} }
func (c *Conn) readloop() { func (c *Conn) readloop() {
buf := make([]byte, 16*1024) type data struct {
for { buf []byte
if c.isClosed() { err error
return }
}
n, err := c.conn.Read(buf)
if err != nil {
c.errorNotify(err)
return
}
c.lock.Lock() var ret C.ssize_t
// check again var err error
if c.isClosed() { var d data
c.lock.Unlock()
return datach := make(chan data)
go func() {
d1 := data{}
var n int
var err1 error
for {
buf := make([]byte, 16*1024)
n, err1 = c.conn.Read(buf)
d1.buf = buf[:n]
d1.err = err1
datach <- d1
} }
ret := C.nghttp2_session_mem_recv(c.session, }()
(*C.uchar)(unsafe.Pointer(&buf[0])), C.size_t(n))
c.lock.Unlock() for {
if int(ret) < 0 { select {
err = fmt.Errorf("http2 recv error: %s", case <-c.exitch:
C.GoString(C.nghttp2_strerror(C.int(ret))))
c.errorNotify(err)
return return
case d = <-datach:
if d.err != nil {
c.errorNotify(d.err)
return
}
c.lock.Lock()
ret = C.nghttp2_session_mem_recv(c.session,
(*C.uchar)(unsafe.Pointer(&d.buf[0])), C.size_t(len(d.buf)))
c.lock.Unlock()
if int(ret) < 0 {
err = fmt.Errorf("http2 recv error: %s", C.GoString(C.nghttp2_strerror(C.int(ret))))
c.errorNotify(err)
return
}
} }
} }
} }
@ -427,18 +357,17 @@ func (c *Conn) writeloop() {
var ret C.int var ret C.int
var err error var err error
var delay = 50 * time.Millisecond var delay = 50 * time.Millisecond
for { for {
c.lock.Lock() select {
if c.isClosed() { case <-c.exitch:
c.lock.Unlock()
return return
default:
} }
c.lock.Lock()
ret = C.nghttp2_session_send(c.session) ret = C.nghttp2_session_send(c.session)
c.lock.Unlock() c.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
err = fmt.Errorf("http2 send error: %s", err = fmt.Errorf("http2 send error: %s", C.GoString(C.nghttp2_strerror(C.int(ret))))
C.GoString(C.nghttp2_strerror(C.int(ret))))
c.errorNotify(err) c.errorNotify(err)
return return
} }
@ -451,12 +380,3 @@ func (c *Conn) writeloop() {
} }
} }
} }
func (c *Conn) isClosed() bool {
select {
case <-c.ctx.Done():
return true
default:
}
return false
}

@ -1,189 +1,164 @@
package nghttp2 package nghttp2
/* /*
#include "_nghttp2.h" #include "_nghttp2.h"
*/ */
import "C" import "C"
import ( import (
"bytes" "bytes"
"errors" "errors"
"io" "log"
"log" "sync"
"sync" "time"
"time" "unsafe"
"unsafe" )
)
// dataProvider provider data for libnghttp2 library
// dataProvider provider data for libnghttp2 library // libnghttp2 callback will Read to read the data,
// libnghttp2 callback will Read to read the data, // application call Write to provider data,
// application call Write to provider data, // application call Close will cause Read return io.EOF
// application call Close will cause Read return io.EOF type dataProvider struct {
type dataProvider struct { buf *bytes.Buffer
buf *bytes.Buffer closed bool
closed bool lock *sync.Mutex
lock *sync.Mutex sessLock *sync.Mutex
sessLock *sync.Mutex session *C.nghttp2_session
session *C.nghttp2_session streamID int
streamID int deferred bool
deferred bool }
}
// Read read from data provider
// Read read from data provider func (dp *dataProvider) Read(buf []byte) (n int, err error) {
func (dp *dataProvider) Read(buf []byte) (n int, err error) { if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil {
if dp.buf == nil || dp.lock == nil || log.Println("dp read invalid state")
dp.sessLock == nil || dp.session == nil { return 0, errors.New("invalid state")
log.Println("dp read invalid state") }
return 0, errors.New("invalid state") dp.lock.Lock()
} defer dp.lock.Unlock()
dp.lock.Lock()
defer dp.lock.Unlock() n, err = dp.buf.Read(buf)
if err != nil && !dp.closed {
n, err = dp.buf.Read(buf) //log.Println("deferred")
if err != nil && !dp.closed { dp.deferred = true
//log.Println("deferred") return 0, errAgain
dp.deferred = true }
return 0, errAgain return
} }
return
} // Write provider data for data provider
func (dp *dataProvider) Write(buf []byte) (n int, err error) {
// Write provider data for data provider if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil {
func (dp *dataProvider) Write(buf []byte) (n int, err error) { log.Println("dp write invalid state")
if dp.buf == nil || dp.lock == nil || return 0, errors.New("invalid state")
dp.sessLock == nil || dp.session == nil { }
log.Println("dp write invalid state") dp.lock.Lock()
return 0, errors.New("invalid state") defer dp.lock.Unlock()
}
//if dp.closed {
// make sure the buffer not too large // return 0, io.EOF
delay := 10 * time.Millisecond //}
maxBufSize := 4 * 1024
for { n, err = dp.buf.Write(buf)
dp.lock.Lock() if dp.deferred {
_len := dp.buf.Len() dp.sessLock.Lock()
closed := dp.closed C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID))
dp.lock.Unlock() dp.sessLock.Unlock()
if closed {
return 0, io.EOF //log.Println("resume")
} dp.deferred = false
if _len < maxBufSize { }
break return
} }
time.Sleep(delay)
} // Close end to provide data
func (dp *dataProvider) Close() error {
dp.lock.Lock() if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil {
defer dp.lock.Unlock() log.Println("dp close, invalid state")
return errors.New("invalid state")
//if dp.closed { }
// return 0, io.EOF dp.lock.Lock()
//} defer dp.lock.Unlock()
n, err = dp.buf.Write(buf) if dp.closed {
if dp.deferred { return nil
dp.sessLock.Lock() }
C.nghttp2_session_resume_data( dp.closed = true
dp.session, C.int(dp.streamID)) //log.Printf("dp close stream %d", dp.streamID)
dp.sessLock.Unlock() if dp.deferred {
dp.sessLock.Lock()
//log.Println("resume") C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID))
dp.deferred = false dp.sessLock.Unlock()
}
return dp.deferred = false
} }
return nil
// Close end to provide data }
func (dp *dataProvider) Close() error {
if dp.buf == nil || dp.lock == nil || func newDataProvider(cdp unsafe.Pointer, sessionLock *sync.Mutex, t int) *dataProvider {
dp.sessLock == nil || dp.session == nil { dp := &dataProvider{
log.Println("dp close, invalid state") buf: new(bytes.Buffer),
return errors.New("invalid state") lock: new(sync.Mutex),
} sessLock: sessionLock,
dp.lock.Lock() }
defer dp.lock.Unlock() C.data_provider_set_callback(C.size_t(uintptr(cdp)),
C.size_t(uintptr(unsafe.Pointer(dp))), C.int(t))
if dp.closed { return dp
return nil }
}
dp.closed = true // bodyProvider provide data for http body
//log.Printf("dp close stream %d", dp.streamID) // Read will block when data not yet avaliable
if dp.deferred { type bodyProvider struct {
dp.sessLock.Lock() buf *bytes.Buffer
C.nghttp2_session_resume_data( closed bool
dp.session, C.int(dp.streamID)) lock *sync.Mutex
dp.sessLock.Unlock() }
dp.deferred = false // Read read data from provider
} // will block when data not yet avaliable
return nil func (bp *bodyProvider) Read(buf []byte) (int, error) {
} var delay = 100 * time.Millisecond
func newDataProvider(cdp unsafe.Pointer, for {
sessionLock *sync.Mutex, t int) *dataProvider { bp.lock.Lock()
dp := &dataProvider{ n, err := bp.buf.Read(buf)
buf: new(bytes.Buffer), bp.lock.Unlock()
lock: new(sync.Mutex), if err != nil && !bp.closed {
sessLock: sessionLock, time.Sleep(delay)
} continue
C.data_provider_set_callback(C.size_t(uintptr(cdp)), }
C.size_t(uintptr(unsafe.Pointer(dp))), C.int(t)) return n, err
return dp }
} }
// bodyProvider provide data for http body // Write provide data for dataProvider
// Read will block when data not yet avaliable // libnghttp2 data chunk recv callback will call this
type bodyProvider struct { func (bp *bodyProvider) Write(buf []byte) (int, error) {
buf *bytes.Buffer bp.lock.Lock()
closed bool defer bp.lock.Unlock()
lock *sync.Mutex
} return bp.buf.Write(buf)
}
// Read read data from provider
// will block when data not yet avaliable // Close end to provide data
func (bp *bodyProvider) Read(buf []byte) (int, error) { func (bp *bodyProvider) Close() error {
var delay = 100 * time.Millisecond bp.lock.Lock()
defer bp.lock.Unlock()
for {
bp.lock.Lock() bp.closed = true
n, err := bp.buf.Read(buf) return nil
bp.lock.Unlock() }
if err != nil && !bp.closed {
time.Sleep(delay) func newNV(name, value string) C.nghttp2_nv {
continue nv := C.nghttp2_nv{}
} nameArr := make([]byte, len(name)+1)
return n, err valueArr := make([]byte, len(value)+1)
} copy(nameArr, []byte(name))
} copy(valueArr, []byte(value))
// Write provide data for dataProvider nv.name = (*C.uchar)(unsafe.Pointer(&nameArr[0]))
// libnghttp2 data chunk recv callback will call this nv.value = (*C.uchar)(unsafe.Pointer(&valueArr[0]))
func (bp *bodyProvider) Write(buf []byte) (int, error) { nv.namelen = C.size_t(len(name))
bp.lock.Lock() nv.valuelen = C.size_t(len(value))
defer bp.lock.Unlock() nv.flags = 0
return nv
return bp.buf.Write(buf) }
}
// Close end to provide data
func (bp *bodyProvider) Close() error {
bp.lock.Lock()
defer bp.lock.Unlock()
bp.closed = true
return nil
}
func newNV(name, value string) C.nghttp2_nv {
nv := C.nghttp2_nv{}
nameArr := make([]byte, len(name)+1)
valueArr := make([]byte, len(value)+1)
copy(nameArr, []byte(name))
copy(valueArr, []byte(value))
nv.name = (*C.uchar)(unsafe.Pointer(&nameArr[0]))
nv.value = (*C.uchar)(unsafe.Pointer(&valueArr[0]))
nv.namelen = C.size_t(len(name))
nv.valuelen = C.size_t(len(value))
nv.flags = 0
return nv
}

226
doc.go

@ -1,113 +1,113 @@
/*Package nghttp2 is libnghttp2 binding for golang. /*Package nghttp2 is libnghttp2 binding for golang.
server example server example
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
l, err := tls.Listen("tcp", "127.0.0.1:1100", &tls.Config{ l, err := tls.Listen("tcp", "127.0.0.1:1100", &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
NextProtos: []string{"h2"}, NextProtos: []string{"h2"},
}) })
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
addr := l.Addr().String() addr := l.Addr().String()
http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
log.Printf("%+v", r) log.Printf("%+v", r)
hdr := w.Header() hdr := w.Header()
hdr.Set("content-type", "text/plain") hdr.Set("content-type", "text/plain")
hdr.Set("aa", "bb") hdr.Set("aa", "bb")
d, err := ioutil.ReadAll(r.Body) d, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
w.Write(d) w.Write(d)
}) })
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
break break
} }
h2conn, err := Server(c, nil) h2conn, err := Server(c, nil)
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
log.Printf("%+v", h2conn) log.Printf("%+v", h2conn)
go h2conn.Run() go h2conn.Run()
} }
client example client example
conn, err := tls.Dial("tcp", "nghttp2.org:443", &tls.Config{ conn, err := tls.Dial("tcp", "nghttp2.org:443", &tls.Config{
NextProtos: []string{"h2"}, NextProtos: []string{"h2"},
ServerName: "nghttp2.org", ServerName: "nghttp2.org",
}) })
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
if err := conn.Handshake(); err != nil{ if err := conn.Handshake(); err != nil{
log.Fatal(err) t.Fatal(err)
} }
cstate := conn.ConnectionState() cstate := conn.ConnectionState()
if cstate.NegotiatedProtocol != "h2" { if cstate.NegotiatedProtocol != "h2" {
log.Fatal("no http2 on server") t.Fatal("no http2 on server")
} }
h2conn, err := Client(conn) h2conn, err := Client(conn)
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
param := url.Values{} param := url.Values{}
param.Add("e", "b") param.Add("e", "b")
param.Add("f", "d") param.Add("f", "d")
data := bytes.NewReader([]byte(param.Encode())) data := bytes.NewReader([]byte(param.Encode()))
req, _ := http.NewRequest("POST", req, _ := http.NewRequest("POST",
"https://nghttp2.org/httpbin/post?a=b&c=d", "https://nghttp2.org/httpbin/post?a=b&c=d",
data) data)
log.Printf("%+v", req) log.Printf("%+v", req)
req.Header.Set("user-agent", "go-nghttp2/1.0") req.Header.Set("user-agent", "go-nghttp2/1.0")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
res, err := h2conn.RoundTrip(req) res, err := h2conn.RoundTrip(req)
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
log.Printf("expect %d, got %d", http.StatusOK, res.StatusCode) t.Errorf("expect %d, got %d", http.StatusOK, res.StatusCode)
} }
res.Write(os.Stderr) res.Write(os.Stderr)
co-work with net/http example co-work with net/http example
l, err := net.Listen("tcp", "127.0.0.1:1222") l, err := net.Listen("tcp", "127.0.0.1:1222")
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
srv := &http.Server{ srv := &http.Server{
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
NextProtos: []string{"h2", "http/1.1"}, NextProtos: []string{"h2", "http/1.1"},
}, },
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){ TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
"h2": nghttp2.HTTP2Handler, "h2": nghttp2.HTTP2Handler,
}, },
} }
defer srv.Close() defer srv.Close()
srv.ServeTLS(l, "testdata/server.crt", "testdata/server.key") srv.ServeTLS(l, "testdata/server.crt", "testdata/server.key")
see http2_test.go for more details see http2_test.go for more details
*/ */
package nghttp2 package nghttp2

@ -5,7 +5,6 @@ package nghttp2
*/ */
import "C" import "C"
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -26,31 +25,14 @@ type stream struct {
resch chan *http.Response resch chan *http.Response
headersEnd bool headersEnd bool
streamEnd bool streamEnd bool
closed bool
cdp C.nghttp2_data_provider cdp C.nghttp2_data_provider
ctx context.Context
cancel context.CancelFunc
} }
var _ net.Conn = &stream{} var _ net.Conn = &stream{}
func (s *stream) isClosed() bool {
select {
case <-s.ctx.Done():
return true
default:
}
return false
}
func (s *stream) free() {
//log.Printf("stream free %d", s.streamID)
if !s.isClosed() {
s.Close()
}
}
func (s *stream) Read(buf []byte) (int, error) { func (s *stream) Read(buf []byte) (int, error) {
if s.isClosed() { if s.closed {
return 0, io.EOF return 0, io.EOF
} }
if s.bp != nil { if s.bp != nil {
@ -60,7 +42,7 @@ func (s *stream) Read(buf []byte) (int, error) {
} }
func (s *stream) WriteHeader(code int) { func (s *stream) WriteHeader(code int) {
if s.isClosed() { if s.closed {
return return
} }
if s.response == nil { if s.response == nil {
@ -82,33 +64,24 @@ func (s *stream) WriteHeader(code int) {
nv = append(nv, newNV(":status", fmt.Sprintf("%d", code))) nv = append(nv, newNV(":status", fmt.Sprintf("%d", code)))
for k, v := range s.response.Header { for k, v := range s.response.Header {
_k := strings.ToLower(k) _k := strings.ToLower(k)
if _k == "host" || _k == "connection" || if _k == "host" || _k == "connection" || _k == "transfer-encoding" {
_k == "transfer-encoding" {
continue continue
} }
nv = append(nv, newNV(k, v[0])) nv = append(nv, newNV(k, v[0]))
} }
s.cdp = C.nghttp2_data_provider{} s.cdp = C.nghttp2_data_provider{}
s.dp = newDataProvider(unsafe.Pointer(&s.cdp), s.dp = newDataProvider(unsafe.Pointer(&s.cdp), s.conn.lock, 0)
s.conn.lock, 0)
s.dp.session = s.conn.session s.dp.session = s.conn.session
s.dp.streamID = s.streamID s.dp.streamID = s.streamID
s.conn.lock.Lock() s.conn.lock.Lock()
if s.conn.isClosed() { ret := C._nghttp2_submit_response(s.conn.session, C.int(s.streamID),
s.conn.lock.Unlock() C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), &s.cdp)
return
}
ret := C._nghttp2_submit_response(s.conn.session,
C.int(s.streamID),
C.size_t(uintptr(unsafe.Pointer(&nv[0]))),
C.size_t(len(nv)), &s.cdp)
s.conn.lock.Unlock() s.conn.lock.Unlock()
if int(ret) < 0 { if int(ret) < 0 {
panic(fmt.Sprintf("submit response error: %s", panic(fmt.Sprintf("submit response error: %s", C.GoString(C.nghttp2_strerror(ret))))
C.GoString(C.nghttp2_strerror(ret))))
} }
} }
@ -125,11 +98,10 @@ func (s *stream) Header() http.Header {
} }
func (s *stream) Write(buf []byte) (int, error) { func (s *stream) Write(buf []byte) (int, error) {
if s.isClosed() { if s.closed {
return 0, io.EOF return 0, io.EOF
} }
if s.conn.isServer && (s.response == nil || if s.conn.isServer && (s.response == nil || s.response.StatusCode == 0) {
s.response.StatusCode == 0) {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
@ -140,12 +112,10 @@ func (s *stream) Write(buf []byte) (int, error) {
} }
func (s *stream) Close() error { func (s *stream) Close() error {
if s.isClosed() { if s.closed {
return nil return nil
} }
s.closed = true
s.cancel()
if s.dp != nil { if s.dp != nil {
s.dp.Close() s.dp.Close()
} }
@ -160,8 +130,7 @@ func (s *stream) Close() error {
if s.request != nil && s.request.Method == "CONNECT" { if s.request != nil && s.request.Method == "CONNECT" {
//log.Println("rst stream") //log.Println("rst stream")
s.conn.lock.Lock() s.conn.lock.Lock()
C.nghttp2_submit_rst_stream(s.conn.session, 0, C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 8)
C.int(s.streamID), 8)
s.conn.lock.Unlock() s.conn.lock.Unlock()
} }
return nil return nil

Loading…
Cancel
Save