diff --git a/obfssh_scp/scp.go b/obfssh_scp/scp.go index fb35bb8..62cf0a8 100644 --- a/obfssh_scp/scp.go +++ b/obfssh_scp/scp.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "github.com/bgentry/speakeasy" @@ -20,28 +21,34 @@ import ( "time" ) -func main() { - var user, port, pass, key string - var recursive bool - var obfsMethod, obfsKey string - var disableObfsAfterHandshake bool - var debug bool - var hasError bool +type options struct { + Debug bool + Port int + User string + Passwd string + Recursive bool + ObfsMethod string + ObfsKey string + DisableObfsAfterHandshake bool + PrivateKey string +} +func main() { + var cfg options flag.Usage = usage - flag.BoolVar(&debug, "d", false, "verbose mode") - flag.StringVar(&port, "p", "22", "port") - flag.StringVar(&user, "l", os.Getenv("USER"), "user") - flag.StringVar(&pass, "pw", "", "password") - flag.StringVar(&key, "i", "", "private key") - flag.BoolVar(&recursive, "r", false, "recursively copy entries") - flag.StringVar(&obfsMethod, "obfs_method", "none", "obfs encrypt method, rc4, aes or none") - flag.StringVar(&obfsKey, "obfs_key", "", "obfs encrypt key") - flag.BoolVar(&disableObfsAfterHandshake, "disable_obfs_after_handshake", false, "disable obfs after handshake") + flag.BoolVar(&cfg.Debug, "d", false, "verbose mode") + flag.IntVar(&cfg.Port, "p", 22, "port") + flag.StringVar(&cfg.User, "l", os.Getenv("USER"), "user") + flag.StringVar(&cfg.Passwd, "pw", "", "password") + flag.StringVar(&cfg.PrivateKey, "i", "", "private key") + flag.BoolVar(&cfg.Recursive, "r", false, "recursively copy entries") + flag.StringVar(&cfg.ObfsMethod, "obfs_method", "none", "obfs encrypt method, rc4, aes or none") + flag.StringVar(&cfg.ObfsKey, "obfs_key", "", "obfs encrypt key") + flag.BoolVar(&cfg.DisableObfsAfterHandshake, "disable_obfs_after_handshake", false, "disable obfs after handshake") flag.Parse() - if debug { + if cfg.Debug { obfssh.SSHLogLevel = obfssh.DEBUG } @@ -52,35 +59,24 @@ func main() { os.Exit(1) } - var host, path string - r1 := "" - var toLocal = false + var err error + if strings.Contains(args[0], ":") { - toLocal = true - r1 = args[0] + err = download(args, &cfg) } else { - toLocal = false - r1 = args[len(args)-1] + err = upload(args, &cfg) } - if strings.Contains(r1, "@") { - ss1 := strings.SplitN(r1, "@", 2) - user = ss1[0] - r1 = ss1[1] - } - ss2 := strings.SplitN(r1, ":", 2) - if len(ss2) != 2 { - //log.Fatal("Usage: \n\tscp user@host:path local\n\tor\n\tscp local... user@host:path") - flag.Usage() - os.Exit(1) + if err != nil { + log.Fatal(err) } - host = ss2[0] - path = ss2[1] +} +func createSFTPConn(host, user string, cfg *options) (*sftp.Client, error) { auths := []ssh.AuthMethod{} // read ssh agent and default auth key - if pass == "" && key == "" { + if cfg.Passwd == "" && cfg.PrivateKey == "" { var pkeys []ssh.Signer if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { //auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) @@ -121,19 +117,19 @@ func main() { } } - if pass != "" { + if cfg.Passwd != "" { debuglog("add password auth") - auths = append(auths, ssh.Password(pass)) + auths = append(auths, ssh.Password(cfg.Passwd)) } else { debuglog("add keyboard interactive") auths = append(auths, ssh.RetryableAuthMethod(ssh.PasswordCallback(passwordAuth), 3)) } - if key != "" { - if buf, err := ioutil.ReadFile(key); err == nil { + if cfg.PrivateKey != "" { + if buf, err := ioutil.ReadFile(cfg.PrivateKey); err == nil { if p, err := ssh.ParsePrivateKey(buf); err == nil { - debuglog("add private key: %s", key) + debuglog("add private key: %s", cfg.PrivateKey) auths = append(auths, ssh.PublicKeys(p)) } else { debuglog("parse private key failed: %s", err) @@ -142,6 +138,9 @@ func main() { debuglog("read private key failed: %s", err) } } + if user == "" { + user = cfg.User + } config := &ssh.ClientConfig{ User: user, @@ -149,113 +148,189 @@ func main() { Timeout: 5 * time.Second, } - rhost := net.JoinHostPort(host, port) + rhost := net.JoinHostPort(host, fmt.Sprintf("%d", cfg.Port)) c, err := net.Dial("tcp", rhost) if err != nil { - log.Fatal(err) + //log.Fatal(err) + return nil, err } conf := &obfssh.Conf{ - ObfsMethod: obfsMethod, - ObfsKey: obfsKey, + ObfsMethod: cfg.ObfsMethod, + ObfsKey: cfg.ObfsKey, Timeout: 10 * time.Second, KeepAliveInterval: 10 * time.Second, KeepAliveMax: 5, - DisableObfsAfterHandshake: disableObfsAfterHandshake, + DisableObfsAfterHandshake: cfg.DisableObfsAfterHandshake, } conn, err := obfssh.NewClient(c, config, rhost, conf) - - //conn, err := ssh.Dial("tcp", h, config) if err != nil { - log.Fatal(err) + //log.Fatal(err) + return nil, err } - defer conn.Close() + //defer conn.Close() sftpConn, err := sftp.NewClient(conn.Client(), sftp.MaxPacket(64*1024)) if err != nil { - log.Fatal(err) + //log.Fatal(err) + return nil, err } - defer sftpConn.Close() + return sftpConn, nil +} - // download - if toLocal { - localFile := args[1] - st, err := sftpConn.Stat(path) - if err != nil { - log.Fatal(err) - } +func splitHostPath(s string) (string, string, string) { + var user, host, path string + r1 := s + if strings.Contains(r1, "@") { + ss1 := strings.SplitN(r1, "@", 2) + user = ss1[0] + r1 = ss1[1] + } + if strings.Contains(r1, ":") { + ss2 := strings.SplitN(r1, ":", 2) + host = ss2[0] + path = ss2[1] + } else { + host = r1 + } + return user, host, path +} + +func download(args []string, cfg *options) error { + + var err1 error - if st.Mode().IsDir() && !recursive { - log.Fatal("use -r to transfer the directory") + localFile := clean(args[len(args)-1]) + + st, _ := os.Stat(localFile) + + if len(args) > 2 { + if st != nil && !st.Mode().IsDir() { + log.Fatal("can't transfer multiple files to file") + } + if st == nil { + makeDirs(localFile, osDir{}) + if err := os.Mkdir(localFile, 0755); err != nil { + log.Fatal(err) + } } + st, _ = os.Stat(localFile) + } - st1, err := os.Stat(localFile) - if err == nil && !st1.Mode().IsDir() && st.Mode().IsDir() { - log.Fatal("can't transfer directory to file") + for _, f := range args[:len(args)-1] { + user, host, path := splitHostPath(f) + if host == "" || path == "" { + return errors.New("invalid path") } - if !st.Mode().IsDir() { - if st1 != nil && st1.Mode().IsDir() { - // to local directory - bname := filepath.Base(path) - localFile = filepath.Join(localFile, bname) - } + path = clean(path) - debuglog("transfer remote to local, %s -> %s", path, localFile) + sftpConn, err := createSFTPConn(host, user, cfg) + if err != nil { + return err + } - if err := get(sftpConn, path, localFile); err != nil { - log.Fatal(err) + st1, err := sftpConn.Stat(path) + if err != nil { + err1 = err + debuglog("%s", err) + sftpConn.Close() + continue + } + if st1.Mode().IsDir() { + if !cfg.Recursive { + debuglog("omit remote directory %s", path) + sftpConn.Close() + continue } - debuglog("done") - return + if err := rget(sftpConn, path, localFile); err != nil { + debuglog("download error: %s", err) + err1 = err + } + sftpConn.Close() + continue } - // recursive download - if err := rget(sftpConn, path, localFile); err != nil { - log.Fatal(err) + lfile := localFile + if st != nil && st.Mode().IsDir() { + lfile = filepath.Join(lfile, filepath.Base(path)) } - // download done - debuglog("done") + if err := get(sftpConn, path, lfile); err != nil { + debuglog("download error: %s", err) + err1 = err + } - return + sftpConn.Close() } - // upload + debuglog("done") + return err1 +} + +func upload(args []string, cfg *options) error { + + rfile := args[len(args)-1] + + rfile = clean(rfile) + + user, host, path := splitHostPath(rfile) + + if host == "" || path == "" { + return errors.New("invalid path") + } + + path = clean(path) + + sftpConn, err := createSFTPConn(host, user, cfg) + if err != nil { + return err + } + defer sftpConn.Close() + + st, _ := sftpConn.Stat(path) + + var err1 error if len(args) > 2 { - if st, err := sftpConn.Stat(path); err == nil { - if !st.Mode().IsDir() { - log.Fatal("multiple files can only been transferred to directory") + if st != nil && !st.Mode().IsDir() { + log.Fatal("multiple files can only been transferred to directory") + } + if st == nil { + makeDirs(path, sftpConn) + if err := sftpConn.Mkdir(path); err != nil { + log.Fatal(err) } - } else { - log.Fatalf("remote file or directory not exists") } + st, _ = sftpConn.Stat(path) } for i := 0; i < len(args)-1; i++ { localFile := args[i] - st, err := os.Stat(localFile) + + localFile = clean(localFile) + + st1, err := os.Stat(localFile) // local file not exists if err != nil { debuglog("%s", err) - hasError = true + err1 = err continue } // directory - if st.Mode().IsDir() { - if !recursive { + if st1.Mode().IsDir() { + if !cfg.Recursive { debuglog("omit directory %s", localFile) continue } // transfer directory if err := rput(sftpConn, localFile, path); err != nil { debuglog("%s", err) - hasError = true + err1 = err } // next entry @@ -266,21 +341,16 @@ func main() { remoteFile := path - st1, err := sftpConn.Stat(path) - if err == nil && st1.Mode().IsDir() { + if st != nil && st.Mode().IsDir() { remoteFile = filepath.Join(path, filepath.Base(localFile)) } if err := put(sftpConn, localFile, remoteFile); err != nil { - hasError = true debuglog("upload %s failed: %s", localFile, err.Error()) + err1 = err } } - - if hasError { - os.Exit(1) - } - + return err1 } func get(sftpConn *sftp.Client, remoteFile, localFile string) error { @@ -382,13 +452,13 @@ func rput(sftpConn *sftp.Client, localDir, remoteDir string) error { continue } - p := walker.Path() + p := clean(walker.Path()) p1 := strings.Replace(p, localDir, "", 1) fmt.Println(strings.TrimLeft(p1, "/")) - p2 := filepath.Join(remoteDir, p1) + p2 := clean(filepath.Join(remoteDir, p1)) if err := makeDirs(p2, sftpConn); err != nil { return err @@ -415,13 +485,13 @@ func rget(sftpConn *sftp.Client, remoteDir, localDir string) error { continue } - p := walker.Path() + p := clean(walker.Path()) p1 := strings.Replace(p, remoteDir, "", 1) - p2 := filepath.Join(localDir, p1) + p2 := clean(filepath.Join(localDir, p1)) fmt.Println(strings.TrimLeft(p1, "/")) - if err := makeDirs(p2, fi{}); err != nil { + if err := makeDirs(p2, osDir{}); err != nil { return err } @@ -433,22 +503,24 @@ func rget(sftpConn *sftp.Client, remoteDir, localDir string) error { return nil } -type fi struct{} +type osDir struct{} -func (f fi) Stat(s string) (os.FileInfo, error) { +func (f osDir) Stat(s string) (os.FileInfo, error) { return os.Stat(s) } -func (f fi) Mkdir(s string) error { +func (f osDir) Mkdir(s string) error { return os.Mkdir(s, 0755) } -type fileInterface interface { +type dirInterface interface { Stat(s string) (os.FileInfo, error) Mkdir(s string) error } -func makeDirs(p string, c fileInterface) error { +func makeDirs(p string, c dirInterface) error { + p = clean(p) + debuglog("make directory for %s", p) for i := 1; i < len(p); i++ { @@ -544,3 +616,11 @@ Options for obfuscation: fmt.Printf("%s", usageStr) os.Exit(1) } + +func clean(p string) string { + p = filepath.Clean(p) + if os.PathSeparator != '/' { + p = strings.Replace(p, string([]byte{os.PathSeparator}), "/", -1) + } + return p +}