diff --git a/callbacks.go b/callbacks.go index b9e1454..b9a0943 100644 --- a/callbacks.go +++ b/callbacks.go @@ -7,7 +7,6 @@ import "C" import ( "bytes" "io" - "log" "net/http" "net/url" "strings" @@ -16,7 +15,8 @@ import ( ) // OnServerDataRecvCallback callback function for libnghttp2 library -// want receive data from network, +// want receive data from network. +// //export OnServerDataRecvCallback func OnServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, length C.size_t) C.ssize_t { @@ -33,7 +33,8 @@ func OnServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, } // OnServerDataSendCallback callback function for libnghttp2 library -// want send data to network +// want send data to network. +// //export OnServerDataSendCallback func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, length C.size_t) C.ssize_t { @@ -48,7 +49,8 @@ func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, return C.ssize_t(n) } -// OnServerDataChunkRecv callback function for libnghttp2 library's data chunk recv +// OnServerDataChunkRecv callback function for libnghttp2 library's data chunk recv. +// //export OnServerDataChunkRecv func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int, data unsafe.Pointer, length C.size_t) C.int { @@ -60,7 +62,8 @@ func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int, return C.int(length) } -// OnServerBeginHeaderCallback callback function for begin begin header recv +// OnServerBeginHeaderCallback callback function for begin begin header recv. +// //export OnServerBeginHeaderCallback func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { conn := (*ServerConn)(ptr) @@ -80,7 +83,8 @@ func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { return 0 } -// OnServerHeaderCallback callback function for each header recv +// OnServerHeaderCallback callback function for each header recv. +// //export OnServerHeaderCallback func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, name unsafe.Pointer, namelen C.int, @@ -113,6 +117,7 @@ func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, } // OnServerStreamEndCallback callback function for the stream when END_STREAM flag set +// //export OnServerStreamEndCallback func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { @@ -122,13 +127,14 @@ func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { bp := s.req.Body.(*bodyProvider) if s.req.Method != "CONNECT" { bp.closed = true - log.Println("stream end flag set, begin to serve") + //log.Println("stream end flag set, begin to serve") go conn.serve(s) } return 0 } -// OnServerHeadersDoneCallback callback function for the stream when all headers received +// OnServerHeadersDoneCallback callback function for the stream when all headers received. +// //export OnServerHeadersDoneCallback func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { conn := (*ServerConn)(ptr) @@ -145,7 +151,8 @@ func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { return 0 } -// OnServerStreamClose callback function for the stream when closed +// OnServerStreamClose callback function for the stream when closed. +// //export OnServerStreamClose func OnServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { conn := (*ServerConn)(ptr) @@ -158,9 +165,9 @@ func OnServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { } // OnDataSourceReadCallback callback function for libnghttp2 library -// want read data from data provider source, -// return NGHTTP2_ERR_DEFERED will cause data frame defered, -// application later call nghttp2_session_resume_data will re-quene the data frame +// want read data from data provider source, +// return NGHTTP2_ERR_DEFERED will cause data frame defered, +// application later call nghttp2_session_resume_data will re-quene the data frame // //export OnDataSourceReadCallback func OnDataSourceReadCallback(ptr unsafe.Pointer, @@ -185,7 +192,8 @@ func OnDataSourceReadCallback(ptr unsafe.Pointer, return C.ssize_t(n) } -// OnClientDataChunkRecv callback function for libnghttp2 library data chunk received, +// OnClientDataChunkRecv callback function for libnghttp2 library data chunk received. +// //export OnClientDataChunkRecv func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int, buf unsafe.Pointer, length C.size_t) C.int { @@ -196,7 +204,8 @@ func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int, return 0 } -// OnClientDataRecvCallback callback function for libnghttp2 library want read data from network, +// OnClientDataRecvCallback callback function for libnghttp2 library want read data from network. +// //export OnClientDataRecvCallback func OnClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { //log.Println("data read req", int(size)) @@ -214,7 +223,8 @@ func OnClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si return C.ssize_t(n) } -// OnClientDataSendCallback callback function for libnghttp2 library want send data to network, +// OnClientDataSendCallback callback function for libnghttp2 library want send data to network. +// //export OnClientDataSendCallback func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { //log.Println("data write req ", int(size)) @@ -230,7 +240,8 @@ func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si return C.ssize_t(n) } -// OnClientBeginHeaderCallback callback function for begin header receive, +// OnClientBeginHeaderCallback callback function for begin header receive. +// //export OnClientBeginHeaderCallback func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("begin header") @@ -239,7 +250,8 @@ func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { return 0 } -// OnClientHeaderCallback callback function for each header received, +// OnClientHeaderCallback callback function for each header received. +// //export OnClientHeaderCallback func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, name unsafe.Pointer, namelen C.int, @@ -252,7 +264,8 @@ func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, return 0 } -// OnClientHeadersDoneCallback callback function for the stream when all headers received, +// OnClientHeadersDoneCallback callback function for the stream when all headers received. +// //export OnClientHeadersDoneCallback func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("frame recv") @@ -261,7 +274,8 @@ func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { return 0 } -// OnClientStreamClose callback function for the stream when closed, +// OnClientStreamClose callback function for the stream when closed. +// //export OnClientStreamClose func OnClientStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { //log.Println("stream close") diff --git a/conn.go b/conn.go index a253dbe..ea43fae 100644 --- a/conn.go +++ b/conn.go @@ -7,6 +7,7 @@ package nghttp2 import "C" import ( "bytes" + "crypto/tls" "errors" "fmt" "io" @@ -290,6 +291,27 @@ type ServerConn struct { err error } +// HTTP2Handler is the http2 server handler that can co-work with standard net/http. +// +// usage example: +// l, err := net.Listen("tcp", ":1222") +// srv := &http.Server{ +// TLSConfig: &tls.Config{ +// NextProtos:[]string{"h2", "http/1.1"}, +// } +// TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){ +// "h2": nghttp2.Http2Handler +// } +// } +// srv.ServeTLS(l, "server.crt", "server.key") +func HTTP2Handler(srv *http.Server, conn *tls.Conn, handler http.Handler) { + h2conn, err := NewServerConn(conn, handler) + if err != nil { + panic(err.Error()) + } + h2conn.Run() +} + // NewServerConn create new server connection func NewServerConn(c net.Conn, handler http.Handler) (*ServerConn, error) { conn := &ServerConn{ diff --git a/http2_test.go b/http2_test.go index 40265f0..53a8b7f 100644 --- a/http2_test.go +++ b/http2_test.go @@ -6,10 +6,13 @@ import ( "fmt" "io/ioutil" "log" + "net" "net/http" "net/url" "os" "testing" + + "golang.org/x/net/http2" ) func TestHttp2Client(t *testing.T) { @@ -80,7 +83,7 @@ func TestHttp2Server(t *testing.T) { defer l.Close() addr := l.Addr().String() go func() { - http.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { + http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { log.Printf("%+v", r) hdr := w.Header() hdr.Set("content-type", "text/plain") @@ -119,16 +122,26 @@ func TestHttp2Server(t *testing.T) { if cstate.NegotiatedProtocol != "h2" { t.Fatal("no http2 on server") } - h2conn, err := NewClientConn(conn) - if err != nil { - t.Fatal(err) + client := &http.Client{ + Transport: &http2.Transport{ + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + conn, err := tls.Dial(network, addr, &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + }) + if err := conn.Handshake(); err != nil { + return nil, err + } + return conn, err + }, + }, } d := bytes.NewBuffer([]byte("hello")) req, _ := http.NewRequest("POST", - fmt.Sprintf("https://%s/get?a=b&c=d", addr), d) + fmt.Sprintf("https://%s/test?a=b&c=d", addr), d) req.Header.Add("User-Agent", "nghttp2/1.32") req.Header.Add("Content-Type", "text/palin") - res, err := h2conn.CreateRequest(req) + res, err := client.Do(req) if err != nil { t.Fatal(err) } @@ -146,3 +159,72 @@ func TestHttp2Server(t *testing.T) { t.Errorf("expect %s, got %s", "hello", string(data)) } } + +func TestHttp2Handler(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + srv := &http.Server{ + TLSConfig: &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + }, + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){ + "h2": HTTP2Handler, + }, + } + defer srv.Close() + + testdata := "asc fasdf32ddfasfff\r\nassdf312313" + addr := l.Addr().String() + go func() { + http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + hdr := w.Header() + hdr.Set("content-type", "text/plain") + hdr.Set("aa", "bb") + fmt.Fprintf(w, testdata) + }) + http.Handle("/", http.FileServer(http.Dir("/"))) + srv.ServeTLS(l, "testdata/server.crt", "testdata/server.key") + }() + client := &http.Client{ + Transport: &http2.Transport{ + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + conn, err := tls.Dial(network, addr, &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + }) + if err := conn.Handshake(); err != nil { + return nil, err + } + return conn, err + }, + }, + } + u := fmt.Sprintf("https://%s/test", addr) + resp, err := client.Get(u) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("http error %d", resp.StatusCode) + } + if resp.TLS == nil { + t.Errorf("not tls") + } + if resp.TLS.NegotiatedProtocol != "h2" { + t.Errorf("http2 is not enabled") + } + d, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + if string(d) != testdata { + t.Errorf("expect %s, got %s", testdata, string(d)) + } + if resp.Header.Get("aa") != "bb" { + t.Errorf("expect header not found") + } + //io.Copy(os.Stdout, resp.Body) + resp.Write(os.Stdout) +}