add http2 handler can cowork with net/http

merge_conn
fangdingjun 7 years ago
parent 8d4069b1c2
commit 790bc1b7c8

@ -7,7 +7,6 @@ import "C"
import (
"bytes"
"io"
"log"
"net/http"
"net/url"
"strings"
@ -16,7 +15,8 @@ import (
)
// OnServerDataRecvCallback callback function for libnghttp2 library
// want receive data from network,
// want receive data from network.
//
//export OnServerDataRecvCallback
func OnServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer,
length C.size_t) C.ssize_t {
@ -33,7 +33,8 @@ func OnServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer,
}
// OnServerDataSendCallback callback function for libnghttp2 library
// want send data to network
// want send data to network.
//
//export OnServerDataSendCallback
func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer,
length C.size_t) C.ssize_t {
@ -48,7 +49,8 @@ func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer,
return C.ssize_t(n)
}
// OnServerDataChunkRecv callback function for libnghttp2 library's data chunk recv
// OnServerDataChunkRecv callback function for libnghttp2 library's data chunk recv.
//
//export OnServerDataChunkRecv
func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
data unsafe.Pointer, length C.size_t) C.int {
@ -60,7 +62,8 @@ func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
return C.int(length)
}
// OnServerBeginHeaderCallback callback function for begin begin header recv
// OnServerBeginHeaderCallback callback function for begin begin header recv.
//
//export OnServerBeginHeaderCallback
func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
conn := (*ServerConn)(ptr)
@ -80,7 +83,8 @@ func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0
}
// OnServerHeaderCallback callback function for each header recv
// OnServerHeaderCallback callback function for each header recv.
//
//export OnServerHeaderCallback
func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int,
name unsafe.Pointer, namelen C.int,
@ -113,6 +117,7 @@ func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int,
}
// OnServerStreamEndCallback callback function for the stream when END_STREAM flag set
//
//export OnServerStreamEndCallback
func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int {
@ -122,13 +127,14 @@ func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int {
bp := s.req.Body.(*bodyProvider)
if s.req.Method != "CONNECT" {
bp.closed = true
log.Println("stream end flag set, begin to serve")
//log.Println("stream end flag set, begin to serve")
go conn.serve(s)
}
return 0
}
// OnServerHeadersDoneCallback callback function for the stream when all headers received
// OnServerHeadersDoneCallback callback function for the stream when all headers received.
//
//export OnServerHeadersDoneCallback
func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
conn := (*ServerConn)(ptr)
@ -145,7 +151,8 @@ func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0
}
// OnServerStreamClose callback function for the stream when closed
// OnServerStreamClose callback function for the stream when closed.
//
//export OnServerStreamClose
func OnServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int {
conn := (*ServerConn)(ptr)
@ -158,9 +165,9 @@ func OnServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int {
}
// OnDataSourceReadCallback callback function for libnghttp2 library
// want read data from data provider source,
// return NGHTTP2_ERR_DEFERED will cause data frame defered,
// application later call nghttp2_session_resume_data will re-quene the data frame
// want read data from data provider source,
// return NGHTTP2_ERR_DEFERED will cause data frame defered,
// application later call nghttp2_session_resume_data will re-quene the data frame
//
//export OnDataSourceReadCallback
func OnDataSourceReadCallback(ptr unsafe.Pointer,
@ -185,7 +192,8 @@ func OnDataSourceReadCallback(ptr unsafe.Pointer,
return C.ssize_t(n)
}
// OnClientDataChunkRecv callback function for libnghttp2 library data chunk received,
// OnClientDataChunkRecv callback function for libnghttp2 library data chunk received.
//
//export OnClientDataChunkRecv
func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
buf unsafe.Pointer, length C.size_t) C.int {
@ -196,7 +204,8 @@ func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
return 0
}
// OnClientDataRecvCallback callback function for libnghttp2 library want read data from network,
// OnClientDataRecvCallback callback function for libnghttp2 library want read data from network.
//
//export OnClientDataRecvCallback
func OnClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t {
//log.Println("data read req", int(size))
@ -214,7 +223,8 @@ func OnClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si
return C.ssize_t(n)
}
// OnClientDataSendCallback callback function for libnghttp2 library want send data to network,
// OnClientDataSendCallback callback function for libnghttp2 library want send data to network.
//
//export OnClientDataSendCallback
func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t {
//log.Println("data write req ", int(size))
@ -230,7 +240,8 @@ func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si
return C.ssize_t(n)
}
// OnClientBeginHeaderCallback callback function for begin header receive,
// OnClientBeginHeaderCallback callback function for begin header receive.
//
//export OnClientBeginHeaderCallback
func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Println("begin header")
@ -239,7 +250,8 @@ func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0
}
// OnClientHeaderCallback callback function for each header received,
// OnClientHeaderCallback callback function for each header received.
//
//export OnClientHeaderCallback
func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int,
name unsafe.Pointer, namelen C.int,
@ -252,7 +264,8 @@ func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int,
return 0
}
// OnClientHeadersDoneCallback callback function for the stream when all headers received,
// OnClientHeadersDoneCallback callback function for the stream when all headers received.
//
//export OnClientHeadersDoneCallback
func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Println("frame recv")
@ -261,7 +274,8 @@ func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0
}
// OnClientStreamClose callback function for the stream when closed,
// OnClientStreamClose callback function for the stream when closed.
//
//export OnClientStreamClose
func OnClientStreamClose(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Println("stream close")

@ -7,6 +7,7 @@ package nghttp2
import "C"
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
@ -290,6 +291,27 @@ type ServerConn struct {
err error
}
// HTTP2Handler is the http2 server handler that can co-work with standard net/http.
//
// usage example:
// l, err := net.Listen("tcp", ":1222")
// srv := &http.Server{
// TLSConfig: &tls.Config{
// NextProtos:[]string{"h2", "http/1.1"},
// }
// TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
// "h2": nghttp2.Http2Handler
// }
// }
// srv.ServeTLS(l, "server.crt", "server.key")
func HTTP2Handler(srv *http.Server, conn *tls.Conn, handler http.Handler) {
h2conn, err := NewServerConn(conn, handler)
if err != nil {
panic(err.Error())
}
h2conn.Run()
}
// NewServerConn create new server connection
func NewServerConn(c net.Conn, handler http.Handler) (*ServerConn, error) {
conn := &ServerConn{

@ -6,10 +6,13 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/url"
"os"
"testing"
"golang.org/x/net/http2"
)
func TestHttp2Client(t *testing.T) {
@ -80,7 +83,7 @@ func TestHttp2Server(t *testing.T) {
defer l.Close()
addr := l.Addr().String()
go func() {
http.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) {
http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
log.Printf("%+v", r)
hdr := w.Header()
hdr.Set("content-type", "text/plain")
@ -119,16 +122,26 @@ func TestHttp2Server(t *testing.T) {
if cstate.NegotiatedProtocol != "h2" {
t.Fatal("no http2 on server")
}
h2conn, err := NewClientConn(conn)
if err != nil {
t.Fatal(err)
client := &http.Client{
Transport: &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := tls.Dial(network, addr, &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
})
if err := conn.Handshake(); err != nil {
return nil, err
}
return conn, err
},
},
}
d := bytes.NewBuffer([]byte("hello"))
req, _ := http.NewRequest("POST",
fmt.Sprintf("https://%s/get?a=b&c=d", addr), d)
fmt.Sprintf("https://%s/test?a=b&c=d", addr), d)
req.Header.Add("User-Agent", "nghttp2/1.32")
req.Header.Add("Content-Type", "text/palin")
res, err := h2conn.CreateRequest(req)
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
@ -146,3 +159,72 @@ func TestHttp2Server(t *testing.T) {
t.Errorf("expect %s, got %s", "hello", string(data))
}
}
func TestHttp2Handler(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
srv := &http.Server{
TLSConfig: &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
},
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
"h2": HTTP2Handler,
},
}
defer srv.Close()
testdata := "asc fasdf32ddfasfff\r\nassdf312313"
addr := l.Addr().String()
go func() {
http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
hdr.Set("content-type", "text/plain")
hdr.Set("aa", "bb")
fmt.Fprintf(w, testdata)
})
http.Handle("/", http.FileServer(http.Dir("/")))
srv.ServeTLS(l, "testdata/server.crt", "testdata/server.key")
}()
client := &http.Client{
Transport: &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := tls.Dial(network, addr, &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
})
if err := conn.Handshake(); err != nil {
return nil, err
}
return conn, err
},
},
}
u := fmt.Sprintf("https://%s/test", addr)
resp, err := client.Get(u)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("http error %d", resp.StatusCode)
}
if resp.TLS == nil {
t.Errorf("not tls")
}
if resp.TLS.NegotiatedProtocol != "h2" {
t.Errorf("http2 is not enabled")
}
d, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Error(err)
}
if string(d) != testdata {
t.Errorf("expect %s, got %s", testdata, string(d))
}
if resp.Header.Get("aa") != "bb" {
t.Errorf("expect header not found")
}
//io.Copy(os.Stdout, resp.Body)
resp.Write(os.Stdout)
}

Loading…
Cancel
Save