diff --git a/dial.go b/dial.go deleted file mode 100644 index 5d771ea..0000000 --- a/dial.go +++ /dev/null @@ -1,75 +0,0 @@ -package main - -import ( - "context" - "net" - "time" -) - -var dialer *net.Dialer - -func dial(network, addr string) (net.Conn, error) { - var err error - - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - if ip := net.ParseIP(host); ip != nil { - return dialer.Dial(network, addr) - } - - ips, err := lookupHost(host) - if err != nil { - return nil, err - } - - var conn net.Conn - - for _, ip := range ips { - address := net.JoinHostPort(ip, port) - if conn, err = dialer.Dial(network, address); err == nil { - return conn, err - } - } - - // return last error - return conn, err -} - -func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { - var err error - - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - if ip := net.ParseIP(host); ip != nil { - return dialer.DialContext(ctx, network, addr) - } - - ips, err := lookupHost(host) - if err != nil { - return nil, err - } - - var conn net.Conn - - for _, ip := range ips { - address := net.JoinHostPort(ip, port) - if conn, err = dialer.DialContext(ctx, network, address); err == nil { - return conn, err - } - } - - // return last error - return conn, err -} - -func init() { - dialer = &net.Dialer{ - Timeout: 2 * time.Second, - } -} diff --git a/dns.go b/dns.go deleted file mode 100644 index 768ac38..0000000 --- a/dns.go +++ /dev/null @@ -1,127 +0,0 @@ -// +build unix - -package main - -import ( - "fmt" - "github.com/miekg/dns" - "log" - "time" -) - -var clientConfig *dns.ClientConfig -var dnsClient *dns.Client - -func lookupHost(host string) ([]string, error) { - var result = []string{} - var err error - - ret, err1 := getAAAA(host) - if err1 == nil { - result = append(result, ret...) - } else { - err = err1 - } - - ret1, err2 := getA(host) - if err2 == nil { - result = append(result, ret1...) - } else { - err = err2 - } - - if len(result) > 0 { - return result, nil - } - - if err == nil { - return nil, fmt.Errorf("dns lookup failed for %s", host) - } - - return nil, err -} - -func getA(host string) ([]string, error) { - var err error - var msg *dns.Msg - var result = []string{} - - m := new(dns.Msg) - m.SetQuestion(dns.Fqdn(host), dns.TypeA) - - for _, srv := range clientConfig.Servers { - dnsserver := fmt.Sprintf("%s:%s", srv, clientConfig.Port) - msg, _, err = dnsClient.Exchange(m, dnsserver) - if err == nil { - break - } else { - log.Println(err) - } - } - - if err != nil { - return result, err - } - - for _, rr := range msg.Answer { - if a, ok := rr.(*dns.A); ok { - result = append(result, a.A.String()) - } - - } - - return result, nil -} - -func getAAAA(host string) ([]string, error) { - var err error - var msg *dns.Msg - var result = []string{} - - m := new(dns.Msg) - m.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) - - for _, srv := range clientConfig.Servers { - dnsserver := fmt.Sprintf("%s:%s", srv, clientConfig.Port) - - msg, _, err = dnsClient.Exchange(m, dnsserver) - if err == nil { - break - } else { - log.Println(err) - } - } - if err != nil { - return result, err - } - - for _, rr := range msg.Answer { - if aaaa, ok := rr.(*dns.AAAA); ok { - result = append(result, aaaa.AAAA.String()) - } - - } - - return result, nil -} - -func init() { - var err error - clientConfig, err = dns.ClientConfigFromFile("/etc/resolv.conf") - if err != nil { - clientConfig = &dns.ClientConfig{ - Servers: []string{"8.8.8.8", "4.2.2.2"}, - Port: "53", - Ndots: 1, - Timeout: 2, - Attempts: 3, - } - } - - //clientConfig.Port = "53" - dnsClient = &dns.Client{ - Net: "udp", - Timeout: time.Duration(clientConfig.Timeout) * time.Second, - UDPSize: 4096, - } -} diff --git a/dns_other.go b/dns_other.go deleted file mode 100644 index 11fc5f1..0000000 --- a/dns_other.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build !unix - -package main - -import ( - "net" -) - -func lookupHost(host string) ([]string, error) { - return net.LookupHost(host) -} diff --git a/dns_test.go b/dns_test.go deleted file mode 100644 index 91fd1c7..0000000 --- a/dns_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import ( - "fmt" - //"net" - "testing" -) - -func TestLookuphost(t *testing.T) { - for _, h := range []string{"www.ifeng.com", "www.taobao.com", - "www.baidu.com", "www.sina.com.cn", "www.163.com", "www.qq.com", - "www.google.com", "www.facebook.com", "twitter.com", - } { - - ret, err := lookupHost(h) - if err != nil { - t.Errorf("%s: %s", h, err) - } - fmt.Printf("%#v\n", ret) - } -} diff --git a/handler.go b/handler.go index 3d5783f..ef0267d 100644 --- a/handler.go +++ b/handler.go @@ -19,11 +19,12 @@ type handler struct { } var defaultTransport http.RoundTripper = &http.Transport{ - DialContext: dialContext, - MaxIdleConns: 50, - IdleConnTimeout: 30 * time.Second, - MaxIdleConnsPerHost: 3, - //ResponseHeaderTimeout: 2 * time.Second, + //DialContext: dialContext, + MaxIdleConns: 50, + IdleConnTimeout: 30 * time.Second, + MaxIdleConnsPerHost: 3, + DisableKeepAlives: true, + ResponseHeaderTimeout: 2 * time.Second, } // ServeHTTP implements the http.Handler interface @@ -132,7 +133,7 @@ func (h *handler) handleCONNECT(w http.ResponseWriter, r *http.Request) { var conn net.Conn var err error - conn, err = dial("tcp", host) + conn, err = net.Dial("tcp", host) if err != nil { log.Printf("net.dial: %s", err) w.Header().Set("Content-Type", "text/plain")