server: add sftp support

master
Dingjun 8 years ago
parent a3c1a3ee75
commit 0ae9fe428e

1
.gitignore vendored

@ -1 +1,2 @@
*~ *~
*.swp

@ -2,9 +2,8 @@ package obfssh
import ( import (
"fmt" "fmt"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
//"log"
"net" "net"
) )
@ -73,22 +72,30 @@ func (sc *Server) close() {
func (sc *Server) handleNewChannelRequest(ch <-chan ssh.NewChannel) { func (sc *Server) handleNewChannelRequest(ch <-chan ssh.NewChannel) {
for newch := range ch { for newch := range ch {
Log(DEBUG, "request channel %s", newch.ChannelType())
switch newch.ChannelType() { switch newch.ChannelType() {
case "session": case "session":
//go sc.handleSession(newch) go sc.handleSession(newch)
//continue continue
case "direct-tcpip": case "direct-tcpip":
go handleDirectTcpip(newch) go handleDirectTcpip(newch)
continue continue
} }
Log(DEBUG, "reject channel request %s", newch.ChannelType()) Log(DEBUG, "reject channel request %s", newch.ChannelType())
newch.Reject(ssh.UnknownChannelType, "unknown channel type") newch.Reject(ssh.UnknownChannelType, "unknown channel type")
//channel, request, err := newch.Accept()
} }
} }
func (sc *Server) handleGlobalRequest(req <-chan *ssh.Request) { func (sc *Server) handleGlobalRequest(req <-chan *ssh.Request) {
for r := range req { for r := range req {
Log(DEBUG, "global request %s", r.Type)
switch r.Type { switch r.Type {
case "tcpip-forward": case "tcpip-forward":
Log(DEBUG, "request port forward") Log(DEBUG, "request port forward")
@ -99,32 +106,26 @@ func (sc *Server) handleGlobalRequest(req <-chan *ssh.Request) {
go sc.handleCancelTcpipForward(r) go sc.handleCancelTcpipForward(r)
continue continue
} }
Log(DEBUG, "global request %s", r.Type)
if r.WantReply { if r.WantReply {
r.Reply(false, nil) r.Reply(false, nil)
} }
} }
} }
func (sc *Server) handleChannelRequest(req <-chan *ssh.Request) { func serveSFTP(ch ssh.Channel) {
ret := false defer ch.Close()
for r := range req {
switch r.Type { server, err := sftp.NewServer(ch)
case "shell":
ret = true if err != nil {
case "pty-req": Log(DEBUG, "start sftp server failed: %s", err)
ret = true return
case "env":
ret = true
case "exec":
ret = false
case "subsystem":
default:
ret = false
}
if r.WantReply {
r.Reply(ret, nil)
} }
if err := server.Serve(); err != nil {
Log(DEBUG, "sftp server finished with error: %s", err)
return
} }
} }
@ -135,49 +136,81 @@ type directTcpipMsg struct {
Lport uint32 Lport uint32
} }
type args struct {
Arg string
}
func (sc *Server) handleSession(newch ssh.NewChannel) { func (sc *Server) handleSession(newch ssh.NewChannel) {
ch, req, err := newch.Accept() ch, req, err := newch.Accept()
if err != nil { if err != nil {
Log(ERROR, "%s", err.Error()) Log(ERROR, "%s", err.Error())
return return
} }
go sc.handleChannelRequest(req)
term := terminal.NewTerminal(ch, "shell>") var cmd args
defer ch.Close()
for { ret := false
line, err := term.ReadLine()
if err != nil { for r := range req {
break switch r.Type {
case "subsystem":
if err := ssh.Unmarshal(r.Payload, &cmd); err != nil {
ret = false
} else {
if cmd.Arg != "sftp" { // only support sftp
ret = false
} else {
ret = true
Log(DEBUG, "handle sftp request")
go serveSFTP(ch)
}
}
default:
ret = false
}
Log(DEBUG, "session request %s, reply %q", r.Type, ret)
if r.WantReply {
r.Reply(ret, nil)
} }
term.Write([]byte(line))
term.Write([]byte("\n"))
} }
} }
func handleDirectTcpip(newch ssh.NewChannel) { func handleDirectTcpip(newch ssh.NewChannel) {
data := newch.ExtraData()
var r directTcpipMsg var r directTcpipMsg
data := newch.ExtraData()
err := ssh.Unmarshal(data, &r) err := ssh.Unmarshal(data, &r)
if err != nil { if err != nil {
Log(DEBUG, "invalid ssh parameter") Log(DEBUG, "invalid ssh parameter")
newch.Reject(ssh.ConnectionFailed, "invalid argument") newch.Reject(ssh.ConnectionFailed, "invalid argument")
return return
} }
Log(DEBUG, "create connection to %s:%d", r.Raddr, r.Rport) Log(DEBUG, "create connection to %s:%d", r.Raddr, r.Rport)
rconn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", r.Raddr, r.Rport)) rconn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", r.Raddr, r.Rport))
if err != nil { if err != nil {
Log(ERROR, "%s", err.Error()) Log(ERROR, "%s", err.Error())
newch.Reject(ssh.ConnectionFailed, "invalid argument") newch.Reject(ssh.ConnectionFailed, "invalid argument")
return return
} }
channel, requests, err := newch.Accept() channel, requests, err := newch.Accept()
if err != nil { if err != nil {
rconn.Close() rconn.Close()
Log(ERROR, "%s", err.Error()) Log(ERROR, "%s", err.Error())
return return
} }
//log.Println("forward") //log.Println("forward")
go ssh.DiscardRequests(requests) go ssh.DiscardRequests(requests)
PipeAndClose(channel, rconn) PipeAndClose(channel, rconn)
} }

Loading…
Cancel
Save