diff --git a/.gitignore b/.gitignore index 8a138d7..ca7ff6d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ *.swp *.json *.txt -dns +gdns diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..9bda29c --- /dev/null +++ b/cache.go @@ -0,0 +1,84 @@ +package main + +import ( + "crypto/md5" + "encoding/binary" + "encoding/hex" + "github.com/miekg/dns" + "log" + "sync" + "time" +) + +type cache struct { + m map[string]*elem + lock sync.RWMutex + ttl int64 + max int +} + +type elem struct { + m *dns.Msg + t int64 +} + +func newCache(max int, ttl int64) *cache { + return &cache{ + max: max, + ttl: ttl, + m: map[string]*elem{}, + lock: sync.RWMutex{}, + } +} + +func key(m *dns.Msg) string { + d := m.Question[0].Name + b := []byte(d) + b1 := make([]byte, 4) + binary.BigEndian.PutUint16(b1[0:], m.Question[0].Qclass) + binary.BigEndian.PutUint16(b1[2:], m.Question[0].Qtype) + b = append(b, b1...) + h := md5.New() + h.Write(b) + s1 := hex.EncodeToString(h.Sum(nil)) + return s1 +} + +func (c cache) get(m *dns.Msg) *dns.Msg { + c.lock.RLock() + defer c.lock.RUnlock() + k := key(m) + if m1, ok := c.m[k]; ok { + t := time.Now().Unix() + if t < m1.t { + return m1.m + } + } + return nil +} + +func (c cache) set(m *dns.Msg) { + c.lock.Lock() + defer c.lock.Unlock() + + if len(c.m) >= c.max { + log.Printf("clean the old cache") + c.cleanOld() + } + + k := key(m) + c.m[k] = &elem{ + m.Copy(), + time.Now().Unix() + c.ttl, + } +} + +// must hold the write lock +func (c cache) cleanOld() { + t1 := time.Now().Unix() + for k, v := range c.m { + if v.t >= t1 { + delete(c.m, k) + } + } +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..48a033b --- /dev/null +++ b/cache_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "github.com/miekg/dns" + "testing" + "time" +) + +func TestCache(t *testing.T) { + c := newCache(5, 2) + + tests := map[string]uint16{ + "www.google.com": dns.TypeA, + "www.google.com.hk": dns.TypeA, + "www.google.com.sg": dns.TypeA, + "www.google.com.it": dns.TypeA, + "www.google.com.de": dns.TypeA, + "www.google.com.cn": dns.TypeA, + } + + var datas []*dns.Msg + + for k, v := range tests { + m1 := new(dns.Msg) + m1.SetQuestion(k, v) + datas = append(datas, m1) + } + + for i := 0; i < 3; i++ { + c.set(datas[i]) + } + + for i := 0; i < 3; i++ { + m2 := c.get(datas[i]) + if m2 == nil { + t.Errorf("store cache failed") + } + if m2.Question[0].Name != datas[i].Question[0].Name { + t.Errorf("cache error") + } + } + + time.Sleep(3 * time.Second) + for i := 0; i < 3; i++ { + m2 := c.get(datas[i]) + if m2 != nil { + t.Errorf("cache not expired") + } + } + + for i := 3; i < 6; i++ { + c.set(datas[i]) + } + + if len(c.m) > len(datas) { + t.Errorf("old cache not purged") + } +} diff --git a/cfg_test.go b/cfg_test.go index 454ef7a..95228d0 100644 --- a/cfg_test.go +++ b/cfg_test.go @@ -2,14 +2,17 @@ package main import ( "fmt" + "os" "testing" ) func TestCfg(t *testing.T) { + os.Chdir("example_config") c, err := parseCfg("config.json") if err != nil { t.Fatalf("%s\n", err) } + fmt.Printf("%+v\n", c) fmt.Printf("%v\n", c.Rules[0].domains.match("google.com")) fmt.Printf("%v\n", c.Rules[0].domains.match("www.ip.cn")) diff --git a/example_config/config.json b/example_config/config.json index a87d181..0c43fc5 100644 --- a/example_config/config.json +++ b/example_config/config.json @@ -1,6 +1,6 @@ { "listen":["tcp:0.0.0.0:8053","udp::8053"], - "default_servers":["tcp:114.114.114.114:53","tcp:8.8.8.8:53"], + "default_servers":["tcp:208.67.222.222:53","tcp:8.8.8.8:53"], "timeout":1, "blacklist_ips":["ip.txt"], "rules":[ diff --git a/routers.go b/routers.go index 77e5d62..0ca03b8 100644 --- a/routers.go +++ b/routers.go @@ -8,9 +8,10 @@ import ( ) type routers struct { - c *cfg - tcp *dns.Client - udp *dns.Client + c *cfg + tcp *dns.Client + udp *dns.Client + cache *cache } func (r routers) checkBlacklist(m *dns.Msg) bool { @@ -39,6 +40,15 @@ func (r routers) checkBlacklist(m *dns.Msg) bool { func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) { var up *dns.Client var lastErr error + + // query cache + m2 := r.cache.get(m) + if m2 != nil { + log.Printf("query %s, reply from cache\n", m.Question[0].Name) + m2.Id = m.Id + return m2, nil + } + for _, srv := range servers { switch srv.network { case "tcp": @@ -51,9 +61,13 @@ func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) { log.Printf("query %s use %s:%s\n", m.Question[0].Name, srv.network, srv.addr) - m, _, err := up.Exchange(m, srv.addr) + m1, _, err := up.Exchange(m, srv.addr) if err == nil && !r.checkBlacklist(m) { - return m, err + if m1.Rcode == dns.RcodeSuccess { + // store to cache + r.cache.set(m1) + } + return m1, err } log.Println(err) @@ -74,9 +88,10 @@ func (r routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { if err == nil { w.WriteMsg(m1) return - } else { - log.Println(err) } + + log.Println(err) + } } @@ -95,6 +110,7 @@ func initRouters(c *cfg) { c, &dns.Client{Net: "tcp", Timeout: time.Duration(c.Timeout) * time.Second}, &dns.Client{Net: "udp", Timeout: time.Duration(c.Timeout) * time.Second}, + newCache(1000, 5*60*60), // cache 5 hours } dns.Handle(".", router) } diff --git a/server.go b/server.go index f484cf6..274bbcf 100644 --- a/server.go +++ b/server.go @@ -17,8 +17,10 @@ func initListeners(c *cfg) { func main() { var configFile string + flag.StringVar(&configFile, "c", "", "config file") flag.Parse() + config, err := parseCfg(configFile) if err != nil { log.Println(err) @@ -27,5 +29,6 @@ func main() { initRouters(config) initListeners(config) + select {} }