commit 034712fda08c011a1410d797230493a35dcef00c Author: Dingjun Date: Mon Aug 1 15:53:28 2016 +0800 gdns v0.0.1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8a138d7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*~ +*.swp +*.json +*.txt +dns diff --git a/README.md b/README.md new file mode 100644 index 0000000..5317a19 --- /dev/null +++ b/README.md @@ -0,0 +1,12 @@ +gdns +==== + +a dns forward proxy write by go + +**Features** + +- support listen tcp and udp protocol +- support configure different domains forward to different upstream servers +- support tcp or udp to communicate to upstream servers +- support ip blacklist, drop unwanted dns reply + diff --git a/cfg.go b/cfg.go new file mode 100644 index 0000000..b156eff --- /dev/null +++ b/cfg.go @@ -0,0 +1,142 @@ +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "strings" +) + +type item map[string]int + +func (i item) has(s string) bool { + if _, ok := i[s]; ok { + return true + } + return false +} + +func (it item) match(s string) bool { + iis := strings.Split(s, ".") + for i := 0; i < len(iis); i++ { + ii := strings.Join(iis[i:], ".") + if _, ok := it[ii]; ok { + return true + } + } + return false +} + +type addr struct { + network string + addr string +} + +// Rule present a forward rule +type Rule struct { + DomainlistFile string `json:"domain_list_file"` + domains item + ServersString []string `json:"servers"` + servers []addr +} + +type cfg struct { + Listen []string `json:"listen"` + listen []addr + ServersString []string `json:"default_servers"` + servers []addr + Timeout int `json:"timeout"` + BlacklistFiles []string `json:"blacklist_ips"` + blacklistIps item + Rules []Rule `json:"rules"` +} + +func parseCfg(fn string) (*cfg, error) { + fp, err := os.Open(fn) + if err != nil { + return nil, err + } + defer fp.Close() + c := cfg{} + buf, err := ioutil.ReadAll(fp) + err = json.Unmarshal(buf, &c) + if err != nil { + return nil, err + } + + var adr []addr + for _, a := range c.ServersString { + a1 := parseAddr(a) + if a1.network != "" { + adr = append(adr, a1) + } + } + c.servers = adr + + var ll []addr + for _, a := range c.Listen { + a1 := parseAddr(a) + if a1.network != "" { + ll = append(ll, a1) + } + } + c.listen = ll + + l1 := make(item) + for _, a := range c.BlacklistFiles { + parseFile(a, &l1) + } + c.blacklistIps = l1 + + for i, r := range c.Rules { + l2 := make(item) + parseFile(r.DomainlistFile, &l2) + c.Rules[i].domains = l2 + + var adr1 []addr + for _, a := range r.ServersString { + a1 := parseAddr(a) + if a1.network != "" { + adr1 = append(adr1, a1) + } + } + c.Rules[i].servers = adr1 + } + return &c, nil +} + +func parseAddr(addr1 string) addr { + a := strings.SplitN(addr1, ":", 2) + if len(a) != 2 { + fmt.Printf("addr error") + return addr{"", ""} + } + return addr{a[0], a[1]} +} + +func parseFile(fn string, i *item) { + ii := *i + fp, err := os.Open(fn) + if err != nil { + fmt.Printf("open failed: %s\n", err) + return + } + defer fp.Close() + r := bufio.NewReader(fp) + for { + line, err := r.ReadString('\n') + l := strings.Trim(line, " \r\n\t") + if err != nil { + if l != "" { + ii[l] = 1 + } + break + } + if l == "" { + continue + } + ii[l] = 1 + } +} diff --git a/cfg_test.go b/cfg_test.go new file mode 100644 index 0000000..454ef7a --- /dev/null +++ b/cfg_test.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "testing" +) + +func TestCfg(t *testing.T) { + c, err := parseCfg("config.json") + if err != nil { + t.Fatalf("%s\n", err) + } + fmt.Printf("%+v\n", c) + fmt.Printf("%v\n", c.Rules[0].domains.match("google.com")) + fmt.Printf("%v\n", c.Rules[0].domains.match("www.ip.cn")) +} diff --git a/example_config/config.json b/example_config/config.json new file mode 100644 index 0000000..a87d181 --- /dev/null +++ b/example_config/config.json @@ -0,0 +1,16 @@ +{ + "listen":["tcp:0.0.0.0:8053","udp::8053"], + "default_servers":["tcp:114.114.114.114:53","tcp:8.8.8.8:53"], + "timeout":1, + "blacklist_ips":["ip.txt"], + "rules":[ + { + "domain_list_file":"domain1.txt", + "servers":["tcp:4.2.2.2:53"] + }, + { + "domain_list_file":"domain2.txt", + "servers":["tcp:8.8.4.4:53"] + } + ] +} diff --git a/example_config/domain1.txt b/example_config/domain1.txt new file mode 100644 index 0000000..dfad8ae --- /dev/null +++ b/example_config/domain1.txt @@ -0,0 +1,4 @@ +ww.goole.com +ip.cn +taobao.com +a.cn diff --git a/example_config/domain2.txt b/example_config/domain2.txt new file mode 100644 index 0000000..214818c --- /dev/null +++ b/example_config/domain2.txt @@ -0,0 +1,3 @@ +a.com +b.org +c.net diff --git a/example_config/ip.txt b/example_config/ip.txt new file mode 100644 index 0000000..378038c --- /dev/null +++ b/example_config/ip.txt @@ -0,0 +1,2 @@ +1.2.3.4 +2.21.2.2 diff --git a/routers.go b/routers.go new file mode 100644 index 0000000..77e5d62 --- /dev/null +++ b/routers.go @@ -0,0 +1,100 @@ +package main + +import ( + "github.com/miekg/dns" + "log" + "strings" + "time" +) + +type routers struct { + c *cfg + tcp *dns.Client + udp *dns.Client +} + +func (r routers) checkBlacklist(m *dns.Msg) bool { + if m.Rcode != dns.RcodeSuccess { + // not success, not in blacklist + return false + } + + for _, rr := range m.Answer { + var ip = "" + if t, ok := rr.(*dns.A); ok { + ip = t.A.String() + } else if t, ok := rr.(*dns.AAAA); ok { + ip = t.AAAA.String() + } + + if ip != "" && r.c.blacklistIps.has(ip) { + log.Printf("%s is in blacklist.\n", ip) + return true + } + + } + return false +} + +func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) { + var up *dns.Client + var lastErr error + for _, srv := range servers { + switch srv.network { + case "tcp": + up = r.tcp + case "udp": + up = r.udp + default: + up = r.udp + } + + log.Printf("query %s use %s:%s\n", m.Question[0].Name, srv.network, srv.addr) + + m, _, err := up.Exchange(m, srv.addr) + if err == nil && !r.checkBlacklist(m) { + return m, err + } + + log.Println(err) + lastErr = err + } + + // return last error + return nil, lastErr +} + +// ServeDNS implements dns.Handler interface +func (r routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { + domain := m.Question[0].Name + d := strings.Trim(domain, ".") + for _, rule := range r.c.Rules { + if rule.domains.match(d) { + m1, err := r.query(m, rule.servers) + if err == nil { + w.WriteMsg(m1) + return + } else { + log.Println(err) + } + } + } + + // no match or failed, fallback to default + m1, err := r.query(m, r.c.servers) + if err != nil { + log.Println(err) + dns.HandleFailed(w, m) + } else { + w.WriteMsg(m1) + } +} + +func initRouters(c *cfg) { + router := &routers{ + c, + &dns.Client{Net: "tcp", Timeout: time.Duration(c.Timeout) * time.Second}, + &dns.Client{Net: "udp", Timeout: time.Duration(c.Timeout) * time.Second}, + } + dns.Handle(".", router) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..f484cf6 --- /dev/null +++ b/server.go @@ -0,0 +1,31 @@ +package main + +import ( + "flag" + "github.com/miekg/dns" + "log" + "os" +) + +func initListeners(c *cfg) { + for _, a := range c.listen { + log.Printf("Listen on %s %s...\n", a.network, a.addr) + s := dns.Server{Addr: a.addr, Net: a.network} + go s.ListenAndServe() + } +} + +func main() { + var configFile string + flag.StringVar(&configFile, "c", "", "config file") + flag.Parse() + config, err := parseCfg(configFile) + if err != nil { + log.Println(err) + os.Exit(-1) + } + + initRouters(config) + initListeners(config) + select {} +}