add optional Dial option

the dial function provider a way to connect to server
master
Dingjun 8 years ago
parent 179b61da8d
commit 0cdf31ddaf

@ -1,9 +1,10 @@
package main package main
import ( import (
"github.com/fangdingjun/socks" socks "github.com/fangdingjun/socks-go"
"log" "log"
"net" "net"
"time"
) )
func main() { func main() {
@ -18,7 +19,8 @@ func main() {
continue continue
} }
log.Printf("connected from %s", c.RemoteAddr()) log.Printf("connected from %s", c.RemoteAddr())
s := socks.SocksConn{ClientConn: c} d := net.Dialer{Timeout: 10 * time.Second}
s := socks.SocksConn{ClientConn: c, Dial: d.Dial}
go s.Serve() go s.Serve()
} }
} }

@ -4,6 +4,7 @@ import (
"io" "io"
"log" "log"
"net" "net"
"time"
) )
const ( const (
@ -15,8 +16,11 @@ const (
addrTypeIPv6 = 0x04 addrTypeIPv6 = 0x04
) )
type dialFunc func(network, addr string) (net.Conn, error)
type SocksConn struct { type SocksConn struct {
ClientConn net.Conn ClientConn net.Conn
Dial dialFunc
} }
func (s *SocksConn) Serve() { func (s *SocksConn) Serve() {
@ -25,12 +29,18 @@ func (s *SocksConn) Serve() {
// read version // read version
io.ReadFull(s.ClientConn, buf) io.ReadFull(s.ClientConn, buf)
dial := s.Dial
if s.Dial == nil {
d := net.Dialer{Timeout: 10 * time.Second}
dial = d.Dial
}
switch buf[0] { switch buf[0] {
case socks4Version: case socks4Version:
s4 := socks4Conn{client_conn: s.ClientConn} s4 := socks4Conn{client_conn: s.ClientConn, dial: dial}
s4.Serve() s4.Serve()
case socks5Version: case socks5Version:
s5 := socks5Conn{client_conn: s.ClientConn} s5 := socks5Conn{client_conn: s.ClientConn, dial: dial}
s5.Serve() s5.Serve()
default: default:
log.Printf("error version %s", buf[0]) log.Printf("error version %s", buf[0])

@ -11,6 +11,7 @@ import (
type socks4Conn struct { type socks4Conn struct {
server_conn net.Conn server_conn net.Conn
client_conn net.Conn client_conn net.Conn
dial dialFunc
} }
func (s4 *socks4Conn) Serve() { func (s4 *socks4Conn) Serve() {
@ -92,7 +93,7 @@ func (s4 *socks4Conn) processRequest() error {
log.Printf("connecting to %s", target) log.Printf("connecting to %s", target)
// connect to the target // connect to the target
s4.server_conn, err = net.Dial("tcp", target) s4.server_conn, err = s4.dial("tcp", target)
if err != nil { if err != nil {
return err return err
} }

@ -13,6 +13,7 @@ type socks5Conn struct {
//addr string //addr string
client_conn net.Conn client_conn net.Conn
server_conn net.Conn server_conn net.Conn
dial dialFunc
} }
func (s5 *socks5Conn) Serve() { func (s5 *socks5Conn) Serve() {
@ -107,7 +108,7 @@ func (s5 *socks5Conn) processRequest() error {
log.Printf("connecing to %s", target) log.Printf("connecing to %s", target)
// connect to the target // connect to the target
s5.server_conn, err = net.Dial("tcp", target) s5.server_conn, err = s5.dial("tcp", target)
if err != nil { if err != nil {
return err return err
} }

Loading…
Cancel
Save