master
Dingjun 8 years ago
parent 0cdf31ddaf
commit cdfc17d24a

@ -18,11 +18,13 @@ const (
type dialFunc func(network, addr string) (net.Conn, error) type dialFunc func(network, addr string) (net.Conn, error)
// SocksConn present a client connection
type SocksConn struct { type SocksConn struct {
ClientConn net.Conn ClientConn net.Conn
Dial dialFunc Dial dialFunc
} }
// Serve serve the client
func (s *SocksConn) Serve() { func (s *SocksConn) Serve() {
buf := make([]byte, 1) buf := make([]byte, 1)

@ -1,105 +1,105 @@
package socks package socks
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
) )
type socks4Conn struct { type socks4Conn struct {
server_conn net.Conn serverConn net.Conn
client_conn net.Conn clientConn net.Conn
dial dialFunc dial dialFunc
} }
func (s4 *socks4Conn) Serve() { func (s4 *socks4Conn) Serve() {
defer s4.Close() defer s4.Close()
if err := s4.processRequest(); err != nil { if err := s4.processRequest(); err != nil {
log.Println(err) log.Println(err)
return return
} }
} }
func (s4 *socks4Conn) Close() { func (s4 *socks4Conn) Close() {
if s4.client_conn != nil { if s4.clientConn != nil {
s4.client_conn.Close() s4.clientConn.Close()
} }
if s4.server_conn != nil { if s4.serverConn != nil {
s4.server_conn.Close() s4.serverConn.Close()
} }
} }
func (s4 *socks4Conn) forward() { func (s4 *socks4Conn) forward() {
go func() { go func() {
io.Copy(s4.client_conn, s4.server_conn) io.Copy(s4.clientConn, s4.serverConn)
}() }()
io.Copy(s4.server_conn, s4.client_conn) io.Copy(s4.serverConn, s4.clientConn)
} }
func (s4 *socks4Conn) processRequest() error { func (s4 *socks4Conn) processRequest() error {
// version has already read out by socksConn.Serve() // version has already read out by socksConn.Serve()
// process command and target here // process command and target here
buf := make([]byte, 128) buf := make([]byte, 128)
n, err := io.ReadAtLeast(s4.client_conn, buf, 8) n, err := io.ReadAtLeast(s4.clientConn, buf, 8)
if err != nil { if err != nil {
return err return err
} }
// only support connect // only support connect
if buf[0] != cmdConnect { if buf[0] != cmdConnect {
return fmt.Errorf("error command %s", buf[0]) return fmt.Errorf("error command %s", buf[0])
} }
port := binary.BigEndian.Uint16(buf[1:3]) port := binary.BigEndian.Uint16(buf[1:3])
ip := net.IP(buf[3:7]) ip := net.IP(buf[3:7])
// NULL-terminated user string // NULL-terminated user string
// jump to NULL character // jump to NULL character
var j int var j int
for j = 7; j < n; j++ { for j = 7; j < n; j++ {
if buf[j] == 0x00 { if buf[j] == 0x00 {
break break
} }
} }
host := ip.String() host := ip.String()
// socks4a // socks4a
// 0.0.0.x // 0.0.0.x
if ip[0] == 0x00 && ip[1] == 0x00 && ip[2] == 0x00 && ip[3] != 0x00 { if ip[0] == 0x00 && ip[1] == 0x00 && ip[2] == 0x00 && ip[3] != 0x00 {
j++ j++
var i = j var i = j
// jump to the end of hostname // jump to the end of hostname
for j = i; j < n; j++ { for j = i; j < n; j++ {
if buf[j] == 0x00 { if buf[j] == 0x00 {
break break
} }
} }
host = string(buf[i:j]) host = string(buf[i:j])
} }
target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) target := net.JoinHostPort(host, fmt.Sprintf("%d", port))
// reply user with connect success // reply user with connect success
// if dial to target failed, user will receive connection reset // if dial to target failed, user will receive connection reset
s4.client_conn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00}) s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00})
log.Printf("connecting to %s", target) log.Printf("connecting to %s", target)
// connect to the target // connect to the target
s4.server_conn, err = s4.dial("tcp", target) s4.serverConn, err = s4.dial("tcp", target)
if err != nil { if err != nil {
return err return err
} }
// enter data exchange // enter data exchange
s4.forward() s4.forward()
return nil return nil
} }

@ -11,9 +11,9 @@ import (
type socks5Conn struct { type socks5Conn struct {
//addr string //addr string
client_conn net.Conn clientConn net.Conn
server_conn net.Conn serverConn net.Conn
dial dialFunc dial dialFunc
} }
func (s5 *socks5Conn) Serve() { func (s5 *socks5Conn) Serve() {
@ -34,28 +34,28 @@ func (s5 *socks5Conn) handshake() error {
// only process auth methods here // only process auth methods here
buf := make([]byte, 258) buf := make([]byte, 258)
n, err := io.ReadAtLeast(s5.client_conn, buf, 1) n, err := io.ReadAtLeast(s5.clientConn, buf, 1)
if err != nil { if err != nil {
return err return err
} }
l := int(buf[0]) + 1 l := int(buf[0]) + 1
if n < l { if n < l {
_, err := io.ReadFull(s5.client_conn, buf[n:l]) _, err := io.ReadFull(s5.clientConn, buf[n:l])
if err != nil { if err != nil {
return err return err
} }
} }
// no auth required // no auth required
s5.client_conn.Write([]byte{0x05, 0x00}) s5.clientConn.Write([]byte{0x05, 0x00})
return nil return nil
} }
func (s5 *socks5Conn) processRequest() error { func (s5 *socks5Conn) processRequest() error {
buf := make([]byte, 258) buf := make([]byte, 258)
n, err := io.ReadAtLeast(s5.client_conn, buf, 10) n, err := io.ReadAtLeast(s5.clientConn, buf, 10)
if err != nil { if err != nil {
return err return err
} }
@ -83,7 +83,7 @@ func (s5 *socks5Conn) processRequest() error {
msglen = 6 + hlen msglen = 6 + hlen
if n < msglen { if n < msglen {
_, err := io.ReadFull(s5.client_conn, buf[n:msglen]) _, err := io.ReadFull(s5.clientConn, buf[n:msglen])
if err != nil { if err != nil {
return err return err
} }
@ -103,12 +103,12 @@ func (s5 *socks5Conn) processRequest() error {
// reply user with connect success // reply user with connect success
// if dial to target failed, user will receive connection reset // if dial to target failed, user will receive connection reset
s5.client_conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01}) s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01})
log.Printf("connecing to %s", target) log.Printf("connecing to %s", target)
// connect to the target // connect to the target
s5.server_conn, err = s5.dial("tcp", target) s5.serverConn, err = s5.dial("tcp", target)
if err != nil { if err != nil {
return err return err
} }
@ -120,15 +120,15 @@ func (s5 *socks5Conn) processRequest() error {
} }
func (s5 *socks5Conn) forward() { func (s5 *socks5Conn) forward() {
go io.Copy(s5.client_conn, s5.server_conn) go io.Copy(s5.clientConn, s5.serverConn)
io.Copy(s5.server_conn, s5.client_conn) io.Copy(s5.serverConn, s5.clientConn)
} }
func (s5 *socks5Conn) Close() { func (s5 *socks5Conn) Close() {
if s5.server_conn != nil { if s5.serverConn != nil {
s5.server_conn.Close() s5.serverConn.Close()
} }
if s5.client_conn != nil { if s5.clientConn != nil {
s5.client_conn.Close() s5.clientConn.Close()
} }
} }

Loading…
Cancel
Save