diff --git a/client.go b/client.go index f8a90f5..defe15a 100644 --- a/client.go +++ b/client.go @@ -14,10 +14,10 @@ import ( "syscall" "time" + "github.com/containerd/console" "github.com/fangdingjun/go-log" socks "github.com/fangdingjun/socks-go" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" ) // Client is ssh client connection @@ -147,30 +147,46 @@ func (cc *Client) Shell() error { return err } - session.Stdin = os.Stdin - session.Stdout = os.Stdout - session.Stderr = os.Stderr modes := ssh.TerminalModes{ ssh.ECHO: 1, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, } + _console := console.Current() + defer _console.Reset() + // this make CTRL+C works log.Debugf("turn terminal mode to raw") + _console.SetRaw() - oldState, _ := terminal.MakeRaw(0) - defer func() { - log.Debugf("restore terminal mode") - terminal.Restore(0, oldState) - }() + ws, _ := _console.Size() - w, h, _ := terminal.GetSize(0) log.Debugf("request pty") - if err := session.RequestPty("xterm", h, w, modes); err != nil { + if err := session.RequestPty("xterm", int(ws.Height), int(ws.Width), modes); err != nil { log.Errorf("request pty error: %s", err.Error()) return err } + + ch := make(chan os.Signal, 2) + signal.Notify(ch, syscall.SIGWINCH) + go func() { + for { + select { + case <-ch: + ws, _ := _console.Size() + _winCh := windowChange{Rows: uint32(ws.Height), Columns: uint32(ws.Width)} + d := ssh.Marshal(_winCh) + ok, err := session.SendRequest("window-change", true, d) + log.Debugf("send window change request %+v %+v", ok, err) + } + } + }() + + session.Stdin = _console + session.Stdout = os.Stdout + session.Stderr = os.Stderr + log.Debugf("request shell") if err := session.Shell(); err != nil { log.Errorf("start shell error: %s", err.Error()) diff --git a/server.go b/server.go index 4b4d262..20b0d87 100644 --- a/server.go +++ b/server.go @@ -10,7 +10,7 @@ import ( "syscall" "time" - "github.com/creack/pty" + "github.com/containerd/console" "github.com/fangdingjun/go-log" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" @@ -178,7 +178,8 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { var env []string var _ptyReq ptyReq var _windowChange windowChange - var _pty, _tty *os.File + var _console console.Console + var ptsname string for r := range req { switch r.Type { @@ -204,7 +205,7 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { cmd = exec.Command("bash", "-l") } cmd.Env = env - go handleShell(cmd, ch, _pty, _tty) + go handleShell(cmd, ch, _console, ptsname) case "signal": log.Debugln("got signal") ret = true @@ -233,26 +234,23 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { } log.Debugf("pty req %+v", _ptyReq) if err == nil && (runtime.GOOS == "unix" || runtime.GOOS == "linux") { - _pty, _tty, err = pty.Open() - if err != nil { - log.Errorln(err) - ret = false - } else { - env = append(env, fmt.Sprintf("TERM=%s", _ptyReq.Term)) - size, err := pty.GetsizeFull(_pty) - if err == nil { - log.Debugf("term size %+v", size) - size.Rows = uint16(_ptyReq.Rows) - size.Cols = uint16(_ptyReq.Columns) - if err = pty.Setsize(_pty, size); err != nil { - log.Errorln(err) - } - if err = pty.Setsize(_tty, size); err != nil { + _console, ptsname, err = console.NewPty() + if err == nil { + log.Debugf("allocate pty %s", ptsname) + env = append(env, fmt.Sprintf("SSH_TTY=%s", ptsname)) + ws, err := _console.Size() + if err != nil { + log.Errorln(err) + } else { + log.Debugf("current console %+v", ws) + ws.Height = uint16(_ptyReq.Rows) + ws.Width = uint16(_ptyReq.Columns) + if err = _console.Resize(ws); err != nil { log.Errorln(err) } - } else { - log.Errorln(err) } + } else { + log.Errorln(err) } } case "env": @@ -272,20 +270,17 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { ret = false log.Errorln(err) } - if err == nil && _pty != nil { - size, err := pty.GetsizeFull(_pty) - if err == nil { - log.Debugf("term size %+v", size) - size.Rows = uint16(_ptyReq.Rows) - size.Cols = uint16(_ptyReq.Columns) - if err = pty.Setsize(_pty, size); err != nil { - log.Errorln(err) - } - if err = pty.Setsize(_tty, size); err != nil { + if err == nil && _console != nil { + ws, err := _console.Size() + if err != nil { + log.Errorln(err) + } else { + log.Debugf("current console %+v", ws) + ws.Height = uint16(_windowChange.Rows) + ws.Width = uint16(_windowChange.Columns) + if err = _console.Resize(ws); err != nil { log.Errorln(err) } - } else { - log.Errorln(err) } } log.Debugf("window change %+v", _windowChange) @@ -301,12 +296,11 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { } } -func handleShell(cmd *exec.Cmd, ch ssh.Channel, _pty, _tty *os.File) { +func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsname string) { defer func() { ch.Close() - if _pty != nil { - _pty.Close() - _tty.Close() + if _console != nil { + _console.Close() } }() @@ -314,7 +308,13 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _pty, _tty *os.File) { log.Infoln("start shell") - if _tty != nil { + if _console != nil { + _tty, err := os.OpenFile(ptsname, syscall.O_RDWR|syscall.O_NOCTTY, 0600) + if err != nil { + log.Errorln(err) + return + } + defer _tty.Close() cmd.Stderr = _tty cmd.Stdout = _tty cmd.Stdin = _tty @@ -323,10 +323,10 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _pty, _tty *os.File) { } cmd.SysProcAttr.Setsid = true cmd.SysProcAttr.Setctty = true - cmd.SysProcAttr.Ctty = int(_tty.Fd()) + // cmd.SysProcAttr.Ctty = int(_tty.Fd()) - go io.Copy(ch, _pty) - go io.Copy(_pty, ch) + go io.Copy(ch, _console) + go io.Copy(_console, ch) } else { cmd.Stderr = ch cmd.Stdout = ch