commit 30000054d0d9d04845b9cd2096d241ec4364b997 Author: dingjun Date: Thu Dec 22 14:24:13 2016 +0800 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..08bd99b --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*~ +fileserver* diff --git a/dial.go b/dial.go new file mode 100644 index 0000000..5d771ea --- /dev/null +++ b/dial.go @@ -0,0 +1,75 @@ +package main + +import ( + "context" + "net" + "time" +) + +var dialer *net.Dialer + +func dial(network, addr string) (net.Conn, error) { + var err error + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + if ip := net.ParseIP(host); ip != nil { + return dialer.Dial(network, addr) + } + + ips, err := lookupHost(host) + if err != nil { + return nil, err + } + + var conn net.Conn + + for _, ip := range ips { + address := net.JoinHostPort(ip, port) + if conn, err = dialer.Dial(network, address); err == nil { + return conn, err + } + } + + // return last error + return conn, err +} + +func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { + var err error + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + if ip := net.ParseIP(host); ip != nil { + return dialer.DialContext(ctx, network, addr) + } + + ips, err := lookupHost(host) + if err != nil { + return nil, err + } + + var conn net.Conn + + for _, ip := range ips { + address := net.JoinHostPort(ip, port) + if conn, err = dialer.DialContext(ctx, network, address); err == nil { + return conn, err + } + } + + // return last error + return conn, err +} + +func init() { + dialer = &net.Dialer{ + Timeout: 2 * time.Second, + } +} diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..0a91795 --- /dev/null +++ b/dns.go @@ -0,0 +1,127 @@ +// +build linux darwin + +package main + +import ( + "fmt" + "github.com/miekg/dns" + "log" + "time" +) + +var clientConfig *dns.ClientConfig +var dnsClient *dns.Client + +func lookupHost(host string) ([]string, error) { + var result = []string{} + var err error + + ret, err1 := getAAAA(host) + if err1 == nil { + result = append(result, ret...) + } else { + err = err1 + } + + ret1, err2 := getA(host) + if err2 == nil { + result = append(result, ret1...) + } else { + err = err2 + } + + if len(result) > 0 { + return result, nil + } + + if err == nil { + return nil, fmt.Errorf("dns lookup failed for %s", host) + } + + return nil, err +} + +func getA(host string) ([]string, error) { + var err error + var msg *dns.Msg + var result = []string{} + + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(host), dns.TypeA) + + for _, srv := range clientConfig.Servers { + dnsserver := fmt.Sprintf("%s:%s", srv, clientConfig.Port) + msg, _, err = dnsClient.Exchange(m, dnsserver) + if err == nil { + break + } else { + log.Println(err) + } + } + + if err != nil { + return result, err + } + + for _, rr := range msg.Answer { + if a, ok := rr.(*dns.A); ok { + result = append(result, a.A.String()) + } + + } + + return result, nil +} + +func getAAAA(host string) ([]string, error) { + var err error + var msg *dns.Msg + var result = []string{} + + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) + + for _, srv := range clientConfig.Servers { + dnsserver := fmt.Sprintf("%s:%s", srv, clientConfig.Port) + + msg, _, err = dnsClient.Exchange(m, dnsserver) + if err == nil { + break + } else { + log.Println(err) + } + } + if err != nil { + return result, err + } + + for _, rr := range msg.Answer { + if aaaa, ok := rr.(*dns.AAAA); ok { + result = append(result, aaaa.AAAA.String()) + } + + } + + return result, nil +} + +func init() { + var err error + clientConfig, err = dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + clientConfig = &dns.ClientConfig{ + Servers: []string{"8.8.8.8", "4.2.2.2"}, + Port: "53", + Ndots: 1, + Timeout: 2, + Attempts: 3, + } + } + + //clientConfig.Port = "53" + dnsClient = &dns.Client{ + Net: "udp", + Timeout: time.Duration(clientConfig.Timeout) * time.Second, + UDPSize: 4096, + } +} diff --git a/dns_test.go b/dns_test.go new file mode 100644 index 0000000..91fd1c7 --- /dev/null +++ b/dns_test.go @@ -0,0 +1,21 @@ +package main + +import ( + "fmt" + //"net" + "testing" +) + +func TestLookuphost(t *testing.T) { + for _, h := range []string{"www.ifeng.com", "www.taobao.com", + "www.baidu.com", "www.sina.com.cn", "www.163.com", "www.qq.com", + "www.google.com", "www.facebook.com", "twitter.com", + } { + + ret, err := lookupHost(h) + if err != nil { + t.Errorf("%s: %s", h, err) + } + fmt.Printf("%#v\n", ret) + } +} diff --git a/dns_windows.go b/dns_windows.go new file mode 100644 index 0000000..0665eb8 --- /dev/null +++ b/dns_windows.go @@ -0,0 +1,9 @@ +package main + +import ( + "net" +) + +func lookupHost(host string) ([]string, error) { + return net.LookupHost(host) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..5582ad3 --- /dev/null +++ b/server.go @@ -0,0 +1,147 @@ +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "strings" + "time" +) + +var defaultTransport http.RoundTripper = &http.Transport{ + DialContext: dialContext, + MaxIdleConns: 50, + IdleConnTimeout: 30 * time.Second, + MaxIdleConnsPerHost: 3, + //ResponseHeaderTimeout: 2 * time.Second, +} + +func main() { + var docroot string + var enableProxy bool + var port int + + curdir, err := os.Getwd() + if err != nil { + curdir = "." + } + + flag.StringVar(&docroot, "docroot", curdir, "document root") + flag.BoolVar(&enableProxy, "enable_proxy", false, "enable proxy function") + flag.IntVar(&port, "port", 8080, "the port listen to") + flag.Parse() + + http.Handle("/", http.FileServer(http.Dir(docroot))) + + log.Printf("Listen on :%d", port) + log.Printf("document root %s", docroot) + if enableProxy { + log.Println("proxy enabled") + } + err = http.ListenAndServe(fmt.Sprintf(":%d", port), &handler{ + enableProxy: enableProxy, + }) + if err != nil { + log.Fatal(err) + } +} + +type handler struct { + enableProxy bool +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.RequestURI[0] == '/' { + http.DefaultServeMux.ServeHTTP(w, r) + return + } + + if !h.enableProxy { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, "

page not found!

") + return + } + if r.Method == http.MethodConnect { + h.handleCONNECT(w, r) + } else { + h.handleHTTP(w, r) + } +} + +func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { + + var resp *http.Response + var err error + + r.Header.Del("proxy-connection") + + resp, err = defaultTransport.RoundTrip(r) + if err != nil { + log.Printf("RoundTrip: %s", err) + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(w, "%s", err) + return + } + + defer resp.Body.Close() + + hdr := w.Header() + + //resp.Header.Del("connection") + + for k, v := range resp.Header { + for _, v1 := range v { + hdr.Add(k, v1) + } + } + + w.WriteHeader(resp.StatusCode) + + io.Copy(w, resp.Body) +} + +func (h *handler) handleCONNECT(w http.ResponseWriter, r *http.Request) { + host := r.RequestURI + if !strings.Contains(host, ":") { + host = fmt.Sprintf("%s:443", host) + } + + var conn net.Conn + var err error + + conn, err = dial("tcp", host) + if err != nil { + log.Printf("net.dial: %s", err) + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(w, "dial to %s failed: %s", host, err) + return + } + + hj, _ := w.(http.Hijacker) + conn1, _, _ := hj.Hijack() + + fmt.Fprintf(conn1, "%s 200 connection established\r\n\r\n", r.Proto) + + pipeAndClose(conn, conn1) +} + +func pipeAndClose(r1, r2 io.ReadWriteCloser) { + ch := make(chan int, 2) + go func() { + io.Copy(r1, r2) + ch <- 1 + }() + + go func() { + io.Copy(r2, r1) + ch <- 1 + }() + + <-ch +}