refine the code

read all data at once
allocate byte buffer once
reply a error to client when connect to remote failed, instead of
connection reset
socks4/5 use the same data exchange method
master
Dingjun 7 years ago
parent 1b11369965
commit 1ff9df1ae0

@ -30,8 +30,9 @@ func (sc *Client) handShake() error {
return err return err
} }
buf := make([]byte, 2) buf := make([]byte, 512)
if _, err := io.ReadFull(sc.Conn, buf); err != nil {
if _, err := io.ReadFull(sc.Conn, buf[:2]); err != nil {
return err return err
} }
@ -50,7 +51,7 @@ func (sc *Client) handShake() error {
// password auth // password auth
buf = make([]byte, 3+len(sc.Username)+len(sc.Password)) l := 3 + len(sc.Username) + len(sc.Password)
buf[0] = 0x01 // auth protocol version buf[0] = 0x01 // auth protocol version
buf[1] = byte(len(sc.Username)) // username length buf[1] = byte(len(sc.Username)) // username length
@ -58,7 +59,7 @@ func (sc *Client) handShake() error {
buf[2+len(sc.Username)] = byte(len(sc.Password)) // password length buf[2+len(sc.Username)] = byte(len(sc.Password)) // password length
copy(buf[3+len(sc.Username):], []byte(sc.Password)) //password copy(buf[3+len(sc.Username):], []byte(sc.Password)) //password
if _, err := sc.Conn.Write(buf); err != nil { if _, err := sc.Conn.Write(buf[:l]); err != nil {
return err return err
} }
@ -118,8 +119,9 @@ func (sc *Client) Connect(host string, port uint16) error {
return fmt.Errorf("only one connection allowed") return fmt.Errorf("only one connection allowed")
} }
buf := make([]byte, 512)
l := 4 + len(host) + 1 + 2 l := 4 + len(host) + 1 + 2
buf := make([]byte, l)
buf[0] = socks5Version buf[0] = socks5Version
buf[1] = cmdConnect buf[1] = cmdConnect
buf[2] = 0x00 buf[2] = 0x00
@ -130,22 +132,20 @@ func (sc *Client) Connect(host string, port uint16) error {
binary.BigEndian.PutUint16(buf[l-2:l], port) binary.BigEndian.PutUint16(buf[l-2:l], port)
if _, err := sc.Conn.Write(buf); err != nil { if _, err := sc.Conn.Write(buf[:l]); err != nil {
return err return err
} }
buf1 := make([]byte, 128) if _, err := io.ReadAtLeast(sc.Conn, buf, 10); err != nil {
if _, err := io.ReadAtLeast(sc.Conn, buf1, 10); err != nil {
return err return err
} }
if buf1[0] != socks5Version { if buf[0] != socks5Version {
return fmt.Errorf("error socks version %d", buf1[0]) return fmt.Errorf("error socks version %d", buf[0])
} }
if buf1[1] != 0x00 { if buf[1] != 0x00 {
return fmt.Errorf("server error code %d", buf1[1]) return fmt.Errorf("server error code %d", buf[1])
} }
sc.connected = true sc.connected = true

@ -33,10 +33,14 @@ type Conn struct {
// Serve serve the client // Serve serve the client
func (s *Conn) Serve() { func (s *Conn) Serve() {
buf := make([]byte, 1) buf := make([]byte, 512)
// read version // read version
io.ReadFull(s.Conn, buf) n, err := io.ReadAtLeast(s.Conn, buf, 1)
if err != nil {
log.Println(err)
return
}
dial := s.Dial dial := s.Dial
if s.Dial == nil { if s.Dial == nil {
@ -47,13 +51,30 @@ func (s *Conn) Serve() {
switch buf[0] { switch buf[0] {
case socks4Version: case socks4Version:
s4 := socks4Conn{clientConn: s.Conn, dial: dial} s4 := socks4Conn{clientConn: s.Conn, dial: dial}
s4.Serve() s4.Serve(buf, n)
case socks5Version: case socks5Version:
s5 := socks5Conn{clientConn: s.Conn, dial: dial, s5 := socks5Conn{clientConn: s.Conn, dial: dial,
username: s.Username, password: s.Password} username: s.Username, password: s.Password}
s5.Serve() s5.Serve(buf, n)
default: default:
log.Printf("error version %d", buf[0]) log.Printf("unknown socks version 0x%x", buf[0])
s.Conn.Close() s.Conn.Close()
} }
} }
func forward(c1, c2 io.ReadWriter) {
c := make(chan struct{}, 2)
go func() {
io.Copy(c1, c2)
c <- struct{}{}
}()
go func() {
io.Copy(c2, c1)
c <- struct{}{}
}()
<-c
}

@ -37,10 +37,10 @@ type socks4Conn struct {
dial DialFunc dial DialFunc
} }
func (s4 *socks4Conn) Serve() { func (s4 *socks4Conn) Serve(b []byte, n int) {
defer s4.Close() defer s4.Close()
if err := s4.processRequest(); err != nil { if err := s4.processRequest(b, n); err != nil {
log.Println(err) log.Println(err)
return return
} }
@ -56,34 +56,18 @@ func (s4 *socks4Conn) Close() {
} }
} }
func (s4 *socks4Conn) forward() { func (s4 *socks4Conn) processRequest(buf []byte, n int) (err error) {
c := make(chan int, 2)
go func() {
io.Copy(s4.clientConn, s4.serverConn)
c <- 1
}()
go func() {
io.Copy(s4.serverConn, s4.clientConn)
c <- 1
}()
<-c
}
func (s4 *socks4Conn) processRequest() error {
// version has already read out by socksConn.Serve()
// process command and target here // process command and target here
buf := make([]byte, 128) if n < 8 {
n1, err := io.ReadAtLeast(s4.clientConn, buf[n:], 8-n)
// read header
n, err := io.ReadAtLeast(s4.clientConn, buf, 8)
if err != nil { if err != nil {
return err return err
} }
n += n1
}
buf = buf[1:n]
// command only support connect // command only support connect
if buf[0] != cmdConnect { if buf[0] != cmdConnect {
@ -99,7 +83,7 @@ func (s4 *socks4Conn) processRequest() error {
// NULL-terminated user string // NULL-terminated user string
// jump to NULL character // jump to NULL character
var j int var j int
for j = 7; j < n; j++ { for j = 7; j < n-1; j++ {
if buf[j] == 0x00 { if buf[j] == 0x00 {
break break
} }
@ -114,7 +98,7 @@ func (s4 *socks4Conn) processRequest() error {
var i = j var i = j
// jump to the end of hostname // jump to the end of hostname
for j = i; j < n; j++ { for j = i; j < n-1; j++ {
if buf[j] == 0x00 { if buf[j] == 0x00 {
break break
} }
@ -124,20 +108,21 @@ func (s4 *socks4Conn) processRequest() error {
target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) target := net.JoinHostPort(host, fmt.Sprintf("%d", port))
// reply user with connect success
// if dial to target failed, user will receive connection reset
s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00})
//log.Printf("connecting to %s\r\n", target) //log.Printf("connecting to %s\r\n", target)
// connect to the target // connect to the target
s4.serverConn, err = s4.dial("tcp", target) s4.serverConn, err = s4.dial("tcp", target)
if err != nil { if err != nil {
// connection failed
s4.clientConn.Write([]byte{0x00, 0x5b, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00})
return err return err
} }
// connection success
s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00})
// enter data exchange // enter data exchange
s4.forward() forward(s4.clientConn, s4.serverConn)
return nil return nil
} }

@ -59,10 +59,10 @@ type socks5Conn struct {
dial DialFunc dial DialFunc
} }
func (s5 *socks5Conn) Serve() { func (s5 *socks5Conn) Serve(b []byte, n int) {
defer s5.Close() defer s5.Close()
if err := s5.handshake(); err != nil { if err := s5.handshake(b, n); err != nil {
log.Println(err) log.Println(err)
return return
} }
@ -73,22 +73,21 @@ func (s5 *socks5Conn) Serve() {
} }
} }
func (s5 *socks5Conn) handshake() error { func (s5 *socks5Conn) handshake(buf []byte, n int) (err error) {
// version has already readed by socksConn.Serve()
// only process auth methods here
buf := make([]byte, 258)
// read auth methods // read auth methods
n, err := io.ReadAtLeast(s5.clientConn, buf, 1) if n < 2 {
n1, err := io.ReadAtLeast(s5.clientConn, buf[1:], 1)
if err != nil { if err != nil {
return err return err
} }
n += n1
}
l := int(buf[0]) + 1 l := int(buf[1])
if n < l { if n != (l + 2) {
// read remains data // read remains data
n1, err := io.ReadFull(s5.clientConn, buf[n:l]) n1, err := io.ReadFull(s5.clientConn, buf[n:l+2+1])
if err != nil { if err != nil {
return err return err
} }
@ -106,7 +105,7 @@ func (s5 *socks5Conn) handshake() error {
// check auth method // check auth method
// only password(0x02) supported // only password(0x02) supported
for i := 1; i < n; i++ { for i := 2; i < n; i++ {
if buf[i] == passAuth { if buf[i] == passAuth {
hasPassAuth = true hasPassAuth = true
break break
@ -236,41 +235,25 @@ func (s5 *socks5Conn) processRequest() error {
// target address // target address
target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) target := net.JoinHostPort(host, fmt.Sprintf("%d", port))
// reply user with connect success
// if dial to target failed, user will receive connection reset
s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01})
//log.Printf("connecing to %s\r\n", target) //log.Printf("connecing to %s\r\n", target)
// connect to the target // connect to the target
s5.serverConn, err = s5.dial("tcp", target) s5.serverConn, err = s5.dial("tcp", target)
if err != nil { if err != nil {
// connection failed
s5.clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01})
return err return err
} }
// connection success
s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01})
// enter data exchange // enter data exchange
s5.forward() forward(s5.clientConn, s5.serverConn)
return nil return nil
} }
func (s5 *socks5Conn) forward() {
c := make(chan int, 2)
go func() {
io.Copy(s5.clientConn, s5.serverConn)
c <- 1
}()
go func() {
io.Copy(s5.serverConn, s5.clientConn)
c <- 1
}()
<-c
}
func (s5 *socks5Conn) Close() { func (s5 *socks5Conn) Close() {
if s5.serverConn != nil { if s5.serverConn != nil {
s5.serverConn.Close() s5.serverConn.Close()

Loading…
Cancel
Save