From b6afc1ebed1363783eeb3f70dc993e86af261f3c Mon Sep 17 00:00:00 2001 From: fangdingjun Date: Mon, 21 Sep 2015 15:23:05 +0800 Subject: [PATCH] add function to parse /etc/hosts file --- arguments.go | 14 ++++- hosts.go | 135 +++++++++++++++++++++++++++++++++++++++++++++ hosts_test.go | 40 ++++++++++++++ server.go | 24 +++++++- testdata/hosts.txt | 8 +++ 5 files changed, 218 insertions(+), 3 deletions(-) create mode 100644 hosts.go create mode 100644 hosts_test.go create mode 100644 testdata/hosts.txt diff --git a/arguments.go b/arguments.go index 5cd0866..e51a797 100644 --- a/arguments.go +++ b/arguments.go @@ -68,6 +68,17 @@ func parse_flags() { } else { Blacklist_ips = a } + + if hostfile == "" { + hostfile = GetHost() + } + + if hostfile != "" { + record_hosts, err = ReadHosts(hostfile) + if err != nil { + log.Fatal(err) + } + } } func init() { @@ -83,7 +94,8 @@ func init() { flag.StringVar(&bind_addr, "bind", ":53", "the address bind to") flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use") - flag.StringVar(&logfile, "logfile", "error.log", "the logfile, default stdout") + flag.StringVar(&logfile, "logfile", "", "the logfile, default stdout") flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file") flag.BoolVar(&debug, "debug", false, "output debug log, default false") + flag.StringVar(&hostfile, "hosts", "", "load special ip from hosts or /etc/hosts") } diff --git a/hosts.go b/hosts.go new file mode 100644 index 0000000..e16bdf0 --- /dev/null +++ b/hosts.go @@ -0,0 +1,135 @@ +package main + +import ( + "bufio" + "errors" + "fmt" + "github.com/miekg/dns" + "net" + "os" + "path/filepath" + "runtime" + "strings" +) + +type HostRecord struct { + /* RR record */ + rr dns.RR + + /* type, dns.A or dns.AAAA */ + t uint16 +} + +type Hosts map[string][]HostRecord + +/* + get special type of record form Hosts +*/ +func (h Hosts) Get(n string, t uint16) dns.RR { + n1 := dns.Fqdn(n) + if hr, ok := h[n1]; ok { + for _, v := range hr { + if v.t == t { + return v.rr + } + } + } + return nil +} + +/* + read and parse the hosts file +*/ +func ReadHosts(fn string) (Hosts, error) { + fp, err := os.Open(fn) + if err != nil { + return nil, err + } + + defer fp.Close() + + hts := Hosts{} + + bf := bufio.NewReader(fp) + + for { + var t uint16 + bline, _, err := bf.ReadLine() + if err != nil { + break + } + + sline := string(bline) + sline = strings.Trim(sline, " ") + + /* empty line */ + if sline == "" { + continue + } + + /* comment */ + if sline[0] == '#' { + continue + } + + /* replace tab to space */ + sline = strings.Replace(sline, "\t", " ", -1) + + lns := strings.Split(sline, " ") + + if len(lns) < 1 { + return nil, errors.New(fmt.Sprintf("invalid hosts line: %s", sline)) + } + + ip := net.ParseIP(lns[0]) + if ip == nil { + return nil, errors.New(fmt.Sprintf("invalid ip: %s", lns[0])) + } + + if strings.Index(lns[0], ".") != -1 { + t = dns.TypeA + } else { + t = dns.TypeAAAA + } + + for _, dn := range lns[1:] { + + dd := dns.Fqdn(strings.Trim(dn, " ")) + + /* ignore space */ + if dd == "." { + continue + } + + s := fmt.Sprintf("%s 36000 IN %s %s", dd, + dns.TypeToString[t], lns[0]) + + r, err := dns.NewRR(s) + if err != nil { + return nil, err + } + + if _, ok := hts[dd]; ok { + hts[dd] = append(hts[dd], HostRecord{r, t}) + } else { + hts[dd] = []HostRecord{HostRecord{r, t}} + } + } + } + + return hts, nil +} + +/* + return the path of hosts file +*/ +func GetHost() string { + var p string + if runtime.GOOS == "windows" { + p = filepath.Join(os.Getenv("SYSTEMROOT"), + "system32/drivers/etc/hosts") + } else { + p = "/etc/hosts" + } + return filepath.Clean(p) +} diff --git a/hosts_test.go b/hosts_test.go new file mode 100644 index 0000000..b84a0dc --- /dev/null +++ b/hosts_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "github.com/miekg/dns" + "testing" + //"fmt" +) + +func TestReadHosts(t *testing.T) { + a, err := ReadHosts("testdata/hosts.txt") + if err != nil { + t.Error(err) + } + + for k, v := range a { + for _, v1 := range v { + t.Logf("%s: %s\n", k, v1.rr.String()) + } + } + r1 := a.Get("localhost", dns.TypeA) + if dnsa, ok := r1.(*dns.A); ok { + if dnsa.A.String() != "127.0.0.1" { + t.Errorf("get failed a\n") + } + } else { + t.Errorf("type not a\n") + } + r2 := a.Get("localhost", dns.TypeAAAA) + if dnsaa, ok := r2.(*dns.AAAA); ok { + if dnsaa.AAAA.String() != "::1" { + t.Errorf("get failed aaaa\n") + } + } else { + t.Errorf("type not aaaa\n") + } +} + +func TestGetHost(t *testing.T) { + t.Logf("host: %s\n", GetHost()) +} diff --git a/server.go b/server.go index 02eec53..daf58ef 100644 --- a/server.go +++ b/server.go @@ -1,5 +1,3 @@ - - /* gdns is a dns proxy server write by go. @@ -41,6 +39,9 @@ var Blacklist_ips Kv = nil var debug bool = false +var hostfile string = "" +var record_hosts Hosts = nil + func in_blacklist(m *dns.Msg) bool { if Blacklist_ips == nil { return false @@ -68,6 +69,25 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) { var done int + /* + reply from hosts + */ + if record_hosts != nil { + rr := record_hosts.Get(domain, r.Question[0].Qtype) + if rr != nil { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Answer = append(msg.Answer, rr) + w.WriteMsg(msg) + logger.Debug("query %s %s %s, reply from hosts\n", + domain, + dns.ClassToString[r.Question[0].Qclass], + dns.TypeToString[r.Question[0].Qtype], + ) + return + } + } + for i := 0; i < 2; i++ { done = 0 for _, sv := range Servers { diff --git a/testdata/hosts.txt b/testdata/hosts.txt new file mode 100644 index 0000000..e447fa9 --- /dev/null +++ b/testdata/hosts.txt @@ -0,0 +1,8 @@ +127.0.0.1 localhost localhost.localdomain +192.168.1.1 gw +::1 localhost + + +192.243.112.217 www.ratafee.nl + 2607:8700:103:9b22:: www.ratafee.nl +# 127.0.0.1 ccc.com