add cache store
parent
a1536bb7b4
commit
1cf266166f
@ -0,0 +1,84 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"github.com/miekg/dns"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
m map[string]*elem
|
||||
lock sync.RWMutex
|
||||
ttl int64
|
||||
max int
|
||||
}
|
||||
|
||||
type elem struct {
|
||||
m *dns.Msg
|
||||
t int64
|
||||
}
|
||||
|
||||
func newCache(max int, ttl int64) *cache {
|
||||
return &cache{
|
||||
max: max,
|
||||
ttl: ttl,
|
||||
m: map[string]*elem{},
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func key(m *dns.Msg) string {
|
||||
d := m.Question[0].Name
|
||||
b := []byte(d)
|
||||
b1 := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(b1[0:], m.Question[0].Qclass)
|
||||
binary.BigEndian.PutUint16(b1[2:], m.Question[0].Qtype)
|
||||
b = append(b, b1...)
|
||||
h := md5.New()
|
||||
h.Write(b)
|
||||
s1 := hex.EncodeToString(h.Sum(nil))
|
||||
return s1
|
||||
}
|
||||
|
||||
func (c cache) get(m *dns.Msg) *dns.Msg {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
k := key(m)
|
||||
if m1, ok := c.m[k]; ok {
|
||||
t := time.Now().Unix()
|
||||
if t < m1.t {
|
||||
return m1.m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c cache) set(m *dns.Msg) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if len(c.m) >= c.max {
|
||||
log.Printf("clean the old cache")
|
||||
c.cleanOld()
|
||||
}
|
||||
|
||||
k := key(m)
|
||||
c.m[k] = &elem{
|
||||
m.Copy(),
|
||||
time.Now().Unix() + c.ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// must hold the write lock
|
||||
func (c cache) cleanOld() {
|
||||
t1 := time.Now().Unix()
|
||||
for k, v := range c.m {
|
||||
if v.t >= t1 {
|
||||
delete(c.m, k)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
c := newCache(5, 2)
|
||||
|
||||
tests := map[string]uint16{
|
||||
"www.google.com": dns.TypeA,
|
||||
"www.google.com.hk": dns.TypeA,
|
||||
"www.google.com.sg": dns.TypeA,
|
||||
"www.google.com.it": dns.TypeA,
|
||||
"www.google.com.de": dns.TypeA,
|
||||
"www.google.com.cn": dns.TypeA,
|
||||
}
|
||||
|
||||
var datas []*dns.Msg
|
||||
|
||||
for k, v := range tests {
|
||||
m1 := new(dns.Msg)
|
||||
m1.SetQuestion(k, v)
|
||||
datas = append(datas, m1)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
c.set(datas[i])
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
m2 := c.get(datas[i])
|
||||
if m2 == nil {
|
||||
t.Errorf("store cache failed")
|
||||
}
|
||||
if m2.Question[0].Name != datas[i].Question[0].Name {
|
||||
t.Errorf("cache error")
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
for i := 0; i < 3; i++ {
|
||||
m2 := c.get(datas[i])
|
||||
if m2 != nil {
|
||||
t.Errorf("cache not expired")
|
||||
}
|
||||
}
|
||||
|
||||
for i := 3; i < 6; i++ {
|
||||
c.set(datas[i])
|
||||
}
|
||||
|
||||
if len(c.m) > len(datas) {
|
||||
t.Errorf("old cache not purged")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue