add socks client support

master
fangdingjun 8 years ago
parent 14c964c1c2
commit 4094a344ef

@ -0,0 +1,145 @@
package socks
import (
"encoding/binary"
"fmt"
"io"
"net"
)
// Client is a net.Conn with socks5 support
type Client struct {
net.Conn
// socks5 username
Username string
// socks5 password
Password string
handshakeDone bool
connected bool
}
func (sc *Client) handShake() error {
if sc.handshakeDone {
return nil
}
// password auth or none
if _, err := sc.Conn.Write([]byte{socks5Version, 0x02, 0x00, 0x02}); err != nil {
return err
}
buf := make([]byte, 2)
if _, err := io.ReadFull(sc.Conn, buf); err != nil {
return err
}
if buf[0] != socks5Version {
return fmt.Errorf("error socks version %s", buf[0])
}
if buf[1] != 0x00 && buf[1] != 0x02 {
return fmt.Errorf("server return with code %s", buf[1])
}
if buf[1] == 0x00 {
sc.handshakeDone = true
return nil
}
// password auth
buf = make([]byte, 3+len(sc.Username)+len(sc.Password))
buf[0] = 0x01 // auth protocol version
buf[1] = byte(len(sc.Username)) // username length
copy(buf[2:], []byte(sc.Username)) // username
buf[2+len(sc.Username)] = byte(len(sc.Password)) // password length
copy(buf[3+len(sc.Username):], []byte(sc.Password)) //password
if _, err := sc.Conn.Write(buf); err != nil {
return err
}
if _, err := sc.Conn.Read(buf[:2]); err != nil {
return err
}
if buf[0] != 0x01 {
return fmt.Errorf("unexpected auth protocol version %v", buf[0])
}
// password auth success
if buf[1] == 0x00 {
sc.handshakeDone = true
return nil
}
return fmt.Errorf("password rejected")
}
// Connect connects to socks5 server and handshake
func (sc *Client) Connect(host string, port uint16) error {
if !sc.handshakeDone {
if err := sc.handShake(); err != nil {
return err
}
}
if sc.connected {
return nil
}
l := 4 + len(host) + 1 + 2
buf := make([]byte, l)
buf[0] = socks5Version
buf[1] = cmdConnect
buf[2] = 0x00
buf[3] = addrTypeDomain
buf[4] = byte(len(host))
copy(buf[5:5+len(host)], []byte(host))
binary.BigEndian.PutUint16(buf[l-2:l], port)
if _, err := sc.Conn.Write(buf); err != nil {
return err
}
buf1 := make([]byte, 128)
if _, err := io.ReadAtLeast(sc.Conn, buf1, 10); err != nil {
return err
}
if buf1[0] != socks5Version {
return fmt.Errorf("error socks version %d", buf1[0])
}
if buf1[1] != 0x00 {
return fmt.Errorf("server error code %d", buf1[1])
}
sc.connected = true
return nil
}
// Read read from the underlying connection
func (sc *Client) Read(b []byte) (int, error) {
if !sc.connected {
return 0, fmt.Errorf("call connect first")
}
return sc.Conn.Read(b)
}
// Write write data to underlying connection
func (sc *Client) Write(b []byte) (int, error) {
if !sc.connected {
return 0, fmt.Errorf("call connect first")
}
return sc.Conn.Write(b)
}
// Close close the underlying connection
func (sc *Client) Close() error {
return sc.Conn.Close()
}

