package main import ( "bytes" "crypto/tls" "errors" "flag" "fmt" "io/ioutil" "net" "os" "os/signal" "syscall" "github.com/fangdingjun/go-log" "github.com/fangdingjun/obfssh" "github.com/fangdingjun/protolistener" "golang.org/x/crypto/ssh" ) func main() { var configfile string var logfile string var logFileCount int var logFileSize int64 var loglevel string flag.StringVar(&configfile, "c", "config.yaml", "configure file") flag.StringVar(&logfile, "log_file", "", "log file, default stdout") flag.IntVar(&logFileCount, "log_count", 10, "max count of log to keep") flag.Int64Var(&logFileSize, "log_size", 10, "max log file size MB") flag.StringVar(&loglevel, "log_level", "INFO", "log level, values:\nOFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG") flag.Parse() if logfile != "" { log.Default.Out = &log.FixedSizeFileWriter{ MaxCount: logFileCount, Name: logfile, MaxSize: logFileSize * 1024 * 1024, } } if loglevel != "" { lv, err := log.ParseLevel(loglevel) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } log.Default.Level = lv } conf, err := loadConfig(configfile) if err != nil { log.Fatal(err) } sconf := &obfssh.Conf{} config := &ssh.ServerConfig{ PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { if u, err := conf.getUser(c.User()); err == nil { if u.Password != "" && c.User() == u.Username && string(pass) == u.Password { return nil, nil } } return nil, fmt.Errorf("password reject for user %s", c.User()) }, PublicKeyCallback: func(c ssh.ConnMetadata, k ssh.PublicKey) (*ssh.Permissions, error) { checker := &ssh.CertChecker{ IsUserAuthority: func(k ssh.PublicKey) bool { if u, err := conf.getUser(c.User()); err == nil { for _, pk := range u.publicKeys { if k.Type() == pk.Type() && bytes.Compare(k.Marshal(), pk.Marshal()) == 0 { return true } } } return false }, } checker.UserKeyFallback = func(c1 ssh.ConnMetadata, k1 ssh.PublicKey) (*ssh.Permissions, error) { log.Debug("user key fallback") if checker.IsUserAuthority(k1) { return nil, nil } return nil, errors.New("public not acceptable") } return checker.Authenticate(c, k) }, // auth log AuthLogCallback: func(c ssh.ConnMetadata, method string, err error) { if err != nil { log.Debugf("%s", err.Error()) if method != "none" { log.Errorf("%s auth failed for %s from %s", method, c.User(), c.RemoteAddr()) } } else { log.Printf("Accepted %s for user %s from %s", method, c.User(), c.RemoteAddr()) } }, } privateBytes, err := ioutil.ReadFile(conf.HostKey) if err != nil { log.Fatal(err) } private, err := ssh.ParsePrivateKey(privateBytes) if err != nil { log.Fatal(err) } config.AddHostKey(private) for _, lst := range conf.Listen { go func(lst listen) { var l net.Listener var err error l, err = net.Listen("tcp", fmt.Sprintf(":%d", lst.Port)) if err != nil { log.Fatal(err) } defer l.Close() l = protolistener.New(l) if lst.Key != "" && lst.Cert != "" { cert, err := tls.LoadX509KeyPair(lst.Cert, lst.Key) if err != nil { log.Fatal(err) } l = tls.NewListener(l, &tls.Config{ Certificates: []tls.Certificate{cert}, }) } for { c, err := l.Accept() if err != nil { fmt.Println(err) return } log.Infof("accept tcp connection from %s", c.RemoteAddr()) go func(c net.Conn) { defer c.Close() sc, err := obfssh.NewServer(c, config, sconf) if err != nil { log.Errorf("%s", err.Error()) return } sc.Run() }(c) } }(lst) } ch := make(chan os.Signal, 2) signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) select { case s := <-ch: log.Printf("received signal %s, exit.", s) } }