From 8b6c5a7f266e74de93bad27b073cbe9b97b50d48 Mon Sep 17 00:00:00 2001 From: Dingjun Date: Wed, 3 Aug 2016 17:50:05 +0800 Subject: [PATCH] use pointer --- cache.go | 10 +++++----- routers.go | 12 +++++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/cache.go b/cache.go index 9bda29c..7923066 100644 --- a/cache.go +++ b/cache.go @@ -12,7 +12,7 @@ import ( type cache struct { m map[string]*elem - lock sync.RWMutex + lock *sync.RWMutex ttl int64 max int } @@ -27,7 +27,7 @@ func newCache(max int, ttl int64) *cache { max: max, ttl: ttl, m: map[string]*elem{}, - lock: sync.RWMutex{}, + lock: new(sync.RWMutex), } } @@ -44,7 +44,7 @@ func key(m *dns.Msg) string { return s1 } -func (c cache) get(m *dns.Msg) *dns.Msg { +func (c *cache) get(m *dns.Msg) *dns.Msg { c.lock.RLock() defer c.lock.RUnlock() k := key(m) @@ -57,7 +57,7 @@ func (c cache) get(m *dns.Msg) *dns.Msg { return nil } -func (c cache) set(m *dns.Msg) { +func (c *cache) set(m *dns.Msg) { c.lock.Lock() defer c.lock.Unlock() @@ -74,7 +74,7 @@ func (c cache) set(m *dns.Msg) { } // must hold the write lock -func (c cache) cleanOld() { +func (c *cache) cleanOld() { t1 := time.Now().Unix() for k, v := range c.m { if v.t >= t1 { diff --git a/routers.go b/routers.go index 0ca03b8..8f0a5ed 100644 --- a/routers.go +++ b/routers.go @@ -1,6 +1,7 @@ package main import ( + "errors" "github.com/miekg/dns" "log" "strings" @@ -14,7 +15,7 @@ type routers struct { cache *cache } -func (r routers) checkBlacklist(m *dns.Msg) bool { +func (r *routers) checkBlacklist(m *dns.Msg) bool { if m.Rcode != dns.RcodeSuccess { // not success, not in blacklist return false @@ -37,7 +38,7 @@ func (r routers) checkBlacklist(m *dns.Msg) bool { return false } -func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) { +func (r *routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) { var up *dns.Client var lastErr error @@ -74,12 +75,17 @@ func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) { lastErr = err } + if lastErr == nil { + // this happens when ip in blacklist + lastErr = errors.New("timeout") + } + // return last error return nil, lastErr } // ServeDNS implements dns.Handler interface -func (r routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { +func (r *routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { domain := m.Question[0].Name d := strings.Trim(domain, ".") for _, rule := range r.c.Rules {