diff --git a/arguments.go b/arguments.go index e51a797..8bbaabb 100644 --- a/arguments.go +++ b/arguments.go @@ -9,7 +9,7 @@ import ( var bind_addr string -var default_server string +var default_server ArgSrvs var srv ArgSrvs @@ -17,7 +17,7 @@ var logfile string type ArgSrvs []string -var DefaultServer *UpstreamServer +var DefaultServer []*UpstreamServer var blacklist_file string @@ -44,22 +44,29 @@ func parse_flags() { } } - proto, addr, err := parse_addr(default_server) - if err != nil { - log.Fatal(err) - } + for _, dsvr := range default_server { + proto, addr, err := parse_addr(dsvr) + if err != nil { + log.Fatal(err) + } - var c *dns.Client - if proto == "udp" { - c = client_udp - } else { - c = client_tcp + var c *dns.Client + if proto == "udp" { + c = client_udp + } else { + c = client_tcp + } + + upsrv := &UpstreamServer{ + Addr: addr, + Proto: proto, + client: c, + } + DefaultServer = append(DefaultServer, upsrv) } - DefaultServer = &UpstreamServer{ - Addr: addr, - Proto: proto, - client: c, + if len(DefaultServer) == 0 { + log.Fatal("please special a -upstream") } a, err := load_domain(blacklist_file) @@ -93,7 +100,7 @@ func init() { `) flag.StringVar(&bind_addr, "bind", ":53", "the address bind to") - flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use") + flag.Var(&default_server, "upstream", "special the upstream server to use") flag.StringVar(&logfile, "logfile", "", "the logfile, default stdout") flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file") flag.BoolVar(&debug, "debug", false, "output debug log, default false") diff --git a/server.go b/server.go index 3578bf4..55752f2 100644 --- a/server.go +++ b/server.go @@ -168,30 +168,33 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) { // fallback to default upstream server if done != 1 { - res, err = DefaultServer.query(r) - if err != nil { - logger.Error("%s", err) - continue - } + for _, dfsrv := range DefaultServer { + res, err = dfsrv.query(r) + if err != nil { + logger.Error("%s", err) + continue + } - logger.Debug("%s query %s %s %s, use default server %s:%s, %s\n", - w.RemoteAddr(), - domain, - dns.ClassToString[r.Question[0].Qclass], - dns.TypeToString[r.Question[0].Qtype], - DefaultServer.Proto, DefaultServer.Addr, - dns.RcodeToString[res.Rcode], - ) + logger.Debug("%s query %s %s %s, use default server %s:%s, %s\n", + w.RemoteAddr(), + domain, + dns.ClassToString[r.Question[0].Qclass], + dns.TypeToString[r.Question[0].Qtype], + dfsrv.Proto, dfsrv.Addr, + dns.RcodeToString[res.Rcode], + ) - if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) { - // add to cache - v := []string{} - for _, as := range res.Answer { - v = append(v, as.String()) + if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) { + // add to cache + v := []string{} + for _, as := range res.Answer { + v = append(v, as.String()) + } + dns_cache.Add(key, strings.Join(v, "|")) + w.WriteMsg(res) + done = 1 + break } - dns_cache.Add(key, strings.Join(v, "|")) - w.WriteMsg(res) - done = 1 } }