use pointer

dns
Dingjun 8 years ago
parent a0c68ccf61
commit 8b6c5a7f26

@ -12,7 +12,7 @@ import (
type cache struct { type cache struct {
m map[string]*elem m map[string]*elem
lock sync.RWMutex lock *sync.RWMutex
ttl int64 ttl int64
max int max int
} }
@ -27,7 +27,7 @@ func newCache(max int, ttl int64) *cache {
max: max, max: max,
ttl: ttl, ttl: ttl,
m: map[string]*elem{}, m: map[string]*elem{},
lock: sync.RWMutex{}, lock: new(sync.RWMutex),
} }
} }
@ -44,7 +44,7 @@ func key(m *dns.Msg) string {
return s1 return s1
} }
func (c cache) get(m *dns.Msg) *dns.Msg { func (c *cache) get(m *dns.Msg) *dns.Msg {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
k := key(m) k := key(m)
@ -57,7 +57,7 @@ func (c cache) get(m *dns.Msg) *dns.Msg {
return nil return nil
} }
func (c cache) set(m *dns.Msg) { func (c *cache) set(m *dns.Msg) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
@ -74,7 +74,7 @@ func (c cache) set(m *dns.Msg) {
} }
// must hold the write lock // must hold the write lock
func (c cache) cleanOld() { func (c *cache) cleanOld() {
t1 := time.Now().Unix() t1 := time.Now().Unix()
for k, v := range c.m { for k, v := range c.m {
if v.t >= t1 { if v.t >= t1 {

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"github.com/miekg/dns" "github.com/miekg/dns"
"log" "log"
"strings" "strings"
@ -14,7 +15,7 @@ type routers struct {
cache *cache cache *cache
} }
func (r routers) checkBlacklist(m *dns.Msg) bool { func (r *routers) checkBlacklist(m *dns.Msg) bool {
if m.Rcode != dns.RcodeSuccess { if m.Rcode != dns.RcodeSuccess {
// not success, not in blacklist // not success, not in blacklist
return false return false
@ -37,7 +38,7 @@ func (r routers) checkBlacklist(m *dns.Msg) bool {
return false return false
} }
func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) { func (r *routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) {
var up *dns.Client var up *dns.Client
var lastErr error var lastErr error
@ -74,12 +75,17 @@ func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) {
lastErr = err lastErr = err
} }
if lastErr == nil {
// this happens when ip in blacklist
lastErr = errors.New("timeout")
}
// return last error // return last error
return nil, lastErr return nil, lastErr
} }
// ServeDNS implements dns.Handler interface // ServeDNS implements dns.Handler interface
func (r routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { func (r *routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
domain := m.Question[0].Name domain := m.Question[0].Name
d := strings.Trim(domain, ".") d := strings.Trim(domain, ".")
for _, rule := range r.c.Rules { for _, rule := range r.c.Rules {

Loading…
Cancel
Save