From c9eb6957be34350e196e75e4ac627d01d15f3d41 Mon Sep 17 00:00:00 2001 From: dingjun Date: Wed, 30 Jan 2019 17:58:40 +0800 Subject: [PATCH] use context --- http_transport.go | 2 +- jsonrpc.go | 3 ++- websocket_transport.go | 53 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/http_transport.go b/http_transport.go index 1706320..028c1bb 100644 --- a/http_transport.go +++ b/http_transport.go @@ -67,7 +67,7 @@ func (h *HTTPTransport) Call(method string, args interface{}, reply interface{}) log.Debugf("send %s", data) - body := bytes.NewBuffer(data) + body := bytes.NewReader(data) req, err := http.NewRequest("POST", h.URL, body) if err != nil { diff --git a/jsonrpc.go b/jsonrpc.go index 6a25e7e..f7d851d 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -1,6 +1,7 @@ package jsonrpc import ( + "context" "encoding/json" "errors" "fmt" @@ -66,7 +67,7 @@ func NewClient(uri string) (*Client, error) { return &Client{Transport: tr, URL: uri}, nil } if t == "ws" { - tr, _ := NewWebsocketTransport(uri) + tr, _ := NewWebsocketTransport(context.Background(), uri) return &Client{Transport: tr, URL: uri}, nil } return nil, errors.New("not supported") diff --git a/websocket_transport.go b/websocket_transport.go index 27a749f..c75d75a 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -1,7 +1,9 @@ package jsonrpc import ( + "context" "encoding/json" + "errors" "fmt" "net/http" "sync" @@ -12,12 +14,14 @@ import ( // WebsocketTransport json rpc over websocket type WebsocketTransport struct { - Conn *websocket.Conn - URL string - Mu *sync.Mutex - inflight map[string]*inflightReq - nextid uint64 - err error + Conn *websocket.Conn + URL string + Mu *sync.Mutex + inflight map[string]*inflightReq + nextid uint64 + err error + ctx context.Context + cancelFunc context.CancelFunc } type inflightReq struct { @@ -26,16 +30,19 @@ type inflightReq struct { errch chan *Error } +// ErrConnClosed error for connection closed +var ErrConnClosed = errors.New("connection closed") + var _ Transport = &WebsocketTransport{} // NewWebsocketTransport create a new websocket transport -func NewWebsocketTransport(uri string) (Transport, error) { - dialer := &websocket.Dialer{} - - conn, res, err := dialer.Dial(uri, nil) +func NewWebsocketTransport(ctx context.Context, uri string) (Transport, error) { + var dialer = &websocket.Dialer{} + conn, res, err := dialer.DialContext(ctx, uri, nil) if err != nil { return nil, err } + defer res.Body.Close() if res.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("http error %d", res.StatusCode) @@ -47,16 +54,35 @@ func NewWebsocketTransport(uri string) (Transport, error) { inflight: make(map[string]*inflightReq), Mu: new(sync.Mutex), } + w.ctx, w.cancelFunc = context.WithCancel(ctx) go w.readloop() return w, nil } +// Context return the context transport used +func (h *WebsocketTransport) Context() context.Context { + return h.ctx +} + func (h *WebsocketTransport) readloop() { + defer func() { + //log.Debugf("close websocket connection") + h.Conn.Close() + //log.Debugf("cancel context") + h.cancelFunc() + }() + for { + select { + case <-h.ctx.Done(): + return + default: + } var res response _, data, err := h.Conn.ReadMessage() if err != nil { + log.Errorln(err) h.err = err return } @@ -122,6 +148,13 @@ func (h *WebsocketTransport) nextID() uint64 { // Call call a remote method func (h *WebsocketTransport) Call(method string, args interface{}, reply interface{}) error { + + select { + case <-h.ctx.Done(): + return ErrConnClosed + default: + } + if h.err != nil { return h.err }