finish server/client logical

merge_conn
fangdingjun 6 years ago
parent c87800ed83
commit bf4658d8f4

@ -5,12 +5,15 @@ package nghttp2
*/
import "C"
import (
"bytes"
"crypto/tls"
"errors"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"unsafe"
)
@ -49,7 +52,7 @@ func onDataSourceReadCallback(ptr unsafe.Pointer, streamID C.int,
}
if err == errAgain {
//log.Println("onDataSourceReadCallback end")
s.dp.deferred = true
//s.dp.deferred = true
return NGHTTP2_ERR_DEFERRED
}
//log.Println("onDataSourceReadCallback end")
@ -82,7 +85,7 @@ func onDataChunkRecv(ptr unsafe.Pointer, streamID C.int,
//log.Println("onDataChunkRecv end")
return C.int(length)
}
//log.Println("bp write")
n, err := s.bp.Write(gobuf)
if err != nil {
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
@ -118,34 +121,47 @@ func onBeginHeaderCallback(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Printf("stream %d begin headers", int(streamID))
conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
s, ok := conn.streams[int(streamID)]
if !ok {
//log.Println("onBeginHeaderCallback end")
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
}
var TLS tls.ConnectionState
if tlsconn, ok := conn.conn.(*tls.Conn); ok {
TLS = tlsconn.ConnectionState()
}
if conn.isServer {
s.request = &http.Request{
Header: make(http.Header),
// client
if !conn.isServer {
s, ok := conn.streams[int(streamID)]
if !ok {
//log.Println("onBeginHeaderCallback end")
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
}
var TLS tls.ConnectionState
if tlsconn, ok := conn.conn.(*tls.Conn); ok {
TLS = tlsconn.ConnectionState()
}
s.response = &http.Response{
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
TLS: &TLS,
Header: make(http.Header),
Body: s.bp,
TLS: &TLS,
}
return NGHTTP2_NO_ERROR
}
s.response = &http.Response{
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: make(http.Header),
Body: s.bp,
TLS: &TLS,
// server
s := &stream{
streamID: int(streamID),
conn: conn,
bp: &bodyProvider{
buf: new(bytes.Buffer),
lock: new(sync.Mutex),
},
request: &http.Request{
Header: make(http.Header),
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
},
}
s.request.Body = s.bp
conn.streams[int(streamID)] = s
//log.Println("onBeginHeaderCallback end")
return NGHTTP2_NO_ERROR
}
@ -157,7 +173,7 @@ func onHeaderCallback(ptr unsafe.Pointer, streamID C.int,
name unsafe.Pointer, namelen C.int,
value unsafe.Pointer, valuelen C.int) C.int {
//log.Println("onHeaderCallback begin")
//log.Println("header")
//log.Printf("header %d", int(streamID))
conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
goname := string(C.GoBytes(name, namelen))
govalue := string(C.GoBytes(value, valuelen))
@ -176,7 +192,16 @@ func onHeaderCallback(ptr unsafe.Pointer, streamID C.int,
s.request.Host = govalue
case ":path":
s.request.RequestURI = govalue
u, err := url.Parse(govalue)
if err != nil {
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
}
s.request.URL = u
case ":status":
if s.response == nil {
//log.Println("empty response")
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
}
statusCode, _ := strconv.Atoi(govalue)
s.response.StatusCode = statusCode
s.response.Status = http.StatusText(statusCode)
@ -208,7 +233,11 @@ func onHeadersDoneCallback(ptr unsafe.Pointer, streamID C.int) C.int {
//log.Println("onHeadersDoneCallback end")
return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE
}
s.headersEnd = true
if conn.isServer {
if s.request.Method == "CONNECT" {
go conn.serve(s)
}
return NGHTTP2_NO_ERROR
}
select {
@ -257,5 +286,19 @@ func onConnectionCloseCallback(ptr unsafe.Pointer) {
//export onStreamEndCallback
func onStreamEndCallback(ptr unsafe.Pointer, streamID C.int) {
conn := (*Conn)(unsafe.Pointer(uintptr(ptr)))
stream, ok := conn.streams[int(streamID)]
if !ok {
return
}
stream.streamEnd = true
stream.bp.Close()
if stream.conn.isServer {
if stream.request.Method != "CONNECT" {
go conn.serve(stream)
}
return
}
}

@ -6,10 +6,13 @@ package nghttp2
*/
import "C"
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sync"
"time"
"unsafe"
@ -30,14 +33,189 @@ type Conn struct {
exitch chan struct{}
}
// Server create server side http2 connection
func Server(c net.Conn, handler http.Handler) (*Conn, error) {
conn := &Conn{
conn: c,
handler: handler,
errch: make(chan error),
exitch: make(chan struct{}),
lock: new(sync.Mutex),
isServer: true,
streams: make(map[int]*stream),
}
conn.session = C.init_nghttp2_server_session(C.size_t(uintptr(unsafe.Pointer(conn))))
if conn.session == nil {
return nil, errors.New("init server session failed")
}
ret := C.send_connection_header(conn.session)
if int(ret) < 0 {
conn.Close()
return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret)))
}
return conn, nil
}
// Client create client side http2 connection
func Client(c net.Conn) (*Conn, error) {
conn := &Conn{
conn: c,
errch: make(chan error),
exitch: make(chan struct{}),
lock: new(sync.Mutex),
streams: make(map[int]*stream),
}
conn.session = C.init_nghttp2_client_session(C.size_t(uintptr(unsafe.Pointer(conn))))
if conn.session == nil {
return nil, errors.New("init server session failed")
}
ret := C.send_connection_header(conn.session)
if int(ret) < 0 {
conn.Close()
return nil, fmt.Errorf("send settings error: %s", C.GoString(C.nghttp2_strerror(ret)))
}
go conn.Run()
return conn, nil
}
// Error return conn error
func (c *Conn) Error() error {
c.lock.Lock()
defer c.lock.Unlock()
return c.err
}
// CanTakeNewRequest check if conn can create new request
func (c *Conn) CanTakeNewRequest() bool {
if c.streamCount > ((2<<31 - 1) / 2) {
return false
}
if c.err != nil {
return false
}
return true
}
// RoundTrip submit http request and return the response
func (c *Conn) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, errors.New("not implement")
nv := []C.nghttp2_nv{}
nv = append(nv, newNV(":method", req.Method))
nv = append(nv, newNV(":authority", req.Host))
nv = append(nv, newNV(":scheme", "https"))
p := req.URL.Path
q := req.URL.Query().Encode()
if q != "" {
p = fmt.Sprintf("%s?%s", p, q)
}
nv = append(nv, newNV(":path", p))
cdp := C.nghttp2_data_provider{}
dp := newDataProvider(unsafe.Pointer(&cdp), c.lock, 1)
dp.session = c.session
var _cdp *C.nghttp2_data_provider
if req.Method == "POST" || req.Method == "PUT" {
_cdp = &cdp
}
s, err := c.submitRequest(nv, _cdp)
if err != nil {
return nil, err
}
s.dp = dp
s.dp.streamID = s.streamID
c.lock.Lock()
c.streams[s.streamID] = s
c.streamCount++
c.lock.Unlock()
if req.Method == "POST" || req.Method == "PUT" {
go func() {
io.Copy(dp, req.Body)
dp.Close()
}()
}
select {
case res := <-s.resch:
/*
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("http error code %d", res.StatusCode)
}
*/
s.request = req
res.Request = s.request
return res, nil
case <-c.exitch:
return nil, errors.New("connection closed")
}
}
func (c *Conn) submitRequest(nv []C.nghttp2_nv, cdp *C.nghttp2_data_provider) (*stream, error) {
c.lock.Lock()
ret := C._nghttp2_submit_request(c.session, nil,
C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), cdp, nil)
c.lock.Unlock()
if int(ret) < 0 {
return nil, fmt.Errorf("submit request error: %s", C.GoString(C.nghttp2_strerror(ret)))
}
streamID := int(ret)
s := &stream{
streamID: streamID,
conn: c,
bp: &bodyProvider{
buf: new(bytes.Buffer),
lock: new(sync.Mutex),
},
resch: make(chan *http.Response),
}
if cdp != nil {
s.cdp = *cdp
}
return s, nil
}
// Connect submit connect request
func (c *Conn) Connect(addr string) (net.Conn, error) {
return nil, errors.New("not implement")
func (c *Conn) Connect(addr string) (net.Conn, int, error) {
nv := []C.nghttp2_nv{}
nv = append(nv, newNV(":method", "CONNECT"))
nv = append(nv, newNV(":authority", addr))
cdp := C.nghttp2_data_provider{}
dp := newDataProvider(unsafe.Pointer(&cdp), c.lock, 1)
dp.session = c.session
s, err := c.submitRequest(nv, &cdp)
if err != nil {
return nil, http.StatusBadGateway, err
}
s.dp = dp
c.lock.Lock()
c.streams[s.streamID] = s
c.streamCount++
c.lock.Unlock()
s.dp.streamID = s.streamID
select {
case res := <-s.resch:
if res.StatusCode != http.StatusOK {
return nil, res.StatusCode, fmt.Errorf("http error code %d", res.StatusCode)
}
s.request = &http.Request{
Method: "CONNECT",
Host: addr,
RequestURI: addr,
URL: &url.URL{},
}
res.Request = s.request
return s, res.StatusCode, nil
case <-c.exitch:
return nil, http.StatusServiceUnavailable, errors.New("connection closed")
}
}
// Run run the event loop
@ -58,6 +236,15 @@ func (c *Conn) Run() {
}
}
func (c *Conn) serve(s *stream) {
var handler = c.handler
if handler == nil {
handler = http.DefaultServeMux
}
handler.ServeHTTP(s, s.request)
s.Close()
}
// Close close the connection
func (c *Conn) Close() error {
if c.closed {
@ -67,6 +254,12 @@ func (c *Conn) Close() error {
for _, s := range c.streams {
s.Close()
}
c.lock.Lock()
C.nghttp2_session_terminate_session(c.session, 0)
C.nghttp2_session_del(c.session)
c.lock.Unlock()
close(c.exitch)
c.conn.Close()
return nil
@ -148,6 +341,7 @@ func (c *Conn) writeloop() {
wantWrite := C.nghttp2_session_want_write(c.session)
c.lock.Unlock()
if int(wantWrite) == 0 {
//log.Println("write loop, sleep")
time.Sleep(delay)
}
}

@ -7,7 +7,6 @@ import "C"
import (
"bytes"
"errors"
"io"
"log"
"sync"
"time"
@ -31,7 +30,7 @@ type dataProvider struct {
// Read read from data provider
func (dp *dataProvider) Read(buf []byte) (n int, err error) {
if dp.buf == nil || dp.lock == nil || dp.sessLock == nil || dp.session == nil {
log.Println("db read invalid state")
log.Println("dp read invalid state")
return 0, errors.New("invalid state")
}
dp.lock.Lock()
@ -39,6 +38,8 @@ func (dp *dataProvider) Read(buf []byte) (n int, err error) {
n, err = dp.buf.Read(buf)
if err != nil && !dp.closed {
//log.Println("deferred")
dp.deferred = true
return 0, errAgain
}
return
@ -53,18 +54,20 @@ func (dp *dataProvider) Write(buf []byte) (n int, err error) {
dp.lock.Lock()
defer dp.lock.Unlock()
if dp.closed {
return 0, io.EOF
}
//if dp.closed {
// return 0, io.EOF
//}
n, err = dp.buf.Write(buf)
if dp.deferred {
dp.sessLock.Lock()
C.nghttp2_session_resume_data(dp.session, C.int(dp.streamID))
dp.sessLock.Unlock()
//log.Println("resume")
dp.deferred = false
}
return dp.buf.Write(buf)
return
}
// Close end to provide data

@ -23,8 +23,8 @@ int on_invalid_frame_recv_callback(nghttp2_session *session,
return 0;
}
static ssize_t on_data_source_read_callback(nghttp2_session *session, int32_t stream_id,
uint8_t *buf, size_t length, uint32_t *data_flags,
nghttp2_data_source *source, void *user_data)
uint8_t *buf, size_t length, uint32_t *data_flags,
nghttp2_data_source *source, void *user_data)
{
int ret = onDataSourceReadCallback(user_data, stream_id, buf, length);
if (ret == 0)
@ -48,7 +48,8 @@ static int on_frame_recv_callback(nghttp2_session *session,
switch (frame->hd.type)
{
case NGHTTP2_HEADERS:
if (frame->headers.cat == NGHTTP2_HCAT_REQUEST)
if (frame->headers.cat == NGHTTP2_HCAT_REQUEST ||
frame->headers.cat == NGHTTP2_HCAT_RESPONSE)
{
onHeadersDoneCallback(user_data, frame->hd.stream_id);
}
@ -82,7 +83,8 @@ static int on_header_callback(nghttp2_session *session,
switch (frame->hd.type)
{
case NGHTTP2_HEADERS:
if (frame->headers.cat == NGHTTP2_HCAT_REQUEST)
if (frame->headers.cat == NGHTTP2_HCAT_REQUEST ||
frame->headers.cat == NGHTTP2_HCAT_RESPONSE)
{
onHeaderCallback(user_data, frame->hd.stream_id,
(void *)name, namelen, (void *)value, valuelen);
@ -109,7 +111,8 @@ static int on_begin_headers_callback(nghttp2_session *session,
switch (frame->hd.type)
{
case NGHTTP2_HEADERS:
if (frame->headers.cat == NGHTTP2_HCAT_REQUEST)
if (frame->headers.cat == NGHTTP2_HCAT_REQUEST ||
frame->headers.cat == NGHTTP2_HCAT_RESPONSE)
{
onBeginHeaderCallback(user_data, frame->hd.stream_id);
}
@ -123,6 +126,7 @@ nghttp2_session *init_nghttp2_server_session(size_t data)
nghttp2_session_callbacks *callbacks;
nghttp2_session *session;
nghttp2_session_callbacks_new(&callbacks);
init_nghttp2_callbacks(callbacks);
nghttp2_session_server_new(&session, callbacks, (void *)((int *)(data)));
@ -177,8 +181,26 @@ int send_connection_header(nghttp2_session *session)
return rv;
}
int data_provider_set_callback(size_t dp, size_t data, int t){
nghttp2_data_provider *cdp = (nghttp2_data_provider*)dp;
int data_provider_set_callback(size_t dp, size_t data, int t)
{
nghttp2_data_provider *cdp = (nghttp2_data_provider *)dp;
cdp->source.ptr = (void *)data;
cdp->read_callback=on_data_source_read_callback;
cdp->read_callback = on_data_source_read_callback;
}
int _nghttp2_submit_response(nghttp2_session *sess, int streamid,
size_t nv, size_t nvlen, nghttp2_data_provider *dp)
{
return nghttp2_submit_response(sess, streamid, (nghttp2_nv *)nv, nvlen, dp);
}
int _nghttp2_submit_request(nghttp2_session *session, const nghttp2_priority_spec *pri_spec,
size_t nva, size_t nvlen,
const nghttp2_data_provider *data_prd, void *stream_user_data)
{
return nghttp2_submit_request(session, pri_spec,
(nghttp2_nv *)nva, nvlen,
data_prd, stream_user_data);
}

@ -1,48 +1,147 @@
package nghttp2
/*
#include "_nghttp2.h"
*/
import "C"
import (
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"unsafe"
)
type stream struct {
streamID int
conn *Conn
dp *dataProvider
bp *bodyProvider
request *http.Request
response *http.Response
resch chan *http.Response
streamID int
conn *Conn
dp *dataProvider
bp *bodyProvider
request *http.Request
response *http.Response
resch chan *http.Response
headersEnd bool
streamEnd bool
closed bool
cdp C.nghttp2_data_provider
}
var _ net.Conn = &stream{}
func (s *stream) Read(buf []byte) (int, error) {
return 0, errors.New("not implement")
if s.bp != nil {
return s.bp.Read(buf)
}
return 0, errors.New("empty body")
}
func (s *stream) WriteHeader(code int) {
if s.response == nil {
s.response = &http.Response{
Proto: "http/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: make(http.Header),
}
}
if s.response.StatusCode != 0 {
return
}
s.response.StatusCode = code
s.response.Status = http.StatusText(code)
nv := []C.nghttp2_nv{}
nv = append(nv, newNV(":status", fmt.Sprintf("%d", code)))
for k, v := range s.response.Header {
_k := strings.ToLower(k)
if _k == "host" || _k == "connection" || _k == "transfer-encoding" {
continue
}
nv = append(nv, newNV(k, v[0]))
}
s.cdp = C.nghttp2_data_provider{}
s.dp = newDataProvider(unsafe.Pointer(&s.cdp), s.conn.lock, 0)
s.dp.session = s.conn.session
s.dp.streamID = s.streamID
s.conn.lock.Lock()
ret := C._nghttp2_submit_response(s.conn.session, C.int(s.streamID),
C.size_t(uintptr(unsafe.Pointer(&nv[0]))), C.size_t(len(nv)), &s.cdp)
s.conn.lock.Unlock()
if int(ret) < 0 {
panic(fmt.Sprintf("submit response error: %s", C.GoString(C.nghttp2_strerror(ret))))
}
}
func (s *stream) Header() http.Header {
if s.response == nil {
s.response = &http.Response{
Proto: "http/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: make(http.Header),
}
}
return s.response.Header
}
func (s *stream) Write(buf []byte) (int, error) {
if s.conn.isServer {
return 0, errors.New("not implement")
if s.conn.isServer && s.response == nil {
s.WriteHeader(http.StatusOK)
}
if s.dp != nil {
return s.dp.Write(buf)
}
return 0, errors.New("not implement")
return 0, errors.New("empty dp")
}
func (s *stream) Close() error {
if s.closed {
return nil
}
s.closed = true
if s.dp != nil {
s.dp.Close()
}
if s.bp != nil {
s.bp.Close()
}
//s.conn.lock.Lock()
//if _, ok := s.conn.streams[s.streamID]; ok {
// delete(s.conn.streams, s.streamID)
///}
//s.conn.lock.Unlock()
if s.request != nil && s.request.Method == "CONNECT" {
//log.Println("rst stream")
s.conn.lock.Lock()
C.nghttp2_submit_rst_stream(s.conn.session, 0, C.int(s.streamID), 8)
s.conn.lock.Unlock()
}
return nil
}
func (s *stream) LocalAddr() net.Addr {
return nil
}
func (s *stream) RemoteAddr() net.Addr {
return nil
}
func (s *stream) SetDeadline(t time.Time) error {
return errors.New("not implement")
}
func (s *stream) SetReadDeadline(t time.Time) error {
return errors.New("not implement")
}
func (s *stream) SetWriteDeadline(t time.Time) error {
return errors.New("not implement")
}

Loading…
Cancel
Save