support more than one upstream dns server

nghttp2
fangdingjun 9 years ago
parent 922a51b106
commit 5cae50d369

@ -9,7 +9,7 @@ import (
var bind_addr string
var default_server string
var default_server ArgSrvs
var srv ArgSrvs
@ -17,7 +17,7 @@ var logfile string
type ArgSrvs []string
var DefaultServer *UpstreamServer
var DefaultServer []*UpstreamServer
var blacklist_file string
@ -44,22 +44,29 @@ func parse_flags() {
}
}
proto, addr, err := parse_addr(default_server)
if err != nil {
log.Fatal(err)
}
for _, dsvr := range default_server {
proto, addr, err := parse_addr(dsvr)
if err != nil {
log.Fatal(err)
}
var c *dns.Client
if proto == "udp" {
c = client_udp
} else {
c = client_tcp
var c *dns.Client
if proto == "udp" {
c = client_udp
} else {
c = client_tcp
}
upsrv := &UpstreamServer{
Addr: addr,
Proto: proto,
client: c,
}
DefaultServer = append(DefaultServer, upsrv)
}
DefaultServer = &UpstreamServer{
Addr: addr,
Proto: proto,
client: c,
if len(DefaultServer) == 0 {
log.Fatal("please special a -upstream")
}
a, err := load_domain(blacklist_file)
@ -93,7 +100,7 @@ func init() {
`)
flag.StringVar(&bind_addr, "bind", ":53", "the address bind to")
flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use")
flag.Var(&default_server, "upstream", "special the upstream server to use")
flag.StringVar(&logfile, "logfile", "", "the logfile, default stdout")
flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file")
flag.BoolVar(&debug, "debug", false, "output debug log, default false")

@ -168,30 +168,33 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
// fallback to default upstream server
if done != 1 {
res, err = DefaultServer.query(r)
if err != nil {
logger.Error("%s", err)
continue
}
for _, dfsrv := range DefaultServer {
res, err = dfsrv.query(r)
if err != nil {
logger.Error("%s", err)
continue
}
logger.Debug("%s query %s %s %s, use default server %s:%s, %s\n",
w.RemoteAddr(),
domain,
dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype],
DefaultServer.Proto, DefaultServer.Addr,
dns.RcodeToString[res.Rcode],
)
logger.Debug("%s query %s %s %s, use default server %s:%s, %s\n",
w.RemoteAddr(),
domain,
dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype],
dfsrv.Proto, dfsrv.Addr,
dns.RcodeToString[res.Rcode],
)
if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) {
// add to cache
v := []string{}
for _, as := range res.Answer {
v = append(v, as.String())
if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) {
// add to cache
v := []string{}
for _, as := range res.Answer {
v = append(v, as.String())
}
dns_cache.Add(key, strings.Join(v, "|"))
w.WriteMsg(res)
done = 1
break
}
dns_cache.Add(key, strings.Join(v, "|"))
w.WriteMsg(res)
done = 1
}
}

Loading…
Cancel
Save