diff --git a/console_unix.go b/console_unix.go index 3c5e6a0..d0beed9 100644 --- a/console_unix.go +++ b/console_unix.go @@ -9,6 +9,7 @@ import ( "os/signal" "os/user" "strconv" + "strings" "syscall" "github.com/containerd/console" @@ -109,13 +110,29 @@ func setUserEnv(_cmd *exec.Cmd, u *user.User, attr *syscall.SysProcAttr) { } _uid, _ := strconv.ParseUint(u.Uid, 10, 32) _gid, _ := strconv.ParseUint(u.Gid, 10, 32) + + _cmd.Env = append(_cmd.Env, fmt.Sprintf("HOME=%s", u.HomeDir)) + _cmd.Env = append(_cmd.Env, fmt.Sprintf("LOGNAME=%s", u.Name)) + _cmd.Dir = u.HomeDir + + if os.Getuid() != 0 { + return + } + if attr.Credential == nil { attr.Credential = &syscall.Credential{} } + attr.Credential.Uid = uint32(_uid) attr.Credential.Gid = uint32(_gid) + for _, _env := range _cmd.Env { + ss := strings.Split(_env, "=") + if ss[0] == "SSH_AUTH_SOCK" { + os.Chown(ss[1], int(_uid), int(_gid)) + } + if ss[0] == "SSH_TTY" { + os.Chown(ss[1], int(_uid), 0) + } + } - _cmd.Env = append(_cmd.Env, fmt.Sprintf("HOME=%s", u.HomeDir)) - _cmd.Env = append(_cmd.Env, fmt.Sprintf("LOGNAME=%s", u.Name)) - _cmd.Dir = u.HomeDir } diff --git a/obfscp/scp.go b/obfscp/scp.go index 5707a68..1929b97 100644 --- a/obfscp/scp.go +++ b/obfscp/scp.go @@ -102,13 +102,7 @@ func createSFTPConn(host, user string, cfg *options) (*sftp.Client, error) { if cfg.Passwd == "" && cfg.PrivateKey == "" { var pkeys []ssh.Signer if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { - //auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) - if signers, err := agent.NewClient(aconn).Signers(); err == nil { - log.Debugf("add private key from agent") - pkeys = append(pkeys, signers...) - } else { - log.Debugf("get key from agent failed: %s", err) - } + auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) } else { log.Debugf("dial to agent failed: %s", err) } diff --git a/obfssh/ssh.go b/obfssh/ssh.go index 261ac08..2c94dca 100644 --- a/obfssh/ssh.go +++ b/obfssh/ssh.go @@ -117,13 +117,7 @@ func main() { defer agentConn.Close() log.Debugf("add auth method with agent %s", os.Getenv("SSH_AUTH_SOCK")) agentClient = agent.NewClient(agentConn) - //auth = append(auth, ssh.PublicKeysCallback(agentClient.Signers)) - signers, err := agentClient.Signers() - if err == nil { - pkeys = append(pkeys, signers...) - } else { - log.Debugf("get key from agent failed: %s", err) - } + auth = append(auth, ssh.PublicKeysCallback(agentClient.Signers)) } else { log.Debugf("connect to agent failed") } diff --git a/server.go b/server.go index fc143ee..d8936ff 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "io" + "io/ioutil" "net" "os" "os/exec" @@ -322,6 +323,50 @@ func (s *session) handleWindowChange(payload []byte) bool { return true } +func (sc *Server) handleAuthAgentForward(sess *session) bool { + f, err := ioutil.TempFile("", "agent-") + if err != nil { + log.Errorln(err) + return false + } + + p := f.Name() + f.Close() + os.Remove(p) + + l, err := net.Listen("unix", p) + if err != nil { + log.Errorln(err) + return false + } + + sess.env = append(sess.env, fmt.Sprintf("SSH_AUTH_SOCK=%s", p)) + + sc.forwardedPorts[p] = l + + go func() { + defer os.Remove(p) + for { + c, err := l.Accept() + if err != nil { + log.Errorln(err) + break + } + go func(c net.Conn) { + ch, req, err := sc.sshConn.OpenChannel("auth-agent@openssh.com", nil) + if err != nil { + c.Close() + log.Errorln(err) + return + } + go ssh.DiscardRequests(req) + PipeAndClose(ch, c) + }(c) + } + }() + return true +} + func (sc *Server) handleSession(newch ssh.NewChannel) { ch, req, err := newch.Accept() if err != nil { @@ -346,6 +391,8 @@ func (sc *Server) handleSession(newch ssh.NewChannel) { ret = sess.handleEnv(r.Payload) case "window-change": ret = sess.handleWindowChange(r.Payload) + case "auth-agent-req@openssh.com": + ret = sc.handleAuthAgentForward(sess) case "signal": log.Debugln("got signal") ret = true