diff --git a/pty_unix.go b/pty_unix.go new file mode 100644 index 0000000..400e726 --- /dev/null +++ b/pty_unix.go @@ -0,0 +1,14 @@ +// +build linux + +package obfssh + +import ( + "io" + "os/exec" + + "github.com/kr/pty" +) + +func startPty(cmd *exec.Cmd) (io.ReadWriteCloser, error) { + return pty.Start(cmd) +} diff --git a/pty_windows.go b/pty_windows.go new file mode 100644 index 0000000..de08339 --- /dev/null +++ b/pty_windows.go @@ -0,0 +1,11 @@ +package obfssh + +import ( + "errors" + "io" + "os/exec" +) + +func startPty(cmd *exec.Cmd) (io.ReadWriteCloser, error) { + return nil, errors.New("not implement") +} diff --git a/server.go b/server.go index fd57d46..178c041 100644 --- a/server.go +++ b/server.go @@ -2,7 +2,11 @@ package obfssh import ( "fmt" + "io" "net" + "os/exec" + "runtime" + "syscall" "time" "github.com/fangdingjun/go-log" @@ -130,6 +134,15 @@ type args struct { Arg string } +type envArgs struct { + Name string + Value string +} + +type exitStatus struct { + Status uint32 +} + func (sc *Server) handleSession(newch ssh.NewChannel) { ch, req, err := newch.Accept() if err != nil { @@ -137,27 +150,69 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { return } - var cmd args + var _cmd args ret := false + var cmd *exec.Cmd + var env []string for r := range req { switch r.Type { case "subsystem": - if err := ssh.Unmarshal(r.Payload, &cmd); err != nil { - ret = false - } else { - if cmd.Arg != "sftp" { // only support sftp - ret = false - } else { - + if err := ssh.Unmarshal(r.Payload, &_cmd); err == nil { + if _cmd.Arg == "sftp" { // only support sftp ret = true - log.Debugf("handle sftp request") - go serveSFTP(ch) + } else { + log.Debugln("subsystem", _cmd.Arg, "not support") } + } else { + ret = false + log.Debugln("get subsystem arg error", err) } + case "shell": + ret = true + if runtime.GOOS == "windows" { + cmd = exec.Command("powershell") + } else { + cmd = exec.Command("bash", "-l") + } + cmd.Env = env + go handleShell(cmd, ch) + case "signal": + log.Debugln("got signal") + ret = true + case "exec": + ret = true + if err = ssh.Unmarshal(r.Payload, &_cmd); err == nil { + log.Debugln("execute command", _cmd.Arg) + if runtime.GOOS == "windows" { + cmd = exec.Command("powershell", "-Command", _cmd.Arg) + } else { + cmd = exec.Command("bash", "-c", _cmd.Arg) + } + cmd.Env = env + //cmd.Stdin = ch + go handleCommand(cmd, ch) + } else { + log.Debugln(err) + ret = false + } + case "pty-req": + ret = true + case "env": + var arg envArgs + ret = true + if err = ssh.Unmarshal(r.Payload, &arg); err == nil { + log.Debugf("got env %s=%s", arg.Name, arg.Value) + env = append(env, fmt.Sprintf("%s=%s", arg.Name, arg.Value)) + } else { + log.Debugln("parse env failed", err) + ret = false + } + case "window-change": + ret = true default: ret = false } @@ -170,6 +225,92 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { } } +func handleShell(cmd *exec.Cmd, ch ssh.Channel) { + defer ch.Close() + + var _pty io.ReadWriteCloser + var err error + + log.Debugln("start shell") + + //_pty, err = pty.Start(cmd) + if runtime.GOOS == "unix" || runtime.GOOS == "linux" { + _pty, err = startPty(cmd) + if err != nil { + log.Debugln("start pty", err) + ch.SendRequest("exit-status", false, + ssh.Marshal(exitStatus{Status: 127})) + return + } + } + + if runtime.GOOS == "unix" || runtime.GOOS == "linux" { + defer _pty.Close() + go io.Copy(ch, _pty) + go io.Copy(_pty, ch) + } else { // windows + cmd.Stderr = ch + cmd.Stdout = ch + in, err := cmd.StdinPipe() + if err != nil { + ch.SendRequest("exit-status", false, + ssh.Marshal(exitStatus{Status: 127})) + return + } + go func() { + defer in.Close() + io.Copy(in, ch) + }() + if err := cmd.Start(); err != nil { + log.Debugln("start command ", err) + ch.SendRequest("exit-status", false, + ssh.Marshal(exitStatus{Status: 126})) + return + } + } + code := 0 + if err = cmd.Wait(); err != nil { + log.Debugln(err) + if exiterr, ok := err.(*exec.ExitError); ok { + if s, ok := exiterr.Sys().(syscall.WaitStatus); ok { + code = s.ExitStatus() + } + } + } + ch.SendRequest("exit-status", false, + ssh.Marshal(exitStatus{Status: uint32(code)})) +} + +func handleCommand(cmd *exec.Cmd, ch ssh.Channel) { + defer ch.Close() + + cmd.Stdout = ch + cmd.Stderr = ch + //log.Debugln("execute command", cmd) + in, err := cmd.StdinPipe() + if err != nil { + log.Debugln(err) + ch.SendRequest("exit-status", false, + ssh.Marshal(exitStatus{Status: 127})) + return + } + go func() { + defer in.Close() + io.Copy(in, ch) + }() + code := 0 + if err := cmd.Run(); err != nil { + log.Debugln(err) + if exiterr, ok := err.(*exec.ExitError); ok { + if s, ok := exiterr.Sys().(syscall.WaitStatus); ok { + code = s.ExitStatus() + } + } + } + ch.SendRequest("exit-status", false, + ssh.Marshal(exitStatus{Status: uint32(code)})) +} + func handleDirectTcpip(newch ssh.NewChannel) { var r directTcpipMsg