diff --git a/server.go b/server.go index 9a4bcf8..3578bf4 100644 --- a/server.go +++ b/server.go @@ -23,8 +23,11 @@ run it package main import ( + "fmt" + lru "github.com/hashicorp/golang-lru" "github.com/miekg/dns" "log" + "strings" ) var client_udp *dns.Client = &dns.Client{} @@ -39,6 +42,8 @@ var Blacklist_ips Kv = nil var debug bool = false +var dns_cache *lru.Cache + var hostfile string = "" var record_hosts Hosts = nil @@ -101,6 +106,32 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) { } } + key := fmt.Sprintf("%s_%s", domain, dns.TypeToString[r.Question[0].Qtype]) + + // reply from cache + if a, ok := dns_cache.Get(key); ok { + msg := new(dns.Msg) + msg.SetReply(r) + + aa := strings.Split(a.(string), "|") + for _, a1 := range aa { + rr, _ := dns.NewRR(a1) + if rr != nil { + msg.Answer = append(msg.Answer, rr) + } + } + + w.WriteMsg(msg) + logger.Debug("%s query %s %s %s, reply from cache\n", + w.RemoteAddr(), + domain, + dns.ClassToString[r.Question[0].Qclass], + dns.TypeToString[r.Question[0].Qtype], + ) + return + } + + // forward to upstream server for i := 0; i < 2; i++ { done = 0 for _, sv := range Servers { @@ -122,6 +153,12 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) { ) if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) { + // add to cache + v := []string{} + for _, as := range res.Answer { + v = append(v, as.String()) + } + dns_cache.Add(key, strings.Join(v, "|")) w.WriteMsg(res) done = 1 break @@ -129,6 +166,7 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) { } } + // fallback to default upstream server if done != 1 { res, err = DefaultServer.query(r) if err != nil { @@ -146,6 +184,12 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) { ) if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) { + // add to cache + v := []string{} + for _, as := range res.Answer { + v = append(v, as.String()) + } + dns_cache.Add(key, strings.Join(v, "|")) w.WriteMsg(res) done = 1 } @@ -164,6 +208,14 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) { func main() { parse_flags() + var err error + + // create cache + dns_cache, err = lru.New(1000) + if err != nil { + log.Fatal(err) + } + dns.HandleFunc(".", handleRoot) logger = NewLogger(logfile, debug) @@ -179,7 +231,7 @@ func main() { }() /* listen udp */ - err := dns.ListenAndServe(bind_addr, "udp", nil) + err = dns.ListenAndServe(bind_addr, "udp", nil) if err != nil { log.Fatal(err) }