diff --git a/client.go b/client.go index 49f540d..d22da0a 100755 --- a/client.go +++ b/client.go @@ -30,8 +30,9 @@ func (sc *Client) handShake() error { return err } - buf := make([]byte, 2) - if _, err := io.ReadFull(sc.Conn, buf); err != nil { + buf := make([]byte, 512) + + if _, err := io.ReadFull(sc.Conn, buf[:2]); err != nil { return err } @@ -50,7 +51,7 @@ func (sc *Client) handShake() error { // password auth - buf = make([]byte, 3+len(sc.Username)+len(sc.Password)) + l := 3 + len(sc.Username) + len(sc.Password) buf[0] = 0x01 // auth protocol version buf[1] = byte(len(sc.Username)) // username length @@ -58,7 +59,7 @@ func (sc *Client) handShake() error { buf[2+len(sc.Username)] = byte(len(sc.Password)) // password length copy(buf[3+len(sc.Username):], []byte(sc.Password)) //password - if _, err := sc.Conn.Write(buf); err != nil { + if _, err := sc.Conn.Write(buf[:l]); err != nil { return err } @@ -118,8 +119,9 @@ func (sc *Client) Connect(host string, port uint16) error { return fmt.Errorf("only one connection allowed") } + buf := make([]byte, 512) + l := 4 + len(host) + 1 + 2 - buf := make([]byte, l) buf[0] = socks5Version buf[1] = cmdConnect buf[2] = 0x00 @@ -130,22 +132,20 @@ func (sc *Client) Connect(host string, port uint16) error { binary.BigEndian.PutUint16(buf[l-2:l], port) - if _, err := sc.Conn.Write(buf); err != nil { + if _, err := sc.Conn.Write(buf[:l]); err != nil { return err } - buf1 := make([]byte, 128) - - if _, err := io.ReadAtLeast(sc.Conn, buf1, 10); err != nil { + if _, err := io.ReadAtLeast(sc.Conn, buf, 10); err != nil { return err } - if buf1[0] != socks5Version { - return fmt.Errorf("error socks version %d", buf1[0]) + if buf[0] != socks5Version { + return fmt.Errorf("error socks version %d", buf[0]) } - if buf1[1] != 0x00 { - return fmt.Errorf("server error code %d", buf1[1]) + if buf[1] != 0x00 { + return fmt.Errorf("server error code %d", buf[1]) } sc.connected = true diff --git a/socks.go b/socks.go index 1f9a625..ffa89bc 100755 --- a/socks.go +++ b/socks.go @@ -33,10 +33,14 @@ type Conn struct { // Serve serve the client func (s *Conn) Serve() { - buf := make([]byte, 1) + buf := make([]byte, 512) // read version - io.ReadFull(s.Conn, buf) + n, err := io.ReadAtLeast(s.Conn, buf, 1) + if err != nil { + log.Println(err) + return + } dial := s.Dial if s.Dial == nil { @@ -47,13 +51,30 @@ func (s *Conn) Serve() { switch buf[0] { case socks4Version: s4 := socks4Conn{clientConn: s.Conn, dial: dial} - s4.Serve() + s4.Serve(buf, n) case socks5Version: s5 := socks5Conn{clientConn: s.Conn, dial: dial, username: s.Username, password: s.Password} - s5.Serve() + s5.Serve(buf, n) default: - log.Printf("error version %d", buf[0]) + log.Printf("unknown socks version 0x%x", buf[0]) s.Conn.Close() } } + +func forward(c1, c2 io.ReadWriter) { + + c := make(chan struct{}, 2) + + go func() { + io.Copy(c1, c2) + c <- struct{}{} + }() + + go func() { + io.Copy(c2, c1) + c <- struct{}{} + }() + + <-c +} diff --git a/socks4.go b/socks4.go index 6e9b1de..c735ff4 100755 --- a/socks4.go +++ b/socks4.go @@ -37,10 +37,10 @@ type socks4Conn struct { dial DialFunc } -func (s4 *socks4Conn) Serve() { +func (s4 *socks4Conn) Serve(b []byte, n int) { defer s4.Close() - if err := s4.processRequest(); err != nil { + if err := s4.processRequest(b, n); err != nil { log.Println(err) return } @@ -56,35 +56,19 @@ func (s4 *socks4Conn) Close() { } } -func (s4 *socks4Conn) forward() { - - c := make(chan int, 2) - - go func() { - io.Copy(s4.clientConn, s4.serverConn) - c <- 1 - }() - - go func() { - io.Copy(s4.serverConn, s4.clientConn) - c <- 1 - }() - - <-c -} - -func (s4 *socks4Conn) processRequest() error { - // version has already read out by socksConn.Serve() +func (s4 *socks4Conn) processRequest(buf []byte, n int) (err error) { // process command and target here - buf := make([]byte, 128) - - // read header - n, err := io.ReadAtLeast(s4.clientConn, buf, 8) - if err != nil { - return err + if n < 8 { + n1, err := io.ReadAtLeast(s4.clientConn, buf[n:], 8-n) + if err != nil { + return err + } + n += n1 } + buf = buf[1:n] + // command only support connect if buf[0] != cmdConnect { return fmt.Errorf("error command %d", buf[0]) @@ -99,7 +83,7 @@ func (s4 *socks4Conn) processRequest() error { // NULL-terminated user string // jump to NULL character var j int - for j = 7; j < n; j++ { + for j = 7; j < n-1; j++ { if buf[j] == 0x00 { break } @@ -114,7 +98,7 @@ func (s4 *socks4Conn) processRequest() error { var i = j // jump to the end of hostname - for j = i; j < n; j++ { + for j = i; j < n-1; j++ { if buf[j] == 0x00 { break } @@ -124,20 +108,21 @@ func (s4 *socks4Conn) processRequest() error { target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) - // reply user with connect success - // if dial to target failed, user will receive connection reset - s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00}) - //log.Printf("connecting to %s\r\n", target) // connect to the target s4.serverConn, err = s4.dial("tcp", target) if err != nil { + // connection failed + s4.clientConn.Write([]byte{0x00, 0x5b, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00}) return err } + // connection success + s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00}) + // enter data exchange - s4.forward() + forward(s4.clientConn, s4.serverConn) return nil } diff --git a/socks5.go b/socks5.go index 0535555..1c00c39 100755 --- a/socks5.go +++ b/socks5.go @@ -59,10 +59,10 @@ type socks5Conn struct { dial DialFunc } -func (s5 *socks5Conn) Serve() { +func (s5 *socks5Conn) Serve(b []byte, n int) { defer s5.Close() - if err := s5.handshake(); err != nil { + if err := s5.handshake(b, n); err != nil { log.Println(err) return } @@ -73,22 +73,21 @@ func (s5 *socks5Conn) Serve() { } } -func (s5 *socks5Conn) handshake() error { - // version has already readed by socksConn.Serve() - // only process auth methods here - - buf := make([]byte, 258) +func (s5 *socks5Conn) handshake(buf []byte, n int) (err error) { // read auth methods - n, err := io.ReadAtLeast(s5.clientConn, buf, 1) - if err != nil { - return err + if n < 2 { + n1, err := io.ReadAtLeast(s5.clientConn, buf[1:], 1) + if err != nil { + return err + } + n += n1 } - l := int(buf[0]) + 1 - if n < l { + l := int(buf[1]) + if n != (l + 2) { // read remains data - n1, err := io.ReadFull(s5.clientConn, buf[n:l]) + n1, err := io.ReadFull(s5.clientConn, buf[n:l+2+1]) if err != nil { return err } @@ -106,7 +105,7 @@ func (s5 *socks5Conn) handshake() error { // check auth method // only password(0x02) supported - for i := 1; i < n; i++ { + for i := 2; i < n; i++ { if buf[i] == passAuth { hasPassAuth = true break @@ -236,41 +235,25 @@ func (s5 *socks5Conn) processRequest() error { // target address target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) - // reply user with connect success - // if dial to target failed, user will receive connection reset - s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01}) - //log.Printf("connecing to %s\r\n", target) // connect to the target s5.serverConn, err = s5.dial("tcp", target) if err != nil { + // connection failed + s5.clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01}) return err } + // connection success + s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01}) + // enter data exchange - s5.forward() + forward(s5.clientConn, s5.serverConn) return nil } -func (s5 *socks5Conn) forward() { - - c := make(chan int, 2) - - go func() { - io.Copy(s5.clientConn, s5.serverConn) - c <- 1 - }() - - go func() { - io.Copy(s5.serverConn, s5.clientConn) - c <- 1 - }() - - <-c -} - func (s5 *socks5Conn) Close() { if s5.serverConn != nil { s5.serverConn.Close()