diff --git a/.gitignore b/.gitignore index 2ed562a..a62d708 100755 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.exe *~ *bak +*.swp diff --git a/README.md b/README.md index aa5f47e..21329bf 100755 --- a/README.md +++ b/README.md @@ -1,20 +1,21 @@ socks-go ======= -A socks server implement by golang, support socks4/4a, socks5. +A socks server implemented by golang, support socks4/4a, socks5. +Only support connect command now. usage ==== Usage example: - import "github.com/fangdingjun/socks" + import socks "github.com/fangdingjun/socks-go" - fucn main(){ + func main(){ l, _ := net.Listen("tcp", ":1080") for { conn, _ := l.Accept() - s := socks.SocksConn{conn} - go s.Serve() + s := socks.SocksConn{ClientConn: conn, Dial: nil} // Dial is a function which dial to the upstream server + go s.Serve() // serve the socks request } } diff --git a/examples/server/.gitignore b/examples/server/.gitignore new file mode 100644 index 0000000..254defd --- /dev/null +++ b/examples/server/.gitignore @@ -0,0 +1 @@ +server diff --git a/cmd/server/README.md b/examples/server/README.md similarity index 100% rename from cmd/server/README.md rename to examples/server/README.md diff --git a/cmd/server/server.go b/examples/server/server.go old mode 100755 new mode 100644 similarity index 66% rename from cmd/server/server.go rename to examples/server/server.go index 249ec7f..527f3c8 --- a/cmd/server/server.go +++ b/examples/server/server.go @@ -1,9 +1,10 @@ package main import ( - "github.com/fangdingjun/socks" + socks "github.com/fangdingjun/socks-go" "log" "net" + "time" ) func main() { @@ -18,7 +19,8 @@ func main() { continue } log.Printf("connected from %s", c.RemoteAddr()) - s := socks.SocksConn{ClientConn: c} + d := net.Dialer{Timeout: 10 * time.Second} + s := socks.SocksConn{ClientConn: c, Dial: d.Dial} go s.Serve() } } diff --git a/socks.go b/socks.go old mode 100755 new mode 100644 index 4c58e97..e150545 --- a/socks.go +++ b/socks.go @@ -4,6 +4,7 @@ import ( "io" "log" "net" + "time" ) const ( @@ -15,25 +16,36 @@ const ( addrTypeIPv6 = 0x04 ) +type dialFunc func(network, addr string) (net.Conn, error) + +// SocksConn present a client connection type SocksConn struct { ClientConn net.Conn + Dial dialFunc } +// Serve serve the client func (s *SocksConn) Serve() { buf := make([]byte, 1) // read version io.ReadFull(s.ClientConn, buf) + dial := s.Dial + if s.Dial == nil { + d := net.Dialer{Timeout: 10 * time.Second} + dial = d.Dial + } + switch buf[0] { case socks4Version: - s4 := socks4Conn{client_conn: s.ClientConn} + s4 := socks4Conn{clientConn: s.ClientConn, dial: dial} s4.Serve() case socks5Version: - s5 := socks5Conn{client_conn: s.ClientConn} + s5 := socks5Conn{clientConn: s.ClientConn, dial: dial} s5.Serve() default: - log.Printf("error version %s", buf[0]) + log.Printf("error version %d", buf[0]) s.ClientConn.Close() } } diff --git a/socks4.go b/socks4.go index aa4cdc4..0eeda0d 100755 --- a/socks4.go +++ b/socks4.go @@ -8,13 +8,38 @@ import ( "net" ) +/* +socks4 protocol + +request +byte | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | ... | + |0x04|cmd| port | ip | user\0 | + +reply +byte | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7| + |0x00|status| | | + + +socks4a protocol + +request +byte | 0 | 1 | 2 | 3 |4 | 5 | 6 | 7 | 8 | ... |... | + |0x04|cmd| port | 0.0.0.x | user\0 |domain\0| + +reply +byte | 0 | 1 | 2 | 3 | 4 | 5 | 6| 7 | + |0x00|staus| port | ip | + +*/ type socks4Conn struct { - server_conn net.Conn - client_conn net.Conn + serverConn net.Conn + clientConn net.Conn + dial dialFunc } func (s4 *socks4Conn) Serve() { defer s4.Close() + if err := s4.processRequest(); err != nil { log.Println(err) return @@ -22,20 +47,30 @@ func (s4 *socks4Conn) Serve() { } func (s4 *socks4Conn) Close() { - if s4.client_conn != nil { - s4.client_conn.Close() + if s4.clientConn != nil { + s4.clientConn.Close() } - if s4.server_conn != nil { - s4.server_conn.Close() + + if s4.serverConn != nil { + s4.serverConn.Close() } } func (s4 *socks4Conn) forward() { + + c := make(chan int, 2) + go func() { - io.Copy(s4.client_conn, s4.server_conn) + io.Copy(s4.clientConn, s4.serverConn) + c <- 1 }() - io.Copy(s4.server_conn, s4.client_conn) + go func() { + io.Copy(s4.serverConn, s4.clientConn) + c <- 1 + }() + + <-c } func (s4 *socks4Conn) processRequest() error { @@ -43,18 +78,22 @@ func (s4 *socks4Conn) processRequest() error { // process command and target here buf := make([]byte, 128) - n, err := io.ReadAtLeast(s4.client_conn, buf, 8) + + // read header + n, err := io.ReadAtLeast(s4.clientConn, buf, 8) if err != nil { return err } - // only support connect + // command only support connect if buf[0] != cmdConnect { return fmt.Errorf("error command %s", buf[0]) } + // get port port := binary.BigEndian.Uint16(buf[1:3]) + // get ip ip := net.IP(buf[3:7]) // NULL-terminated user string @@ -87,12 +126,12 @@ func (s4 *socks4Conn) processRequest() error { // reply user with connect success // if dial to target failed, user will receive connection reset - s4.client_conn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00}) + s4.clientConn.Write([]byte{0x00, 0x5a, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00}) - log.Printf("connecting to %s", target) + //log.Printf("connecting to %s\r\n", target) // connect to the target - s4.server_conn, err = net.Dial("tcp", target) + s4.serverConn, err = s4.dial("tcp", target) if err != nil { return err } diff --git a/socks5.go b/socks5.go old mode 100755 new mode 100644 index dd2a8f1..d624a4d --- a/socks5.go +++ b/socks5.go @@ -9,14 +9,41 @@ import ( "encoding/binary" ) +/* +socks5 protocol + +initial + +byte | 0 | 1 | 2 | ...... | n | + |0x05|num auth| auth methods | + + +reply + +byte | 0 | 1 | + |0x05| auth| + + +request + +byte | 0 | 1 | 2 | 3 | 4 | .. | n-2 | n-1| n | + |0x05|cmd|0x00|addrtype| addr | port | + +response +byte |0 | 1 | 2 | 3 | 4 | .. | n-2 | n-1 | n | + |0x05|status|0x00|addrtype| addr | port | + +*/ type socks5Conn struct { //addr string - client_conn net.Conn - server_conn net.Conn + clientConn net.Conn + serverConn net.Conn + dial dialFunc } func (s5 *socks5Conn) Serve() { defer s5.Close() + if err := s5.handshake(); err != nil { log.Println(err) return @@ -33,43 +60,50 @@ func (s5 *socks5Conn) handshake() error { // only process auth methods here buf := make([]byte, 258) - n, err := io.ReadAtLeast(s5.client_conn, buf, 1) + + // read auth methods + n, err := io.ReadAtLeast(s5.clientConn, buf, 1) if err != nil { return err } l := int(buf[0]) + 1 if n < l { - _, err := io.ReadFull(s5.client_conn, buf[n:l]) + // read remains data + _, err := io.ReadFull(s5.clientConn, buf[n:l]) if err != nil { return err } } // no auth required - s5.client_conn.Write([]byte{0x05, 0x00}) + s5.clientConn.Write([]byte{0x05, 0x00}) return nil } func (s5 *socks5Conn) processRequest() error { buf := make([]byte, 258) - n, err := io.ReadAtLeast(s5.client_conn, buf, 10) + + // read header + n, err := io.ReadAtLeast(s5.clientConn, buf, 10) if err != nil { return err } + if buf[0] != socks5Version { - return fmt.Errorf("error version %s", buf[0]) + return fmt.Errorf("error version %d", buf[0]) } - // only support connect + // command only support connect if buf[1] != cmdConnect { return fmt.Errorf("unsupported command %s", buf[1]) } - hlen := 0 - host := "" - msglen := 0 + hlen := 0 // target address length + host := "" // target address + msglen := 0 // header length + switch buf[3] { case addrTypeIPv4: hlen = 4 @@ -82,7 +116,8 @@ func (s5 *socks5Conn) processRequest() error { msglen = 6 + hlen if n < msglen { - _, err := io.ReadFull(s5.client_conn, buf[n:msglen]) + // read remains header + _, err := io.ReadFull(s5.clientConn, buf[n:msglen]) if err != nil { return err } @@ -96,18 +131,20 @@ func (s5 *socks5Conn) processRequest() error { host = net.IP(addr).String() } + // get target port port := binary.BigEndian.Uint16(buf[msglen-2 : msglen]) + // target address target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) // reply user with connect success // if dial to target failed, user will receive connection reset - s5.client_conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01}) + s5.clientConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01}) - log.Printf("connecing to %s", target) + //log.Printf("connecing to %s\r\n", target) // connect to the target - s5.server_conn, err = net.Dial("tcp", target) + s5.serverConn, err = s5.dial("tcp", target) if err != nil { return err } @@ -119,15 +156,27 @@ func (s5 *socks5Conn) processRequest() error { } func (s5 *socks5Conn) forward() { - go io.Copy(s5.client_conn, s5.server_conn) - io.Copy(s5.server_conn, s5.client_conn) + + c := make(chan int, 2) + + go func() { + io.Copy(s5.clientConn, s5.serverConn) + c <- 1 + }() + + go func() { + io.Copy(s5.serverConn, s5.clientConn) + c <- 1 + }() + + <-c } func (s5 *socks5Conn) Close() { - if s5.server_conn != nil { - s5.server_conn.Close() + if s5.serverConn != nil { + s5.serverConn.Close() } - if s5.client_conn != nil { - s5.client_conn.Close() + if s5.clientConn != nil { + s5.clientConn.Close() } }