refine the code

read all data at once
allocate byte buffer once
reply a error to client when connect to remote failed, instead of
connection reset
socks4/5 use the same data exchange method
master
Dingjun 7 years ago
parent 1b11369965
commit 1ff9df1ae0

@ -30,8 +30,9 @@ func (sc *Client) handShake() error {
return err
}
buf := make([]byte, 2)
if _, err := io.ReadFull(sc.Conn, buf); err != nil {
buf := make([]byte, 512)
if _, err := io.ReadFull(sc.Conn, buf[:2]); err != nil {
return err
}
@ -50,7 +51,7 @@ func (sc *Client) handShake() error {
// password auth
buf = make([]byte, 3+len(sc.Username)+len(sc.Password))
l := 3 + len(sc.Username) + len(sc.Password)
buf[0] = 0x01 // auth protocol version
buf[1] = byte(len(sc.Username)) // username length
@ -58,7 +59,7 @@ func (sc *Client) handShake() error {
buf[2+len(sc.Username)] = byte(len(sc.Password)) // password length
copy(buf[3+len(sc.Username):], []byte(sc.Password)) //password
if _, err := sc.Conn.Write(buf); err != nil {
if _, err := sc.Conn.Write(buf[:l]); err != nil {
return err
}
@ -118,8 +119,9 @@ func (sc *Client) Connect(host string, port uint16) error {
return fmt.Errorf("only one connection allowed")
}
buf := make([]byte, 512)
l := 4 + len(host) + 1 + 2
buf := make([]byte, l)
buf[0] = socks5Version
buf[1] = cmdConnect
buf[2] = 0x00
@ -130,22 +132,20 @@ func (sc *Client) Connect(host string, port uint16) error {
binary.BigEndian.PutUint16(buf[l-2:l], port)
if _, err := sc.Conn.Write(buf); err != nil {
if _, err := sc.Conn.Write(buf[:l]); err != nil {
return err
}
buf1 := make([]byte, 128)
if _, err := io.ReadAtLeast(sc.Conn, buf1, 10); err != nil {
if _, err := io.ReadAtLeast(sc.Conn, buf, 10); err != nil {
return err
}
if buf1[0] != socks5Version {
return fmt.Errorf("error socks version %d", buf1[0])
if buf[0] != socks5Version {
return fmt.Errorf("error socks version %d", buf[0])
}
if buf1[1] != 0x00 {
return fmt.Errorf("server error code %d", buf1[1])
if buf[1] != 0x00 {
return fmt.Errorf("server error code %d", buf[1])
}
sc.connected = true

@ -33,10 +33,14 @@ type Conn struct {
// Serve serve the client
func (s *Conn) Serve() {
buf := make([]byte, 1)
buf := make([]byte, 512)
// read version
io.ReadFull(s.Conn, buf)
n, err := io.ReadAtLeast(s.Conn, buf, 1)
if err != nil {
log.Println(err)
return
}
dial := s.Dial
if s.Dial == nil {
@ -47,13 +51,30 @@ func (s *Conn) Serve() {
switch buf[0] {
case socks4Version:
s4 := socks4Conn{clientConn: s.Conn, dial: dial}
s4.Serve()
s4.Serve(buf, n)
case socks5Version:
s5 := socks5Conn{clientConn: s.Conn, dial: dial,
username: s.Username, password: s.Password}
s5.Serve()
s5.Serve(buf, n)
default:
log.Printf("error version %d", buf[0])
log.Printf("unknown socks version 0x%x", buf[0])
s.Conn.Close()
}
}
func forward(c1, c2 io.ReadWriter) {
c := make(chan struct{}, 2)
go func() {
io.Copy(c1, c2)
c <- struct{}{}
}()
go func() {
io.Copy(c2, c1)
c <- struct{}{}
}()
<-c
}

@ -37,10 +37,10 @@ type socks4Conn struct {
dial DialFunc
}
func (s4 *socks4Conn) Serve() {
func (s4 *socks4Conn) Serve(b []byte, n int) {
defer s4.Close()
if err := s4.processRequest(); err != nil {
if err := s4.processRequest(b, n); err != nil {
log.Println(err)
return
}
@ -56,34 +56,18 @@ func (s4 *socks4Conn) Close() {
}
}
func (s4 *socks4Conn) forward() {
c := make(chan int, 2)
go func() {
io.Copy(s4.clientConn, s4.serverConn)
c <- 1
}()
go func() {
io.Copy(s4.serverConn, s4.clientConn)
c <- 1
}()
<-c
}
func (s4 *socks4Conn) processRequest() error {
// version has already read out by socksConn.Serve()
func (s4 *socks4Conn) processRequest(buf []byte, n int) (err error) {
// process command and target here
buf := make([]byte, 128)
// read header
n, err := io.ReadAtLeast(s4.clientConn, buf, 8)
if n < 8 {
n1, err := io.ReadAtLeast(s4.clientConn, buf[n:], 8-n)
if err != nil {
return err
}
n += n1
}
buf = buf[1:n]
// command only support connect
if buf[0] != cmdConnect {
@ -99,7 +83,7 @@ func (s4 *socks4Conn) processRequest() error {
// NULL-terminated user string
// jump to NULL character
var j int
for j = 7; j < n; j++ {
for j = 7; j < n-1; j++ {
if buf[j] == 0x00 {
break
}
@ -114,7 +98,7 @@ func (s4 *socks4Conn) processRequest() error {
var i = j
// jump to the end of hostname
for j = i; j < n; j++ {
for j = i; j < n-1; j++ {
if buf[j] == 0x00 {
break
}
@ -124,20 +108,21 @@ func (s4 *socks4Conn) processRequest() error {
target := net.JoinHostPort(host, fmt.Sprintf("%d", port))
// reply user with connect success
// if dial to target failed, user will receive connection reset
s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00})
//log.Printf("connecting to %s\r\n", target)
// connect to the target
s4.serverConn, err = s4.dial("tcp", target)
if err != nil {
// connection failed
s4.clientConn.Write([]byte{0x00, 0x5b, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00})
return err
}
// connection success
s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00})
// enter data exchange
s4.forward()
forward(s4.clientConn, s4.serverConn)
return nil
}

@ -59,10 +59,10 @@ type socks5Conn struct {
dial DialFunc
}
func (s5 *socks5Conn) Serve() {
func (s5 *socks5Conn) Serve(b []byte, n int) {
defer s5.Close()
if err := s5.handshake(); err != nil {
if err := s5.handshake(b, n); err != nil {
log.Println(err)
return
}
@ -73,22 +73,21 @@ func (s5 *socks5Conn) Serve() {
}
}
func (s5 *socks5Conn) handshake() error {
// version has already readed by socksConn.Serve()
// only process auth methods here
buf := make([]byte, 258)
func (s5 *socks5Conn) handshake(buf []byte, n int) (err error) {
// read auth methods
n, err := io.ReadAtLeast(s5.clientConn, buf, 1)
if n < 2 {
n1, err := io.ReadAtLeast(s5.clientConn, buf[1:], 1)
if err != nil {
return err
}
n += n1
}
l := int(buf[0]) + 1
if n < l {
l := int(buf[1])
if n != (l + 2) {
// read remains data
n1, err := io.ReadFull(s5.clientConn, buf[n:l])
n1, err := io.ReadFull(s5.clientConn, buf[n:l+2+1])
if err != nil {
return err
}
@ -106,7 +105,7 @@ func (s5 *socks5Conn) handshake() error {
// check auth method
// only password(0x02) supported
for i := 1; i < n; i++ {
for i := 2; i < n; i++ {
if buf[i] == passAuth {
hasPassAuth = true
break
@ -236,41 +235,25 @@ func (s5 *socks5Conn) processRequest() error {
// target address
target := net.JoinHostPort(host, fmt.Sprintf("%d", port))
// reply user with connect success
// if dial to target failed, user will receive connection reset
s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01})
//log.Printf("connecing to %s\r\n", target)
// connect to the target
s5.serverConn, err = s5.dial("tcp", target)
if err != nil {
// connection failed
s5.clientConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01})
return err
}
// connection success
s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01})
// enter data exchange
s5.forward()
forward(s5.clientConn, s5.serverConn)
return nil
}
func (s5 *socks5Conn) forward() {
c := make(chan int, 2)
go func() {
io.Copy(s5.clientConn, s5.serverConn)
c <- 1
}()
go func() {
io.Copy(s5.serverConn, s5.clientConn)
c <- 1
}()
<-c
}
func (s5 *socks5Conn) Close() {
if s5.serverConn != nil {
s5.serverConn.Close()

Loading…
Cancel
Save