From 2b973b326d5ee2f7a3ff425ac26e5ae04eea1814 Mon Sep 17 00:00:00 2001 From: fangdingjun Date: Thu, 17 May 2018 17:57:43 +0800 Subject: [PATCH] fix hang issue on close remote listener --- client.go | 85 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/client.go b/client.go index c914a6e..b71c885 100644 --- a/client.go +++ b/client.go @@ -3,12 +3,12 @@ package obfssh import ( "bufio" "fmt" - "io" "net" "net/http" "os" "os/signal" "strings" + "sync" "syscall" "time" @@ -80,21 +80,51 @@ func (cc *Client) Run() error { Log(INFO, "got signal, exit") } } - Log(DEBUG, "Done") cc.Close() + Log(DEBUG, "Done") return cc.err } +func (cc *Client) closeListener() { + if len(cc.listeners) == 0 { + return + } + + // close remote listener may block, because of connection issue + // so only 1 second to wait + + wg := &sync.WaitGroup{} + for _, l := range cc.listeners { + go func(l net.Listener) { + Log(DEBUG, "begin to close listener %s", l.Addr().String()) + l.Close() + Log(DEBUG, "close listener %s done", l.Addr().String()) + wg.Done() + }(l) + wg.Add(1) + } + + ch := make(chan struct{}) + go func() { + wg.Wait() + ch <- struct{}{} + }() + + select { + case <-ch: + case <-time.After(1 * time.Second): + } +} + // Close close the ssh connection // and free all the port forward resources func (cc *Client) Close() { - for _, l := range cc.listeners { - Log(INFO, "close the listener %s", l.Addr()) - l.Close() - } - //Log(DEBUG, "close ssh connection") + cc.closeListener() + + Log(DEBUG, "close ssh connection") cc.sshConn.Close() cc.conn.Close() + Log(DEBUG, "close ssh connection done") } // RunCmd run a single command on server @@ -304,6 +334,7 @@ func (cc *Client) keepAlive(interval time.Duration, maxCount int) { count = 0 } case <-time.After(3 * time.Second): + Log(DEBUG, "keep alive timed out") count++ } @@ -366,13 +397,14 @@ func (cc *Client) AddDynamicHTTPForward(addr string) error { } func (cc *Client) handleHTTPIncoming(c net.Conn) { - defer c.Close() + //defer c.Close() r := bufio.NewReader(c) req, err := http.ReadRequest(r) if err != nil { Log(ERROR, "read http request error %s", err) + c.Close() return } @@ -388,27 +420,16 @@ func (cc *Client) handleConnect(req *http.Request, c net.Conn) { c1, err := cc.client.Dial("tcp", req.RequestURI) if err != nil { + c.Close() fmt.Fprintf(c, "HTTP/1.0 503 connection failed\r\n\r\n") Log(ERROR, "dial error %s", err) return } - defer c1.Close() + //defer c1.Close() fmt.Fprintf(c, "HTTP/1.0 200 connection established\r\n\r\n") - - ch := make(chan struct{}, 2) - go func() { - io.Copy(c1, c) - ch <- struct{}{} - }() - - go func() { - io.Copy(c, c1) - ch <- struct{}{} - }() - - <-ch + PipeAndClose(c, c1) } func (cc *Client) handleHTTPReq(req *http.Request, c net.Conn) { @@ -420,29 +441,19 @@ func (cc *Client) handleHTTPReq(req *http.Request, c net.Conn) { Log(DEBUG, "request to %s", host) c1, err := cc.client.Dial("tcp", host) if err != nil { + c.Close() fmt.Fprintf(c, "HTTP/1.1 503 connection failed\r\nConnection: close\r\n\r\n") Log(ERROR, "connection failed %s", err) return } - defer c1.Close() + //defer c1.Close() if err = req.Write(c1); err != nil { fmt.Fprintf(c, "HTTP/1.1 503 write to server error\r\nConnection: close\r\n\r\n") Log(ERROR, "write request to server error %s", err) + c.Close() + c1.Close() return } - - ch := make(chan struct{}, 2) - - go func() { - io.Copy(c1, c) - ch <- struct{}{} - }() - - go func() { - io.Copy(c, c1) - ch <- struct{}{} - }() - - <-ch + PipeAndClose(c, c1) }