add function to parse /etc/hosts file

nghttp2
fangdingjun 9 years ago
parent f23b0f27a5
commit b6afc1ebed

@ -68,6 +68,17 @@ func parse_flags() {
} else { } else {
Blacklist_ips = a Blacklist_ips = a
} }
if hostfile == "" {
hostfile = GetHost()
}
if hostfile != "" {
record_hosts, err = ReadHosts(hostfile)
if err != nil {
log.Fatal(err)
}
}
} }
func init() { func init() {
@ -83,7 +94,8 @@ func init() {
flag.StringVar(&bind_addr, "bind", ":53", "the address bind to") flag.StringVar(&bind_addr, "bind", ":53", "the address bind to")
flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use") flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use")
flag.StringVar(&logfile, "logfile", "error.log", "the logfile, default stdout") flag.StringVar(&logfile, "logfile", "", "the logfile, default stdout")
flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file") flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file")
flag.BoolVar(&debug, "debug", false, "output debug log, default false") flag.BoolVar(&debug, "debug", false, "output debug log, default false")
flag.StringVar(&hostfile, "hosts", "", "load special ip from hosts or /etc/hosts")
} }

@ -0,0 +1,135 @@
package main
import (
"bufio"
"errors"
"fmt"
"github.com/miekg/dns"
"net"
"os"
"path/filepath"
"runtime"
"strings"
)
type HostRecord struct {
/* RR record */
rr dns.RR
/* type, dns.A or dns.AAAA */
t uint16
}
type Hosts map[string][]HostRecord
/*
get special type of record form Hosts
*/
func (h Hosts) Get(n string, t uint16) dns.RR {
n1 := dns.Fqdn(n)
if hr, ok := h[n1]; ok {
for _, v := range hr {
if v.t == t {
return v.rr
}
}
}
return nil
}
/*
read and parse the hosts file
*/
func ReadHosts(fn string) (Hosts, error) {
fp, err := os.Open(fn)
if err != nil {
return nil, err
}
defer fp.Close()
hts := Hosts{}
bf := bufio.NewReader(fp)
for {
var t uint16
bline, _, err := bf.ReadLine()
if err != nil {
break
}
sline := string(bline)
sline = strings.Trim(sline, " ")
/* empty line */
if sline == "" {
continue
}
/* comment */
if sline[0] == '#' {
continue
}
/* replace tab to space */
sline = strings.Replace(sline, "\t", " ", -1)
lns := strings.Split(sline, " ")
if len(lns) < 1 {
return nil, errors.New(fmt.Sprintf("invalid hosts line: %s", sline))
}
ip := net.ParseIP(lns[0])
if ip == nil {
return nil, errors.New(fmt.Sprintf("invalid ip: %s", lns[0]))
}
if strings.Index(lns[0], ".") != -1 {
t = dns.TypeA
} else {
t = dns.TypeAAAA
}
for _, dn := range lns[1:] {
dd := dns.Fqdn(strings.Trim(dn, " "))
/* ignore space */
if dd == "." {
continue
}
s := fmt.Sprintf("%s 36000 IN %s %s", dd,
dns.TypeToString[t], lns[0])
r, err := dns.NewRR(s)
if err != nil {
return nil, err
}
if _, ok := hts[dd]; ok {
hts[dd] = append(hts[dd], HostRecord{r, t})
} else {
hts[dd] = []HostRecord{HostRecord{r, t}}
}
}
}
return hts, nil
}
/*
return the path of hosts file
*/
func GetHost() string {
var p string
if runtime.GOOS == "windows" {
p = filepath.Join(os.Getenv("SYSTEMROOT"),
"system32/drivers/etc/hosts")
} else {
p = "/etc/hosts"
}
return filepath.Clean(p)
}

@ -0,0 +1,40 @@
package main
import (
"github.com/miekg/dns"
"testing"
//"fmt"
)
func TestReadHosts(t *testing.T) {
a, err := ReadHosts("testdata/hosts.txt")
if err != nil {
t.Error(err)
}
for k, v := range a {
for _, v1 := range v {
t.Logf("%s: %s\n", k, v1.rr.String())
}
}
r1 := a.Get("localhost", dns.TypeA)
if dnsa, ok := r1.(*dns.A); ok {
if dnsa.A.String() != "127.0.0.1" {
t.Errorf("get failed a\n")
}
} else {
t.Errorf("type not a\n")
}
r2 := a.Get("localhost", dns.TypeAAAA)
if dnsaa, ok := r2.(*dns.AAAA); ok {
if dnsaa.AAAA.String() != "::1" {
t.Errorf("get failed aaaa\n")
}
} else {
t.Errorf("type not aaaa\n")
}
}
func TestGetHost(t *testing.T) {
t.Logf("host: %s\n", GetHost())
}

@ -1,5 +1,3 @@
/* /*
gdns is a dns proxy server write by go. gdns is a dns proxy server write by go.
@ -41,6 +39,9 @@ var Blacklist_ips Kv = nil
var debug bool = false var debug bool = false
var hostfile string = ""
var record_hosts Hosts = nil
func in_blacklist(m *dns.Msg) bool { func in_blacklist(m *dns.Msg) bool {
if Blacklist_ips == nil { if Blacklist_ips == nil {
return false return false
@ -68,6 +69,25 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
var done int var done int
/*
reply from hosts
*/
if record_hosts != nil {
rr := record_hosts.Get(domain, r.Question[0].Qtype)
if rr != nil {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Answer = append(msg.Answer, rr)
w.WriteMsg(msg)
logger.Debug("query %s %s %s, reply from hosts\n",
domain,
dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype],
)
return
}
}
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
done = 0 done = 0
for _, sv := range Servers { for _, sv := range Servers {

@ -0,0 +1,8 @@
127.0.0.1 localhost localhost.localdomain
192.168.1.1 gw
::1 localhost
192.243.112.217 www.ratafee.nl
2607:8700:103:9b22:: www.ratafee.nl
# 127.0.0.1 ccc.com
Loading…
Cancel
Save