diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..de647f7 --- /dev/null +++ b/auth.go @@ -0,0 +1,23 @@ +package socks + +import "net" + +// AuthService the service to authenticate the user for socks5 +type AuthService interface { + // Authenticate auth the user + // return true means ok, false means no access + Authenticate(username, password string, addr net.Addr) bool +} + +// default password auth service +type passwordAuth struct { + username string + password string +} + +func (pa *passwordAuth) Authenticate(username, password string, addr net.Addr) bool { + if username == pa.username && password == pa.password { + return true + } + return false +} diff --git a/socks.go b/socks.go index edc7737..0ace162 100755 --- a/socks.go +++ b/socks.go @@ -25,10 +25,8 @@ type Conn struct { // the function to dial to upstream server // when nil, use net.Dial Dial DialFunc - // username for socks5 server - Username string - // password - Password string + // Auth the auth service to authenticate the user for socks5 + Auth AuthService } // Serve serve the client @@ -53,8 +51,7 @@ func (s *Conn) Serve() { s4 := socks4Conn{clientConn: s.Conn, dial: dial} s4.Serve(buf, n) case socks5Version: - s5 := socks5Conn{clientConn: s.Conn, dial: dial, - username: s.Username, password: s.Password} + s5 := socks5Conn{clientConn: s.Conn, dial: dial, auth: s.Auth} s5.Serve(buf, n) default: log.Printf("unknown socks version 0x%x", buf[0]) diff --git a/socks5.go b/socks5.go index 1c00c39..bd05ace 100755 --- a/socks5.go +++ b/socks5.go @@ -49,14 +49,11 @@ byte |0 | 1 | 2 | 3 | 4 | .. | n-2 | n-1 | n | // Socks5AuthRequired means socks5 server need auth or not type socks5Conn struct { - // username - username string - // password - password string //addr string clientConn net.Conn serverConn net.Conn dial DialFunc + auth AuthService } func (s5 *socks5Conn) Serve(b []byte, n int) { @@ -94,7 +91,7 @@ func (s5 *socks5Conn) handshake(buf []byte, n int) (err error) { n += n1 } - if s5.username == "" { + if s5.auth == nil { // no auth required s5.clientConn.Write([]byte{0x05, 0x00}) return nil @@ -169,15 +166,20 @@ func (s5 *socks5Conn) passwordAuth() error { // log.Printf("get username: %s, password: %s", username, password) - if string(username) == s5.username && string(password) == s5.password { - s5.clientConn.Write([]byte{0x01, 0x00}) - return nil - } + if s5.auth != nil { + ret := s5.auth.Authenticate( + string(username), string(password), + s5.clientConn.RemoteAddr()) + if ret { + s5.clientConn.Write([]byte{0x01, 0x00}) + return nil + } + s5.clientConn.Write([]byte{0x01, 0x01}) - // auth failed - s5.clientConn.Write([]byte{0x01, 0x01}) + return errors.New("access denied") + } - return fmt.Errorf("wrong password") + return errors.New("no auth method") } func (s5 *socks5Conn) processRequest() error { diff --git a/socks_test.go b/socks_test.go index 24d7463..ef6ef0a 100755 --- a/socks_test.go +++ b/socks_test.go @@ -56,7 +56,7 @@ func testSocks(t *testing.T, user, pass string, auth bool) error { return } log.Printf("connected from %s", conn.RemoteAddr()) - s := Conn{Conn: conn, Username: user, Password: pass} + s := Conn{Conn: conn, Auth: &passwordAuth{user, pass}} s.Serve() }()