fix hang issue on close remote listener

tls
fangdingjun 7 years ago
parent 397f60da61
commit 2b973b326d

@ -3,12 +3,12 @@ package obfssh
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
@ -80,21 +80,51 @@ func (cc *Client) Run() error {
Log(INFO, "got signal, exit") Log(INFO, "got signal, exit")
} }
} }
Log(DEBUG, "Done")
cc.Close() cc.Close()
Log(DEBUG, "Done")
return cc.err return cc.err
} }
// Close close the ssh connection func (cc *Client) closeListener() {
// and free all the port forward resources if len(cc.listeners) == 0 {
func (cc *Client) Close() { return
}
// close remote listener may block, because of connection issue
// so only 1 second to wait
wg := &sync.WaitGroup{}
for _, l := range cc.listeners { for _, l := range cc.listeners {
Log(INFO, "close the listener %s", l.Addr()) go func(l net.Listener) {
Log(DEBUG, "begin to close listener %s", l.Addr().String())
l.Close() l.Close()
Log(DEBUG, "close listener %s done", l.Addr().String())
wg.Done()
}(l)
wg.Add(1)
}
ch := make(chan struct{})
go func() {
wg.Wait()
ch <- struct{}{}
}()
select {
case <-ch:
case <-time.After(1 * time.Second):
}
} }
//Log(DEBUG, "close ssh connection")
// Close close the ssh connection
// and free all the port forward resources
func (cc *Client) Close() {
cc.closeListener()
Log(DEBUG, "close ssh connection")
cc.sshConn.Close() cc.sshConn.Close()
cc.conn.Close() cc.conn.Close()
Log(DEBUG, "close ssh connection done")
} }
// RunCmd run a single command on server // RunCmd run a single command on server
@ -304,6 +334,7 @@ func (cc *Client) keepAlive(interval time.Duration, maxCount int) {
count = 0 count = 0
} }
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
Log(DEBUG, "keep alive timed out")
count++ count++
} }
@ -366,13 +397,14 @@ func (cc *Client) AddDynamicHTTPForward(addr string) error {
} }
func (cc *Client) handleHTTPIncoming(c net.Conn) { func (cc *Client) handleHTTPIncoming(c net.Conn) {
defer c.Close() //defer c.Close()
r := bufio.NewReader(c) r := bufio.NewReader(c)
req, err := http.ReadRequest(r) req, err := http.ReadRequest(r)
if err != nil { if err != nil {
Log(ERROR, "read http request error %s", err) Log(ERROR, "read http request error %s", err)
c.Close()
return return
} }
@ -388,27 +420,16 @@ func (cc *Client) handleConnect(req *http.Request, c net.Conn) {
c1, err := cc.client.Dial("tcp", req.RequestURI) c1, err := cc.client.Dial("tcp", req.RequestURI)
if err != nil { if err != nil {
c.Close()
fmt.Fprintf(c, "HTTP/1.0 503 connection failed\r\n\r\n") fmt.Fprintf(c, "HTTP/1.0 503 connection failed\r\n\r\n")
Log(ERROR, "dial error %s", err) Log(ERROR, "dial error %s", err)
return return
} }
defer c1.Close() //defer c1.Close()
fmt.Fprintf(c, "HTTP/1.0 200 connection established\r\n\r\n") fmt.Fprintf(c, "HTTP/1.0 200 connection established\r\n\r\n")
PipeAndClose(c, c1)
ch := make(chan struct{}, 2)
go func() {
io.Copy(c1, c)
ch <- struct{}{}
}()
go func() {
io.Copy(c, c1)
ch <- struct{}{}
}()
<-ch
} }
func (cc *Client) handleHTTPReq(req *http.Request, c net.Conn) { func (cc *Client) handleHTTPReq(req *http.Request, c net.Conn) {
@ -420,29 +441,19 @@ func (cc *Client) handleHTTPReq(req *http.Request, c net.Conn) {
Log(DEBUG, "request to %s", host) Log(DEBUG, "request to %s", host)
c1, err := cc.client.Dial("tcp", host) c1, err := cc.client.Dial("tcp", host)
if err != nil { if err != nil {
c.Close()
fmt.Fprintf(c, "HTTP/1.1 503 connection failed\r\nConnection: close\r\n\r\n") fmt.Fprintf(c, "HTTP/1.1 503 connection failed\r\nConnection: close\r\n\r\n")
Log(ERROR, "connection failed %s", err) Log(ERROR, "connection failed %s", err)
return return
} }
defer c1.Close() //defer c1.Close()
if err = req.Write(c1); err != nil { if err = req.Write(c1); err != nil {
fmt.Fprintf(c, "HTTP/1.1 503 write to server error\r\nConnection: close\r\n\r\n") fmt.Fprintf(c, "HTTP/1.1 503 write to server error\r\nConnection: close\r\n\r\n")
Log(ERROR, "write request to server error %s", err) Log(ERROR, "write request to server error %s", err)
c.Close()
c1.Close()
return return
} }
PipeAndClose(c, c1)
ch := make(chan struct{}, 2)
go func() {
io.Copy(c1, c)
ch <- struct{}{}
}()
go func() {
io.Copy(c, c1)
ch <- struct{}{}
}()
<-ch
} }

Loading…
Cancel
Save