change log library

tls
fangdingjun 7 years ago
parent 2b973b326d
commit 911c955dd9

@ -12,6 +12,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/fangdingjun/go-log"
socks "github.com/fangdingjun/socks-go" socks "github.com/fangdingjun/socks-go"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
@ -62,11 +63,11 @@ func (cc *Client) Run() error {
} }
// wait port forward to finish // wait port forward to finish
if cc.listeners != nil { if cc.listeners != nil {
Log(DEBUG, "wait all channel to be done") log.Debugf("wait all channel to be done")
go cc.registerSignal() go cc.registerSignal()
go func() { go func() {
cc.err = cc.sshConn.Wait() cc.err = cc.sshConn.Wait()
Log(DEBUG, "connection hang up") log.Debugf("connection hang up")
//close(cc.ch) //close(cc.ch)
select { select {
case cc.ch <- struct{}{}: case cc.ch <- struct{}{}:
@ -77,11 +78,11 @@ func (cc *Client) Run() error {
// wait exit signal // wait exit signal
select { select {
case <-cc.ch: case <-cc.ch:
Log(INFO, "got signal, exit") log.Debugf("got signal, exit")
} }
} }
cc.Close() cc.Close()
Log(DEBUG, "Done") log.Debugf("Done")
return cc.err return cc.err
} }
@ -96,9 +97,9 @@ func (cc *Client) closeListener() {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for _, l := range cc.listeners { for _, l := range cc.listeners {
go func(l net.Listener) { go func(l net.Listener) {
Log(DEBUG, "begin to close listener %s", l.Addr().String()) log.Debugf("begin to close listener %s", l.Addr().String())
l.Close() l.Close()
Log(DEBUG, "close listener %s done", l.Addr().String()) log.Debugf("close listener %s done", l.Addr().String())
wg.Done() wg.Done()
}(l) }(l)
wg.Add(1) wg.Add(1)
@ -121,20 +122,20 @@ func (cc *Client) closeListener() {
func (cc *Client) Close() { func (cc *Client) Close() {
cc.closeListener() cc.closeListener()
Log(DEBUG, "close ssh connection") log.Debugf("close ssh connection")
cc.sshConn.Close() cc.sshConn.Close()
cc.conn.Close() cc.conn.Close()
Log(DEBUG, "close ssh connection done") log.Debugf("close ssh connection done")
} }
// RunCmd run a single command on server // RunCmd run a single command on server
func (cc *Client) RunCmd(cmd string) ([]byte, error) { func (cc *Client) RunCmd(cmd string) ([]byte, error) {
Log(INFO, "run command %s", cmd) log.Debugf("run command %s", cmd)
session, err := cc.client.NewSession() session, err := cc.client.NewSession()
if err != nil { if err != nil {
Log(DEBUG, "command exited with error: %s", err.Error()) log.Debugf("command exited with error: %s", err.Error())
} else { } else {
Log(DEBUG, "command exited with no error") log.Debugf("command exited with no error")
} }
if err != nil { if err != nil {
@ -147,7 +148,7 @@ func (cc *Client) RunCmd(cmd string) ([]byte, error) {
// Shell start a login shell on server // Shell start a login shell on server
func (cc *Client) Shell() error { func (cc *Client) Shell() error {
Log(DEBUG, "request new session") log.Debugf("request new session")
session, err := cc.client.NewSession() session, err := cc.client.NewSession()
if err != nil { if err != nil {
return err return err
@ -163,34 +164,34 @@ func (cc *Client) Shell() error {
} }
// this make CTRL+C works // this make CTRL+C works
Log(DEBUG, "turn terminal mode to raw") log.Debugf("turn terminal mode to raw")
oldState, _ := terminal.MakeRaw(0) oldState, _ := terminal.MakeRaw(0)
w, h, _ := terminal.GetSize(0) w, h, _ := terminal.GetSize(0)
Log(DEBUG, "request pty") log.Debugf("request pty")
if err := session.RequestPty("xterm", h, w, modes); err != nil { if err := session.RequestPty("xterm", h, w, modes); err != nil {
Log(ERROR, "request pty error: %s", err.Error()) log.Errorf("request pty error: %s", err.Error())
Log(DEBUG, "restore terminal mode") log.Debugf("restore terminal mode")
terminal.Restore(0, oldState) terminal.Restore(0, oldState)
return err return err
} }
Log(DEBUG, "request shell") log.Debugf("request shell")
if err := session.Shell(); err != nil { if err := session.Shell(); err != nil {
Log(ERROR, "start shell error: %s", err.Error()) log.Errorf("start shell error: %s", err.Error())
Log(DEBUG, "restore terminal mode") log.Debugf("restore terminal mode")
terminal.Restore(0, oldState) terminal.Restore(0, oldState)
return err return err
} }
session.Wait() session.Wait()
Log(DEBUG, "session closed") log.Debugf("session closed")
terminal.Restore(0, oldState) terminal.Restore(0, oldState)
Log(DEBUG, "restore terminal mode") log.Debugf("restore terminal mode")
return nil return nil
} }
// AddLocalForward add a local to remote port forward // AddLocalForward add a local to remote port forward
func (cc *Client) AddLocalForward(local, remote string) error { func (cc *Client) AddLocalForward(local, remote string) error {
Log(DEBUG, "add local forward %s -> %s", local, remote) log.Debugf("add local forward %s -> %s", local, remote)
l, err := net.Listen("tcp", local) l, err := net.Listen("tcp", local)
if err != nil { if err != nil {
return err return err
@ -201,10 +202,10 @@ func (cc *Client) AddLocalForward(local, remote string) error {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
Log(DEBUG, "local listen %s closed", l.Addr()) log.Debugf("local listen %s closed", l.Addr())
return return
} }
Log(DEBUG, "connection accepted from %s", c.RemoteAddr()) log.Debugf("connection accepted from %s", c.RemoteAddr())
go cc.handleLocalForward(c, remote) go cc.handleLocalForward(c, remote)
} }
}(l) }(l)
@ -214,7 +215,7 @@ func (cc *Client) AddLocalForward(local, remote string) error {
// AddRemoteForward add a remote to local port forward // AddRemoteForward add a remote to local port forward
func (cc *Client) AddRemoteForward(local, remote string) error { func (cc *Client) AddRemoteForward(local, remote string) error {
Log(DEBUG, "add remote forward %s -> %s", remote, local) log.Debugf("add remote forward %s -> %s", remote, local)
l, err := cc.client.Listen("tcp", remote) l, err := cc.client.Listen("tcp", remote)
if err != nil { if err != nil {
return err return err
@ -226,10 +227,10 @@ func (cc *Client) AddRemoteForward(local, remote string) error {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
Log(DEBUG, "remote listener %s closed", l.Addr()) log.Debugf("remote listener %s closed", l.Addr())
return return
} }
Log(DEBUG, "accept remote forward connection from %s", c.RemoteAddr()) log.Debugf("accept remote forward connection from %s", c.RemoteAddr())
go cc.handleRemoteForward(c, local) go cc.handleRemoteForward(c, local)
} }
}(l) }(l)
@ -238,7 +239,7 @@ func (cc *Client) AddRemoteForward(local, remote string) error {
// AddDynamicForward add a dynamic port forward // AddDynamicForward add a dynamic port forward
func (cc *Client) AddDynamicForward(local string) error { func (cc *Client) AddDynamicForward(local string) error {
Log(DEBUG, "add dynamic forward %s", local) log.Debugf("add dynamic forward %s", local)
l, err := net.Listen("tcp", local) l, err := net.Listen("tcp", local)
if err != nil { if err != nil {
return err return err
@ -249,10 +250,10 @@ func (cc *Client) AddDynamicForward(local string) error {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
Log(DEBUG, "local listener %s closed", l.Addr()) log.Debugf("local listener %s closed", l.Addr())
return return
} }
Log(DEBUG, "accept connection from %s", c.RemoteAddr()) log.Debugf("accept connection from %s", c.RemoteAddr())
go cc.handleDynamicForward(c) go cc.handleDynamicForward(c)
} }
}(l) }(l)
@ -262,22 +263,22 @@ func (cc *Client) AddDynamicForward(local string) error {
func (cc *Client) handleLocalForward(conn net.Conn, remote string) { func (cc *Client) handleLocalForward(conn net.Conn, remote string) {
rconn, err := cc.client.Dial("tcp", remote) rconn, err := cc.client.Dial("tcp", remote)
if err != nil { if err != nil {
Log(ERROR, "connect to %s failed: %s", remote, err.Error()) log.Errorf("connect to %s failed: %s", remote, err.Error())
conn.Close() conn.Close()
return return
} }
Log(DEBUG, "remote connect to %s success", remote) log.Debugf("remote connect to %s success", remote)
PipeAndClose(rconn, conn) PipeAndClose(rconn, conn)
} }
func (cc *Client) handleRemoteForward(conn net.Conn, local string) { func (cc *Client) handleRemoteForward(conn net.Conn, local string) {
lconn, err := dialer.Dial("tcp", local) lconn, err := dialer.Dial("tcp", local)
if err != nil { if err != nil {
Log(ERROR, "connect to %s failed: %s", local, err.Error()) log.Errorf("connect to %s failed: %s", local, err.Error())
conn.Close() conn.Close()
return return
} }
Log(DEBUG, "connect to %s success", local) log.Debugf("connect to %s success", local)
PipeAndClose(conn, lconn) PipeAndClose(conn, lconn)
} }
@ -287,19 +288,19 @@ func (cc *Client) handleDynamicForward(conn net.Conn) {
if addr.String() != conn.LocalAddr().String() { if addr.String() != conn.LocalAddr().String() {
// transparent proxy // transparent proxy
// iptables redirect the packet to this port // iptables redirect the packet to this port
Log(DEBUG, "transparent %s -> %s", conn.RemoteAddr(), addr) log.Debugf("transparent %s -> %s", conn.RemoteAddr(), addr)
cc.handleTransparentProxy(conn, addr) cc.handleTransparentProxy(conn, addr)
return return
} }
} else { } else {
// SO_ORIGNAL_DST failed // SO_ORIGNAL_DST failed
// just ignore it // just ignore it
Log(DEBUG, "get original destination on %s failed: %s, ignore", log.Debugf("get original destination on %s failed: %s, ignore",
conn.LocalAddr(), err) conn.LocalAddr(), err)
} }
// socks5 to this port // socks5 to this port
Log(DEBUG, "socks %s", conn.RemoteAddr()) log.Debugf("socks %s", conn.RemoteAddr())
s := socks.Conn{Conn: conn, Dial: cc.client.Dial} s := socks.Conn{Conn: conn, Dial: cc.client.Dial}
s.Serve() s.Serve()
} }
@ -307,7 +308,7 @@ func (cc *Client) handleDynamicForward(conn net.Conn) {
func (cc *Client) handleTransparentProxy(c net.Conn, addr net.Addr) { func (cc *Client) handleTransparentProxy(c net.Conn, addr net.Addr) {
c2, err := cc.client.Dial("tcp", addr.String()) c2, err := cc.client.Dial("tcp", addr.String())
if err != nil { if err != nil {
Log(ERROR, "%s", err) log.Errorf("%s", err)
c.Close() c.Close()
return return
} }
@ -328,19 +329,19 @@ func (cc *Client) keepAlive(interval time.Duration, maxCount int) {
select { select {
case err := <-resCh: case err := <-resCh:
if err != nil { if err != nil {
Log(DEBUG, "keep alive error: %s", err.Error()) log.Debugf("keep alive error: %s", err.Error())
count++ count++
} else { } else {
count = 0 count = 0
} }
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
Log(DEBUG, "keep alive timed out") log.Debugf("keep alive timed out")
count++ count++
} }
if count >= maxCount { if count >= maxCount {
cc.err = fmt.Errorf("keep alive detects connection hang up") cc.err = fmt.Errorf("keep alive detects connection hang up")
Log(ERROR, "keep alive hit max count, exit") log.Errorf("keep alive hit max count, exit")
//cc.sshConn.Close() //cc.sshConn.Close()
//cc.conn.Close() //cc.conn.Close()
// send exit signal // send exit signal
@ -361,7 +362,7 @@ func (cc *Client) registerSignal() {
select { select {
case s1 := <-c: case s1 := <-c:
cc.err = fmt.Errorf("signal %v", s1) cc.err = fmt.Errorf("signal %v", s1)
Log(ERROR, "signal %d received, exit", s1) log.Errorf("signal %d received, exit", s1)
//close(cc.ch) //close(cc.ch)
select { select {
case cc.ch <- struct{}{}: case cc.ch <- struct{}{}:
@ -373,10 +374,10 @@ func (cc *Client) registerSignal() {
// AddDynamicHTTPForward add a http dynamic forward through // AddDynamicHTTPForward add a http dynamic forward through
// secure channel // secure channel
func (cc *Client) AddDynamicHTTPForward(addr string) error { func (cc *Client) AddDynamicHTTPForward(addr string) error {
Log(DEBUG, "add dynamic http listen: %s", addr) log.Debugf("add dynamic http listen: %s", addr)
l, err := net.Listen("tcp", addr) l, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
Log(ERROR, "listen on %s failed, %s", addr, err) log.Errorf("listen on %s failed, %s", addr, err)
return err return err
} }
@ -387,7 +388,7 @@ func (cc *Client) AddDynamicHTTPForward(addr string) error {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
Log(ERROR, "accept error %s", err) log.Errorf("accept error %s", err)
break break
} }
go cc.handleHTTPIncoming(c) go cc.handleHTTPIncoming(c)
@ -403,7 +404,7 @@ func (cc *Client) handleHTTPIncoming(c net.Conn) {
req, err := http.ReadRequest(r) req, err := http.ReadRequest(r)
if err != nil { if err != nil {
Log(ERROR, "read http request error %s", err) log.Errorf("read http request error %s", err)
c.Close() c.Close()
return return
} }
@ -416,13 +417,13 @@ func (cc *Client) handleHTTPIncoming(c net.Conn) {
} }
func (cc *Client) handleConnect(req *http.Request, c net.Conn) { func (cc *Client) handleConnect(req *http.Request, c net.Conn) {
Log(DEBUG, "connect to %s", req.RequestURI) log.Debugf("connect to %s", req.RequestURI)
c1, err := cc.client.Dial("tcp", req.RequestURI) c1, err := cc.client.Dial("tcp", req.RequestURI)
if err != nil { if err != nil {
c.Close() c.Close()
fmt.Fprintf(c, "HTTP/1.0 503 connection failed\r\n\r\n") fmt.Fprintf(c, "HTTP/1.0 503 connection failed\r\n\r\n")
Log(ERROR, "dial error %s", err) log.Errorf("dial error %s", err)
return return
} }
@ -438,19 +439,19 @@ func (cc *Client) handleHTTPReq(req *http.Request, c net.Conn) {
host = fmt.Sprintf("%s:80", host) host = fmt.Sprintf("%s:80", host)
} }
Log(DEBUG, "request to %s", host) log.Debugf("request to %s", host)
c1, err := cc.client.Dial("tcp", host) c1, err := cc.client.Dial("tcp", host)
if err != nil { if err != nil {
c.Close() c.Close()
fmt.Fprintf(c, "HTTP/1.1 503 connection failed\r\nConnection: close\r\n\r\n") fmt.Fprintf(c, "HTTP/1.1 503 connection failed\r\nConnection: close\r\n\r\n")
Log(ERROR, "connection failed %s", err) log.Errorf("connection failed %s", err)
return return
} }
//defer c1.Close() //defer c1.Close()
if err = req.Write(c1); err != nil { if err = req.Write(c1); err != nil {
fmt.Fprintf(c, "HTTP/1.1 503 write to server error\r\nConnection: close\r\n\r\n") fmt.Fprintf(c, "HTTP/1.1 503 write to server error\r\nConnection: close\r\n\r\n")
Log(ERROR, "write request to server error %s", err) log.Errorf("write request to server error %s", err)
c.Close() c.Close()
c1.Close() c1.Close()
return return

@ -5,41 +5,46 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"github.com/bgentry/speakeasy"
"github.com/fangdingjun/obfssh"
"github.com/kr/fs"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"github.com/bgentry/speakeasy"
"github.com/fangdingjun/go-log"
"github.com/fangdingjun/go-log/formatters"
"github.com/fangdingjun/go-log/writers"
"github.com/fangdingjun/obfssh"
"github.com/kr/fs"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
) )
type options struct { type options struct {
Debug bool Debug bool
Port int Port int
User string User string
Passwd string Passwd string
TLS bool TLS bool
TLSInsecure bool TLSInsecure bool
Recursive bool Recursive bool
ObfsMethod string PrivateKey string
ObfsKey string
DisableObfsAfterHandshake bool
PrivateKey string
} }
var dialer = &net.Dialer{Timeout: 10 * time.Second} var dialer = &net.Dialer{Timeout: 10 * time.Second}
func main() { func main() {
var cfg options var cfg options
var logfile string
var logFileCount int
var logFileSize int64
var loglevel string
flag.Usage = usage flag.Usage = usage
flag.BoolVar(&cfg.Debug, "d", false, "verbose mode") flag.BoolVar(&cfg.Debug, "d", false, "verbose mode")
@ -50,18 +55,36 @@ func main() {
flag.StringVar(&cfg.Passwd, "pw", "", "password") flag.StringVar(&cfg.Passwd, "pw", "", "password")
flag.StringVar(&cfg.PrivateKey, "i", "", "private key") flag.StringVar(&cfg.PrivateKey, "i", "", "private key")
flag.BoolVar(&cfg.Recursive, "r", false, "recursively copy entries") flag.BoolVar(&cfg.Recursive, "r", false, "recursively copy entries")
flag.StringVar(&logfile, "log_file", "", "log file, default stdout")
flag.IntVar(&logFileCount, "log_count", 10, "max count of log to keep")
flag.Int64Var(&logFileSize, "log_size", 10, "max log file size MB")
flag.StringVar(&loglevel, "log_level", "INFO", "log level, values:\nOFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG")
flag.Parse() flag.Parse()
if cfg.Debug {
obfssh.SSHLogLevel = obfssh.DEBUG
}
args := flag.Args() args := flag.Args()
if len(args) < 2 { if len(args) < 2 {
flag.Usage() flag.Usage()
os.Exit(1) os.Exit(1)
} }
if logfile != "" {
log.Default.Out = &writers.FixedSizeFileWriter{
MaxCount: logFileCount,
Name: logfile,
MaxSize: logFileSize * 1024 * 1024,
}
}
if loglevel != "" {
lv, err := log.ParseLevel(loglevel)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
log.Default.Level = lv
}
log.Default.Formatter = &formatters.TextFormatter{TimeFormat: "2006-01-02 15:04:05.000"}
var err error var err error
@ -85,13 +108,13 @@ func createSFTPConn(host, user string, cfg *options) (*sftp.Client, error) {
if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
//auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) //auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers))
if signers, err := agent.NewClient(aconn).Signers(); err == nil { if signers, err := agent.NewClient(aconn).Signers(); err == nil {
debuglog("add private key from agent") log.Debugf("add private key from agent")
pkeys = append(pkeys, signers...) pkeys = append(pkeys, signers...)
} else { } else {
debuglog("get key from agent failed: %s", err) log.Debugf("get key from agent failed: %s", err)
} }
} else { } else {
debuglog("dial to agent failed: %s", err) log.Debugf("dial to agent failed: %s", err)
} }
home := os.Getenv("HOME") home := os.Getenv("HOME")
@ -107,25 +130,25 @@ func createSFTPConn(host, user string, cfg *options) (*sftp.Client, error) {
if priKey, err := ssh.ParsePrivateKey(pemBytes); err == nil { if priKey, err := ssh.ParsePrivateKey(pemBytes); err == nil {
//auths = append(auths, ssh.PublicKeys(priKey)) //auths = append(auths, ssh.PublicKeys(priKey))
pkeys = append(pkeys, priKey) pkeys = append(pkeys, priKey)
debuglog("add private key %s", k1) log.Debugf("add private key %s", k1)
} else { } else {
debuglog("parse private key failed: %s", err) log.Debugf("parse private key failed: %s", err)
} }
} }
} }
if len(pkeys) != 0 { if len(pkeys) != 0 {
debuglog("totol %d private keys", len(pkeys)) log.Debugf("totol %d private keys", len(pkeys))
auths = append(auths, ssh.PublicKeys(pkeys...)) auths = append(auths, ssh.PublicKeys(pkeys...))
} }
} }
if cfg.Passwd != "" { if cfg.Passwd != "" {
debuglog("add password auth") log.Debugf("add password auth")
auths = append(auths, ssh.Password(cfg.Passwd)) auths = append(auths, ssh.Password(cfg.Passwd))
} else { } else {
debuglog("add keyboard interactive") log.Debugf("add keyboard interactive")
auths = append(auths, auths = append(auths,
ssh.RetryableAuthMethod(ssh.PasswordCallback(passwordAuth), 3)) ssh.RetryableAuthMethod(ssh.PasswordCallback(passwordAuth), 3))
} }
@ -133,13 +156,13 @@ func createSFTPConn(host, user string, cfg *options) (*sftp.Client, error) {
if cfg.PrivateKey != "" { if cfg.PrivateKey != "" {
if buf, err := ioutil.ReadFile(cfg.PrivateKey); err == nil { if buf, err := ioutil.ReadFile(cfg.PrivateKey); err == nil {
if p, err := ssh.ParsePrivateKey(buf); err == nil { if p, err := ssh.ParsePrivateKey(buf); err == nil {
debuglog("add private key: %s", cfg.PrivateKey) log.Debugf("add private key: %s", cfg.PrivateKey)
auths = append(auths, ssh.PublicKeys(p)) auths = append(auths, ssh.PublicKeys(p))
} else { } else {
debuglog("parse private key failed: %s", err) log.Debugf("parse private key failed: %s", err)
} }
} else { } else {
debuglog("read private key failed: %s", err) log.Debugf("read private key failed: %s", err)
} }
} }
if user == "" { if user == "" {
@ -250,18 +273,18 @@ func download(args []string, cfg *options) error {
st1, err := sftpConn.Stat(path) st1, err := sftpConn.Stat(path)
if err != nil { if err != nil {
err1 = err err1 = err
debuglog("%s", err) log.Debugf("%s", err)
sftpConn.Close() sftpConn.Close()
continue continue
} }
if st1.Mode().IsDir() { if st1.Mode().IsDir() {
if !cfg.Recursive { if !cfg.Recursive {
debuglog("omit remote directory %s", path) log.Debugf("omit remote directory %s", path)
sftpConn.Close() sftpConn.Close()
continue continue
} }
if err := rget(sftpConn, path, localFile); err != nil { if err := rget(sftpConn, path, localFile); err != nil {
debuglog("download error: %s", err) log.Debugf("download error: %s", err)
err1 = err err1 = err
} }
sftpConn.Close() sftpConn.Close()
@ -276,14 +299,14 @@ func download(args []string, cfg *options) error {
lfile = clean(lfile) lfile = clean(lfile)
if err := get(sftpConn, path, lfile); err != nil { if err := get(sftpConn, path, lfile); err != nil {
debuglog("download error: %s", err) log.Debugf("download error: %s", err)
err1 = err err1 = err
} }
sftpConn.Close() sftpConn.Close()
} }
debuglog("done") log.Debugf("done")
return err1 return err1
} }
@ -332,7 +355,7 @@ func upload(args []string, cfg *options) error {
// local file not exists // local file not exists
if err != nil { if err != nil {
debuglog("%s", err) log.Debugf("%s", err)
err1 = err err1 = err
continue continue
} }
@ -340,12 +363,12 @@ func upload(args []string, cfg *options) error {
// directory // directory
if st1.Mode().IsDir() { if st1.Mode().IsDir() {
if !cfg.Recursive { if !cfg.Recursive {
debuglog("omit directory %s", localFile) log.Debugf("omit directory %s", localFile)
continue continue
} }
// transfer directory // transfer directory
if err := rput(sftpConn, localFile, path); err != nil { if err := rput(sftpConn, localFile, path); err != nil {
debuglog("%s", err) log.Debugf("%s", err)
err1 = err err1 = err
} }
@ -364,7 +387,7 @@ func upload(args []string, cfg *options) error {
remoteFile = clean(remoteFile) remoteFile = clean(remoteFile)
if err := put(sftpConn, localFile, remoteFile); err != nil { if err := put(sftpConn, localFile, remoteFile); err != nil {
debuglog("upload %s failed: %s", localFile, err.Error()) log.Debugf("upload %s failed: %s", localFile, err.Error())
err1 = err err1 = err
} }
} }
@ -373,7 +396,7 @@ func upload(args []string, cfg *options) error {
func get(sftpConn *sftp.Client, remoteFile, localFile string) error { func get(sftpConn *sftp.Client, remoteFile, localFile string) error {
debuglog("download %s -> %s", remoteFile, localFile) log.Debugf("download %s -> %s", remoteFile, localFile)
fp, err := sftpConn.Open(remoteFile) fp, err := sftpConn.Open(remoteFile)
if err != nil { if err != nil {
@ -410,13 +433,13 @@ func get(sftpConn *sftp.Client, remoteFile, localFile string) error {
return err return err
} }
debuglog("done") log.Debugf("done")
return nil return nil
} }
func put(sftpConn *sftp.Client, localFile, remoteFile string) error { func put(sftpConn *sftp.Client, localFile, remoteFile string) error {
debuglog("upload %s -> %s", localFile, remoteFile) log.Debugf("upload %s -> %s", localFile, remoteFile)
fpw, err := sftpConn.OpenFile(remoteFile, syscall.O_WRONLY|syscall.O_CREAT|syscall.O_TRUNC) fpw, err := sftpConn.OpenFile(remoteFile, syscall.O_WRONLY|syscall.O_CREAT|syscall.O_TRUNC)
if err != nil { if err != nil {
@ -452,7 +475,7 @@ func put(sftpConn *sftp.Client, localFile, remoteFile string) error {
return err return err
} }
debuglog("done") log.Debugf("done")
return nil return nil
} }
@ -466,7 +489,7 @@ func rput(sftpConn *sftp.Client, localDir, remoteDir string) error {
} }
if st := walker.Stat(); !st.Mode().IsRegular() { if st := walker.Stat(); !st.Mode().IsRegular() {
debuglog("skip %s", walker.Path()) log.Debugf("skip %s", walker.Path())
continue continue
} }
@ -490,7 +513,7 @@ func rput(sftpConn *sftp.Client, localDir, remoteDir string) error {
} }
func rget(sftpConn *sftp.Client, remoteDir, localDir string) error { func rget(sftpConn *sftp.Client, remoteDir, localDir string) error {
debuglog("transfer recusive from remote to local, %s -> %s", remoteDir, localDir) log.Debugf("transfer recusive from remote to local, %s -> %s", remoteDir, localDir)
walker := sftpConn.Walk(remoteDir) walker := sftpConn.Walk(remoteDir)
for walker.Step() { for walker.Step() {
@ -499,7 +522,7 @@ func rget(sftpConn *sftp.Client, remoteDir, localDir string) error {
} }
if st := walker.Stat(); !st.Mode().IsRegular() { if st := walker.Stat(); !st.Mode().IsRegular() {
debuglog("skip %s", walker.Path()) log.Debugf("skip %s", walker.Path())
continue continue
} }
@ -539,13 +562,13 @@ type dirInterface interface {
func makeDirs(p string, c dirInterface) error { func makeDirs(p string, c dirInterface) error {
p = clean(p) p = clean(p)
debuglog("make directory for %s", p) log.Debugf("make directory for %s", p)
for i := 1; i < len(p); i++ { for i := 1; i < len(p); i++ {
if p[i] == '/' { if p[i] == '/' {
p1 := p[:i] p1 := p[:i]
if _, err := c.Stat(p1); err != nil { if _, err := c.Stat(p1); err != nil {
debuglog("make directory %s", p1) log.Debugf("make directory %s", p1)
if err := c.Mkdir(p1); err != nil { if err := c.Mkdir(p1); err != nil {
return err return err
} }
@ -561,10 +584,6 @@ func passwordAuth() (string, error) {
return strings.Trim(s, " \r\n"), err return strings.Trim(s, " \r\n"), err
} }
func debuglog(format string, args ...interface{}) {
obfssh.Log(obfssh.DEBUG, format, args...)
}
// //
// when use pkg/sftp client transfer a big file from pkg/sftp server, // when use pkg/sftp client transfer a big file from pkg/sftp server,
// io.Copy while cause connection hang, // io.Copy while cause connection hang,
@ -618,7 +637,20 @@ Options:
connect to server via TLS connect to server via TLS
-tls-insecure -tls-insecure
do not verify server's certificate do not verify server's certificate
-log_file
log file, default stdout
-log_count
max count of log file to keep, default 10
-log_size
max log size MB, default 10
-log_level
log level, values:
OFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG
` `
fmt.Printf("%s", usageStr) fmt.Printf("%s", usageStr)
os.Exit(1) os.Exit(1)

@ -13,7 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/fangdingjun/obfssh" "github.com/fangdingjun/go-log"
socks "github.com/fangdingjun/socks-go" socks "github.com/fangdingjun/socks-go"
) )
@ -58,7 +58,7 @@ var _ net.Conn = &httpProxyConn{}
func updateProxyFromEnv(cfg *config) { func updateProxyFromEnv(cfg *config) {
if cfg.Proxy.Scheme != "" && cfg.Proxy.Host != "" && cfg.Proxy.Port != 0 { if cfg.Proxy.Scheme != "" && cfg.Proxy.Host != "" && cfg.Proxy.Port != 0 {
obfssh.Log(obfssh.DEBUG, "proxy already specified by config, not parse environment proxy") log.Debugf("proxy already specified by config, not parse environment proxy")
return return
} }
@ -81,7 +81,7 @@ func updateProxyFromEnv(cfg *config) {
u, err := url.Parse(proxyStr) u, err := url.Parse(proxyStr)
if err != nil { if err != nil {
obfssh.Log(obfssh.DEBUG, "parse proxy from environment failed: %s", err) log.Debugf("parse proxy from environment failed: %s", err)
return return
} }

@ -5,7 +5,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -13,6 +12,9 @@ import (
"time" "time"
"github.com/bgentry/speakeasy" "github.com/bgentry/speakeasy"
"github.com/fangdingjun/go-log"
"github.com/fangdingjun/go-log/formatters"
"github.com/fangdingjun/go-log/writers"
"github.com/fangdingjun/obfssh" "github.com/fangdingjun/obfssh"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/agent"
@ -23,6 +25,10 @@ var dialer = &net.Dialer{Timeout: 15 * time.Second}
func main() { func main() {
var configfile string var configfile string
var cfg config var cfg config
var logfile string
var logFileCount int
var logFileSize int64
var loglevel string
flag.StringVar(&configfile, "f", "", "configure file") flag.StringVar(&configfile, "f", "", "configure file")
flag.StringVar(&cfg.Username, "l", os.Getenv("USER"), "ssh username") flag.StringVar(&cfg.Username, "l", os.Getenv("USER"), "ssh username")
@ -38,20 +44,41 @@ func main() {
flag.BoolVar(&cfg.Debug, "d", false, "verbose mode") flag.BoolVar(&cfg.Debug, "d", false, "verbose mode")
flag.IntVar(&cfg.KeepaliveInterval, "keepalive_interval", 10, "keep alive interval") flag.IntVar(&cfg.KeepaliveInterval, "keepalive_interval", 10, "keep alive interval")
flag.IntVar(&cfg.KeepaliveMax, "keepalive_max", 5, "keep alive max") flag.IntVar(&cfg.KeepaliveMax, "keepalive_max", 5, "keep alive max")
flag.StringVar(&logfile, "log_file", "", "log file, default stdout")
flag.IntVar(&logFileCount, "log_count", 10, "max count of log to keep")
flag.Int64Var(&logFileSize, "log_size", 10, "max log file size MB")
flag.StringVar(&loglevel, "log_level", "INFO", "log level, values:\nOFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG")
flag.Usage = usage flag.Usage = usage
flag.Parse() flag.Parse()
if logfile != "" {
log.Default.Out = &writers.FixedSizeFileWriter{
MaxCount: logFileCount,
Name: logfile,
MaxSize: logFileSize * 1024 * 1024,
}
}
if loglevel != "" {
lv, err := log.ParseLevel(loglevel)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
log.Default.Level = lv
}
log.Default.Formatter = &formatters.TextFormatter{TimeFormat: "2006-01-02 15:04:05.000"}
if configfile != "" { if configfile != "" {
if err := loadConfig(&cfg, configfile); err != nil { if err := loadConfig(&cfg, configfile); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
if cfg.Debug { log.Debugf("obfssh client start")
obfssh.SSHLogLevel = obfssh.DEBUG
}
obfssh.Log(obfssh.DEBUG, "obfssh client start")
auth := []ssh.AuthMethod{} auth := []ssh.AuthMethod{}
@ -74,7 +101,7 @@ func main() {
k1 := filepath.Join(home, f) k1 := filepath.Join(home, f)
if pemBytes, err := ioutil.ReadFile(k1); err == nil { if pemBytes, err := ioutil.ReadFile(k1); err == nil {
if priKey, err := ssh.ParsePrivateKey(pemBytes); err == nil { if priKey, err := ssh.ParsePrivateKey(pemBytes); err == nil {
obfssh.Log(obfssh.DEBUG, "add private key: %s", k1) log.Debugf("add private key: %s", k1)
//auth = append(auth, ssh.PublicKeys(priKey)) //auth = append(auth, ssh.PublicKeys(priKey))
pkeys = append(pkeys, priKey) pkeys = append(pkeys, priKey)
} }
@ -85,21 +112,21 @@ func main() {
agentConn, err = net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) agentConn, err = net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil { if err == nil {
defer agentConn.Close() defer agentConn.Close()
obfssh.Log(obfssh.DEBUG, "add auth method with agent %s", os.Getenv("SSH_AUTH_SOCK")) log.Debugf("add auth method with agent %s", os.Getenv("SSH_AUTH_SOCK"))
agentClient := agent.NewClient(agentConn) agentClient := agent.NewClient(agentConn)
//auth = append(auth, ssh.PublicKeysCallback(agentClient.Signers)) //auth = append(auth, ssh.PublicKeysCallback(agentClient.Signers))
signers, err := agentClient.Signers() signers, err := agentClient.Signers()
if err == nil { if err == nil {
pkeys = append(pkeys, signers...) pkeys = append(pkeys, signers...)
} else { } else {
obfssh.Log(obfssh.DEBUG, "get key from agent failed: %s", err) log.Debugf("get key from agent failed: %s", err)
} }
} else { } else {
obfssh.Log(obfssh.DEBUG, "connect to agent failed") log.Debugf("connect to agent failed")
} }
if len(pkeys) != 0 { if len(pkeys) != 0 {
obfssh.Log(obfssh.DEBUG, "private key length %d", len(pkeys)) log.Debugf("private key length %d", len(pkeys))
auth = append(auth, ssh.PublicKeys(pkeys...)) auth = append(auth, ssh.PublicKeys(pkeys...))
} }
@ -140,15 +167,15 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
obfssh.Log(obfssh.DEBUG, "add private key %s", cfg.PrivateKey) log.Debugf("add private key %s", cfg.PrivateKey)
auth = append(auth, ssh.PublicKeys(priKey)) auth = append(auth, ssh.PublicKeys(priKey))
} }
if cfg.Password != "" { if cfg.Password != "" {
obfssh.Log(obfssh.DEBUG, "add password auth method") log.Debugf("add password auth method")
auth = append(auth, ssh.Password(cfg.Password)) auth = append(auth, ssh.Password(cfg.Password))
} else { } else {
obfssh.Log(obfssh.DEBUG, "add keyboard interactive auth") log.Debugf("add keyboard interactive auth")
//auth = append(auth, //auth = append(auth,
// ssh.RetryableAuthMethod(ssh.KeyboardInteractive(keyboardAuth), 3)) // ssh.RetryableAuthMethod(ssh.KeyboardInteractive(keyboardAuth), 3))
auth = append(auth, auth = append(auth,
@ -161,7 +188,7 @@ func main() {
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
HostKeyCallback: func(hostname string, remote net.Addr, HostKeyCallback: func(hostname string, remote net.Addr,
key ssh.PublicKey) error { key ssh.PublicKey) error {
obfssh.Log(obfssh.INFO, "%s %s %+v", hostname, remote, key) log.Debugf("%s %s %+v", hostname, remote, key)
return nil return nil
}, },
} }
@ -175,22 +202,22 @@ func main() {
if cfg.Proxy.Scheme != "" && cfg.Proxy.Host != "" && cfg.Proxy.Port != 0 { if cfg.Proxy.Scheme != "" && cfg.Proxy.Host != "" && cfg.Proxy.Port != 0 {
switch cfg.Proxy.Scheme { switch cfg.Proxy.Scheme {
case "http": case "http":
obfssh.Log(obfssh.DEBUG, "use http proxy %s:%d to connect to server", log.Debugf("use http proxy %s:%d to connect to server",
cfg.Proxy.Host, cfg.Proxy.Port) cfg.Proxy.Host, cfg.Proxy.Port)
c, err = dialHTTPProxy(host, cfg.Port, cfg.Proxy) c, err = dialHTTPProxy(host, cfg.Port, cfg.Proxy)
case "https": case "https":
obfssh.Log(obfssh.DEBUG, "use https proxy %s:%d to connect to server", log.Debugf("use https proxy %s:%d to connect to server",
cfg.Proxy.Host, cfg.Proxy.Port) cfg.Proxy.Host, cfg.Proxy.Port)
c, err = dialHTTPSProxy(host, cfg.Port, cfg.Proxy) c, err = dialHTTPSProxy(host, cfg.Port, cfg.Proxy)
case "socks5": case "socks5":
obfssh.Log(obfssh.DEBUG, "use socks proxy %s:%d to connect to server", log.Debugf("use socks proxy %s:%d to connect to server",
cfg.Proxy.Host, cfg.Proxy.Port) cfg.Proxy.Host, cfg.Proxy.Port)
c, err = dialSocks5Proxy(host, cfg.Port, cfg.Proxy) c, err = dialSocks5Proxy(host, cfg.Port, cfg.Proxy)
default: default:
err = fmt.Errorf("unsupported scheme: %s", cfg.Proxy.Scheme) err = fmt.Errorf("unsupported scheme: %s", cfg.Proxy.Scheme)
} }
} else { } else {
obfssh.Log(obfssh.DEBUG, "dail to %s", rhost) log.Debugf("dail to %s", rhost)
c, err = dialer.Dial("tcp", rhost) c, err = dialer.Dial("tcp", rhost)
} }
@ -198,7 +225,7 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
obfssh.Log(obfssh.DEBUG, "dail success") log.Debugf("dail success")
timeout := time.Duration(cfg.KeepaliveInterval*2) * time.Second timeout := time.Duration(cfg.KeepaliveInterval*2) * time.Second
@ -207,7 +234,7 @@ func main() {
conn := &obfssh.TimedOutConn{Conn: c, Timeout: timeout} conn := &obfssh.TimedOutConn{Conn: c, Timeout: timeout}
if cfg.TLS { if cfg.TLS {
obfssh.Log(obfssh.DEBUG, "begin tls handshake") log.Debugf("begin tls handshake")
_conn = tls.Client(conn, &tls.Config{ _conn = tls.Client(conn, &tls.Config{
ServerName: host, ServerName: host,
InsecureSkipVerify: cfg.TLSInsecure, InsecureSkipVerify: cfg.TLSInsecure,
@ -215,7 +242,7 @@ func main() {
if err := _conn.(*tls.Conn).Handshake(); err != nil { if err := _conn.(*tls.Conn).Handshake(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
obfssh.Log(obfssh.DEBUG, "tls handshake done") log.Debugf("tls handshake done")
} }
conf := &obfssh.Conf{ conf := &obfssh.Conf{
@ -224,13 +251,13 @@ func main() {
KeepAliveMax: cfg.KeepaliveMax, KeepAliveMax: cfg.KeepaliveMax,
} }
obfssh.Log(obfssh.DEBUG, "ssh negotation") log.Debugf("ssh negotation")
client, err := obfssh.NewClient(_conn, config, rhost, conf) client, err := obfssh.NewClient(_conn, config, rhost, conf)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
obfssh.Log(obfssh.DEBUG, "ssh negotation success") log.Debugf("ssh negotation success")
var local, remote string var local, remote string
@ -239,7 +266,7 @@ func main() {
for _, p := range cfg.LocalForwards { for _, p := range cfg.LocalForwards {
addr := parseForwardAddr(p) addr := parseForwardAddr(p)
if len(addr) != 4 && len(addr) != 3 { if len(addr) != 4 && len(addr) != 3 {
log.Printf("wrong forward addr %s, format: [local_host:]local_port:remote_host:remote_port", p) log.Errorf("wrong forward addr %s, format: [local_host:]local_port:remote_host:remote_port", p)
continue continue
} }
if len(addr) == 4 { if len(addr) == 4 {
@ -251,14 +278,14 @@ func main() {
} }
//log.Printf("add local to remote %s->%s", local, remote) //log.Printf("add local to remote %s->%s", local, remote)
if err := client.AddLocalForward(local, remote); err != nil { if err := client.AddLocalForward(local, remote); err != nil {
log.Println(err) log.Errorln(err)
} }
} }
for _, p := range cfg.RemoteForwards { for _, p := range cfg.RemoteForwards {
addr := parseForwardAddr(p) addr := parseForwardAddr(p)
if len(addr) != 4 && len(addr) != 3 { if len(addr) != 4 && len(addr) != 3 {
log.Printf("wrong forward addr %s, format: [local_host:]local_port:remote_host:remote_port", p) log.Errorf("wrong forward addr %s, format: [local_host:]local_port:remote_host:remote_port", p)
continue continue
} }
if len(addr) == 4 { if len(addr) == 4 {
@ -270,7 +297,7 @@ func main() {
} }
//log.Printf("add remote to local %s->%s", remote, local) //log.Printf("add remote to local %s->%s", remote, local)
if err := client.AddRemoteForward(local, remote); err != nil { if err := client.AddRemoteForward(local, remote); err != nil {
log.Println(err) log.Errorln(err)
} }
} }
for _, p := range cfg.DynamicForwards { for _, p := range cfg.DynamicForwards {
@ -282,7 +309,7 @@ func main() {
} }
//log.Printf("listen on %s", local) //log.Printf("listen on %s", local)
if err := client.AddDynamicForward(local); err != nil { if err := client.AddDynamicForward(local); err != nil {
log.Println(err) log.Errorln(err)
} }
} }
@ -294,7 +321,7 @@ func main() {
} }
//log.Printf("listen on %s", local) //log.Printf("listen on %s", local)
if err := client.AddDynamicHTTPForward(local); err != nil { if err := client.AddDynamicHTTPForward(local); err != nil {
log.Println(err) log.Errorln(err)
} }
} }
@ -304,7 +331,7 @@ func main() {
if !cfg.NotRunCmd { if !cfg.NotRunCmd {
if cmd != "" { if cmd != "" {
if d, err := client.RunCmd(cmd); err != nil { if d, err := client.RunCmd(cmd); err != nil {
log.Println(err) log.Errorln(err)
hasErr = true hasErr = true
} else { } else {
//log.Printf("%s", string(d)) //log.Printf("%s", string(d))
@ -313,17 +340,17 @@ func main() {
} else { } else {
if err := client.Shell(); err != nil { if err := client.Shell(); err != nil {
hasErr = true hasErr = true
log.Println(err) log.Errorln(err)
} }
} }
} }
if err := client.Run(); err != nil { if err := client.Run(); err != nil {
log.Println(err) log.Errorln(err)
hasErr = true hasErr = true
} }
obfssh.Log(obfssh.DEBUG, "obfssh client exit") log.Debugf("obfssh client exit")
if hasErr { if hasErr {
os.Exit(1) os.Exit(1)
} }
@ -428,6 +455,19 @@ Options:
-tls-insecure -tls-insecure
do not verify server's tls ceritificate do not verify server's tls ceritificate
-log_file
log file, default stdout
-log_count
max count of log file to keep, default 10
-log_size
max log size MB, default 10
-log_level
log level, values:
OFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG
` `
fmt.Printf("%s", usageStr) fmt.Printf("%s", usageStr)
os.Exit(1) os.Exit(1)

@ -5,30 +5,57 @@ import (
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt" "fmt"
"github.com/fangdingjun/obfssh"
"golang.org/x/crypto/ssh"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"os"
"github.com/fangdingjun/go-log"
"github.com/fangdingjun/go-log/formatters"
"github.com/fangdingjun/go-log/writers"
"github.com/fangdingjun/obfssh"
"golang.org/x/crypto/ssh"
) )
func main() { func main() {
var configfile string var configfile string
var logfile string
var logFileCount int
var logFileSize int64
var loglevel string
flag.StringVar(&configfile, "c", "config.yaml", "configure file") flag.StringVar(&configfile, "c", "config.yaml", "configure file")
flag.StringVar(&logfile, "log_file", "", "log file, default stdout")
flag.IntVar(&logFileCount, "log_count", 10, "max count of log to keep")
flag.Int64Var(&logFileSize, "log_size", 10, "max log file size MB")
flag.StringVar(&loglevel, "log_level", "INFO", "log level, values:\nOFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG")
flag.Parse() flag.Parse()
if logfile != "" {
log.Default.Out = &writers.FixedSizeFileWriter{
MaxCount: logFileCount,
Name: logfile,
MaxSize: logFileSize * 1024 * 1024,
}
}
if loglevel != "" {
lv, err := log.ParseLevel(loglevel)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
log.Default.Level = lv
}
log.Default.Formatter = &formatters.TextFormatter{TimeFormat: "2006-01-02 15:04:05.000"}
conf, err := loadConfig(configfile) conf, err := loadConfig(configfile)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// set log level
if conf.Debug {
obfssh.SSHLogLevel = obfssh.DEBUG
}
sconf := &obfssh.Conf{} sconf := &obfssh.Conf{}
config := &ssh.ServerConfig{ config := &ssh.ServerConfig{
@ -55,10 +82,10 @@ func main() {
// auth log // auth log
AuthLogCallback: func(c ssh.ConnMetadata, method string, err error) { AuthLogCallback: func(c ssh.ConnMetadata, method string, err error) {
if err != nil { if err != nil {
obfssh.Log(obfssh.ERROR, "%s", err.Error()) log.Errorf("%s", err.Error())
obfssh.Log(obfssh.ERROR, "%s auth failed for %s from %s", method, c.User(), c.RemoteAddr()) log.Errorf("%s auth failed for %s from %s", method, c.User(), c.RemoteAddr())
} else { } else {
obfssh.Log(obfssh.INFO, "Accepted %s for user %s from %s", method, c.User(), c.RemoteAddr()) log.Debugf("Accepted %s for user %s from %s", method, c.User(), c.RemoteAddr())
} }
}, },
} }
@ -102,14 +129,14 @@ func main() {
return return
} }
obfssh.Log(obfssh.DEBUG, "accept tcp connection from %s", c.RemoteAddr()) log.Debugf("accept tcp connection from %s", c.RemoteAddr())
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
sc, err := obfssh.NewServer(c, config, sconf) sc, err := obfssh.NewServer(c, config, sconf)
if err != nil { if err != nil {
c.Close() c.Close()
obfssh.Log(obfssh.ERROR, "%s", err.Error()) log.Errorf("%s", err.Error())
return return
} }
sc.Run() sc.Run()

@ -5,6 +5,7 @@ import (
"net" "net"
"time" "time"
"github.com/fangdingjun/go-log"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -45,16 +46,16 @@ func NewServer(c net.Conn, config *ssh.ServerConfig, conf *Conf) (*Server, error
// Run waits for server connection finish // Run waits for server connection finish
func (sc *Server) Run() { func (sc *Server) Run() {
sc.sshConn.Wait() sc.sshConn.Wait()
Log(DEBUG, "ssh connection closed") log.Debugf("ssh connection closed")
sc.close() sc.close()
} }
func (sc *Server) close() { func (sc *Server) close() {
Log(DEBUG, "close connection") log.Debugf("close connection")
sc.sshConn.Close() sc.sshConn.Close()
//Log(DEBUG, "close listener") //log.Debugf( "close listener")
for _, l := range sc.forwardedPorts { for _, l := range sc.forwardedPorts {
Log(DEBUG, "close listener %s", l.Addr()) log.Debugf("close listener %s", l.Addr())
l.Close() l.Close()
} }
} }
@ -62,7 +63,7 @@ func (sc *Server) close() {
func (sc *Server) handleNewChannelRequest(ch <-chan ssh.NewChannel) { func (sc *Server) handleNewChannelRequest(ch <-chan ssh.NewChannel) {
for newch := range ch { for newch := range ch {
Log(DEBUG, "request channel %s", newch.ChannelType()) log.Debugf("request channel %s", newch.ChannelType())
switch newch.ChannelType() { switch newch.ChannelType() {
case "session": case "session":
@ -74,7 +75,7 @@ func (sc *Server) handleNewChannelRequest(ch <-chan ssh.NewChannel) {
continue continue
} }
Log(DEBUG, "reject channel request %s", newch.ChannelType()) log.Debugf("reject channel request %s", newch.ChannelType())
newch.Reject(ssh.UnknownChannelType, "unknown channel type") newch.Reject(ssh.UnknownChannelType, "unknown channel type")
} }
@ -83,15 +84,15 @@ func (sc *Server) handleNewChannelRequest(ch <-chan ssh.NewChannel) {
func (sc *Server) handleGlobalRequest(req <-chan *ssh.Request) { func (sc *Server) handleGlobalRequest(req <-chan *ssh.Request) {
for r := range req { for r := range req {
Log(DEBUG, "global request %s", r.Type) log.Debugf("global request %s", r.Type)
switch r.Type { switch r.Type {
case "tcpip-forward": case "tcpip-forward":
Log(DEBUG, "request port forward") log.Debugf("request port forward")
go sc.handleTcpipForward(r) go sc.handleTcpipForward(r)
continue continue
case "cancel-tcpip-forward": case "cancel-tcpip-forward":
Log(DEBUG, "request cancel port forward") log.Debugf("request cancel port forward")
go sc.handleCancelTcpipForward(r) go sc.handleCancelTcpipForward(r)
continue continue
} }
@ -108,12 +109,12 @@ func serveSFTP(ch ssh.Channel) {
server, err := sftp.NewServer(ch) server, err := sftp.NewServer(ch)
if err != nil { if err != nil {
Log(DEBUG, "start sftp server failed: %s", err) log.Debugf("start sftp server failed: %s", err)
return return
} }
if err := server.Serve(); err != nil { if err := server.Serve(); err != nil {
Log(DEBUG, "sftp server finished with error: %s", err) log.Debugf("sftp server finished with error: %s", err)
return return
} }
} }
@ -132,7 +133,7 @@ type args struct {
func (sc *Server) handleSession(newch ssh.NewChannel) { func (sc *Server) handleSession(newch ssh.NewChannel) {
ch, req, err := newch.Accept() ch, req, err := newch.Accept()
if err != nil { if err != nil {
Log(ERROR, "%s", err.Error()) log.Errorf("%s", err.Error())
return return
} }
@ -152,7 +153,7 @@ func (sc *Server) handleSession(newch ssh.NewChannel) {
ret = true ret = true
Log(DEBUG, "handle sftp request") log.Debugf("handle sftp request")
go serveSFTP(ch) go serveSFTP(ch)
} }
@ -161,7 +162,7 @@ func (sc *Server) handleSession(newch ssh.NewChannel) {
ret = false ret = false
} }
Log(DEBUG, "session request %s, reply %v", r.Type, ret) log.Debugf("session request %s, reply %v", r.Type, ret)
if r.WantReply { if r.WantReply {
r.Reply(ret, nil) r.Reply(ret, nil)
@ -176,16 +177,16 @@ func handleDirectTcpip(newch ssh.NewChannel) {
err := ssh.Unmarshal(data, &r) err := ssh.Unmarshal(data, &r)
if err != nil { if err != nil {
Log(DEBUG, "invalid ssh parameter") log.Debugf("invalid ssh parameter")
newch.Reject(ssh.ConnectionFailed, "invalid argument") newch.Reject(ssh.ConnectionFailed, "invalid argument")
return return
} }
Log(DEBUG, "create connection to %s:%d", r.Raddr, r.Rport) log.Debugf("create connection to %s:%d", r.Raddr, r.Rport)
rconn, err := dialer.Dial("tcp", net.JoinHostPort(r.Raddr, fmt.Sprintf("%d", r.Rport))) rconn, err := dialer.Dial("tcp", net.JoinHostPort(r.Raddr, fmt.Sprintf("%d", r.Rport)))
if err != nil { if err != nil {
Log(ERROR, "%s", err.Error()) log.Errorf("%s", err.Error())
newch.Reject(ssh.ConnectionFailed, "invalid argument") newch.Reject(ssh.ConnectionFailed, "invalid argument")
return return
} }
@ -193,7 +194,7 @@ func handleDirectTcpip(newch ssh.NewChannel) {
channel, requests, err := newch.Accept() channel, requests, err := newch.Accept()
if err != nil { if err != nil {
rconn.Close() rconn.Close()
Log(ERROR, "%s", err.Error()) log.Errorf("%s", err.Error())
return return
} }
@ -212,7 +213,7 @@ func (sc *Server) handleCancelTcpipForward(req *ssh.Request) {
var a tcpipForwardAddr var a tcpipForwardAddr
if err := ssh.Unmarshal(req.Payload, &a); err != nil { if err := ssh.Unmarshal(req.Payload, &a); err != nil {
Log(ERROR, "invalid ssh parameter for cancel port forward") log.Errorf("invalid ssh parameter for cancel port forward")
if req.WantReply { if req.WantReply {
req.Reply(false, nil) req.Reply(false, nil)
} }
@ -233,7 +234,7 @@ func (sc *Server) handleCancelTcpipForward(req *ssh.Request) {
func (sc *Server) handleTcpipForward(req *ssh.Request) { func (sc *Server) handleTcpipForward(req *ssh.Request) {
var addr tcpipForwardAddr var addr tcpipForwardAddr
if err := ssh.Unmarshal(req.Payload, &addr); err != nil { if err := ssh.Unmarshal(req.Payload, &addr); err != nil {
Log(ERROR, "parse ssh data error: %s", err) log.Errorf("parse ssh data error: %s", err)
if req.WantReply { if req.WantReply {
req.Reply(false, nil) req.Reply(false, nil)
} }
@ -241,7 +242,7 @@ func (sc *Server) handleTcpipForward(req *ssh.Request) {
} }
if addr.Port > 65535 || addr.Port < 0 { if addr.Port > 65535 || addr.Port < 0 {
Log(ERROR, "invalid port %d", addr.Port) log.Errorf("invalid port %d", addr.Port)
if req.WantReply { if req.WantReply {
req.Reply(false, nil) req.Reply(false, nil)
} }
@ -250,7 +251,7 @@ func (sc *Server) handleTcpipForward(req *ssh.Request) {
ip := net.ParseIP(addr.Addr) ip := net.ParseIP(addr.Addr)
if ip == nil { if ip == nil {
Log(ERROR, "invalid ip %d", addr.Port) log.Errorf("invalid ip %d", addr.Port)
if req.WantReply { if req.WantReply {
req.Reply(false, nil) req.Reply(false, nil)
} }
@ -261,19 +262,19 @@ func (sc *Server) handleTcpipForward(req *ssh.Request) {
if _, ok := sc.forwardedPorts[k]; ok { if _, ok := sc.forwardedPorts[k]; ok {
// port in use // port in use
Log(ERROR, "port in use: %s", k) log.Errorf("port in use: %s", k)
if req.WantReply { if req.WantReply {
req.Reply(false, nil) req.Reply(false, nil)
} }
return return
} }
//Log(DEBUG, "get request for addr: %s, port: %d", addr.Addr, addr.Port) //log.Debugf( "get request for addr: %s, port: %d", addr.Addr, addr.Port)
l, err := net.ListenTCP("tcp", &net.TCPAddr{IP: ip, Port: int(addr.Port)}) l, err := net.ListenTCP("tcp", &net.TCPAddr{IP: ip, Port: int(addr.Port)})
if err != nil { if err != nil {
// listen failed // listen failed
Log(ERROR, "%s", err.Error()) log.Errorf("%s", err.Error())
if req.WantReply { if req.WantReply {
req.Reply(false, nil) req.Reply(false, nil)
} }
@ -281,7 +282,7 @@ func (sc *Server) handleTcpipForward(req *ssh.Request) {
} }
a1 := l.Addr() a1 := l.Addr()
Log(DEBUG, "Listening port %s", a1) log.Debugf("Listening port %s", a1)
p := struct { p := struct {
Port uint32 Port uint32
}{ }{
@ -297,10 +298,10 @@ func (sc *Server) handleTcpipForward(req *ssh.Request) {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
Log(ERROR, "%s", err.Error()) log.Errorf("%s", err.Error())
return return
} }
Log(DEBUG, "accept connection from %s", c.RemoteAddr()) log.Debugf("accept connection from %s", c.RemoteAddr())
go func(c net.Conn) { go func(c net.Conn) {
laddr := c.LocalAddr() laddr := c.LocalAddr()
raddr := c.RemoteAddr() raddr := c.RemoteAddr()
@ -317,7 +318,7 @@ func (sc *Server) handleTcpipForward(req *ssh.Request) {
} }
ch, r, err := sc.sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(a2)) ch, r, err := sc.sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(a2))
if err != nil { if err != nil {
Log(ERROR, "forward port failed: %s", err.Error()) log.Errorf("forward port failed: %s", err.Error())
c.Close() c.Close()
return return
} }

@ -5,19 +5,6 @@ import (
"log" "log"
) )
const (
_ = iota
// DEBUG log level debug
DEBUG
// INFO log level info
INFO
// ERROR log level error
ERROR
)
// SSHLogLevel global value for log level
var SSHLogLevel = ERROR
// PipeAndClose pipe the data between c and s, close both when done // PipeAndClose pipe the data between c and s, close both when done
func PipeAndClose(c io.ReadWriteCloser, s io.ReadWriteCloser) { func PipeAndClose(c io.ReadWriteCloser, s io.ReadWriteCloser) {
defer func() { defer func() {
@ -41,10 +28,3 @@ func PipeAndClose(c io.ReadWriteCloser, s io.ReadWriteCloser) {
<-cc <-cc
} }
// Log log the message by level
func Log(level int, s string, args ...interface{}) {
if level >= SSHLogLevel {
log.Printf(s, args...)
}
}

Loading…
Cancel
Save