use context.Context to signal exit

tls
fangdingjun 6 years ago
parent 311c076e55
commit 8c32fb3bb1

@ -2,6 +2,8 @@ package obfssh
import ( import (
"bufio" "bufio"
"context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -24,8 +26,9 @@ type Client struct {
sshConn ssh.Conn sshConn ssh.Conn
client *ssh.Client client *ssh.Client
listeners []net.Listener listeners []net.Listener
ch chan struct{}
err error err error
ctx context.Context
cancel context.CancelFunc
} }
// NewClient create a new ssh Client // NewClient create a new ssh Client
@ -45,8 +48,8 @@ func NewClient(c net.Conn, config *ssh.ClientConfig, addr string, conf *Conf) (*
sshClient := ssh.NewClient(sshConn, newch, reqs) sshClient := ssh.NewClient(sshConn, newch, reqs)
client := &Client{ client := &Client{
conn: c, sshConn: sshConn, client: sshClient, conn: c, sshConn: sshConn, client: sshClient,
ch: make(chan struct{}),
} }
client.ctx, client.cancel = context.WithCancel(context.Background())
go client.keepAlive(conf.KeepAliveInterval, conf.KeepAliveMax) go client.keepAlive(conf.KeepAliveInterval, conf.KeepAliveMax)
return client, nil return client, nil
} }
@ -58,6 +61,9 @@ func (cc *Client) Client() *ssh.Client {
// Run wait ssh connection to finish // Run wait ssh connection to finish
func (cc *Client) Run() error { func (cc *Client) Run() error {
defer cc.Close()
defer cc.cancel()
select { select {
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
} }
@ -68,21 +74,11 @@ func (cc *Client) Run() error {
go func() { go func() {
cc.err = cc.sshConn.Wait() cc.err = cc.sshConn.Wait()
log.Debugf("connection hang up") log.Debugf("connection hang up")
cc.cancel()
//close(cc.ch) //close(cc.ch)
select {
case cc.ch <- struct{}{}:
default:
}
}() }()
<-cc.ctx.Done()
// wait exit signal
select {
case <-cc.ch:
log.Debugf("got signal, exit")
}
} }
cc.Close()
log.Debugf("Done")
return cc.err return cc.err
} }
@ -93,6 +89,8 @@ func (cc *Client) closeListener() {
// close remote listener may block, because of connection issue // close remote listener may block, because of connection issue
// so only 1 second to wait // so only 1 second to wait
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for _, l := range cc.listeners { for _, l := range cc.listeners {
@ -105,16 +103,12 @@ func (cc *Client) closeListener() {
wg.Add(1) wg.Add(1)
} }
ch := make(chan struct{})
go func() { go func() {
wg.Wait() wg.Wait()
ch <- struct{}{} cancel()
}() }()
select { <-ctx.Done()
case <-ch:
case <-time.After(1 * time.Second):
}
} }
// Close close the ssh connection // Close close the ssh connection
@ -315,41 +309,47 @@ func (cc *Client) handleTransparentProxy(c net.Conn, addr net.Addr) {
PipeAndClose(c2, c) PipeAndClose(c2, c)
} }
func doKeepAlive(conn ssh.Conn, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
ch := make(chan error, 1)
go func() {
_, _, err := conn.SendRequest("keepalive@openssh.org", true, nil)
ch <- err
}()
select {
case <-ctx.Done():
return errors.New("keepalive timeout")
case err := <-ch:
if err != nil {
return err
}
return nil
}
}
func (cc *Client) keepAlive(interval time.Duration, maxCount int) { func (cc *Client) keepAlive(interval time.Duration, maxCount int) {
count := 0 count := 0
c := time.NewTicker(interval) c := time.NewTicker(interval)
defer c.Stop()
for { for {
select { select {
case <-cc.ctx.Done():
return
case <-c.C: case <-c.C:
resCh := make(chan error, 1) if err := doKeepAlive(cc.sshConn, 3*time.Second); err != nil {
go func(resCh chan error) {
_, _, err := cc.sshConn.SendRequest("keepalive@openssh.org", true, nil)
resCh <- err
}(resCh)
select {
case err := <-resCh:
if err != nil {
log.Debugf("keep alive error: %s", err.Error())
count++
} else {
count = 0
}
case <-time.After(3 * time.Second):
log.Debugf("keep alive timed out")
count++ count++
} else {
count = 0
} }
if count >= maxCount { if count >= maxCount {
cc.err = fmt.Errorf("keep alive detects connection hang up") cc.err = fmt.Errorf("keep alive detects connection hang up")
log.Errorf("keep alive hit max count, exit") log.Errorf("keep alive hit max count, exit")
//cc.sshConn.Close() cc.cancel()
//cc.conn.Close()
// send exit signal
// close(cc.ch)
select {
case cc.ch <- struct{}{}:
default:
}
return return
} }
} }
@ -364,10 +364,7 @@ func (cc *Client) registerSignal() {
cc.err = fmt.Errorf("signal %v", s1) cc.err = fmt.Errorf("signal %v", s1)
log.Errorf("signal %d received, exit", s1) log.Errorf("signal %d received, exit", s1)
//close(cc.ch) //close(cc.ch)
select { cc.cancel()
case cc.ch <- struct{}{}:
default:
}
} }
} }

Loading…
Cancel
Save