add cache store

dns
Dingjun 8 years ago
parent a1536bb7b4
commit 1cf266166f

2
.gitignore vendored

@ -2,4 +2,4 @@
*.swp
*.json
*.txt
dns
gdns

@ -0,0 +1,84 @@
package main
import (
"crypto/md5"
"encoding/binary"
"encoding/hex"
"github.com/miekg/dns"
"log"
"sync"
"time"
)
type cache struct {
m map[string]*elem
lock sync.RWMutex
ttl int64
max int
}
type elem struct {
m *dns.Msg
t int64
}
func newCache(max int, ttl int64) *cache {
return &cache{
max: max,
ttl: ttl,
m: map[string]*elem{},
lock: sync.RWMutex{},
}
}
func key(m *dns.Msg) string {
d := m.Question[0].Name
b := []byte(d)
b1 := make([]byte, 4)
binary.BigEndian.PutUint16(b1[0:], m.Question[0].Qclass)
binary.BigEndian.PutUint16(b1[2:], m.Question[0].Qtype)
b = append(b, b1...)
h := md5.New()
h.Write(b)
s1 := hex.EncodeToString(h.Sum(nil))
return s1
}
func (c cache) get(m *dns.Msg) *dns.Msg {
c.lock.RLock()
defer c.lock.RUnlock()
k := key(m)
if m1, ok := c.m[k]; ok {
t := time.Now().Unix()
if t < m1.t {
return m1.m
}
}
return nil
}
func (c cache) set(m *dns.Msg) {
c.lock.Lock()
defer c.lock.Unlock()
if len(c.m) >= c.max {
log.Printf("clean the old cache")
c.cleanOld()
}
k := key(m)
c.m[k] = &elem{
m.Copy(),
time.Now().Unix() + c.ttl,
}
}
// must hold the write lock
func (c cache) cleanOld() {
t1 := time.Now().Unix()
for k, v := range c.m {
if v.t >= t1 {
delete(c.m, k)
}
}
}

@ -0,0 +1,58 @@
package main
import (
"github.com/miekg/dns"
"testing"
"time"
)
func TestCache(t *testing.T) {
c := newCache(5, 2)
tests := map[string]uint16{
"www.google.com": dns.TypeA,
"www.google.com.hk": dns.TypeA,
"www.google.com.sg": dns.TypeA,
"www.google.com.it": dns.TypeA,
"www.google.com.de": dns.TypeA,
"www.google.com.cn": dns.TypeA,
}
var datas []*dns.Msg
for k, v := range tests {
m1 := new(dns.Msg)
m1.SetQuestion(k, v)
datas = append(datas, m1)
}
for i := 0; i < 3; i++ {
c.set(datas[i])
}
for i := 0; i < 3; i++ {
m2 := c.get(datas[i])
if m2 == nil {
t.Errorf("store cache failed")
}
if m2.Question[0].Name != datas[i].Question[0].Name {
t.Errorf("cache error")
}
}
time.Sleep(3 * time.Second)
for i := 0; i < 3; i++ {
m2 := c.get(datas[i])
if m2 != nil {
t.Errorf("cache not expired")
}
}
for i := 3; i < 6; i++ {
c.set(datas[i])
}
if len(c.m) > len(datas) {
t.Errorf("old cache not purged")
}
}

@ -2,14 +2,17 @@ package main
import (
"fmt"
"os"
"testing"
)
func TestCfg(t *testing.T) {
os.Chdir("example_config")
c, err := parseCfg("config.json")
if err != nil {
t.Fatalf("%s\n", err)
}
fmt.Printf("%+v\n", c)
fmt.Printf("%v\n", c.Rules[0].domains.match("google.com"))
fmt.Printf("%v\n", c.Rules[0].domains.match("www.ip.cn"))

@ -1,6 +1,6 @@
{
"listen":["tcp:0.0.0.0:8053","udp::8053"],
"default_servers":["tcp:114.114.114.114:53","tcp:8.8.8.8:53"],
"default_servers":["tcp:208.67.222.222:53","tcp:8.8.8.8:53"],
"timeout":1,
"blacklist_ips":["ip.txt"],
"rules":[

@ -11,6 +11,7 @@ type routers struct {
c *cfg
tcp *dns.Client
udp *dns.Client
cache *cache
}
func (r routers) checkBlacklist(m *dns.Msg) bool {
@ -39,6 +40,15 @@ func (r routers) checkBlacklist(m *dns.Msg) bool {
func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) {
var up *dns.Client
var lastErr error
// query cache
m2 := r.cache.get(m)
if m2 != nil {
log.Printf("query %s, reply from cache\n", m.Question[0].Name)
m2.Id = m.Id
return m2, nil
}
for _, srv := range servers {
switch srv.network {
case "tcp":
@ -51,9 +61,13 @@ func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) {
log.Printf("query %s use %s:%s\n", m.Question[0].Name, srv.network, srv.addr)
m, _, err := up.Exchange(m, srv.addr)
m1, _, err := up.Exchange(m, srv.addr)
if err == nil && !r.checkBlacklist(m) {
return m, err
if m1.Rcode == dns.RcodeSuccess {
// store to cache
r.cache.set(m1)
}
return m1, err
}
log.Println(err)
@ -74,9 +88,10 @@ func (r routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
if err == nil {
w.WriteMsg(m1)
return
} else {
log.Println(err)
}
log.Println(err)
}
}
@ -95,6 +110,7 @@ func initRouters(c *cfg) {
c,
&dns.Client{Net: "tcp", Timeout: time.Duration(c.Timeout) * time.Second},
&dns.Client{Net: "udp", Timeout: time.Duration(c.Timeout) * time.Second},
newCache(1000, 5*60*60), // cache 5 hours
}
dns.Handle(".", router)
}

@ -17,8 +17,10 @@ func initListeners(c *cfg) {
func main() {
var configFile string
flag.StringVar(&configFile, "c", "", "config file")
flag.Parse()
config, err := parseCfg(configFile)
if err != nil {
log.Println(err)
@ -27,5 +29,6 @@ func main() {
initRouters(config)
initListeners(config)
select {}
}

Loading…
Cancel
Save