From bcf8b9f48156546d7e2f15f93c6317e77882f575 Mon Sep 17 00:00:00 2001 From: dingjun Date: Thu, 9 Jul 2020 19:52:42 +0800 Subject: [PATCH] fix exec and shell pty issue --- client.go | 17 +++---- conn_test.go | 13 ++++- go.mod | 2 +- go.sum | 2 - obfssh/ssh.go | 5 +- obfsshd/server.go | 1 - pty_unix.go | 14 ----- pty_windows.go | 11 ---- server.go | 127 +++++++++++++++++++++++++++++++++++++--------- 9 files changed, 124 insertions(+), 68 deletions(-) delete mode 100644 pty_unix.go delete mode 100644 pty_windows.go diff --git a/client.go b/client.go index 0b43322..f8a90f5 100644 --- a/client.go +++ b/client.go @@ -123,21 +123,20 @@ func (cc *Client) Close() { } // RunCmd run a single command on server -func (cc *Client) RunCmd(cmd string) ([]byte, error) { +func (cc *Client) RunCmd(cmd string) error { log.Debugf("run command %s", cmd) session, err := cc.client.NewSession() if err != nil { - log.Debugf("command exited with error: %s", err.Error()) - } else { - log.Debugf("command exited with no error") + log.Debugf("new session error: %s", err.Error()) + return err } - if err != nil { - return nil, err - } - d, err := session.CombinedOutput(cmd) + session.Stdin = os.Stdin + session.Stderr = os.Stderr + session.Stdout = os.Stdout + err = session.Run(cmd) session.Close() - return d, err + return err } // Shell start a login shell on server diff --git a/conn_test.go b/conn_test.go index 03f7beb..f794432 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,8 +5,15 @@ import ( "net" "testing" "time" + + "github.com/fangdingjun/go-log" ) +func TestTimedOutConn(t *testing.T) { + testTimedOutConn(t, true) + testTimedOutConn(t, false) +} + func testTimedOutConn(t *testing.T, _timeout bool) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -23,7 +30,8 @@ func testTimedOutConn(t *testing.T, _timeout bool) { go func() { s, err := l.Accept() if err != nil { - t.Fatalf("accept failed: %s", err) + log.Errorf("accept failed: %s", err) + return } defer s.Close() @@ -34,7 +42,8 @@ func testTimedOutConn(t *testing.T, _timeout bool) { n, err := sConn.Read(buf) if err != nil { - t.Fatalf("server read failed: %s", err) + log.Errorf("server read failed: %s", err) + return } if _timeout { diff --git a/go.mod b/go.mod index abb3a42..9a9ec81 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,13 @@ go 1.13 require ( github.com/bgentry/speakeasy v0.1.0 + github.com/creack/pty v1.1.7 github.com/fangdingjun/go-log v0.0.0-20190821073628-ae332053d6dc github.com/fangdingjun/protolistener v0.0.0-20190821093313-6d5d2138f296 github.com/fangdingjun/socks-go v0.0.0-20180926100003-fc6f0a9ee1f4 github.com/go-yaml/yaml v2.1.0+incompatible github.com/kr/fs v0.1.0 github.com/kr/pretty v0.1.0 // indirect - github.com/kr/pty v1.1.8 github.com/pkg/errors v0.8.1 // indirect github.com/pkg/sftp v1.10.0 github.com/stretchr/testify v1.4.0 // indirect diff --git a/go.sum b/go.sum index 8889a63..d31a6ac 100644 --- a/go.sum +++ b/go.sum @@ -17,8 +17,6 @@ github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.8 h1:AkaSdXYQOWeaO3neb8EM634ahkXXe3jYbVh/F9lq+GI= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc h1:lNOt1SMsgHXTdpuGw+RpnJtzUcCb/oRKZP65pBy9pr8= diff --git a/obfssh/ssh.go b/obfssh/ssh.go index a7173c3..e49eb72 100644 --- a/obfssh/ssh.go +++ b/obfssh/ssh.go @@ -331,12 +331,9 @@ func main() { if !cfg.NotRunCmd { if cmd != "" { - if d, err := client.RunCmd(cmd); err != nil { + if err := client.RunCmd(cmd); err != nil { log.Errorln(err) hasErr = true - } else { - //log.Printf("%s", string(d)) - fmt.Printf("%s", string(d)) } } else { if err := client.Shell(); err != nil { diff --git a/obfsshd/server.go b/obfsshd/server.go index d2e431a..d552fe7 100644 --- a/obfsshd/server.go +++ b/obfsshd/server.go @@ -152,7 +152,6 @@ func main() { defer c.Close() sc, err := obfssh.NewServer(c, config, sconf) if err != nil { - c.Close() log.Errorf("%s", err.Error()) return } diff --git a/pty_unix.go b/pty_unix.go deleted file mode 100644 index 5b837c4..0000000 --- a/pty_unix.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build linux darwin - -package obfssh - -import ( - "io" - "os/exec" - - "github.com/kr/pty" -) - -func startPty(cmd *exec.Cmd) (io.ReadWriteCloser, error) { - return pty.Start(cmd) -} diff --git a/pty_windows.go b/pty_windows.go deleted file mode 100644 index de08339..0000000 --- a/pty_windows.go +++ /dev/null @@ -1,11 +0,0 @@ -package obfssh - -import ( - "errors" - "io" - "os/exec" -) - -func startPty(cmd *exec.Cmd) (io.ReadWriteCloser, error) { - return nil, errors.New("not implement") -} diff --git a/server.go b/server.go index a278d59..4b4d262 100644 --- a/server.go +++ b/server.go @@ -4,11 +4,13 @@ import ( "fmt" "io" "net" + "os" "os/exec" "runtime" "syscall" "time" + "github.com/creack/pty" "github.com/fangdingjun/go-log" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" @@ -146,6 +148,22 @@ type exitStatus struct { Status uint32 } +type ptyReq struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Mode string +} + +type windowChange struct { + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 +} + func (sc *Server) handleSession(newch ssh.NewChannel) { ch, req, err := newch.Accept() if err != nil { @@ -153,11 +171,14 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { return } - var _cmd args - ret := false + + var _cmd args var cmd *exec.Cmd var env []string + var _ptyReq ptyReq + var _windowChange windowChange + var _pty, _tty *os.File for r := range req { switch r.Type { @@ -168,6 +189,7 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { log.Debugf("handle sftp request") go serveSFTP(ch) } else { + ret = false log.Debugln("subsystem", _cmd.Arg, "not support") } } else { @@ -182,7 +204,7 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { cmd = exec.Command("bash", "-l") } cmd.Env = env - go handleShell(cmd, ch) + go handleShell(cmd, ch, _pty, _tty) case "signal": log.Debugln("got signal") ret = true @@ -204,6 +226,35 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { } case "pty-req": ret = true + err = ssh.Unmarshal(r.Payload, &_ptyReq) + if err != nil { + log.Errorln(err) + ret = false + } + log.Debugf("pty req %+v", _ptyReq) + if err == nil && (runtime.GOOS == "unix" || runtime.GOOS == "linux") { + _pty, _tty, err = pty.Open() + if err != nil { + log.Errorln(err) + ret = false + } else { + env = append(env, fmt.Sprintf("TERM=%s", _ptyReq.Term)) + size, err := pty.GetsizeFull(_pty) + if err == nil { + log.Debugf("term size %+v", size) + size.Rows = uint16(_ptyReq.Rows) + size.Cols = uint16(_ptyReq.Columns) + if err = pty.Setsize(_pty, size); err != nil { + log.Errorln(err) + } + if err = pty.Setsize(_tty, size); err != nil { + log.Errorln(err) + } + } else { + log.Errorln(err) + } + } + } case "env": var arg envArgs ret = true @@ -216,6 +267,28 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { } case "window-change": ret = true + err = ssh.Unmarshal(r.Payload, &_windowChange) + if err != nil { + ret = false + log.Errorln(err) + } + if err == nil && _pty != nil { + size, err := pty.GetsizeFull(_pty) + if err == nil { + log.Debugf("term size %+v", size) + size.Rows = uint16(_ptyReq.Rows) + size.Cols = uint16(_ptyReq.Columns) + if err = pty.Setsize(_pty, size); err != nil { + log.Errorln(err) + } + if err = pty.Setsize(_tty, size); err != nil { + log.Errorln(err) + } + } else { + log.Errorln(err) + } + } + log.Debugf("window change %+v", _windowChange) default: ret = false } @@ -228,32 +301,36 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { } } -func handleShell(cmd *exec.Cmd, ch ssh.Channel) { - defer ch.Close() +func handleShell(cmd *exec.Cmd, ch ssh.Channel, _pty, _tty *os.File) { + defer func() { + ch.Close() + if _pty != nil { + _pty.Close() + _tty.Close() + } + }() - var _pty io.ReadWriteCloser var err error log.Infoln("start shell") - //_pty, err = pty.Start(cmd) - if runtime.GOOS == "unix" || runtime.GOOS == "linux" { - _pty, err = startPty(cmd) - if err != nil { - log.Debugln("start pty", err) - ch.SendRequest("exit-status", false, - ssh.Marshal(exitStatus{Status: 127})) - return + if _tty != nil { + cmd.Stderr = _tty + cmd.Stdout = _tty + cmd.Stdin = _tty + if cmd.SysProcAttr == nil { + cmd.SysProcAttr = &syscall.SysProcAttr{} } - } + cmd.SysProcAttr.Setsid = true + cmd.SysProcAttr.Setctty = true + cmd.SysProcAttr.Ctty = int(_tty.Fd()) - if runtime.GOOS == "unix" || runtime.GOOS == "linux" { - defer _pty.Close() go io.Copy(ch, _pty) go io.Copy(_pty, ch) - } else { // windows + } else { cmd.Stderr = ch cmd.Stdout = ch + in, err := cmd.StdinPipe() if err != nil { ch.SendRequest("exit-status", false, @@ -263,14 +340,16 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel) { go func() { defer in.Close() io.Copy(in, ch) + }() - if err := cmd.Start(); err != nil { - log.Debugln("start command ", err) - ch.SendRequest("exit-status", false, - ssh.Marshal(exitStatus{Status: 126})) - return - } } + if err := cmd.Start(); err != nil { + log.Debugln("start command ", err) + ch.SendRequest("exit-status", false, + ssh.Marshal(exitStatus{Status: 126})) + return + } + code := 0 if err = cmd.Wait(); err != nil { log.Debugln(err)