diff --git a/client.go b/client.go new file mode 100755 index 0000000..d9425cc --- /dev/null +++ b/client.go @@ -0,0 +1,145 @@ +package socks + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +// Client is a net.Conn with socks5 support +type Client struct { + net.Conn + // socks5 username + Username string + // socks5 password + Password string + handshakeDone bool + connected bool +} + +func (sc *Client) handShake() error { + if sc.handshakeDone { + return nil + } + + // password auth or none + if _, err := sc.Conn.Write([]byte{socks5Version, 0x02, 0x00, 0x02}); err != nil { + return err + } + + buf := make([]byte, 2) + if _, err := io.ReadFull(sc.Conn, buf); err != nil { + return err + } + + if buf[0] != socks5Version { + return fmt.Errorf("error socks version %s", buf[0]) + } + + if buf[1] != 0x00 && buf[1] != 0x02 { + return fmt.Errorf("server return with code %s", buf[1]) + } + + if buf[1] == 0x00 { + sc.handshakeDone = true + return nil + } + + // password auth + + buf = make([]byte, 3+len(sc.Username)+len(sc.Password)) + + buf[0] = 0x01 // auth protocol version + buf[1] = byte(len(sc.Username)) // username length + copy(buf[2:], []byte(sc.Username)) // username + buf[2+len(sc.Username)] = byte(len(sc.Password)) // password length + copy(buf[3+len(sc.Username):], []byte(sc.Password)) //password + + if _, err := sc.Conn.Write(buf); err != nil { + return err + } + + if _, err := sc.Conn.Read(buf[:2]); err != nil { + return err + } + + if buf[0] != 0x01 { + return fmt.Errorf("unexpected auth protocol version %v", buf[0]) + } + + // password auth success + if buf[1] == 0x00 { + sc.handshakeDone = true + return nil + } + + return fmt.Errorf("password rejected") +} + +// Connect connects to socks5 server and handshake +func (sc *Client) Connect(host string, port uint16) error { + if !sc.handshakeDone { + if err := sc.handShake(); err != nil { + return err + } + } + + if sc.connected { + return nil + } + + l := 4 + len(host) + 1 + 2 + buf := make([]byte, l) + buf[0] = socks5Version + buf[1] = cmdConnect + buf[2] = 0x00 + buf[3] = addrTypeDomain + buf[4] = byte(len(host)) + + copy(buf[5:5+len(host)], []byte(host)) + + binary.BigEndian.PutUint16(buf[l-2:l], port) + + if _, err := sc.Conn.Write(buf); err != nil { + return err + } + + buf1 := make([]byte, 128) + + if _, err := io.ReadAtLeast(sc.Conn, buf1, 10); err != nil { + return err + } + + if buf1[0] != socks5Version { + return fmt.Errorf("error socks version %d", buf1[0]) + } + + if buf1[1] != 0x00 { + return fmt.Errorf("server error code %d", buf1[1]) + } + + sc.connected = true + return nil +} + +// Read read from the underlying connection +func (sc *Client) Read(b []byte) (int, error) { + if !sc.connected { + return 0, fmt.Errorf("call connect first") + } + return sc.Conn.Read(b) +} + +// Write write data to underlying connection +func (sc *Client) Write(b []byte) (int, error) { + if !sc.connected { + return 0, fmt.Errorf("call connect first") + } + return sc.Conn.Write(b) +} + +// Close close the underlying connection +func (sc *Client) Close() error { + return sc.Conn.Close() +} diff --git a/socks.go b/socks.go index e150545..3bb0ea2 100755 --- a/socks.go +++ b/socks.go @@ -21,7 +21,13 @@ type dialFunc func(network, addr string) (net.Conn, error) // SocksConn present a client connection type SocksConn struct { ClientConn net.Conn - Dial dialFunc + // the function to dial to upstream server + // when nil, use net.Dial + Dial dialFunc + // username for socks5 server + Username string + // password + Password string } // Serve serve the client @@ -42,7 +48,8 @@ func (s *SocksConn) Serve() { s4 := socks4Conn{clientConn: s.ClientConn, dial: dial} s4.Serve() case socks5Version: - s5 := socks5Conn{clientConn: s.ClientConn, dial: dial} + s5 := socks5Conn{clientConn: s.ClientConn, dial: dial, + username: s.Username, password: s.Password} s5.Serve() default: log.Printf("error version %d", buf[0]) diff --git a/socks5.go b/socks5.go index 43b51d0..7a4a638 100755 --- a/socks5.go +++ b/socks5.go @@ -46,9 +46,13 @@ byte |0 | 1 | 2 | 3 | 4 | .. | n-2 | n-1 | n | */ -var Socks5AuthRequired bool +// Socks5AuthRequired means socks5 server need auth or not type socks5Conn struct { + // username + username string + // password + password string //addr string clientConn net.Conn serverConn net.Conn @@ -84,14 +88,14 @@ func (s5 *socks5Conn) handshake() error { l := int(buf[0]) + 1 if n < l { // read remains data - if n1, err := io.ReadFull(s5.clientConn, buf[n:l]); err != nil { + n1, err := io.ReadFull(s5.clientConn, buf[n:l]) + if err != nil { return err - } else { - n += n1 } + n += n1 } - if !Socks5AuthRequired { + if s5.username == "" { // no auth required s5.clientConn.Write([]byte{0x05, 0x00}) return nil @@ -135,10 +139,10 @@ func (s5 *socks5Conn) passwordAuth() error { return errors.New("unsupported auth version") } - username_len := int(buf[1]) + usernameLen := int(buf[1]) p0 := 2 - p1 := p0 + username_len + p1 := p0 + usernameLen if n < p1 { n1, err := s5.clientConn.Read(buf[n:]) @@ -149,10 +153,10 @@ func (s5 *socks5Conn) passwordAuth() error { } username := buf[p0:p1] - password_len := int(buf[p1]) + passwordLen := int(buf[p1]) p3 := p1 + 1 - p4 := p3 + password_len + p4 := p3 + passwordLen if n < p4 { n1, err := s5.clientConn.Read(buf[n:]) @@ -164,15 +168,17 @@ func (s5 *socks5Conn) passwordAuth() error { password := buf[p3:p4] - log.Printf("get username: %s, password: %s", username, password) + // log.Printf("get username: %s, password: %s", username, password) - if string(username) == "" && string(password) == "" { - s5.clientConn.Write([]byte{0x01, 0x01}) - } else { + if string(username) == s5.username && string(password) == s5.password { s5.clientConn.Write([]byte{0x01, 0x00}) + return nil } - return nil + // auth failed + s5.clientConn.Write([]byte{0x01, 0x01}) + + return fmt.Errorf("wrong password") } func (s5 *socks5Conn) processRequest() error { diff --git a/socks_test.go b/socks_test.go new file mode 100755 index 0000000..e680dad --- /dev/null +++ b/socks_test.go @@ -0,0 +1,112 @@ +package socks + +import ( + //"bytes" + //"fmt" + "errors" + "log" + "net" + "testing" +) + +func TestSocks(t *testing.T) { + if err := testSocks(t, "u1", "p1", true); err != nil { + t.Error(err) + } + if err := testSocks(t, "", "", false); err != nil { + t.Error(err) + } + if err := testSocks(t, "u3", "p3", true); err != nil { + t.Error(err) + } + + if err := testSocks(t, "u3", "p3", false); err != nil { + log.Println(err) + } else { + t.Error("password not active") + } + + if err := testSocks(t, "u3", "", false); err != nil { + log.Println(err) + } else { + t.Error("password not active") + } +} + +func testSocks(t *testing.T, user, pass string, auth bool) error { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + defer l.Close() + + l1, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + defer l1.Close() + + addr := l.Addr().String() + + addr1 := l1.Addr() + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + log.Printf("connected from %s", conn.RemoteAddr()) + s := SocksConn{ClientConn: conn, Username: user, Password: pass} + s.Serve() + }() + + go func() { + conn, err := l1.Accept() + if err != nil { + return + } + log.Printf("server 2 accept connection from %s", conn.RemoteAddr()) + defer conn.Close() + buf := make([]byte, 512) + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) + }() + + c, err := net.Dial("tcp", addr) + if err != nil { + return err + } + + defer c.Close() + var sc Client + if auth { + sc = Client{Conn: c, Username: user, Password: pass} + } else { + sc = Client{Conn: c} + } + if err = sc.Connect("localhost", uint16(addr1.(*net.TCPAddr).Port)); err != nil { + return err + } + + log.Printf("connect success") + + str := "hello1234" + buf := make([]byte, 512) + + if _, err := sc.Write([]byte(str)); err != nil { + return err + } + + n, err := sc.Read(buf) + if err != nil { + return err + } + + if string(buf[:n]) != str { + return errors.New("socks test failed") + } + return nil +}