add http2 handler can cowork with net/http

merge_conn
fangdingjun 6 years ago
parent 8d4069b1c2
commit 790bc1b7c8

@ -7,7 +7,6 @@ import "C"
import ( import (
"bytes" "bytes"
"io" "io"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -16,7 +15,8 @@ import (
) )
// OnServerDataRecvCallback callback function for libnghttp2 library // OnServerDataRecvCallback callback function for libnghttp2 library
// want receive data from network, // want receive data from network.
//
//export OnServerDataRecvCallback //export OnServerDataRecvCallback
func OnServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, func OnServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer,
length C.size_t) C.ssize_t { length C.size_t) C.ssize_t {
@ -33,7 +33,8 @@ func OnServerDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer,
} }
// OnServerDataSendCallback callback function for libnghttp2 library // OnServerDataSendCallback callback function for libnghttp2 library
// want send data to network // want send data to network.
//
//export OnServerDataSendCallback //export OnServerDataSendCallback
func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer,
length C.size_t) C.ssize_t { length C.size_t) C.ssize_t {
@ -48,7 +49,8 @@ func OnServerDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer,
return C.ssize_t(n) return C.ssize_t(n)
} }
// OnServerDataChunkRecv callback function for libnghttp2 library's data chunk recv // OnServerDataChunkRecv callback function for libnghttp2 library's data chunk recv.
//
//export OnServerDataChunkRecv //export OnServerDataChunkRecv
func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int, func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
data unsafe.Pointer, length C.size_t) C.int { data unsafe.Pointer, length C.size_t) C.int {
@ -60,7 +62,8 @@ func OnServerDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
return C.int(length) return C.int(length)
} }
// OnServerBeginHeaderCallback callback function for begin begin header recv // OnServerBeginHeaderCallback callback function for begin begin header recv.
//
//export OnServerBeginHeaderCallback //export OnServerBeginHeaderCallback
func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
conn := (*ServerConn)(ptr) conn := (*ServerConn)(ptr)
@ -80,7 +83,8 @@ func OnServerBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0 return 0
} }
// OnServerHeaderCallback callback function for each header recv // OnServerHeaderCallback callback function for each header recv.
//
//export OnServerHeaderCallback //export OnServerHeaderCallback
func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int, func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int,
name unsafe.Pointer, namelen C.int, name unsafe.Pointer, namelen C.int,
@ -113,6 +117,7 @@ func OnServerHeaderCallback(ptr unsafe.Pointer, streamID C.int,
} }
// OnServerStreamEndCallback callback function for the stream when END_STREAM flag set // OnServerStreamEndCallback callback function for the stream when END_STREAM flag set
//
//export OnServerStreamEndCallback //export OnServerStreamEndCallback
func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int { func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int {
@ -122,13 +127,14 @@ func OnServerStreamEndCallback(ptr unsafe.Pointer, streamID C.int) C.int {
bp := s.req.Body.(*bodyProvider) bp := s.req.Body.(*bodyProvider)
if s.req.Method != "CONNECT" { if s.req.Method != "CONNECT" {
bp.closed = true bp.closed = true
log.Println("stream end flag set, begin to serve") //log.Println("stream end flag set, begin to serve")
go conn.serve(s) go conn.serve(s)
} }
return 0 return 0
} }
// OnServerHeadersDoneCallback callback function for the stream when all headers received // OnServerHeadersDoneCallback callback function for the stream when all headers received.
//
//export OnServerHeadersDoneCallback //export OnServerHeadersDoneCallback
func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
conn := (*ServerConn)(ptr) conn := (*ServerConn)(ptr)
@ -145,7 +151,8 @@ func OnServerHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0 return 0
} }
// OnServerStreamClose callback function for the stream when closed // OnServerStreamClose callback function for the stream when closed.
//
//export OnServerStreamClose //export OnServerStreamClose
func OnServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { func OnServerStreamClose(ptr unsafe.Pointer, streamID C.int) C.int {
conn := (*ServerConn)(ptr) conn := (*ServerConn)(ptr)
@ -185,7 +192,8 @@ func OnDataSourceReadCallback(ptr unsafe.Pointer,
return C.ssize_t(n) return C.ssize_t(n)
} }
// OnClientDataChunkRecv callback function for libnghttp2 library data chunk received, // OnClientDataChunkRecv callback function for libnghttp2 library data chunk received.
//
//export OnClientDataChunkRecv //export OnClientDataChunkRecv
func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int, func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
buf unsafe.Pointer, length C.size_t) C.int { buf unsafe.Pointer, length C.size_t) C.int {
@ -196,7 +204,8 @@ func OnClientDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
return 0 return 0
} }
// OnClientDataRecvCallback callback function for libnghttp2 library want read data from network, // OnClientDataRecvCallback callback function for libnghttp2 library want read data from network.
//
//export OnClientDataRecvCallback //export OnClientDataRecvCallback
func OnClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { func OnClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t {
//log.Println("data read req", int(size)) //log.Println("data read req", int(size))
@ -214,7 +223,8 @@ func OnClientDataRecvCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si
return C.ssize_t(n) return C.ssize_t(n)
} }
// OnClientDataSendCallback callback function for libnghttp2 library want send data to network, // OnClientDataSendCallback callback function for libnghttp2 library want send data to network.
//
//export OnClientDataSendCallback //export OnClientDataSendCallback
func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t { func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.size_t) C.ssize_t {
//log.Println("data write req ", int(size)) //log.Println("data write req ", int(size))
@ -230,7 +240,8 @@ func OnClientDataSendCallback(ptr unsafe.Pointer, data unsafe.Pointer, size C.si
return C.ssize_t(n) return C.ssize_t(n)
} }
// OnClientBeginHeaderCallback callback function for begin header receive, // OnClientBeginHeaderCallback callback function for begin header receive.
//
//export OnClientBeginHeaderCallback //export OnClientBeginHeaderCallback
func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int { func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Println("begin header") //log.Println("begin header")
@ -239,7 +250,8 @@ func OnClientBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0 return 0
} }
// OnClientHeaderCallback callback function for each header received, // OnClientHeaderCallback callback function for each header received.
//
//export OnClientHeaderCallback //export OnClientHeaderCallback
func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int, func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int,
name unsafe.Pointer, namelen C.int, name unsafe.Pointer, namelen C.int,
@ -252,7 +264,8 @@ func OnClientHeaderCallback(ptr unsafe.Pointer, streamID C.int,
return 0 return 0
} }
// OnClientHeadersDoneCallback callback function for the stream when all headers received, // OnClientHeadersDoneCallback callback function for the stream when all headers received.
//
//export OnClientHeadersDoneCallback //export OnClientHeadersDoneCallback
func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int { func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Println("frame recv") //log.Println("frame recv")
@ -261,7 +274,8 @@ func OnClientHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
return 0 return 0
} }
// OnClientStreamClose callback function for the stream when closed, // OnClientStreamClose callback function for the stream when closed.
//
//export OnClientStreamClose //export OnClientStreamClose
func OnClientStreamClose(ptr unsafe.Pointer, streamID C.int) C.int { func OnClientStreamClose(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Println("stream close") //log.Println("stream close")

@ -7,6 +7,7 @@ package nghttp2
import "C" import "C"
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -290,6 +291,27 @@ type ServerConn struct {
err error err error
} }
// HTTP2Handler is the http2 server handler that can co-work with standard net/http.
//
// usage example:
// l, err := net.Listen("tcp", ":1222")
// srv := &http.Server{
// TLSConfig: &tls.Config{
// NextProtos:[]string{"h2", "http/1.1"},
// }
// TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
// "h2": nghttp2.Http2Handler
// }
// }
// srv.ServeTLS(l, "server.crt", "server.key")
func HTTP2Handler(srv *http.Server, conn *tls.Conn, handler http.Handler) {
h2conn, err := NewServerConn(conn, handler)
if err != nil {
panic(err.Error())
}
h2conn.Run()
}
// NewServerConn create new server connection // NewServerConn create new server connection
func NewServerConn(c net.Conn, handler http.Handler) (*ServerConn, error) { func NewServerConn(c net.Conn, handler http.Handler) (*ServerConn, error) {
conn := &ServerConn{ conn := &ServerConn{

@ -6,10 +6,13 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"testing" "testing"
"golang.org/x/net/http2"
) )
func TestHttp2Client(t *testing.T) { func TestHttp2Client(t *testing.T) {
@ -80,7 +83,7 @@ func TestHttp2Server(t *testing.T) {
defer l.Close() defer l.Close()
addr := l.Addr().String() addr := l.Addr().String()
go func() { go func() {
http.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
log.Printf("%+v", r) log.Printf("%+v", r)
hdr := w.Header() hdr := w.Header()
hdr.Set("content-type", "text/plain") hdr.Set("content-type", "text/plain")
@ -119,16 +122,26 @@ func TestHttp2Server(t *testing.T) {
if cstate.NegotiatedProtocol != "h2" { if cstate.NegotiatedProtocol != "h2" {
t.Fatal("no http2 on server") t.Fatal("no http2 on server")
} }
h2conn, err := NewClientConn(conn) client := &http.Client{
if err != nil { Transport: &http2.Transport{
t.Fatal(err) DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := tls.Dial(network, addr, &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
})
if err := conn.Handshake(); err != nil {
return nil, err
}
return conn, err
},
},
} }
d := bytes.NewBuffer([]byte("hello")) d := bytes.NewBuffer([]byte("hello"))
req, _ := http.NewRequest("POST", req, _ := http.NewRequest("POST",
fmt.Sprintf("https://%s/get?a=b&c=d", addr), d) fmt.Sprintf("https://%s/test?a=b&c=d", addr), d)
req.Header.Add("User-Agent", "nghttp2/1.32") req.Header.Add("User-Agent", "nghttp2/1.32")
req.Header.Add("Content-Type", "text/palin") req.Header.Add("Content-Type", "text/palin")
res, err := h2conn.CreateRequest(req) res, err := client.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -146,3 +159,72 @@ func TestHttp2Server(t *testing.T) {
t.Errorf("expect %s, got %s", "hello", string(data)) t.Errorf("expect %s, got %s", "hello", string(data))
} }
} }
func TestHttp2Handler(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
srv := &http.Server{
TLSConfig: &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
},
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
"h2": HTTP2Handler,
},
}
defer srv.Close()
testdata := "asc fasdf32ddfasfff\r\nassdf312313"
addr := l.Addr().String()
go func() {
http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
hdr := w.Header()
hdr.Set("content-type", "text/plain")
hdr.Set("aa", "bb")
fmt.Fprintf(w, testdata)
})
http.Handle("/", http.FileServer(http.Dir("/")))
srv.ServeTLS(l, "testdata/server.crt", "testdata/server.key")
}()
client := &http.Client{
Transport: &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := tls.Dial(network, addr, &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
})
if err := conn.Handshake(); err != nil {
return nil, err
}
return conn, err
},
},
}
u := fmt.Sprintf("https://%s/test", addr)
resp, err := client.Get(u)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("http error %d", resp.StatusCode)
}
if resp.TLS == nil {
t.Errorf("not tls")
}
if resp.TLS.NegotiatedProtocol != "h2" {
t.Errorf("http2 is not enabled")
}
d, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Error(err)
}
if string(d) != testdata {
t.Errorf("expect %s, got %s", testdata, string(d))
}
if resp.Header.Get("aa") != "bb" {
t.Errorf("expect header not found")
}
//io.Copy(os.Stdout, resp.Body)
resp.Write(os.Stdout)
}

Loading…
Cancel
Save