use lru for dns result cache

nghttp2
fangdingjun 9 years ago
parent a72bd45423
commit 922a51b106

@ -23,8 +23,11 @@ run it
package main
import (
"fmt"
lru "github.com/hashicorp/golang-lru"
"github.com/miekg/dns"
"log"
"strings"
)
var client_udp *dns.Client = &dns.Client{}
@ -39,6 +42,8 @@ var Blacklist_ips Kv = nil
var debug bool = false
var dns_cache *lru.Cache
var hostfile string = ""
var record_hosts Hosts = nil
@ -101,6 +106,32 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
}
}
key := fmt.Sprintf("%s_%s", domain, dns.TypeToString[r.Question[0].Qtype])
// reply from cache
if a, ok := dns_cache.Get(key); ok {
msg := new(dns.Msg)
msg.SetReply(r)
aa := strings.Split(a.(string), "|")
for _, a1 := range aa {
rr, _ := dns.NewRR(a1)
if rr != nil {
msg.Answer = append(msg.Answer, rr)
}
}
w.WriteMsg(msg)
logger.Debug("%s query %s %s %s, reply from cache\n",
w.RemoteAddr(),
domain,
dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype],
)
return
}
// forward to upstream server
for i := 0; i < 2; i++ {
done = 0
for _, sv := range Servers {
@ -122,6 +153,12 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
)
if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) {
// add to cache
v := []string{}
for _, as := range res.Answer {
v = append(v, as.String())
}
dns_cache.Add(key, strings.Join(v, "|"))
w.WriteMsg(res)
done = 1
break
@ -129,6 +166,7 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
}
}
// fallback to default upstream server
if done != 1 {
res, err = DefaultServer.query(r)
if err != nil {
@ -146,6 +184,12 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
)
if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) {
// add to cache
v := []string{}
for _, as := range res.Answer {
v = append(v, as.String())
}
dns_cache.Add(key, strings.Join(v, "|"))
w.WriteMsg(res)
done = 1
}
@ -164,6 +208,14 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
func main() {
parse_flags()
var err error
// create cache
dns_cache, err = lru.New(1000)
if err != nil {
log.Fatal(err)
}
dns.HandleFunc(".", handleRoot)
logger = NewLogger(logfile, debug)
@ -179,7 +231,7 @@ func main() {
}()
/* listen udp */
err := dns.ListenAndServe(bind_addr, "udp", nil)
err = dns.ListenAndServe(bind_addr, "udp", nil)
if err != nil {
log.Fatal(err)
}

Loading…
Cancel
Save