use context

master
dingjun 6 years ago
parent e9fc2a99ad
commit c9eb6957be

@ -67,7 +67,7 @@ func (h *HTTPTransport) Call(method string, args interface{}, reply interface{})
log.Debugf("send %s", data) log.Debugf("send %s", data)
body := bytes.NewBuffer(data) body := bytes.NewReader(data)
req, err := http.NewRequest("POST", h.URL, body) req, err := http.NewRequest("POST", h.URL, body)
if err != nil { if err != nil {

@ -1,6 +1,7 @@
package jsonrpc package jsonrpc
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -66,7 +67,7 @@ func NewClient(uri string) (*Client, error) {
return &Client{Transport: tr, URL: uri}, nil return &Client{Transport: tr, URL: uri}, nil
} }
if t == "ws" { if t == "ws" {
tr, _ := NewWebsocketTransport(uri) tr, _ := NewWebsocketTransport(context.Background(), uri)
return &Client{Transport: tr, URL: uri}, nil return &Client{Transport: tr, URL: uri}, nil
} }
return nil, errors.New("not supported") return nil, errors.New("not supported")

@ -1,7 +1,9 @@
package jsonrpc package jsonrpc
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@ -18,6 +20,8 @@ type WebsocketTransport struct {
inflight map[string]*inflightReq inflight map[string]*inflightReq
nextid uint64 nextid uint64
err error err error
ctx context.Context
cancelFunc context.CancelFunc
} }
type inflightReq struct { type inflightReq struct {
@ -26,16 +30,19 @@ type inflightReq struct {
errch chan *Error errch chan *Error
} }
// ErrConnClosed error for connection closed
var ErrConnClosed = errors.New("connection closed")
var _ Transport = &WebsocketTransport{} var _ Transport = &WebsocketTransport{}
// NewWebsocketTransport create a new websocket transport // NewWebsocketTransport create a new websocket transport
func NewWebsocketTransport(uri string) (Transport, error) { func NewWebsocketTransport(ctx context.Context, uri string) (Transport, error) {
dialer := &websocket.Dialer{} var dialer = &websocket.Dialer{}
conn, res, err := dialer.DialContext(ctx, uri, nil)
conn, res, err := dialer.Dial(uri, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close()
if res.StatusCode != http.StatusSwitchingProtocols { if res.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("http error %d", res.StatusCode) return nil, fmt.Errorf("http error %d", res.StatusCode)
@ -47,16 +54,35 @@ func NewWebsocketTransport(uri string) (Transport, error) {
inflight: make(map[string]*inflightReq), inflight: make(map[string]*inflightReq),
Mu: new(sync.Mutex), Mu: new(sync.Mutex),
} }
w.ctx, w.cancelFunc = context.WithCancel(ctx)
go w.readloop() go w.readloop()
return w, nil return w, nil
} }
// Context return the context transport used
func (h *WebsocketTransport) Context() context.Context {
return h.ctx
}
func (h *WebsocketTransport) readloop() { func (h *WebsocketTransport) readloop() {
defer func() {
//log.Debugf("close websocket connection")
h.Conn.Close()
//log.Debugf("cancel context")
h.cancelFunc()
}()
for { for {
select {
case <-h.ctx.Done():
return
default:
}
var res response var res response
_, data, err := h.Conn.ReadMessage() _, data, err := h.Conn.ReadMessage()
if err != nil { if err != nil {
log.Errorln(err)
h.err = err h.err = err
return return
} }
@ -122,6 +148,13 @@ func (h *WebsocketTransport) nextID() uint64 {
// Call call a remote method // Call call a remote method
func (h *WebsocketTransport) Call(method string, args interface{}, reply interface{}) error { func (h *WebsocketTransport) Call(method string, args interface{}, reply interface{}) error {
select {
case <-h.ctx.Done():
return ErrConnClosed
default:
}
if h.err != nil { if h.err != nil {
return h.err return h.err
} }

Loading…
Cancel
Save