diff --git a/obfssh_scp/scp.go b/obfssh_scp/scp.go index f3e4d79..783d6a6 100644 --- a/obfssh_scp/scp.go +++ b/obfssh_scp/scp.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "github.com/bgentry/speakeasy" "github.com/fangdingjun/obfssh" "github.com/kr/fs" "github.com/pkg/sftp" @@ -24,7 +25,16 @@ func main() { var recursive bool var obfsMethod, obfsKey string var disableObfsAfterHandshake bool + var debug bool + flag.Usage = func() { + fmt.Printf("Usage: \n\t%s [options] user@host:path local\n\tor\n\t%s [options] local... user@host:path\n", os.Args[0], os.Args[0]) + + fmt.Printf("Options:\n\n") + + flag.PrintDefaults() + } + flag.BoolVar(&debug, "d", false, "verbose mode") flag.StringVar(&port, "p", "22", "port") flag.StringVar(&user, "l", os.Getenv("USER"), "user") flag.StringVar(&pass, "pw", "", "password") @@ -35,10 +45,15 @@ func main() { flag.BoolVar(&disableObfsAfterHandshake, "disable_obfs_after_handshake", false, "disable obfs after handshake") flag.Parse() + if debug { + obfssh.SSHLogLevel = obfssh.DEBUG + } + args := flag.Args() if len(args) < 2 { - fmt.Printf("Usage: \n\tscp user@host:path local\n\tor\n\tscp local... user@host:path\n") + //fmt.Printf("Usage: \n\tscp user@host:path local\n\tor\n\tscp local... user@host:path\n") + flag.Usage() os.Exit(1) } @@ -60,7 +75,9 @@ func main() { } ss2 := strings.SplitN(r1, ":", 2) if len(ss2) != 2 { - log.Fatal("Usage: \n\tscp user@host:path local\n\tor\n\tscp local... user@host:path") + //log.Fatal("Usage: \n\tscp user@host:path local\n\tor\n\tscp local... user@host:path") + flag.Usage() + os.Exit(1) } host = ss2[0] path = ss2[1] @@ -72,9 +89,14 @@ func main() { var pkeys []ssh.Signer if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { //auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) - if signers, err := agent.NewClient(aconn).Signers(); err != nil { + if signers, err := agent.NewClient(aconn).Signers(); err == nil { + debuglog("add private key from agent") pkeys = append(pkeys, signers...) + } else { + debuglog("get key from agent failed: %s", err) } + } else { + debuglog("dial to agent failed: %s", err) } home := os.Getenv("HOME") @@ -88,29 +110,44 @@ func main() { k1 := filepath.Join(home, f) if pemBytes, err := ioutil.ReadFile(k1); err == nil { if priKey, err := ssh.ParsePrivateKey(pemBytes); err == nil { - //log.Printf("add pri...") //auths = append(auths, ssh.PublicKeys(priKey)) pkeys = append(pkeys, priKey) + debuglog("add private key %s", k1) + } else { + debuglog("parse private key failed: %s", err) } } } + if len(pkeys) != 0 { + debuglog("totol %d private keys", len(pkeys)) auths = append(auths, ssh.PublicKeys(pkeys...)) } } if pass != "" { + debuglog("add password auth") auths = append(auths, ssh.Password(pass)) + } else { + debuglog("add keyboard interactive") + auths = append(auths, + ssh.RetryableAuthMethod(ssh.PasswordCallback(passwordAuth), 3)) } if key != "" { if buf, err := ioutil.ReadFile(key); err == nil { if p, err := ssh.ParsePrivateKey(buf); err == nil { + debuglog("add private key: %s", key) auths = append(auths, ssh.PublicKeys(p)) + } else { + debuglog("parse private key failed: %s", err) } + } else { + debuglog("read private key failed: %s", err) } } + config := &ssh.ClientConfig{ User: user, Auth: auths, @@ -166,19 +203,25 @@ func main() { } if !st.Mode().IsDir() { - if st1.Mode().IsDir() { + if st1 != nil && st1.Mode().IsDir() { + // to local directory bname := filepath.Base(path) localFile = filepath.Join(localFile, bname) } + + debuglog("transfer remote to local, %s -> %s", path, localFile) + if err := get(sftpConn, path, localFile); err != nil { log.Fatal(err) } + debuglog("done") return } // recursive download rget(sftpConn, path, localFile) + debuglog("done") return } @@ -199,14 +242,14 @@ func main() { // local file not exists if err != nil { - log.Println(err) + debuglog("%s", err) continue } // directory if st.Mode().IsDir() { if !recursive { - log.Printf("omit directory %s", localFile) + debuglog("omit directory %s", localFile) continue } // transfer directory @@ -219,16 +262,18 @@ func main() { // file remoteFile := filepath.Join(path, filepath.Base(localFile)) if err := put(sftpConn, localFile, remoteFile); err != nil { - log.Printf("upload %s failed: %s", localFile, err.Error()) + debuglog("upload %s failed: %s", localFile, err.Error()) } } } func get(sftpConn *sftp.Client, remoteFile, localFile string) error { + + debuglog("download %s -> %s", remoteFile, localFile) + fp, err := sftpConn.Open(remoteFile) if err != nil { - //log.Fatal(err) return err } @@ -236,45 +281,56 @@ func get(sftpConn *sftp.Client, remoteFile, localFile string) error { fp1, err := os.OpenFile(localFile, syscall.O_WRONLY|syscall.O_CREAT|syscall.O_TRUNC, 0644) if err != nil { - //log.Fatal(err) return err } defer fp1.Close() - _, err = io.Copy(fp1, fp) + //_, err = io.Copy(fp1, fp) + err = copyFile(fp1, fp) if err != nil { - //log.Fatal(err) return err } // set permission and modtime - st, _ := sftpConn.Stat(remoteFile) + + st, err := sftpConn.Stat(remoteFile) + if err != nil { + return err + } + if err := os.Chmod(localFile, st.Mode().Perm()); err != nil { - //log.Println(err) return err } + if err := os.Chtimes(localFile, st.ModTime(), st.ModTime()); err != nil { - //log.Println(err) return err } + + debuglog("done") + return nil } func put(sftpConn *sftp.Client, localFile, remoteFile string) error { + debuglog("upload %s -> %s", localFile, remoteFile) + fpw, err := sftpConn.OpenFile(remoteFile, syscall.O_WRONLY|syscall.O_CREAT|syscall.O_TRUNC) if err != nil { - //log.Fatal(err) return err } + defer fpw.Close() + fpr, err := os.Open(localFile) if err != nil { - //log.Fatal(err) return err } + defer fpr.Close() - _, err = io.Copy(fpw, fpr) + + //_, err = io.Copy(fpw, fpr) + err = copyFile(fpw, fpr) if err != nil { //log.Fatal(err) return err @@ -282,14 +338,17 @@ func put(sftpConn *sftp.Client, localFile, remoteFile string) error { // set permission and modtime st, _ := os.Stat(localFile) + if err := sftpConn.Chmod(remoteFile, st.Mode().Perm()); err != nil { - //log.Println(err) return err } + if err := sftpConn.Chtimes(remoteFile, st.ModTime(), st.ModTime()); err != nil { - //log.Println(err) return err } + + debuglog("done") + return nil } @@ -297,45 +356,62 @@ func rput(sftpConn *sftp.Client, localDir, remoteDir string) error { walker := fs.Walk(localDir) for walker.Step() { if err := walker.Err(); err != nil { - log.Println(err) + debuglog("walker error: %s", err) continue } - if st := walker.Stat(); st.Mode().IsDir() { + + if st := walker.Stat(); !st.Mode().IsRegular() { + debuglog("skip %s", walker.Path()) continue } + p := walker.Path() + p1 := strings.Replace(p, localDir, "", 1) + + fmt.Println(strings.TrimLeft(p1, "/")) + p2 := filepath.Join(remoteDir, p1) + if err := makeDirs(p2, sftpConn); err != nil { - log.Println(err) + debuglog("make directory error: %s", err) continue } if err := put(sftpConn, p, p2); err != nil { - log.Printf("upload %s failed: %s", p, err.Error()) + debuglog("upload %s failed: %s", p, err.Error()) } } return nil } func rget(sftpConn *sftp.Client, remoteDir, localDir string) error { + debuglog("transfer recusive from remote to local, %s -> %s", remoteDir, localDir) + walker := sftpConn.Walk(remoteDir) for walker.Step() { if err := walker.Err(); err != nil { - log.Println(err) + debuglog("walk error: %s", err) continue } - if st := walker.Stat(); st.Mode().IsDir() { + + if st := walker.Stat(); !st.Mode().IsRegular() { + debuglog("skip %s", walker.Path()) continue } + p := walker.Path() p1 := strings.Replace(p, remoteDir, "", 1) p2 := filepath.Join(localDir, p1) + + fmt.Println(strings.TrimLeft(p1, "/")) + if err := makeDirs(p2, fi{}); err != nil { - log.Println(err) + debuglog("make directory failed: %s", err) } + if err := get(sftpConn, p, p2); err != nil { - log.Printf("download %s failed: %s", p, err.Error()) + debuglog("download %s failed: %s", p, err.Error()) } } return nil @@ -357,10 +433,13 @@ type fileInterface interface { } func makeDirs(p string, c fileInterface) error { + debuglog("make directory for %s", p) + for i := 1; i < len(p); i++ { if p[i] == '/' { p1 := p[:i] if _, err := c.Stat(p1); err != nil { + debuglog("make directory %s", p1) if err := c.Mkdir(p1); err != nil { return err } @@ -369,3 +448,39 @@ func makeDirs(p string, c fileInterface) error { } return nil } + +func passwordAuth() (string, error) { + // read password from console + s, err := speakeasy.Ask("Password: ") + return strings.Trim(s, " \r\n"), err +} + +func debuglog(format string, args ...interface{}) { + obfssh.Log(obfssh.DEBUG, format, args...) +} + +// +// when use pkg/sftp client transfer a big file from pkg/sftp server, +// io.Copy while cause connection hang, +// I don't known why, +// use this func with smaller buffer has no problem +// +func copyFile(w io.Writer, r io.Reader) error { + buf := make([]byte, 512) + for { + n, err := r.Read(buf) + if n > 0 { + _, err1 := w.Write(buf[:n]) + if err1 != nil { + return err1 + } + } + if err != nil { + if err == io.EOF { + // trust io.EOF as success + return nil + } + return err + } + } +}