support more than one upstream dns server

nghttp2
fangdingjun 9 years ago
parent 922a51b106
commit 5cae50d369

@ -9,7 +9,7 @@ import (
var bind_addr string var bind_addr string
var default_server string var default_server ArgSrvs
var srv ArgSrvs var srv ArgSrvs
@ -17,7 +17,7 @@ var logfile string
type ArgSrvs []string type ArgSrvs []string
var DefaultServer *UpstreamServer var DefaultServer []*UpstreamServer
var blacklist_file string var blacklist_file string
@ -44,22 +44,29 @@ func parse_flags() {
} }
} }
proto, addr, err := parse_addr(default_server) for _, dsvr := range default_server {
if err != nil { proto, addr, err := parse_addr(dsvr)
log.Fatal(err) if err != nil {
} log.Fatal(err)
}
var c *dns.Client var c *dns.Client
if proto == "udp" { if proto == "udp" {
c = client_udp c = client_udp
} else { } else {
c = client_tcp c = client_tcp
}
upsrv := &UpstreamServer{
Addr: addr,
Proto: proto,
client: c,
}
DefaultServer = append(DefaultServer, upsrv)
} }
DefaultServer = &UpstreamServer{ if len(DefaultServer) == 0 {
Addr: addr, log.Fatal("please special a -upstream")
Proto: proto,
client: c,
} }
a, err := load_domain(blacklist_file) a, err := load_domain(blacklist_file)
@ -93,7 +100,7 @@ func init() {
`) `)
flag.StringVar(&bind_addr, "bind", ":53", "the address bind to") flag.StringVar(&bind_addr, "bind", ":53", "the address bind to")
flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use") flag.Var(&default_server, "upstream", "special the upstream server to use")
flag.StringVar(&logfile, "logfile", "", "the logfile, default stdout") flag.StringVar(&logfile, "logfile", "", "the logfile, default stdout")
flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file") flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file")
flag.BoolVar(&debug, "debug", false, "output debug log, default false") flag.BoolVar(&debug, "debug", false, "output debug log, default false")

@ -168,30 +168,33 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
// fallback to default upstream server // fallback to default upstream server
if done != 1 { if done != 1 {
res, err = DefaultServer.query(r) for _, dfsrv := range DefaultServer {
if err != nil { res, err = dfsrv.query(r)
logger.Error("%s", err) if err != nil {
continue logger.Error("%s", err)
} continue
}
logger.Debug("%s query %s %s %s, use default server %s:%s, %s\n", logger.Debug("%s query %s %s %s, use default server %s:%s, %s\n",
w.RemoteAddr(), w.RemoteAddr(),
domain, domain,
dns.ClassToString[r.Question[0].Qclass], dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype], dns.TypeToString[r.Question[0].Qtype],
DefaultServer.Proto, DefaultServer.Addr, dfsrv.Proto, dfsrv.Addr,
dns.RcodeToString[res.Rcode], dns.RcodeToString[res.Rcode],
) )
if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) { if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) {
// add to cache // add to cache
v := []string{} v := []string{}
for _, as := range res.Answer { for _, as := range res.Answer {
v = append(v, as.String()) v = append(v, as.String())
}
dns_cache.Add(key, strings.Join(v, "|"))
w.WriteMsg(res)
done = 1
break
} }
dns_cache.Add(key, strings.Join(v, "|"))
w.WriteMsg(res)
done = 1
} }
} }

Loading…
Cancel
Save