@ -21,7 +21,13 @@ type dialFunc func(network, addr string) (net.Conn, error)
// SocksConn present a client connection // SocksConn present a client connection
type SocksConn struct { type SocksConn struct {
ClientConn net.Conn ClientConn net.Conn
// the function to dial to upstream server
// when nil, use net.Dial
Dial dialFunc Dial dialFunc
// username for socks5 server
Username string
// password
Password string
} }
// Serve serve the client // Serve serve the client
@ -42,7 +48,8 @@ func (s *SocksConn) Serve() {
s4 := socks4Conn{clientConn: s.ClientConn, dial: dial} s4 := socks4Conn{clientConn: s.ClientConn, dial: dial}
s4.Serve() s4.Serve()
case socks5Version: case socks5Version:
s5 := socks5Conn{clientConn: s.ClientConn, dial: dial} s5 := socks5Conn{clientConn: s.ClientConn, dial: dial,
username: s.Username, password: s.Password}
s5.Serve() s5.Serve()
default: default:
log.Printf("error version %d", buf[0]) log.Printf("error version %d", buf[0])

@ -46,9 +46,13 @@ byte |0 | 1 | 2 | 3 | 4 | .. | n-2 | n-1 | n |
*/ */
var Socks5AuthRequired bool // Socks5AuthRequired means socks5 server need auth or not
type socks5Conn struct { type socks5Conn struct {
// username
username string
// password
password string
//addr string //addr string
clientConn net.Conn clientConn net.Conn
serverConn net.Conn serverConn net.Conn
@ -84,14 +88,14 @@ func (s5 *socks5Conn) handshake() error {
l := int(buf[0]) + 1 l := int(buf[0]) + 1
if n < l { if n < l {
// read remains data // read remains data
if n1, err := io.ReadFull(s5.clientConn, buf[n:l]); err != nil { n1, err := io.ReadFull(s5.clientConn, buf[n:l])
if err != nil {
return err return err
} else {
n += n1
} }
n += n1
} }
if !Socks5AuthRequired { if s5.username == "" {
// no auth required // no auth required
s5.clientConn.Write([]byte{0x05, 0x00}) s5.clientConn.Write([]byte{0x05, 0x00})
return nil return nil
@ -135,10 +139,10 @@ func (s5 *socks5Conn) passwordAuth() error {
return errors.New("unsupported auth version") return errors.New("unsupported auth version")
} }
username_len := int(buf[1]) usernameLen := int(buf[1])
p0 := 2 p0 := 2
p1 := p0 + username_len p1 := p0 + usernameLen
if n < p1 { if n < p1 {
n1, err := s5.clientConn.Read(buf[n:]) n1, err := s5.clientConn.Read(buf[n:])
@ -149,10 +153,10 @@ func (s5 *socks5Conn) passwordAuth() error {
} }
username := buf[p0:p1] username := buf[p0:p1]
password_len := int(buf[p1]) passwordLen := int(buf[p1])
p3 := p1 + 1 p3 := p1 + 1
p4 := p3 + password_len p4 := p3 + passwordLen
if n < p4 { if n < p4 {
n1, err := s5.clientConn.Read(buf[n:]) n1, err := s5.clientConn.Read(buf[n:])
@ -164,15 +168,17 @@ func (s5 *socks5Conn) passwordAuth() error {
password := buf[p3:p4] password := buf[p3:p4]
log.Printf("get username: %s, password: %s", username, password) // log.Printf("get username: %s, password: %s", username, password)
if string(username) == "" && string(password) == "" { if string(username) == s5.username && string(password) == s5.password {
s5.clientConn.Write([]byte{0x01, 0x01})
} else {
s5.clientConn.Write([]byte{0x01, 0x00}) s5.clientConn.Write([]byte{0x01, 0x00})
return nil
} }
return nil // auth failed
s5.clientConn.Write([]byte{0x01, 0x01})
return fmt.Errorf("wrong password")
} }
func (s5 *socks5Conn) processRequest() error { func (s5 *socks5Conn) processRequest() error {

@ -0,0 +1,112 @@
package socks
import (
//"bytes"
//"fmt"
"errors"
"log"
"net"
"testing"
)
func TestSocks(t *testing.T) {
if err := testSocks(t, "u1", "p1", true); err != nil {
t.Error(err)
}
if err := testSocks(t, "", "", false); err != nil {
t.Error(err)
}
if err := testSocks(t, "u3", "p3", true); err != nil {
t.Error(err)
}
if err := testSocks(t, "u3", "p3", false); err != nil {
log.Println(err)
} else {
t.Error("password not active")
}
if err := testSocks(t, "u3", "", false); err != nil {
log.Println(err)
} else {
t.Error("password not active")
}
}
func testSocks(t *testing.T, user, pass string, auth bool) error {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return err
}
defer l.Close()
l1, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return err
}
defer l1.Close()
addr := l.Addr().String()
addr1 := l1.Addr()
go func() {
conn, err := l.Accept()
if err != nil {
return
}
log.Printf("connected from %s", conn.RemoteAddr())
s := SocksConn{ClientConn: conn, Username: user, Password: pass}
s.Serve()
}()
go func() {
conn, err := l1.Accept()
if err != nil {
return
}
log.Printf("server 2 accept connection from %s", conn.RemoteAddr())
defer conn.Close()
buf := make([]byte, 512)
n, err := conn.Read(buf)
if err != nil {
return
}
conn.Write(buf[:n])
}()
c, err := net.Dial("tcp", addr)
if err != nil {
return err
}
defer c.Close()
var sc Client
if auth {
sc = Client{Conn: c, Username: user, Password: pass}
} else {
sc = Client{Conn: c}
}
if err = sc.Connect("localhost", uint16(addr1.(*net.TCPAddr).Port)); err != nil {
return err
}
log.Printf("connect success")
str := "hello1234"
buf := make([]byte, 512)
if _, err := sc.Write([]byte(str)); err != nil {
return err
}
n, err := sc.Read(buf)
if err != nil {
return err
}
if string(buf[:n]) != str {
return errors.New("socks test failed")
}
return nil
}
Loading…
Cancel
Save