diff --git a/cmd/server/server.go b/cmd/server/server.go old mode 100755 new mode 100644 index 249ec7f..527f3c8 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -1,9 +1,10 @@ package main import ( - "github.com/fangdingjun/socks" + socks "github.com/fangdingjun/socks-go" "log" "net" + "time" ) func main() { @@ -18,7 +19,8 @@ func main() { continue } log.Printf("connected from %s", c.RemoteAddr()) - s := socks.SocksConn{ClientConn: c} + d := net.Dialer{Timeout: 10 * time.Second} + s := socks.SocksConn{ClientConn: c, Dial: d.Dial} go s.Serve() } } diff --git a/socks.go b/socks.go old mode 100755 new mode 100644 index 4c58e97..bba1b87 --- a/socks.go +++ b/socks.go @@ -4,6 +4,7 @@ import ( "io" "log" "net" + "time" ) const ( @@ -15,8 +16,11 @@ const ( addrTypeIPv6 = 0x04 ) +type dialFunc func(network, addr string) (net.Conn, error) + type SocksConn struct { ClientConn net.Conn + Dial dialFunc } func (s *SocksConn) Serve() { @@ -25,12 +29,18 @@ func (s *SocksConn) Serve() { // read version io.ReadFull(s.ClientConn, buf) + dial := s.Dial + if s.Dial == nil { + d := net.Dialer{Timeout: 10 * time.Second} + dial = d.Dial + } + switch buf[0] { case socks4Version: - s4 := socks4Conn{client_conn: s.ClientConn} + s4 := socks4Conn{client_conn: s.ClientConn, dial: dial} s4.Serve() case socks5Version: - s5 := socks5Conn{client_conn: s.ClientConn} + s5 := socks5Conn{client_conn: s.ClientConn, dial: dial} s5.Serve() default: log.Printf("error version %s", buf[0]) diff --git a/socks4.go b/socks4.go old mode 100755 new mode 100644 index 9b9da53..af54654 --- a/socks4.go +++ b/socks4.go @@ -11,6 +11,7 @@ import ( type socks4Conn struct { server_conn net.Conn client_conn net.Conn + dial dialFunc } func (s4 *socks4Conn) Serve() { @@ -92,7 +93,7 @@ func (s4 *socks4Conn) processRequest() error { log.Printf("connecting to %s", target) // connect to the target - s4.server_conn, err = net.Dial("tcp", target) + s4.server_conn, err = s4.dial("tcp", target) if err != nil { return err } diff --git a/socks5.go b/socks5.go old mode 100755 new mode 100644 index dd2a8f1..868c602 --- a/socks5.go +++ b/socks5.go @@ -13,6 +13,7 @@ type socks5Conn struct { //addr string client_conn net.Conn server_conn net.Conn + dial dialFunc } func (s5 *socks5Conn) Serve() { @@ -107,7 +108,7 @@ func (s5 *socks5Conn) processRequest() error { log.Printf("connecing to %s", target) // connect to the target - s5.server_conn, err = net.Dial("tcp", target) + s5.server_conn, err = s5.dial("tcp", target) if err != nil { return err }