diff --git a/obfsshd/server.go b/obfsshd/server.go index ed1b6ad..42182d3 100644 --- a/obfsshd/server.go +++ b/obfsshd/server.go @@ -3,6 +3,7 @@ package main import ( "bytes" "crypto/tls" + "errors" "flag" "fmt" "io/ioutil" @@ -69,21 +70,34 @@ func main() { }, PublicKeyCallback: func(c ssh.ConnMetadata, k ssh.PublicKey) (*ssh.Permissions, error) { - if u, err := conf.getUser(c.User()); err == nil { - for _, pk := range u.publicKeys { - if k.Type() == pk.Type() && bytes.Compare(k.Marshal(), pk.Marshal()) == 0 { - return nil, nil + checker := &ssh.CertChecker{ + IsUserAuthority: func(k ssh.PublicKey) bool { + if u, err := conf.getUser(c.User()); err == nil { + for _, pk := range u.publicKeys { + if k.Type() == pk.Type() && + bytes.Compare(k.Marshal(), pk.Marshal()) == 0 { + return true + } + } } + return false + }, + } + checker.UserKeyFallback = func(c1 ssh.ConnMetadata, k1 ssh.PublicKey) (*ssh.Permissions, error) { + log.Debug("user key fallback") + if checker.IsUserAuthority(k1) { + return nil, nil } + return nil, errors.New("public not acceptable") } - return nil, fmt.Errorf("publickey reject for user %s", c.User()) + return checker.Authenticate(c, k) }, // auth log AuthLogCallback: func(c ssh.ConnMetadata, method string, err error) { if err != nil { - log.Errorf("%s", err.Error()) - log.Errorf("%s auth failed for %s from %s", method, c.User(), c.RemoteAddr()) + log.Debugf("%s", err.Error()) + log.Debugf("%s auth failed for %s from %s", method, c.User(), c.RemoteAddr()) } else { log.Debugf("Accepted %s for user %s from %s", method, c.User(), c.RemoteAddr()) }