refine handle session

ws
dingjun 5 years ago
parent e7c09d348e
commit 690c894f36

@ -4,6 +4,8 @@ package obfssh
import (
"errors"
"os/exec"
"os/user"
"syscall"
"github.com/containerd/console"
@ -24,3 +26,10 @@ func setProcAttr(attr *syscall.SysProcAttr) {
func setTermios(fd int, args ssh.TerminalModes) error {
return errors.New("not supported")
}
func setUserEnv(_cmd *exec.Cmd, u *user.User, attr *syscall.SysProcAttr) {
if u == nil {
return
}
_cmd.Dir = u.HomeDir
}

@ -3,8 +3,12 @@
package obfssh
import (
"fmt"
"os"
"os/exec"
"os/signal"
"os/user"
"strconv"
"syscall"
"github.com/containerd/console"
@ -98,3 +102,20 @@ func setTermios(fd int, args ssh.TerminalModes) error {
return nil
}
func setUserEnv(_cmd *exec.Cmd, u *user.User, attr *syscall.SysProcAttr) {
if u == nil {
return
}
_uid, _ := strconv.ParseUint(u.Uid, 10, 32)
_gid, _ := strconv.ParseUint(u.Gid, 10, 32)
if attr.Credential == nil {
attr.Credential = &syscall.Credential{}
}
attr.Credential.Uid = uint32(_uid)
attr.Credential.Gid = uint32(_gid)
_cmd.Env = append(_cmd.Env, fmt.Sprintf("HOME=%s", u.HomeDir))
_cmd.Env = append(_cmd.Env, fmt.Sprintf("LOGNAME=%s", u.Name))
_cmd.Dir = u.HomeDir
}

@ -7,6 +7,7 @@ import (
"net"
"os"
"os/exec"
"os/user"
"runtime"
"syscall"
"time"
@ -181,141 +182,174 @@ func parseTerminalModes(s string) ssh.TerminalModes {
return t
}
func (sc *Server) handleSession(newch ssh.NewChannel) {
ch, req, err := newch.Accept()
if err != nil {
log.Errorf("%s", err.Error())
return
type session struct {
ch ssh.Channel
env []string
_console console.Console
ptsname string
cmd *exec.Cmd
user string
}
func (s *session) handleSubsystem(payload []byte) bool {
var _cmd args
if err := ssh.Unmarshal(payload, &_cmd); err != nil {
log.Errorln(err)
return false
}
if _cmd.Arg != "sftp" { // only support sftp
log.Debugln("subsystem", _cmd.Arg, "not support")
return false
}
log.Debugf("handle sftp request")
go serveSFTP(s.ch)
return true
}
ret := false
func (s *session) handleShell() bool {
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
s.env = append(s.env, fmt.Sprintf("SHELL=powershell"))
cmd = exec.Command("powershell")
} else {
s.env = append(s.env, fmt.Sprintf("SHELL=/bin/bash"))
cmd = exec.Command("/bin/bash", "-l")
}
s.cmd = cmd
cmd.Env = s.env
go handleShell(cmd, s.ch, s._console, s.ptsname, s.user)
return true
}
func (s *session) handleExec(payload []byte) bool {
var _cmd args
var cmd *exec.Cmd
var env []string
if err := ssh.Unmarshal(payload, &_cmd); err != nil {
log.Errorln(err)
return false
}
log.Infoln("execute command", _cmd.Arg)
if runtime.GOOS == "windows" {
s.env = append(s.env, fmt.Sprintf("SHELL=powershell"))
cmd = exec.Command("powershell", "-Command", _cmd.Arg)
} else {
s.env = append(s.env, fmt.Sprintf("SHELL=/bin/bash"))
cmd = exec.Command("/bin/bash", "-c", _cmd.Arg)
}
s.cmd = cmd
cmd.Env = s.env
go handleShell(cmd, s.ch, s._console, s.ptsname, s.user)
return true
}
func (s *session) handlePtyReq(payload []byte) bool {
var _ptyReq ptyReq
var err error
if err = ssh.Unmarshal(payload, &_ptyReq); err != nil {
log.Errorln(err)
return false
}
log.Debugf("pty req Rows: %d, Columns: %d, Mode: %x", _ptyReq.Rows, _ptyReq.Columns, _ptyReq.Mode)
termios := parseTerminalModes(_ptyReq.Mode)
log.Debugf("parsed terminal mode %+v", termios)
s._console, s.ptsname, err = newPty()
if err != nil {
log.Errorln(err)
return false
}
log.Debugf("allocate pty %s", s.ptsname)
log.Debugf("set termios")
if err1 := setTermios(int(s._console.Fd()), termios); err1 != nil {
log.Errorln(err)
return false
}
s.env = append(s.env, fmt.Sprintf("SSH_TTY=%s", s.ptsname))
s.env = append(s.env, fmt.Sprintf("TERM=%s", _ptyReq.Term))
ws, err := s._console.Size()
log.Debugf("current console %+v", ws)
ws.Height = uint16(_ptyReq.Rows)
ws.Width = uint16(_ptyReq.Columns)
if err = s._console.Resize(ws); err != nil {
log.Errorln(err)
return false
}
return true
}
func (s *session) handleEnv(payload []byte) bool {
var arg envArgs
if err := ssh.Unmarshal(payload, &arg); err != nil {
log.Errorln(err)
return false
}
log.Debugf("got env %s=%s", arg.Name, arg.Value)
s.env = append(s.env, fmt.Sprintf("%s=%s", arg.Name, arg.Value))
return true
}
func (s *session) handleWindowChange(payload []byte) bool {
var _windowChange windowChange
var _console console.Console
var ptsname string
if err := ssh.Unmarshal(payload, &_windowChange); err != nil {
log.Errorln(err)
return false
}
log.Debugf("window change %+v", _windowChange)
if s._console == nil {
// ignore
return true
}
ws, err := s._console.Size()
if err != nil {
log.Errorln(err)
return false
}
log.Debugf("current console %+v", ws)
ws.Height = uint16(_windowChange.Rows)
ws.Width = uint16(_windowChange.Columns)
if err := s._console.Resize(ws); err != nil {
log.Errorln(err)
return false
}
return true
}
func (sc *Server) handleSession(newch ssh.NewChannel) {
ch, req, err := newch.Accept()
if err != nil {
log.Errorf("%s", err.Error())
return
}
sess := &session{ch: ch, user: sc.sshConn.User()}
for r := range req {
ret := false
switch r.Type {
case "subsystem":
if err := ssh.Unmarshal(r.Payload, &_cmd); err == nil {
if _cmd.Arg == "sftp" { // only support sftp
ret = true
log.Debugf("handle sftp request")
go serveSFTP(ch)
} else {
ret = false
log.Debugln("subsystem", _cmd.Arg, "not support")
}
} else {
ret = false
log.Debugln("get subsystem arg error", err)
}
ret = sess.handleSubsystem(r.Payload)
case "shell":
ret = true
if runtime.GOOS == "windows" {
env = append(env, fmt.Sprintf("SHELL=powershell"))
cmd = exec.Command("powershell")
} else {
env = append(env, fmt.Sprintf("SHELL=/bin/bash"))
cmd = exec.Command("/bin/bash", "-l")
}
env = append(env, fmt.Sprintf("HOME=/home/guest"))
cmd.Env = env
go handleShell(cmd, ch, _console, ptsname)
case "signal":
log.Debugln("got signal")
ret = true
ret = sess.handleShell()
case "exec":
ret = true
if err = ssh.Unmarshal(r.Payload, &_cmd); err == nil {
log.Infoln("execute command", _cmd.Arg)
if runtime.GOOS == "windows" {
env = append(env, fmt.Sprintf("SHELL=powershell"))
cmd = exec.Command("powershell", "-Command", _cmd.Arg)
} else {
env = append(env, fmt.Sprintf("SHELL=/bin/bash"))
cmd = exec.Command("/bin/bash", "-c", _cmd.Arg)
}
env = append(env, fmt.Sprintf("HOME=/home/guest"))
cmd.Env = env
//cmd.Stdin = ch
go handleCommand(cmd, ch)
} else {
log.Debugln(err)
ret = false
}
ret = sess.handleExec(r.Payload)
case "pty-req":
ret = true
err = ssh.Unmarshal(r.Payload, &_ptyReq)
if err != nil {
log.Errorln(err)
ret = false
}
log.Debugf("pty req Rows: %d, Columns: %d, Mode: %x", _ptyReq.Rows, _ptyReq.Columns, _ptyReq.Mode)
if err == nil && (runtime.GOOS == "unix" || runtime.GOOS == "linux") {
termios := parseTerminalModes(_ptyReq.Mode)
log.Debugf("parsed terminal mode %+v", termios)
_console, ptsname, err = newPty()
if err == nil {
log.Debugf("allocate pty %s", ptsname)
log.Debugf("set termios")
if err1 := setTermios(int(_console.Fd()), termios); err1 != nil {
log.Errorln(err)
}
env = append(env, fmt.Sprintf("SSH_TTY=%s", ptsname))
env = append(env, fmt.Sprintf("TERM=%s", _ptyReq.Term))
ws, err := _console.Size()
if err != nil {
log.Errorln(err)
} else {
log.Debugf("current console %+v", ws)
ws.Height = uint16(_ptyReq.Rows)
ws.Width = uint16(_ptyReq.Columns)
if err = _console.Resize(ws); err != nil {
log.Errorln(err)
}
}
} else {
log.Errorln(err)
}
}
ret = sess.handlePtyReq(r.Payload)
case "env":
var arg envArgs
ret = true
if err = ssh.Unmarshal(r.Payload, &arg); err == nil {
log.Debugf("got env %s=%s", arg.Name, arg.Value)
env = append(env, fmt.Sprintf("%s=%s", arg.Name, arg.Value))
} else {
log.Debugln("parse env failed", err)
ret = false
}
ret = sess.handleEnv(r.Payload)
case "window-change":
ret = sess.handleWindowChange(r.Payload)
case "signal":
log.Debugln("got signal")
ret = true
err = ssh.Unmarshal(r.Payload, &_windowChange)
if err != nil {
ret = false
log.Errorln(err)
}
if err == nil && _console != nil {
ws, err := _console.Size()
if err != nil {
log.Errorln(err)
} else {
log.Debugf("current console %+v", ws)
ws.Height = uint16(_windowChange.Rows)
ws.Width = uint16(_windowChange.Columns)
if err = _console.Resize(ws); err != nil {
log.Errorln(err)
}
}
}
log.Debugf("window change %+v", _windowChange)
default:
ret = false
}
log.Debugf("session request %s, reply %v", r.Type, ret)
@ -324,9 +358,23 @@ func (sc *Server) handleSession(newch ssh.NewChannel) {
r.Reply(ret, nil)
}
}
if sess.cmd != nil && sess.cmd.Process != nil {
log.Debugf("kill the running process %s", sess.cmd.Args)
p := sess.cmd.Process
if err := p.Kill(); err != nil {
log.Debugln(err)
}
time.Sleep(100 * time.Millisecond)
if err := p.Signal(os.Kill); err != nil {
log.Debugln(err)
}
}
log.Debugln("session ended.")
}
func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsname string) {
func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsname string, _user string) {
defer func() {
ch.Close()
if _console != nil {
@ -336,7 +384,16 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam
var err error
log.Infoln("start shell")
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
_u, err := user.Lookup(_user)
if err != nil {
log.Errorln(err)
}
setUserEnv(cmd, _u, cmd.SysProcAttr)
if _console != nil {
_tty, err := os.OpenFile(ptsname, syscall.O_RDWR|syscall.O_NOCTTY, 0600)
@ -348,9 +405,6 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam
cmd.Stderr = _tty
cmd.Stdout = _tty
cmd.Stdin = _tty
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
setProcAttr(cmd.SysProcAttr)
@ -359,6 +413,7 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam
} else {
cmd.Stderr = ch
cmd.Stdout = ch
// cmd.Stdin = ch
in, err := cmd.StdinPipe()
if err != nil {
@ -388,36 +443,7 @@ func handleShell(cmd *exec.Cmd, ch ssh.Channel, _console console.Console, ptsnam
}
}
}
ch.SendRequest("exit-status", false,
ssh.Marshal(exitStatus{Status: uint32(code)}))
}
func handleCommand(cmd *exec.Cmd, ch ssh.Channel) {
defer ch.Close()
cmd.Stdout = ch
cmd.Stderr = ch
//log.Debugln("execute command", cmd)
in, err := cmd.StdinPipe()
if err != nil {
log.Debugln(err)
ch.SendRequest("exit-status", false,
ssh.Marshal(exitStatus{Status: 127}))
return
}
go func() {
defer in.Close()
io.Copy(in, ch)
}()
code := 0
if err := cmd.Run(); err != nil {
log.Debugln(err)
if exiterr, ok := err.(*exec.ExitError); ok {
if s, ok := exiterr.Sys().(syscall.WaitStatus); ok {
code = s.ExitStatus()
}
}
}
cmd.Process = nil
ch.SendRequest("exit-status", false,
ssh.Marshal(exitStatus{Status: uint32(code)}))
}

Loading…
Cancel
Save