From 13c78296f2ba640e77fe322d3a7f3a92f21a2345 Mon Sep 17 00:00:00 2001 From: fangdingjun Date: Mon, 19 Dec 2016 16:36:15 +0800 Subject: [PATCH] gdns: first version --- .gitignore | 2 + README.md | 21 ++++ conf.go | 212 +++++++++++++++++++++++++++++++++++++++++ conf_test.go | 121 +++++++++++++++++++++++ config_example.yaml | 41 ++++++++ google_httpdns.go | 183 +++++++++++++++++++++++++++++++++++ handler.go | 165 ++++++++++++++++++++++++++++++++ log.go | 31 ++++++ server.go | 33 +++++++ testdata/blacklist.dat | 5 + testdata/cn.dat | 8 ++ testdata/config.yaml | 31 ++++++ testdata/hosts | 9 ++ testdata/us.dat | 12 +++ 14 files changed, 874 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 conf.go create mode 100644 conf_test.go create mode 100644 config_example.yaml create mode 100644 google_httpdns.go create mode 100644 handler.go create mode 100644 log.go create mode 100644 server.go create mode 100644 testdata/blacklist.dat create mode 100644 testdata/cn.dat create mode 100644 testdata/config.yaml create mode 100644 testdata/hosts create mode 100644 testdata/us.dat diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8bde6a6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*~ +dns* diff --git a/README.md b/README.md new file mode 100644 index 0000000..1102e64 --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +gdns +==== + +gdns is a dns proxy server + + +features +======= + +- support forward the query by rule, + different domains use different upstream server +- support ip black list +- support google https dns + +usage +===== + + go get github.com/fangdingjun/gdns + cp $GOPATH/src/github.com/fangdingjun/gdns/config_example.yaml config.yaml + vim config.yaml + $GOPATH/bin/gdns -c config.yaml diff --git a/conf.go b/conf.go new file mode 100644 index 0000000..37a62e8 --- /dev/null +++ b/conf.go @@ -0,0 +1,212 @@ +package main + +import ( + "bufio" + "github.com/go-yaml/yaml" + "io/ioutil" + "os" + "strings" +) + +type conf struct { + Listen []addr + BlacklistFile string + HostFile string + ForwardRules []rule + DefaultUpstream []addr + Timeout int + Debug bool + blacklist item + hosts hostitem +} + +type rule struct { + Server []addr + DomainFile string + domains item +} + +type addr struct { + Host string + Port int + Network string +} + +func loadConfig(f string) (*conf, error) { + c := new(conf) + data, err := ioutil.ReadFile(f) + if err != nil { + return nil, err + } + + if err := yaml.Unmarshal(data, c); err != nil { + return nil, err + } + + if c.Debug { + logLevel = DEBUG + } + + if c.blacklist == nil { + c.blacklist = item{} + } + + if c.Timeout == 0 { + c.Timeout = 2 + } + + if err := loadItemFile(c.blacklist, c.BlacklistFile); err != nil { + return nil, err + } + + for i := range c.ForwardRules { + if c.ForwardRules[i].domains == nil { + c.ForwardRules[i].domains = item{} + } + if err := loadItemFile(c.ForwardRules[i].domains, + c.ForwardRules[i].DomainFile); err != nil { + return nil, err + } + } + + if c.hosts == nil { + c.hosts = hostitem{} + } + + if err := loadHostsFile(c.hosts, c.HostFile); err != nil { + return nil, err + } + + return c, nil +} + +func loadHostsFile(h hostitem, f string) error { + if f == "" { + return nil + } + fd, err := os.Open(f) + if err != nil { + return err + } + defer fd.Close() + + r := bufio.NewReader(fd) + for { + s, err := r.ReadString('\n') + if err != nil { + break + } + s1 := strings.Trim(s, " \t\r\n") + + // ignore blank line and comment + if s1 == "" || s1[0] == '#' { + continue + } + s1 = strings.Replace(s1, "\t", " ", -1) + s1 = strings.Trim(s1, " \t\r\n") + ss := strings.Split(s1, " ") + + // ipv4 + t := 1 + if strings.Index(ss[0], ":") != -1 { + // ipv6 + t = 28 + } + + for _, s2 := range ss[1:] { + if s2 == "" { + continue + } + + h.add(s2, ss[0], t) + } + + } + return nil +} + +func loadItemFile(it item, f string) error { + if f == "" { + return nil + } + fd, err := os.Open(f) + if err != nil { + return err + } + defer fd.Close() + + r := bufio.NewReader(fd) + for { + s, err := r.ReadString('\n') + if s != "" { + s1 := strings.Trim(s, " \r\n") + if s1 != "" && s1[0] != '#' { + it.add(s1) + } + } + if err != nil { + break + } + } + return nil +} + +type item map[string]int + +func (it item) has(s string) bool { + ss := strings.Split(s, ".") + + for i := 0; i < len(ss); i++ { + s1 := strings.Join(ss[i:], ".") + if _, ok := it[s1]; ok { + return true + } + } + return false +} + +func (it item) exists(s string) bool { + _, ok := it[s] + return ok +} + +func (it item) add(s string) { + it[s] = 1 +} + +type hostitem map[string][]hostentry + +func (ht hostitem) get(domain string, t int) string { + if v, ok := ht[domain]; ok { + for _, v1 := range v { + if v1.domain == domain && v1.t == t { + return v1.ip + } + } + } + return "" +} + +func (ht hostitem) add(domain, ip string, t int) { + if v, ok := ht[domain]; ok { + exists := false + for _, v1 := range v { + if v1.domain == domain && v1.ip == ip && v1.t == t { + exists = true + break + } + } + if !exists { + ht[domain] = append(ht[domain], hostentry{domain, ip, t}) + } + } else { + v1 := []hostentry{{domain, ip, t}} + ht[domain] = v1 + } +} + +type hostentry struct { + domain string + ip string + t int +} diff --git a/conf_test.go b/conf_test.go new file mode 100644 index 0000000..91260b7 --- /dev/null +++ b/conf_test.go @@ -0,0 +1,121 @@ +package main + +import ( + "fmt" + //"github.com/go-yaml/yaml" + //"io/ioutil" + "testing" +) + +func TestConf(t *testing.T) { + c, err := loadConfig("testdata/config.yaml") + if err != nil { + t.Fatal(err) + } + //fmt.Printf("%#v\n", c) + if len(c.Listen) != 2 { + t.Errorf("expected listers 2, got %d", len(c.Listen)) + } + + if !c.blacklist.exists("1.2.3.4") { + fmt.Printf("%#v\n", c.blacklist) + t.Errorf("blacklist load failed") + } + if c.hosts.get("localhost", 1) != "127.0.0.1" { + fmt.Printf("%#v\n", c.hosts) + t.Errorf("hosts load failed") + } + if c.hosts.get("localhost", 28) != "::1" { + fmt.Printf("%#v\n", c.hosts) + t.Errorf("hosts load failed") + } + if len(c.ForwardRules) != 2 { + fmt.Printf("%#v\n", c.ForwardRules) + t.Errorf("expected rules 2, got %d", len(c.ForwardRules)) + } + if !c.ForwardRules[0].domains.has("a.com") { + fmt.Printf("%#v\n", c.ForwardRules[0].domains) + t.Errorf("some domains should exit, may be load config failed") + } + if !c.ForwardRules[1].domains.has("d.com") { + fmt.Printf("%#v\n", c.ForwardRules[1].domains) + t.Errorf("some domains should exit, may be load config failed") + } +} + +func TestItemExists(t *testing.T) { + it := item{ + "google.cn": 1, + "www.baidu.com": 1, + "org": 1, + } + + testdata := []struct { + d string + b bool + }{ + {"google.cn", true}, + {"www.google.cn", false}, + {"www.a.org", false}, + } + + for _, d := range testdata { + b1 := it.exists(d.d) + if b1 != d.b { + t.Errorf("%s, expected %v, got %v", d.d, d.b, b1) + } + } +} + +func TestItemHas(t *testing.T) { + it := item{ + "google.cn": 1, + "www.baidu.com": 1, + "org": 1, + } + + testdata := []struct { + d string + b bool + }{ + {"google.cn", true}, + {"www.google.cn", true}, + {"www.a.org", true}, + {"pan.baidu.com", false}, + {"abc.org", true}, + } + + for _, d := range testdata { + b1 := it.has(d.d) + if b1 != d.b { + t.Errorf("%s, expected %v, got %v", d.d, d.b, b1) + } + } +} + +func TestItemAdd(t *testing.T) { + it := item{} + it.add("www.example.org") + if !it.exists("www.example.org") { + t.Errorf("add failed") + } +} + +func TestHostitem(t *testing.T) { + ht := hostitem{} + testdata := []hostentry{ + {"www.google.com", "127.0.0.1", 1}, + {"www.google.com", "127.0.0.2", 28}, + {"www.example.org", "127.0.0.3", 28}, + {"www.abc.org", "127.0.0.4", 1}, + } + + for _, d := range testdata { + ht.add(d.domain, d.ip, d.t) + ip := ht.get(d.domain, d.t) + if ip != d.ip { + t.Errorf("%s, expected %s, got %s", d.domain, d.ip, ip) + } + } + //fmt.Printf("%v\n", ht) +} diff --git a/config_example.yaml b/config_example.yaml new file mode 100644 index 0000000..d26a2a1 --- /dev/null +++ b/config_example.yaml @@ -0,0 +1,41 @@ + +listen: + - network: tcp + host: 0.0.0.0 + port: 1053 + + - network: udp + host: 0.0.0.0 + port: 1053 + +blacklistfile: ./testdata/blacklist.dat +hostfile: ./testdata/hosts +timeout: 2 +debug: true + +defaultupstream: + - + network: tcp + host: 8.8.8.8 + port: 53 + - + network:udp + host: 8.8.8.8 + port: 53 + +forwardrules: + - domainfile: ./testdata/cn.dat + server: + - network: tcp + host: 114.114.114.114 + port: 53 + + - domainfile: testdata/us.dat + server: + - network: tcp + host: 8.8.8.8 + port: 53 + + - network: https + host: 74.175.200.100 + port: 443 diff --git a/google_httpdns.go b/google_httpdns.go new file mode 100644 index 0000000..1d2b2f3 --- /dev/null +++ b/google_httpdns.go @@ -0,0 +1,183 @@ +package main + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "github.com/miekg/dns" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +// ServerAddr is Google dns server ip +var ServerAddr = "74.125.200.100" +var queryIPApi = "https://www.simicloud.com/media/httpbin/ip" + +// GoogleHTTPDns struct +type GoogleHTTPDns struct { + myip string + l sync.Mutex +} + +func (h *GoogleHTTPDns) getMyIP() string { + if h.myip != "" { + return h.myip + } + go h.queryMyIP() + return "" +} + +type ipAPI struct { + IP string `json:"origin"` +} + +func (h *GoogleHTTPDns) queryMyIP() { + h.l.Lock() + defer h.l.Unlock() + if h.myip != "" { + //fmt.Printf("myip: %s\n", h.myip) + return + } + //fmt.Println("get ip...") + res, err := http.Get(queryIPApi) + if err != nil { + //fmt.Println(err) + return + } + defer res.Body.Close() + d, err := ioutil.ReadAll(res.Body) + if err != nil { + //fmt.Println(err) + return + } + //fmt.Printf("%s\n", string(d)) + ip := ipAPI{} + err = json.Unmarshal(d, &ip) + if err != nil { + //fmt.Println(err) + return + } + //fmt.Printf("got: %s\n", ip.Ip) + h.myip = ip.IP +} + +func (h *GoogleHTTPDns) getMyNet() string { + ip := h.getMyIP() + if ip == "" { + return "" + } + mask := net.IPv4Mask(255, 255, 255, 0) + ipByte := net.ParseIP(ip) + ipnet := net.IPNet{ipByte.Mask(mask), mask} + return ipnet.String() +} + +// Exchange send query to server and return the response +func (h *GoogleHTTPDns) Exchange(m *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { + name := m.Question[0].Name + t := dns.TypeToString[m.Question[0].Qtype] + mynet := h.getMyNet() + r, err := queryGoogleHTTPDNS(name, t, mynet, "", addr) + if err != nil { + return nil, 0, err + } + + m1 := new(dns.Msg) + + m1.SetRcode(m, r.Status) + for _, rr := range r.Answer { + _rr := fmt.Sprintf("%s %d IN %s %s", rr.Name, rr.TTL, + dns.TypeToString[uint16(rr.Type)], rr.Data) + + an, err := dns.NewRR(_rr) + if err != nil { + return nil, 0, err + } + m1.Answer = append(m1.Answer, an) + } + m1.Truncated = r.TC + m1.RecursionDesired = r.RD + m1.RecursionAvailable = r.RA + m1.AuthenticatedData = r.AD + m1.CheckingDisabled = r.CD + return m1, 0, nil + +} + +// Response represent the dns response from server +type Response struct { + Status int + TC bool + RD bool + RA bool + AD bool + CD bool + Question []RR + Answer []RR + Additional []RR + EDNSClientSubnet string `json:"edns_client_subnet"` + Comment string +} + +// RR represent the RR record +type RR struct { + Name string `json:"name"` + Type int `json:"type"` + TTL int + Data string `json:"data"` +} + +var httpclient = &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ServerName: "dns.google.com"}, + TLSHandshakeTimeout: 3 * time.Second, + }, +} + +func queryGoogleHTTPDNS(name, t, ednsClientSubnet, padding, srvAddr string) (*Response, error) { + srvaddr := ServerAddr + if srvAddr != "" { + srvaddr = srvAddr + } + v := url.Values{} + v.Add("name", name) + v.Add("type", t) + + if ednsClientSubnet != "" { + v.Add("edns_client_subnet", ednsClientSubnet) + } + + if padding != "" { + v.Add("random_padding", padding) + } + + u := fmt.Sprintf("https://%s/resolve?%s", srvaddr, v.Encode()) + r, _ := http.NewRequest("GET", u, nil) + r.Host = "dns.google.com" + //r.URL.Host = "dns.google.com" + + res, err := httpclient.Do(r) + if err != nil { + return nil, err + } + defer res.Body.Close() + + data, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, err + } + + d := Response{} + err = json.Unmarshal(data, &d) + + if err != nil { + return nil, err + } + + return &d, nil +} diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..d93ab87 --- /dev/null +++ b/handler.go @@ -0,0 +1,165 @@ +package main + +import ( + "errors" + "fmt" + "github.com/miekg/dns" + "strings" + "time" +) + +type dnsClient interface { + Exchange(m *dns.Msg, addr string) (*dns.Msg, time.Duration, error) +} + +type dnsHandler struct { + cfg *conf + tcpclient dnsClient + udpclient dnsClient + httpsclient dnsClient +} + +func newDNSHandler(cfg *conf) *dnsHandler { + return &dnsHandler{ + cfg: cfg, + tcpclient: &dns.Client{Net: "tcp", Timeout: 2 * time.Second}, + udpclient: &dns.Client{Net: "udp", Timeout: 2 * time.Second}, + httpsclient: &GoogleHTTPDns{}, + } + +} + +// ServerDNS implements the dns.Handler interface +func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + domain := r.Question[0].Name + + if ok := h.answerFromHosts(w, r); ok { + return + } + + srvs := h.getUpstreamServer(domain) + if srvs == nil { + srvs = h.cfg.DefaultUpstream + } + + if msg, err := h.getAnswerFromUpstream(r, srvs); err == nil { + w.WriteMsg(msg) + return + } + + dns.HandleFailed(w, r) + +} + +func (h *dnsHandler) getUpstreamServer(domain string) []addr { + for _, srv := range h.cfg.ForwardRules { + if ok := srv.domains.has(strings.Trim(domain, ".")); ok { + return srv.Server + } + } + return nil +} + +func (h *dnsHandler) queryUpstream(r *dns.Msg, srv addr, ch chan *dns.Msg) { + var m *dns.Msg + var err error + + switch srv.Network { + case "tcp": + info("query %s IN %s, forward to %s:%d through tcp", + r.Question[0].Name, + dns.TypeToString[r.Question[0].Qtype], + srv.Host, + srv.Port) + m, _, err = h.tcpclient.Exchange(r, fmt.Sprintf("%s:%d", srv.Host, srv.Port)) + case "udp": + info("query %s IN %s, forward to %s:%d through udp", + r.Question[0].Name, + dns.TypeToString[r.Question[0].Qtype], + srv.Host, + srv.Port) + m, _, err = h.tcpclient.Exchange(r, fmt.Sprintf("%s:%d", srv.Host, srv.Port)) + case "https": + info("query %s IN %s, forward to %s:%d through https", + r.Question[0].Name, + dns.TypeToString[r.Question[0].Qtype], + srv.Host, + srv.Port) + m, _, err = h.httpsclient.Exchange(r, fmt.Sprintf("%s:%d", srv.Host, srv.Port)) + default: + // ignore + } + + if err == nil { + select { + case ch <- m: + default: + } + } else { + errorlog("%s", err) + } +} + +func (h *dnsHandler) getAnswerFromUpstream(r *dns.Msg, servers []addr) (*dns.Msg, error) { + ch := make(chan *dns.Msg, 5) + for _, srv := range servers { + go func(a addr) { + h.queryUpstream(r, a, ch) + }(srv) + } + + var savedErr *dns.Msg + for { + select { + case m := <-ch: + if m.Rcode == dns.RcodeSuccess && !h.inBlacklist(m) { + return m, nil + } + savedErr = m + case <-time.After(time.Duration(h.cfg.Timeout) * time.Second): + if savedErr != nil { + return savedErr, nil + } + info("query %s IN %s, timeout", r.Question[0].Name, dns.TypeToString[r.Question[0].Qtype]) + return nil, errors.New("timeout") + } + } +} + +func (h *dnsHandler) inBlacklist(m *dns.Msg) bool { + var ip string + for _, rr := range m.Answer { + if a, ok := rr.(*dns.A); ok { + ip = a.String() + } else if aaaa, ok := rr.(*dns.AAAA); ok { + ip = aaaa.String() + } else { + ip = "" + } + if ip != "" && h.cfg.blacklist.exists(ip) { + info("%s in blacklist", ip) + return true + } + } + return false +} + +func (h *dnsHandler) answerFromHosts(w dns.ResponseWriter, r *dns.Msg) bool { + domain := r.Question[0].Name + t := r.Question[0].Qtype + + ip := h.cfg.hosts.get(strings.Trim(domain, "."), int(t)) + if ip != "" { + rr, _ := dns.NewRR(fmt.Sprintf("%s 3600 IN %s %s", domain, dns.TypeToString[t], ip)) + if rr == nil { + return false + } + msg := new(dns.Msg) + msg.SetReply(r) + msg.Answer = append(msg.Answer, rr) + w.WriteMsg(msg) + info("query %s IN %s, reply from hosts", domain, dns.TypeToString[t]) + return true + } + return false +} diff --git a/log.go b/log.go new file mode 100644 index 0000000..7592c4f --- /dev/null +++ b/log.go @@ -0,0 +1,31 @@ +package main + +import ( + "log" +) + +const ( + _ = iota + FATAL + ERROR + WARN + NOTICE + INFO + DEBUG +) + +var logLevel = WARN + +func logMsg(l int, fmt string, args ...interface{}) { + if l <= logLevel { + log.Printf(fmt, args...) + } +} + +func info(fmt string, args ...interface{}) { + logMsg(INFO, fmt, args...) +} + +func errorlog(fmt string, args ...interface{}) { + logMsg(ERROR, fmt, args...) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..ddd4ca9 --- /dev/null +++ b/server.go @@ -0,0 +1,33 @@ +package main + +import ( + "flag" + "fmt" + "github.com/miekg/dns" + "log" +) + +func main() { + var configfile string + + flag.StringVar(&configfile, "c", "config.yaml", "config file") + flag.Parse() + + config, err := loadConfig(configfile) + if err != nil { + log.Fatal(err) + } + + h := newDNSHandler(config) + + for _, l := range config.Listen { + go func(l addr) { + if err := dns.ListenAndServe( + fmt.Sprintf("%s:%d", l.Host, l.Port), l.Network, h); err != nil { + log.Fatal(err) + } + }(l) + } + + select {} +} diff --git a/testdata/blacklist.dat b/testdata/blacklist.dat new file mode 100644 index 0000000..e7a18f5 --- /dev/null +++ b/testdata/blacklist.dat @@ -0,0 +1,5 @@ +# asdfas +1.2.3.4 +2.3.1.21 + +3.1.1.1 diff --git a/testdata/cn.dat b/testdata/cn.dat new file mode 100644 index 0000000..d83c930 --- /dev/null +++ b/testdata/cn.dat @@ -0,0 +1,8 @@ +#12 + +a.com +b.org + +cn + +c.net diff --git a/testdata/config.yaml b/testdata/config.yaml new file mode 100644 index 0000000..5102d49 --- /dev/null +++ b/testdata/config.yaml @@ -0,0 +1,31 @@ + +listen: + - network: tcp + host: 0.0.0.0 + port: 1053 + + - network: udp + host: 0.0.0.0 + port: 1053 + +blacklistfile: ./testdata/blacklist.dat +hostfile: ./testdata/hosts +timeout: 2 +debug: true + + +forwardrules: + - domainfile: ./testdata/cn.dat + server: + - network: tcp + host: 114.114.114.114 + port: 53 + + - domainfile: testdata/us.dat + server: + - network: tcp + host: 8.8.8.8 + port: 53 + - network: https + host: 74.175.200.200 + port: 443 diff --git a/testdata/hosts b/testdata/hosts new file mode 100644 index 0000000..86c2d20 --- /dev/null +++ b/testdata/hosts @@ -0,0 +1,9 @@ +#asdf + +1.2.3.1 a.com b.com c.aaa + +1.2.9.1 addd bbb ddd aa +::1 localhost +127.0.0.1 localhost +1.2.1.1 kkddaso +# daas diff --git a/testdata/us.dat b/testdata/us.dat new file mode 100644 index 0000000..3c2c677 --- /dev/null +++ b/testdata/us.dat @@ -0,0 +1,12 @@ +# +us +b.cc +a.cc + +dda +# da +a.cf +d.com + + +