scp: support get multiple remote file to local

master
fangdingjun 8 years ago
parent 6bb7601ef4
commit fb4067916e

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"github.com/bgentry/speakeasy" "github.com/bgentry/speakeasy"
@ -20,28 +21,34 @@ import (
"time" "time"
) )
func main() { type options struct {
var user, port, pass, key string Debug bool
var recursive bool Port int
var obfsMethod, obfsKey string User string
var disableObfsAfterHandshake bool Passwd string
var debug bool Recursive bool
var hasError bool ObfsMethod string
ObfsKey string
DisableObfsAfterHandshake bool
PrivateKey string
}
func main() {
var cfg options
flag.Usage = usage flag.Usage = usage
flag.BoolVar(&debug, "d", false, "verbose mode") flag.BoolVar(&cfg.Debug, "d", false, "verbose mode")
flag.StringVar(&port, "p", "22", "port") flag.IntVar(&cfg.Port, "p", 22, "port")
flag.StringVar(&user, "l", os.Getenv("USER"), "user") flag.StringVar(&cfg.User, "l", os.Getenv("USER"), "user")
flag.StringVar(&pass, "pw", "", "password") flag.StringVar(&cfg.Passwd, "pw", "", "password")
flag.StringVar(&key, "i", "", "private key") flag.StringVar(&cfg.PrivateKey, "i", "", "private key")
flag.BoolVar(&recursive, "r", false, "recursively copy entries") flag.BoolVar(&cfg.Recursive, "r", false, "recursively copy entries")
flag.StringVar(&obfsMethod, "obfs_method", "none", "obfs encrypt method, rc4, aes or none") flag.StringVar(&cfg.ObfsMethod, "obfs_method", "none", "obfs encrypt method, rc4, aes or none")
flag.StringVar(&obfsKey, "obfs_key", "", "obfs encrypt key") flag.StringVar(&cfg.ObfsKey, "obfs_key", "", "obfs encrypt key")
flag.BoolVar(&disableObfsAfterHandshake, "disable_obfs_after_handshake", false, "disable obfs after handshake") flag.BoolVar(&cfg.DisableObfsAfterHandshake, "disable_obfs_after_handshake", false, "disable obfs after handshake")
flag.Parse() flag.Parse()
if debug { if cfg.Debug {
obfssh.SSHLogLevel = obfssh.DEBUG obfssh.SSHLogLevel = obfssh.DEBUG
} }
@ -52,35 +59,24 @@ func main() {
os.Exit(1) os.Exit(1)
} }
var host, path string var err error
r1 := ""
var toLocal = false
if strings.Contains(args[0], ":") { if strings.Contains(args[0], ":") {
toLocal = true err = download(args, &cfg)
r1 = args[0]
} else { } else {
toLocal = false err = upload(args, &cfg)
r1 = args[len(args)-1]
} }
if strings.Contains(r1, "@") { if err != nil {
ss1 := strings.SplitN(r1, "@", 2) log.Fatal(err)
user = ss1[0]
r1 = ss1[1]
}
ss2 := strings.SplitN(r1, ":", 2)
if len(ss2) != 2 {
//log.Fatal("Usage: \n\tscp user@host:path local\n\tor\n\tscp local... user@host:path")
flag.Usage()
os.Exit(1)
} }
host = ss2[0] }
path = ss2[1]
func createSFTPConn(host, user string, cfg *options) (*sftp.Client, error) {
auths := []ssh.AuthMethod{} auths := []ssh.AuthMethod{}
// read ssh agent and default auth key // read ssh agent and default auth key
if pass == "" && key == "" { if cfg.Passwd == "" && cfg.PrivateKey == "" {
var pkeys []ssh.Signer var pkeys []ssh.Signer
if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
//auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) //auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers))
@ -121,19 +117,19 @@ func main() {
} }
} }
if pass != "" { if cfg.Passwd != "" {
debuglog("add password auth") debuglog("add password auth")
auths = append(auths, ssh.Password(pass)) auths = append(auths, ssh.Password(cfg.Passwd))
} else { } else {
debuglog("add keyboard interactive") debuglog("add keyboard interactive")
auths = append(auths, auths = append(auths,
ssh.RetryableAuthMethod(ssh.PasswordCallback(passwordAuth), 3)) ssh.RetryableAuthMethod(ssh.PasswordCallback(passwordAuth), 3))
} }
if key != "" { if cfg.PrivateKey != "" {
if buf, err := ioutil.ReadFile(key); err == nil { if buf, err := ioutil.ReadFile(cfg.PrivateKey); err == nil {
if p, err := ssh.ParsePrivateKey(buf); err == nil { if p, err := ssh.ParsePrivateKey(buf); err == nil {
debuglog("add private key: %s", key) debuglog("add private key: %s", cfg.PrivateKey)
auths = append(auths, ssh.PublicKeys(p)) auths = append(auths, ssh.PublicKeys(p))
} else { } else {
debuglog("parse private key failed: %s", err) debuglog("parse private key failed: %s", err)
@ -142,6 +138,9 @@ func main() {
debuglog("read private key failed: %s", err) debuglog("read private key failed: %s", err)
} }
} }
if user == "" {
user = cfg.User
}
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: user, User: user,
@ -149,113 +148,189 @@ func main() {
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
rhost := net.JoinHostPort(host, port) rhost := net.JoinHostPort(host, fmt.Sprintf("%d", cfg.Port))
c, err := net.Dial("tcp", rhost) c, err := net.Dial("tcp", rhost)
if err != nil { if err != nil {
log.Fatal(err) //log.Fatal(err)
return nil, err
} }
conf := &obfssh.Conf{ conf := &obfssh.Conf{
ObfsMethod: obfsMethod, ObfsMethod: cfg.ObfsMethod,
ObfsKey: obfsKey, ObfsKey: cfg.ObfsKey,
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
KeepAliveInterval: 10 * time.Second, KeepAliveInterval: 10 * time.Second,
KeepAliveMax: 5, KeepAliveMax: 5,
DisableObfsAfterHandshake: disableObfsAfterHandshake, DisableObfsAfterHandshake: cfg.DisableObfsAfterHandshake,
} }
conn, err := obfssh.NewClient(c, config, rhost, conf) conn, err := obfssh.NewClient(c, config, rhost, conf)
//conn, err := ssh.Dial("tcp", h, config)
if err != nil { if err != nil {
log.Fatal(err) //log.Fatal(err)
return nil, err
} }
defer conn.Close() //defer conn.Close()
sftpConn, err := sftp.NewClient(conn.Client(), sftp.MaxPacket(64*1024)) sftpConn, err := sftp.NewClient(conn.Client(), sftp.MaxPacket(64*1024))
if err != nil { if err != nil {
log.Fatal(err) //log.Fatal(err)
return nil, err
} }
defer sftpConn.Close() return sftpConn, nil
}
// download func splitHostPath(s string) (string, string, string) {
if toLocal { var user, host, path string
localFile := args[1] r1 := s
st, err := sftpConn.Stat(path) if strings.Contains(r1, "@") {
if err != nil { ss1 := strings.SplitN(r1, "@", 2)
log.Fatal(err) user = ss1[0]
r1 = ss1[1]
} }
if strings.Contains(r1, ":") {
ss2 := strings.SplitN(r1, ":", 2)
host = ss2[0]
path = ss2[1]
} else {
host = r1
}
return user, host, path
}
if st.Mode().IsDir() && !recursive { func download(args []string, cfg *options) error {
log.Fatal("use -r to transfer the directory")
var err1 error
localFile := clean(args[len(args)-1])
st, _ := os.Stat(localFile)
if len(args) > 2 {
if st != nil && !st.Mode().IsDir() {
log.Fatal("can't transfer multiple files to file")
}
if st == nil {
makeDirs(localFile, osDir{})
if err := os.Mkdir(localFile, 0755); err != nil {
log.Fatal(err)
}
}
st, _ = os.Stat(localFile)
} }
st1, err := os.Stat(localFile) for _, f := range args[:len(args)-1] {
if err == nil && !st1.Mode().IsDir() && st.Mode().IsDir() { user, host, path := splitHostPath(f)
log.Fatal("can't transfer directory to file") if host == "" || path == "" {
return errors.New("invalid path")
} }
if !st.Mode().IsDir() { path = clean(path)
if st1 != nil && st1.Mode().IsDir() {
// to local directory sftpConn, err := createSFTPConn(host, user, cfg)
bname := filepath.Base(path) if err != nil {
localFile = filepath.Join(localFile, bname) return err
} }
debuglog("transfer remote to local, %s -> %s", path, localFile) st1, err := sftpConn.Stat(path)
if err != nil {
err1 = err
debuglog("%s", err)
sftpConn.Close()
continue
}
if st1.Mode().IsDir() {
if !cfg.Recursive {
debuglog("omit remote directory %s", path)
sftpConn.Close()
continue
}
if err := rget(sftpConn, path, localFile); err != nil {
debuglog("download error: %s", err)
err1 = err
}
sftpConn.Close()
continue
}
if err := get(sftpConn, path, localFile); err != nil { lfile := localFile
log.Fatal(err) if st != nil && st.Mode().IsDir() {
lfile = filepath.Join(lfile, filepath.Base(path))
} }
debuglog("done")
return if err := get(sftpConn, path, lfile); err != nil {
debuglog("download error: %s", err)
err1 = err
} }
// recursive download sftpConn.Close()
if err := rget(sftpConn, path, localFile); err != nil {
log.Fatal(err)
} }
// download done
debuglog("done") debuglog("done")
return err1
}
func upload(args []string, cfg *options) error {
rfile := args[len(args)-1]
return rfile = clean(rfile)
user, host, path := splitHostPath(rfile)
if host == "" || path == "" {
return errors.New("invalid path")
} }
// upload path = clean(path)
sftpConn, err := createSFTPConn(host, user, cfg)
if err != nil {
return err
}
defer sftpConn.Close()
st, _ := sftpConn.Stat(path)
var err1 error
if len(args) > 2 { if len(args) > 2 {
if st, err := sftpConn.Stat(path); err == nil { if st != nil && !st.Mode().IsDir() {
if !st.Mode().IsDir() {
log.Fatal("multiple files can only been transferred to directory") log.Fatal("multiple files can only been transferred to directory")
} }
} else { if st == nil {
log.Fatalf("remote file or directory not exists") makeDirs(path, sftpConn)
if err := sftpConn.Mkdir(path); err != nil {
log.Fatal(err)
}
} }
st, _ = sftpConn.Stat(path)
} }
for i := 0; i < len(args)-1; i++ { for i := 0; i < len(args)-1; i++ {
localFile := args[i] localFile := args[i]
st, err := os.Stat(localFile)
localFile = clean(localFile)
st1, err := os.Stat(localFile)
// local file not exists // local file not exists
if err != nil { if err != nil {
debuglog("%s", err) debuglog("%s", err)
hasError = true err1 = err
continue continue
} }
// directory // directory
if st.Mode().IsDir() { if st1.Mode().IsDir() {
if !recursive { if !cfg.Recursive {
debuglog("omit directory %s", localFile) debuglog("omit directory %s", localFile)
continue continue
} }
// transfer directory // transfer directory
if err := rput(sftpConn, localFile, path); err != nil { if err := rput(sftpConn, localFile, path); err != nil {
debuglog("%s", err) debuglog("%s", err)
hasError = true err1 = err
} }
// next entry // next entry
@ -266,21 +341,16 @@ func main() {
remoteFile := path remoteFile := path
st1, err := sftpConn.Stat(path) if st != nil && st.Mode().IsDir() {
if err == nil && st1.Mode().IsDir() {
remoteFile = filepath.Join(path, filepath.Base(localFile)) remoteFile = filepath.Join(path, filepath.Base(localFile))
} }
if err := put(sftpConn, localFile, remoteFile); err != nil { if err := put(sftpConn, localFile, remoteFile); err != nil {
hasError = true
debuglog("upload %s failed: %s", localFile, err.Error()) debuglog("upload %s failed: %s", localFile, err.Error())
err1 = err
} }
} }
return err1
if hasError {
os.Exit(1)
}
} }
func get(sftpConn *sftp.Client, remoteFile, localFile string) error { func get(sftpConn *sftp.Client, remoteFile, localFile string) error {
@ -382,13 +452,13 @@ func rput(sftpConn *sftp.Client, localDir, remoteDir string) error {
continue continue
} }
p := walker.Path() p := clean(walker.Path())
p1 := strings.Replace(p, localDir, "", 1) p1 := strings.Replace(p, localDir, "", 1)
fmt.Println(strings.TrimLeft(p1, "/")) fmt.Println(strings.TrimLeft(p1, "/"))
p2 := filepath.Join(remoteDir, p1) p2 := clean(filepath.Join(remoteDir, p1))
if err := makeDirs(p2, sftpConn); err != nil { if err := makeDirs(p2, sftpConn); err != nil {
return err return err
@ -415,13 +485,13 @@ func rget(sftpConn *sftp.Client, remoteDir, localDir string) error {
continue continue
} }
p := walker.Path() p := clean(walker.Path())
p1 := strings.Replace(p, remoteDir, "", 1) p1 := strings.Replace(p, remoteDir, "", 1)
p2 := filepath.Join(localDir, p1) p2 := clean(filepath.Join(localDir, p1))
fmt.Println(strings.TrimLeft(p1, "/")) fmt.Println(strings.TrimLeft(p1, "/"))
if err := makeDirs(p2, fi{}); err != nil { if err := makeDirs(p2, osDir{}); err != nil {
return err return err
} }
@ -433,22 +503,24 @@ func rget(sftpConn *sftp.Client, remoteDir, localDir string) error {
return nil return nil
} }
type fi struct{} type osDir struct{}
func (f fi) Stat(s string) (os.FileInfo, error) { func (f osDir) Stat(s string) (os.FileInfo, error) {
return os.Stat(s) return os.Stat(s)
} }
func (f fi) Mkdir(s string) error { func (f osDir) Mkdir(s string) error {
return os.Mkdir(s, 0755) return os.Mkdir(s, 0755)
} }
type fileInterface interface { type dirInterface interface {
Stat(s string) (os.FileInfo, error) Stat(s string) (os.FileInfo, error)
Mkdir(s string) error Mkdir(s string) error
} }
func makeDirs(p string, c fileInterface) error { func makeDirs(p string, c dirInterface) error {
p = clean(p)
debuglog("make directory for %s", p) debuglog("make directory for %s", p)
for i := 1; i < len(p); i++ { for i := 1; i < len(p); i++ {
@ -544,3 +616,11 @@ Options for obfuscation:
fmt.Printf("%s", usageStr) fmt.Printf("%s", usageStr)
os.Exit(1) os.Exit(1)
} }
func clean(p string) string {
p = filepath.Clean(p)
if os.PathSeparator != '/' {
p = strings.Replace(p, string([]byte{os.PathSeparator}), "/", -1)
}
return p
}

Loading…
Cancel
Save