diff --git a/client.go b/client.go index b490fac..5859e9b 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,8 @@ package obfssh import ( "bufio" + "context" + "errors" "fmt" "net" "net/http" @@ -24,8 +26,9 @@ type Client struct { sshConn ssh.Conn client *ssh.Client listeners []net.Listener - ch chan struct{} err error + ctx context.Context + cancel context.CancelFunc } // NewClient create a new ssh Client @@ -45,8 +48,8 @@ func NewClient(c net.Conn, config *ssh.ClientConfig, addr string, conf *Conf) (* sshClient := ssh.NewClient(sshConn, newch, reqs) client := &Client{ conn: c, sshConn: sshConn, client: sshClient, - ch: make(chan struct{}), } + client.ctx, client.cancel = context.WithCancel(context.Background()) go client.keepAlive(conf.KeepAliveInterval, conf.KeepAliveMax) return client, nil } @@ -58,6 +61,9 @@ func (cc *Client) Client() *ssh.Client { // Run wait ssh connection to finish func (cc *Client) Run() error { + defer cc.Close() + defer cc.cancel() + select { case <-time.After(1 * time.Second): } @@ -68,21 +74,11 @@ func (cc *Client) Run() error { go func() { cc.err = cc.sshConn.Wait() log.Debugf("connection hang up") + cc.cancel() //close(cc.ch) - select { - case cc.ch <- struct{}{}: - default: - } }() - - // wait exit signal - select { - case <-cc.ch: - log.Debugf("got signal, exit") - } + <-cc.ctx.Done() } - cc.Close() - log.Debugf("Done") return cc.err } @@ -93,6 +89,8 @@ func (cc *Client) closeListener() { // close remote listener may block, because of connection issue // so only 1 second to wait + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() wg := &sync.WaitGroup{} for _, l := range cc.listeners { @@ -105,16 +103,12 @@ func (cc *Client) closeListener() { wg.Add(1) } - ch := make(chan struct{}) go func() { wg.Wait() - ch <- struct{}{} + cancel() }() - select { - case <-ch: - case <-time.After(1 * time.Second): - } + <-ctx.Done() } // Close close the ssh connection @@ -315,41 +309,47 @@ func (cc *Client) handleTransparentProxy(c net.Conn, addr net.Addr) { PipeAndClose(c2, c) } +func doKeepAlive(conn ssh.Conn, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + ch := make(chan error, 1) + + go func() { + _, _, err := conn.SendRequest("keepalive@openssh.org", true, nil) + ch <- err + }() + + select { + case <-ctx.Done(): + return errors.New("keepalive timeout") + case err := <-ch: + if err != nil { + return err + } + return nil + } +} + func (cc *Client) keepAlive(interval time.Duration, maxCount int) { count := 0 c := time.NewTicker(interval) + defer c.Stop() + for { select { + case <-cc.ctx.Done(): + return case <-c.C: - resCh := make(chan error, 1) - go func(resCh chan error) { - _, _, err := cc.sshConn.SendRequest("keepalive@openssh.org", true, nil) - resCh <- err - }(resCh) - select { - case err := <-resCh: - if err != nil { - log.Debugf("keep alive error: %s", err.Error()) - count++ - } else { - count = 0 - } - case <-time.After(3 * time.Second): - log.Debugf("keep alive timed out") + if err := doKeepAlive(cc.sshConn, 3*time.Second); err != nil { count++ + } else { + count = 0 } - if count >= maxCount { cc.err = fmt.Errorf("keep alive detects connection hang up") log.Errorf("keep alive hit max count, exit") - //cc.sshConn.Close() - //cc.conn.Close() - // send exit signal - // close(cc.ch) - select { - case cc.ch <- struct{}{}: - default: - } + cc.cancel() return } } @@ -364,10 +364,7 @@ func (cc *Client) registerSignal() { cc.err = fmt.Errorf("signal %v", s1) log.Errorf("signal %d received, exit", s1) //close(cc.ch) - select { - case cc.ch <- struct{}{}: - default: - } + cc.cancel() } }