From 690c894f36a9860f3286485242fe854eff49676b Mon Sep 17 00:00:00 2001 From: dingjun Date: Fri, 17 Jul 2020 17:03:15 +0800 Subject: [PATCH] refine handle session --- console.go | 9 ++ console_unix.go | 21 +++ server.go | 330 ++++++++++++++++++++++++++---------------------- 3 files changed, 208 insertions(+), 152 deletions(-) diff --git a/console.go b/console.go index 8ec4eab..45b7186 100644 --- a/console.go +++ b/console.go @@ -4,6 +4,8 @@ package obfssh import ( "errors" + "os/exec" + "os/user" "syscall" "github.com/containerd/console" @@ -24,3 +26,10 @@ func setProcAttr(attr *syscall.SysProcAttr) { func setTermios(fd int, args ssh.TerminalModes) error { return errors.New("not supported") } + +func setUserEnv(_cmd *exec.Cmd, u *user.User, attr *syscall.SysProcAttr) { + if u == nil { + return + } + _cmd.Dir = u.HomeDir +} diff --git a/console_unix.go b/console_unix.go index 7bdf8e9..3c5e6a0 100644 --- a/console_unix.go +++ b/console_unix.go @@ -3,8 +3,12 @@ package obfssh import ( + "fmt" "os" + "os/exec" "os/signal" + "os/user" + "strconv" "syscall" "github.com/containerd/console" @@ -98,3 +102,20 @@ func setTermios(fd int, args ssh.TerminalModes) error { return nil } + +func setUserEnv(_cmd *exec.Cmd, u *user.User, attr *syscall.SysProcAttr) { + if u == nil { + return + } + _uid, _ := strconv.ParseUint(u.Uid, 10, 32) + _gid, _ := strconv.ParseUint(u.Gid, 10, 32) + if attr.Credential == nil { + attr.Credential = &syscall.Credential{} + } + attr.Credential.Uid = uint32(_uid) + attr.Credential.Gid = uint32(_gid) + + _cmd.Env = append(_cmd.Env, fmt.Sprintf("HOME=%s", u.HomeDir)) + _cmd.Env = append(_cmd.Env, fmt.Sprintf("LOGNAME=%s", u.Name)) + _cmd.Dir = u.HomeDir +} diff --git a/server.go b/server.go index 896247c..fc143ee 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "net" "os" "os/exec" + "os/user" "runtime" "syscall" "time" @@ -181,141 +182,174 @@ func parseTerminalModes(s string) ssh.TerminalModes { return t } -func (sc *Server) handleSession(newch ssh.NewChannel) { - ch, req, err := newch.Accept() - if err != nil { - log.Errorf("%s", err.Error()) - return +type session struct { + ch ssh.Channel + env []string + _console console.Console + ptsname string + cmd *exec.Cmd + user string +} + +func (s *session) handleSubsystem(payload []byte) bool { + var _cmd args + if err := ssh.Unmarshal(payload, &_cmd); err != nil { + log.Errorln(err) + return false + } + + if _cmd.Arg != "sftp" { // only support sftp + log.Debugln("subsystem", _cmd.Arg, "not support") + return false } + log.Debugf("handle sftp request") + go serveSFTP(s.ch) + return true +} - ret := false +func (s *session) handleShell() bool { + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + s.env = append(s.env, fmt.Sprintf("SHELL=powershell")) + cmd = exec.Command("powershell") + } else { + s.env = append(s.env, fmt.Sprintf("SHELL=/bin/bash")) + cmd = exec.Command("/bin/bash", "-l") + } + s.cmd = cmd + cmd.Env = s.env + go handleShell(cmd, s.ch, s._console, s.ptsname, s.user) + return true +} +func (s *session) handleExec(payload []byte) bool { var _cmd args var cmd *exec.Cmd - var env []string + if err := ssh.Unmarshal(payload, &_cmd); err != nil { + log.Errorln(err) + return false + } + log.Infoln("execute command", _cmd.Arg) + if runtime.GOOS == "windows" { + s.env = append(s.env, fmt.Sprintf("SHELL=powershell")) + cmd = exec.Command("powershell", "-Command", _cmd.Arg) + } else { + s.env = append(s.env, fmt.Sprintf("SHELL=/bin/bash")) + cmd = exec.Command("/bin/bash", "-c", _cmd.Arg) + } + s.cmd = cmd + cmd.Env = s.env + go handleShell(cmd, s.ch, s._console, s.ptsname, s.user) + return true +} + +func (s *session) handlePtyReq(payload []byte) bool { var _ptyReq ptyReq + var err error + if err = ssh.Unmarshal(payload, &_ptyReq); err != nil { + log.Errorln(err) + return false + } + + log.Debugf("pty req Rows: %d, Columns: %d, Mode: %x", _ptyReq.Rows, _ptyReq.Columns, _ptyReq.Mode) + + termios := parseTerminalModes(_ptyReq.Mode) + log.Debugf("parsed terminal mode %+v", termios) + + s._console, s.ptsname, err = newPty() + if err != nil { + log.Errorln(err) + return false + } + + log.Debugf("allocate pty %s", s.ptsname) + log.Debugf("set termios") + if err1 := setTermios(int(s._console.Fd()), termios); err1 != nil { + log.Errorln(err) + return false + } + + s.env = append(s.env, fmt.Sprintf("SSH_TTY=%s", s.ptsname)) + s.env = append(s.env, fmt.Sprintf("TERM=%s", _ptyReq.Term)) + + ws, err := s._console.Size() + log.Debugf("current console %+v", ws) + ws.Height = uint16(_ptyReq.Rows) + ws.Width = uint16(_ptyReq.Columns) + if err = s._console.Resize(ws); err != nil { + log.Errorln(err) + return false + } + return true +} + +func (s *session) handleEnv(payload []byte) bool { + var arg envArgs + if err := ssh.Unmarshal(payload, &arg); err != nil { + log.Errorln(err) + return false + } + log.Debugf("got env %s=%s", arg.Name, arg.Value) + s.env = append(s.env, fmt.Sprintf("%s=%s", arg.Name, arg.Value)) + return true +} + +func (s *session) handleWindowChange(payload []byte) bool { var _windowChange windowChange - var _console console.Console - var ptsname string + if err := ssh.Unmarshal(payload, &_windowChange); err != nil { + log.Errorln(err) + return false + } + log.Debugf("window change %+v", _windowChange) + if s._console == nil { + // ignore + return true + } + + ws, err := s._console.Size() + if err != nil { + log.Errorln(err) + return false + } + + log.Debugf("current console %+v", ws) + ws.Height = uint16(_windowChange.Rows) + ws.Width = uint16(_windowChange.Columns) + if err := s._console.Resize(ws); err != nil { + log.Errorln(err) + return false + } + return true +} + +func (sc *Server) handleSession(newch ssh.NewChannel) { + ch, req, err := newch.Accept() + if err != nil { + log.Errorf("%s", err.Error()) + return + } + + sess := &session{ch: ch, user: sc.sshConn.User()} for r := range req { + ret := false switch r.Type { case "subsystem": - if err := ssh.Unmarshal(r.Payload, &_cmd); err == nil { - if _cmd.Arg == "sftp" { // only support sftp - ret = true - log.Debugf("handle sftp request") - go serveSFTP(ch) - } else { - ret = false - log.Debugln("subsystem", _cmd.Arg, "not support") - } - } else { - ret = false - log.Debugln("get subsystem arg error", err) - } + ret = sess.handleSubsystem(r.Payload) case "shell": - ret = true - if runtime.GOOS == "windows" { - env = append(env, fmt.Sprintf("SHELL=powershell")) - cmd = exec.Command("powershell") - } else { - env = append(env, fmt.Sprintf("SHELL=/bin/bash")) - cmd = exec.Command("/bin/bash", "-l") - } - env = append(env, fmt.Sprintf("HOME=/home/guest")) - cmd.Env = env - go handleShell(cmd, ch, _console, ptsname) - case "signal": - log.Debugln("got signal") - ret = true + ret = sess.handleShell() case "exec": - ret = true - if err = ssh.Unmarshal(r.Payload, &_cmd); err == nil { - log.Infoln("execute command", _cmd.Arg) - if runtime.GOOS == "windows" { - env = append(env, fmt.Sprintf("SHELL=powershell")) - cmd = exec.Command("powershell", "-Command", _cmd.Arg) - } else { - env = append(env, fmt.Sprintf("SHELL=/bin/bash")) - cmd = exec.Command("/bin/bash", "-c", _cmd.Arg) - } - env = append(env, fmt.Sprintf("HOME=/home/guest")) - cmd.Env = env - //cmd.Stdin = ch - go handleCommand(cmd, ch) - } else { - log.Debugln(err) - ret = false - } + ret = sess.handleExec(r.Payload) case "pty-req": - ret = true - err = ssh.Unmarshal(r.Payload, &_ptyReq) - if err != nil { - log.Errorln(err) - ret = false - } - log.Debugf("pty req Rows: %d, Columns: %d, Mode: %x", _ptyReq.Rows, _ptyReq.Columns, _ptyReq.Mode) - if err == nil && (runtime.GOOS == "unix" || runtime.GOOS == "linux") { - termios := parseTerminalModes(_ptyReq.Mode) - log.Debugf("parsed terminal mode %+v", termios) - _console, ptsname, err = newPty() - if err == nil { - log.Debugf("allocate pty %s", ptsname) - log.Debugf("set termios") - if err1 := setTermios(int(_console.Fd()), termios); err1 != nil { - log.Errorln(err) - } - env = append(env, fmt.Sprintf("SSH_TTY=%s", ptsname)) - env = append(env, fmt.Sprintf("TERM=%s", _ptyReq.Term)) - ws, err := _console.Size() - if err != nil { - log.Errorln(err) - } else { - log.Debugf("current console %+v", ws) - ws.Height = uint16(_ptyReq.Rows) - ws.Width = uint16(_ptyReq.Columns) - if err = _console.Resize(ws); err != nil { - log.Errorln(err) - } - } - } else { - log.Errorln(err) - } - } + ret = sess.handlePtyReq(r.Payload) case "env": - var arg envArgs - ret = true - if err = ssh.Unmarshal(r.Payload, &arg); err == nil { - log.Debugf("got env %s=%s", arg.Name, arg.Value) - env = append(env, fmt.Sprintf("%s=%s", arg.Name, arg.Value)) - } else { - log.Debugln("parse env failed", err) - ret = false - } + ret = sess.handleEnv(r.Payload) case "window-change": + ret = sess.handleWindowChange(r.Payload) + case "signal": + log.Debugln("got signal") ret = true - err = ssh.Unmarshal(r.Payload, &_windowChange) - if err != nil { - ret = false - log.Errorln(err) - } - if err == nil && _console != nil { - ws, err := _console.Size() - if err != nil { - log.Errorln(err) - } else { - log.Debugf("current console %+v", ws) - ws.Height = uint16(_windowChange.Rows) - ws.Width = uint16(_windowChange.Columns) - if err = _console.Resize(ws); err != nil { - log.Errorln(err) - } - } - } - log.Debugf("window change %+v", _windowChange) default: - ret = false } log.Debugf("session request %s, reply %v", r.Type, ret) @@ -324,9 +358,23 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { r.Reply(ret, nil) } } + + if sess.cmd != nil && sess.cmd.Process != nil { + log.Debugf("kill the running process %s", sess.cmd.Args) + p := sess.cmd.Process + if err := p.Kill(); err != nil { + log.Debugln(err) + } + time.Sleep(100 * time.Millisecond) + if err := p.Signal(os.Kill); err != nil { + log.Debugln(err) + } + } + + log.Debugln("session ended.") } -func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsname string) { +func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsname string, _user string) { defer func() { ch.Close() if _console != nil { @@ -336,7 +384,16 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam var err error - log.Infoln("start shell") + if cmd.SysProcAttr == nil { + cmd.SysProcAttr = &syscall.SysProcAttr{} + } + + _u, err := user.Lookup(_user) + if err != nil { + log.Errorln(err) + } + + setUserEnv(cmd, _u, cmd.SysProcAttr) if _console != nil { _tty, err := os.OpenFile(ptsname, syscall.O_RDWR|syscall.O_NOCTTY, 0600) @@ -348,9 +405,6 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam cmd.Stderr = _tty cmd.Stdout = _tty cmd.Stdin = _tty - if cmd.SysProcAttr == nil { - cmd.SysProcAttr = &syscall.SysProcAttr{} - } setProcAttr(cmd.SysProcAttr) @@ -359,6 +413,7 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam } else { cmd.Stderr = ch cmd.Stdout = ch + // cmd.Stdin = ch in, err := cmd.StdinPipe() if err != nil { @@ -388,36 +443,7 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam } } } - ch.SendRequest("exit-status", false, - ssh.Marshal(exitStatus{Status: uint32(code)})) -} - -func handleCommand(cmd *exec.Cmd, ch ssh.Channel) { - defer ch.Close() - - cmd.Stdout = ch - cmd.Stderr = ch - //log.Debugln("execute command", cmd) - in, err := cmd.StdinPipe() - if err != nil { - log.Debugln(err) - ch.SendRequest("exit-status", false, - ssh.Marshal(exitStatus{Status: 127})) - return - } - go func() { - defer in.Close() - io.Copy(in, ch) - }() - code := 0 - if err := cmd.Run(); err != nil { - log.Debugln(err) - if exiterr, ok := err.(*exec.ExitError); ok { - if s, ok := exiterr.Sys().(syscall.WaitStatus); ok { - code = s.ExitStatus() - } - } - } + cmd.Process = nil ch.SendRequest("exit-status", false, ssh.Marshal(exitStatus{Status: uint32(code)})) }