diff --git a/.gitignore b/.gitignore index b25c15b..b72f9be 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *~ +*.swp diff --git a/server.go b/server.go index 3badc3f..f19fc27 100644 --- a/server.go +++ b/server.go @@ -2,9 +2,8 @@ package obfssh import ( "fmt" + "github.com/pkg/sftp" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" - //"log" "net" ) @@ -73,22 +72,30 @@ func (sc *Server) close() { func (sc *Server) handleNewChannelRequest(ch <-chan ssh.NewChannel) { for newch := range ch { + + Log(DEBUG, "request channel %s", newch.ChannelType()) + switch newch.ChannelType() { case "session": - //go sc.handleSession(newch) - //continue + go sc.handleSession(newch) + continue + case "direct-tcpip": go handleDirectTcpip(newch) continue } + Log(DEBUG, "reject channel request %s", newch.ChannelType()) + newch.Reject(ssh.UnknownChannelType, "unknown channel type") - //channel, request, err := newch.Accept() } } func (sc *Server) handleGlobalRequest(req <-chan *ssh.Request) { for r := range req { + + Log(DEBUG, "global request %s", r.Type) + switch r.Type { case "tcpip-forward": Log(DEBUG, "request port forward") @@ -99,32 +106,26 @@ func (sc *Server) handleGlobalRequest(req <-chan *ssh.Request) { go sc.handleCancelTcpipForward(r) continue } - Log(DEBUG, "global request %s", r.Type) + if r.WantReply { r.Reply(false, nil) } } } -func (sc *Server) handleChannelRequest(req <-chan *ssh.Request) { - ret := false - for r := range req { - switch r.Type { - case "shell": - ret = true - case "pty-req": - ret = true - case "env": - ret = true - case "exec": - ret = false - case "subsystem": - default: - ret = false - } - if r.WantReply { - r.Reply(ret, nil) - } +func serveSFTP(ch ssh.Channel) { + defer ch.Close() + + server, err := sftp.NewServer(ch) + + if err != nil { + Log(DEBUG, "start sftp server failed: %s", err) + return + } + + if err := server.Serve(); err != nil { + Log(DEBUG, "sftp server finished with error: %s", err) + return } } @@ -135,49 +136,81 @@ type directTcpipMsg struct { Lport uint32 } +type args struct { + Arg string +} + func (sc *Server) handleSession(newch ssh.NewChannel) { ch, req, err := newch.Accept() if err != nil { Log(ERROR, "%s", err.Error()) return } - go sc.handleChannelRequest(req) - term := terminal.NewTerminal(ch, "shell>") - defer ch.Close() - for { - line, err := term.ReadLine() - if err != nil { - break + + var cmd args + + ret := false + + for r := range req { + switch r.Type { + case "subsystem": + if err := ssh.Unmarshal(r.Payload, &cmd); err != nil { + ret = false + } else { + if cmd.Arg != "sftp" { // only support sftp + ret = false + } else { + + ret = true + + Log(DEBUG, "handle sftp request") + + go serveSFTP(ch) + } + } + default: + ret = false + } + + Log(DEBUG, "session request %s, reply %q", r.Type, ret) + + if r.WantReply { + r.Reply(ret, nil) } - term.Write([]byte(line)) - term.Write([]byte("\n")) } } func handleDirectTcpip(newch ssh.NewChannel) { - data := newch.ExtraData() var r directTcpipMsg + + data := newch.ExtraData() + err := ssh.Unmarshal(data, &r) if err != nil { Log(DEBUG, "invalid ssh parameter") newch.Reject(ssh.ConnectionFailed, "invalid argument") return } + Log(DEBUG, "create connection to %s:%d", r.Raddr, r.Rport) + rconn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", r.Raddr, r.Rport)) if err != nil { Log(ERROR, "%s", err.Error()) newch.Reject(ssh.ConnectionFailed, "invalid argument") return } + channel, requests, err := newch.Accept() if err != nil { rconn.Close() Log(ERROR, "%s", err.Error()) return } + //log.Println("forward") go ssh.DiscardRequests(requests) + PipeAndClose(channel, rconn) }