From adb6d676c09a205febbd609a00fa4ac178e5200f Mon Sep 17 00:00:00 2001 From: Dingjun Date: Wed, 7 Dec 2016 15:31:12 +0800 Subject: [PATCH] more error check --- obfssh_scp/.gitignore | 2 +- obfssh_scp/scp.go | 56 ++++++++++++++++++++++++++++++------------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/obfssh_scp/.gitignore b/obfssh_scp/.gitignore index 08e177f..c3ee97f 100644 --- a/obfssh_scp/.gitignore +++ b/obfssh_scp/.gitignore @@ -1 +1 @@ -obfssh_scp +obfssh_scp* diff --git a/obfssh_scp/scp.go b/obfssh_scp/scp.go index 783d6a6..141c00c 100644 --- a/obfssh_scp/scp.go +++ b/obfssh_scp/scp.go @@ -26,6 +26,7 @@ func main() { var obfsMethod, obfsKey string var disableObfsAfterHandshake bool var debug bool + var hasError bool flag.Usage = func() { fmt.Printf("Usage: \n\t%s [options] user@host:path local\n\tor\n\t%s [options] local... user@host:path\n", os.Args[0], os.Args[0]) @@ -40,7 +41,7 @@ func main() { flag.StringVar(&pass, "pw", "", "password") flag.StringVar(&key, "i", "", "private key") flag.BoolVar(&recursive, "r", false, "recursively copy entries") - flag.StringVar(&obfsMethod, "obfs_method", "", "obfs encrypt method, rc4, aes or none") + flag.StringVar(&obfsMethod, "obfs_method", "none", "obfs encrypt method, rc4, aes or none") flag.StringVar(&obfsKey, "obfs_key", "", "obfs encrypt key") flag.BoolVar(&disableObfsAfterHandshake, "disable_obfs_after_handshake", false, "disable obfs after handshake") flag.Parse() @@ -219,9 +220,13 @@ func main() { } // recursive download - rget(sftpConn, path, localFile) + if err := rget(sftpConn, path, localFile); err != nil { + log.Fatal(err) + } + // download done debuglog("done") + return } @@ -243,6 +248,7 @@ func main() { // local file not exists if err != nil { debuglog("%s", err) + hasError = true continue } @@ -253,19 +259,34 @@ func main() { continue } // transfer directory - rput(sftpConn, localFile, path) + if err := rput(sftpConn, localFile, path); err != nil { + debuglog("%s", err) + hasError = true + } // next entry continue } // file - remoteFile := filepath.Join(path, filepath.Base(localFile)) + + remoteFile := path + + st1, err := sftpConn.Stat(path) + if err == nil && st1.Mode().IsDir() { + remoteFile = filepath.Join(path, filepath.Base(localFile)) + } + if err := put(sftpConn, localFile, remoteFile); err != nil { + hasError = true debuglog("upload %s failed: %s", localFile, err.Error()) } } + if hasError { + os.Exit(1) + } + } func get(sftpConn *sftp.Client, remoteFile, localFile string) error { @@ -332,12 +353,14 @@ func put(sftpConn *sftp.Client, localFile, remoteFile string) error { //_, err = io.Copy(fpw, fpr) err = copyFile(fpw, fpr) if err != nil { - //log.Fatal(err) return err } // set permission and modtime - st, _ := os.Stat(localFile) + st, err := os.Stat(localFile) + if err != nil { + return err + } if err := sftpConn.Chmod(remoteFile, st.Mode().Perm()); err != nil { return err @@ -354,10 +377,10 @@ func put(sftpConn *sftp.Client, localFile, remoteFile string) error { func rput(sftpConn *sftp.Client, localDir, remoteDir string) error { walker := fs.Walk(localDir) + for walker.Step() { if err := walker.Err(); err != nil { - debuglog("walker error: %s", err) - continue + return err } if st := walker.Stat(); !st.Mode().IsRegular() { @@ -374,12 +397,11 @@ func rput(sftpConn *sftp.Client, localDir, remoteDir string) error { p2 := filepath.Join(remoteDir, p1) if err := makeDirs(p2, sftpConn); err != nil { - debuglog("make directory error: %s", err) - continue + return err } if err := put(sftpConn, p, p2); err != nil { - debuglog("upload %s failed: %s", p, err.Error()) + return err } } return nil @@ -391,8 +413,7 @@ func rget(sftpConn *sftp.Client, remoteDir, localDir string) error { walker := sftpConn.Walk(remoteDir) for walker.Step() { if err := walker.Err(); err != nil { - debuglog("walk error: %s", err) - continue + return err } if st := walker.Stat(); !st.Mode().IsRegular() { @@ -407,13 +428,14 @@ func rget(sftpConn *sftp.Client, remoteDir, localDir string) error { fmt.Println(strings.TrimLeft(p1, "/")) if err := makeDirs(p2, fi{}); err != nil { - debuglog("make directory failed: %s", err) + return err } if err := get(sftpConn, p, p2); err != nil { - debuglog("download %s failed: %s", p, err.Error()) + return err } } + return nil } @@ -463,10 +485,10 @@ func debuglog(format string, args ...interface{}) { // when use pkg/sftp client transfer a big file from pkg/sftp server, // io.Copy while cause connection hang, // I don't known why, -// use this func with smaller buffer has no problem +// use this function has no problem // func copyFile(w io.Writer, r io.Reader) error { - buf := make([]byte, 512) + buf := make([]byte, 34*1024) for { n, err := r.Read(buf) if n > 0 {