|
|
@ -2,6 +2,8 @@ package obfssh
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
import (
|
|
|
|
"bufio"
|
|
|
|
"bufio"
|
|
|
|
|
|
|
|
"context"
|
|
|
|
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"fmt"
|
|
|
|
"net"
|
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"net/http"
|
|
|
@ -24,8 +26,9 @@ type Client struct {
|
|
|
|
sshConn ssh.Conn
|
|
|
|
sshConn ssh.Conn
|
|
|
|
client *ssh.Client
|
|
|
|
client *ssh.Client
|
|
|
|
listeners []net.Listener
|
|
|
|
listeners []net.Listener
|
|
|
|
ch chan struct{}
|
|
|
|
|
|
|
|
err error
|
|
|
|
err error
|
|
|
|
|
|
|
|
ctx context.Context
|
|
|
|
|
|
|
|
cancel context.CancelFunc
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// NewClient create a new ssh Client
|
|
|
|
// NewClient create a new ssh Client
|
|
|
@ -45,8 +48,8 @@ func NewClient(c net.Conn, config *ssh.ClientConfig, addr string, conf *Conf) (*
|
|
|
|
sshClient := ssh.NewClient(sshConn, newch, reqs)
|
|
|
|
sshClient := ssh.NewClient(sshConn, newch, reqs)
|
|
|
|
client := &Client{
|
|
|
|
client := &Client{
|
|
|
|
conn: c, sshConn: sshConn, client: sshClient,
|
|
|
|
conn: c, sshConn: sshConn, client: sshClient,
|
|
|
|
ch: make(chan struct{}),
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
client.ctx, client.cancel = context.WithCancel(context.Background())
|
|
|
|
go client.keepAlive(conf.KeepAliveInterval, conf.KeepAliveMax)
|
|
|
|
go client.keepAlive(conf.KeepAliveInterval, conf.KeepAliveMax)
|
|
|
|
return client, nil
|
|
|
|
return client, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -58,6 +61,9 @@ func (cc *Client) Client() *ssh.Client {
|
|
|
|
|
|
|
|
|
|
|
|
// Run wait ssh connection to finish
|
|
|
|
// Run wait ssh connection to finish
|
|
|
|
func (cc *Client) Run() error {
|
|
|
|
func (cc *Client) Run() error {
|
|
|
|
|
|
|
|
defer cc.Close()
|
|
|
|
|
|
|
|
defer cc.cancel()
|
|
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
select {
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -68,21 +74,11 @@ func (cc *Client) Run() error {
|
|
|
|
go func() {
|
|
|
|
go func() {
|
|
|
|
cc.err = cc.sshConn.Wait()
|
|
|
|
cc.err = cc.sshConn.Wait()
|
|
|
|
log.Debugf("connection hang up")
|
|
|
|
log.Debugf("connection hang up")
|
|
|
|
|
|
|
|
cc.cancel()
|
|
|
|
//close(cc.ch)
|
|
|
|
//close(cc.ch)
|
|
|
|
select {
|
|
|
|
|
|
|
|
case cc.ch <- struct{}{}:
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}()
|
|
|
|
}()
|
|
|
|
|
|
|
|
<-cc.ctx.Done()
|
|
|
|
// wait exit signal
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
|
|
|
case <-cc.ch:
|
|
|
|
|
|
|
|
log.Debugf("got signal, exit")
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
cc.Close()
|
|
|
|
|
|
|
|
log.Debugf("Done")
|
|
|
|
|
|
|
|
return cc.err
|
|
|
|
return cc.err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -93,6 +89,8 @@ func (cc *Client) closeListener() {
|
|
|
|
|
|
|
|
|
|
|
|
// close remote listener may block, because of connection issue
|
|
|
|
// close remote listener may block, because of connection issue
|
|
|
|
// so only 1 second to wait
|
|
|
|
// so only 1 second to wait
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
|
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
|
|
wg := &sync.WaitGroup{}
|
|
|
|
wg := &sync.WaitGroup{}
|
|
|
|
for _, l := range cc.listeners {
|
|
|
|
for _, l := range cc.listeners {
|
|
|
@ -105,16 +103,12 @@ func (cc *Client) closeListener() {
|
|
|
|
wg.Add(1)
|
|
|
|
wg.Add(1)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ch := make(chan struct{})
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
go func() {
|
|
|
|
wg.Wait()
|
|
|
|
wg.Wait()
|
|
|
|
ch <- struct{}{}
|
|
|
|
cancel()
|
|
|
|
}()
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
<-ctx.Done()
|
|
|
|
case <-ch:
|
|
|
|
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Close close the ssh connection
|
|
|
|
// Close close the ssh connection
|
|
|
@ -315,41 +309,47 @@ func (cc *Client) handleTransparentProxy(c net.Conn, addr net.Addr) {
|
|
|
|
PipeAndClose(c2, c)
|
|
|
|
PipeAndClose(c2, c)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func doKeepAlive(conn ssh.Conn, timeout time.Duration) error {
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
|
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ch := make(chan error, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
|
|
|
_, _, err := conn.SendRequest("keepalive@openssh.org", true, nil)
|
|
|
|
|
|
|
|
ch <- err
|
|
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
|
|
|
return errors.New("keepalive timeout")
|
|
|
|
|
|
|
|
case err := <-ch:
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
return err
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (cc *Client) keepAlive(interval time.Duration, maxCount int) {
|
|
|
|
func (cc *Client) keepAlive(interval time.Duration, maxCount int) {
|
|
|
|
count := 0
|
|
|
|
count := 0
|
|
|
|
c := time.NewTicker(interval)
|
|
|
|
c := time.NewTicker(interval)
|
|
|
|
|
|
|
|
defer c.Stop()
|
|
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
for {
|
|
|
|
select {
|
|
|
|
select {
|
|
|
|
|
|
|
|
case <-cc.ctx.Done():
|
|
|
|
|
|
|
|
return
|
|
|
|
case <-c.C:
|
|
|
|
case <-c.C:
|
|
|
|
resCh := make(chan error, 1)
|
|
|
|
if err := doKeepAlive(cc.sshConn, 3*time.Second); err != nil {
|
|
|
|
go func(resCh chan error) {
|
|
|
|
|
|
|
|
_, _, err := cc.sshConn.SendRequest("keepalive@openssh.org", true, nil)
|
|
|
|
|
|
|
|
resCh <- err
|
|
|
|
|
|
|
|
}(resCh)
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
|
|
|
case err := <-resCh:
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
log.Debugf("keep alive error: %s", err.Error())
|
|
|
|
|
|
|
|
count++
|
|
|
|
count++
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
count = 0
|
|
|
|
count = 0
|
|
|
|
}
|
|
|
|
}
|
|
|
|
case <-time.After(3 * time.Second):
|
|
|
|
|
|
|
|
log.Debugf("keep alive timed out")
|
|
|
|
|
|
|
|
count++
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if count >= maxCount {
|
|
|
|
if count >= maxCount {
|
|
|
|
cc.err = fmt.Errorf("keep alive detects connection hang up")
|
|
|
|
cc.err = fmt.Errorf("keep alive detects connection hang up")
|
|
|
|
log.Errorf("keep alive hit max count, exit")
|
|
|
|
log.Errorf("keep alive hit max count, exit")
|
|
|
|
//cc.sshConn.Close()
|
|
|
|
cc.cancel()
|
|
|
|
//cc.conn.Close()
|
|
|
|
|
|
|
|
// send exit signal
|
|
|
|
|
|
|
|
// close(cc.ch)
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
|
|
|
case cc.ch <- struct{}{}:
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -364,10 +364,7 @@ func (cc *Client) registerSignal() {
|
|
|
|
cc.err = fmt.Errorf("signal %v", s1)
|
|
|
|
cc.err = fmt.Errorf("signal %v", s1)
|
|
|
|
log.Errorf("signal %d received, exit", s1)
|
|
|
|
log.Errorf("signal %d received, exit", s1)
|
|
|
|
//close(cc.ch)
|
|
|
|
//close(cc.ch)
|
|
|
|
select {
|
|
|
|
cc.cancel()
|
|
|
|
case cc.ch <- struct{}{}:
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|