diff --git a/obfssh/ssh.go b/obfssh/ssh.go index 7e36f9c..a3002f3 100644 --- a/obfssh/ssh.go +++ b/obfssh/ssh.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "net" + "net/url" "os" "path/filepath" "strings" @@ -197,29 +198,36 @@ func main() { // parse environment proxy updateProxyFromEnv(&cfg) - rhost := net.JoinHostPort(host, fmt.Sprintf("%d", cfg.Port)) - var c net.Conn - if cfg.Proxy.Scheme != "" && cfg.Proxy.Host != "" && cfg.Proxy.Port != 0 { - switch cfg.Proxy.Scheme { - case "http": - log.Debugf("use http proxy %s:%d to connect to server", - cfg.Proxy.Host, cfg.Proxy.Port) - c, err = dialHTTPProxy(host, cfg.Port, cfg.Proxy) - case "https": - log.Debugf("use https proxy %s:%d to connect to server", - cfg.Proxy.Host, cfg.Proxy.Port) - c, err = dialHTTPSProxy(host, cfg.Port, cfg.Proxy) - case "socks5": - log.Debugf("use socks proxy %s:%d to connect to server", - cfg.Proxy.Host, cfg.Proxy.Port) - c, err = dialSocks5Proxy(host, cfg.Port, cfg.Proxy) - default: - err = fmt.Errorf("unsupported scheme: %s", cfg.Proxy.Scheme) - } + var rhost string + + if strings.HasPrefix(host, "ws://") || strings.HasPrefix(host, "wss://") { + c, err = obfssh.NewWSConn(host) + u, _ := url.Parse(host) + rhost = u.Host } else { - log.Debugf("dail to %s", rhost) - c, err = dialer.Dial("tcp", rhost) + rhost = net.JoinHostPort(host, fmt.Sprintf("%d", cfg.Port)) + if cfg.Proxy.Scheme != "" && cfg.Proxy.Host != "" && cfg.Proxy.Port != 0 { + switch cfg.Proxy.Scheme { + case "http": + log.Debugf("use http proxy %s:%d to connect to server", + cfg.Proxy.Host, cfg.Proxy.Port) + c, err = dialHTTPProxy(host, cfg.Port, cfg.Proxy) + case "https": + log.Debugf("use https proxy %s:%d to connect to server", + cfg.Proxy.Host, cfg.Proxy.Port) + c, err = dialHTTPSProxy(host, cfg.Port, cfg.Proxy) + case "socks5": + log.Debugf("use socks proxy %s:%d to connect to server", + cfg.Proxy.Host, cfg.Proxy.Port) + c, err = dialSocks5Proxy(host, cfg.Port, cfg.Proxy) + default: + err = fmt.Errorf("unsupported scheme: %s", cfg.Proxy.Scheme) + } + } else { + log.Debugf("dail to %s", rhost) + c, err = dialer.Dial("tcp", rhost) + } } if err != nil { diff --git a/ws.go b/ws.go new file mode 100644 index 0000000..07912fb --- /dev/null +++ b/ws.go @@ -0,0 +1,110 @@ +package obfssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/fangdingjun/go-log/v5" + "github.com/gorilla/websocket" +) + +type wsConn struct { + *websocket.Conn + buf *bytes.Buffer + mu *sync.Mutex + ch chan struct{} +} + +var _ net.Conn = &wsConn{} + +// NewWSConn dial to websocket server and return net.Conn +func NewWSConn(p string) (net.Conn, error) { + conn, resp, err := websocket.DefaultDialer.Dial(p, nil) + if err != nil { + return nil, err + } + resp.Body.Close() + + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, fmt.Errorf("http status %d", resp.StatusCode) + } + + c := &wsConn{Conn: conn, + buf: bytes.NewBuffer(nil), + mu: new(sync.Mutex), + ch: make(chan struct{}), + } + + go c.readLoop() + + return c, nil +} + +func (wc *wsConn) readLoop() { + for { + _, data, err := wc.ReadMessage() + if err != nil { + log.Debugln(err) + close(wc.ch) + break + } + + wc.mu.Lock() + wc.buf.Write(data) + wc.mu.Unlock() + + select { + case wc.ch <- struct{}{}: + default: + } + } +} + +func (wc *wsConn) Read(buf []byte) (int, error) { + wc.mu.Lock() + + n, err := wc.buf.Read(buf) + if err == nil { + wc.mu.Unlock() + return n, err + } + + wc.mu.Unlock() + + if err != io.EOF { + return 0, err + } + + // EOF, no data avaliable, read again + select { + case _, ok := <-wc.ch: + if !ok { + return 0, errors.New("connection closed") + } + } + + wc.mu.Lock() + defer wc.mu.Unlock() + return wc.buf.Read(buf) +} + +func (wc *wsConn) Write(buf []byte) (int, error) { + err := wc.WriteMessage(websocket.BinaryMessage, buf) + return len(buf), err +} + +func (wc *wsConn) SetDeadline(t time.Time) error { + if err := wc.SetReadDeadline(t); err != nil { + return err + } + if err := wc.SetWriteDeadline(t); err != nil { + return err + } + return nil +}