You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
1.8 KiB
Go
101 lines
1.8 KiB
Go
8 years ago
|
package main
|
||
|
|
||
|
import (
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/miekg/dns"
)
|
||
|
|
||
|
type routers struct {
|
||
|
c *cfg
|
||
|
tcp *dns.Client
|
||
|
udp *dns.Client
|
||
|
}
|
||
|
|
||
|
// checkBlacklist reports whether reply m contains an A or AAAA answer whose
// address is listed in r.c.blacklistIps. Replies whose Rcode is not
// dns.RcodeSuccess are never considered blacklisted. The first blacklisted
// address found is logged.
func (r routers) checkBlacklist(m *dns.Msg) bool {
	if m.Rcode != dns.RcodeSuccess {
		// Only successful answers can carry addresses worth filtering.
		return false
	}

	for _, answer := range m.Answer {
		var ipStr string
		switch rec := answer.(type) {
		case *dns.A:
			ipStr = rec.A.String()
		case *dns.AAAA:
			ipStr = rec.AAAA.String()
		default:
			// Non-address records (CNAME, TXT, ...) are never blacklisted.
			continue
		}

		if ipStr != "" && r.c.blacklistIps.has(ipStr) {
			log.Printf("%s is in blacklist.\n", ipStr)
			return true
		}
	}
	return false
}
|
||
|
|
||
|
// query sends request m to each upstream in servers, in order, and returns the
// first reply that arrives without a transport error and contains no
// blacklisted A/AAAA address (per checkBlacklist). An upstream whose network
// is "tcp" is queried with r.tcp; any other network uses r.udp.
//
// When no usable reply is obtained, query returns a nil message together with
// a non-nil error describing the last failure. This covers transport errors,
// blacklisted replies, and an empty servers list. It never returns (nil, nil),
// so callers can rely on err == nil implying a non-nil reply.
func (r routers) query(m *dns.Msg, servers []addr) (*dns.Msg, error) {
	// The question name is only used for logging; guard against an empty
	// question section instead of panicking on m.Question[0].
	name := ""
	if len(m.Question) > 0 {
		name = m.Question[0].Name
	}

	if len(servers) == 0 {
		return nil, fmt.Errorf("query %s: no upstream servers configured", name)
	}

	var lastErr error
	for _, srv := range servers {
		up := r.udp
		if srv.network == "tcp" {
			up = r.tcp
		}

		log.Printf("query %s use %s:%s\n", name, srv.network, srv.addr)

		// Use a distinct name so the request m is not shadowed by the reply.
		resp, _, err := up.Exchange(m, srv.addr)
		if err != nil {
			log.Println(err)
			lastErr = err
			continue
		}

		if r.checkBlacklist(resp) {
			// Record a real error; previously lastErr became nil here and the
			// caller could receive (nil, nil).
			lastErr = fmt.Errorf("query %s: reply from %s:%s contains a blacklisted address",
				name, srv.network, srv.addr)
			log.Println(lastErr)
			continue
		}

		return resp, nil
	}

	// Every upstream failed; report the most recent failure.
	return nil, lastErr
}
|
||
|
|
||
|
// ServeDNS implements the dns.Handler interface.
//
// The first question's name, with leading and trailing dots trimmed, is
// matched against r.c.Rules in order. For each matching rule the request is
// forwarded to that rule's servers, and the first successful reply is written
// back. If no rule matches, or every matching rule fails, the default servers
// in r.c.servers are tried. If those also fail, dns.HandleFailed writes a
// failure response.
func (r routers) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
	// A request without a question cannot be routed, and indexing
	// m.Question[0] would panic.
	if len(m.Question) == 0 {
		log.Println("request has no question section")
		dns.HandleFailed(w, m)
		return
	}

	// reply writes resp and logs (rather than silently drops) a write error.
	reply := func(resp *dns.Msg) {
		if err := w.WriteMsg(resp); err != nil {
			log.Println(err)
		}
	}

	d := strings.Trim(m.Question[0].Name, ".")
	for _, rule := range r.c.Rules {
		if !rule.domains.match(d) {
			continue
		}
		m1, err := r.query(m, rule.servers)
		if err != nil {
			log.Println(err)
			continue
		}
		// Defensive: never hand WriteMsg a nil message, even if query
		// reports success without a reply.
		if m1 != nil {
			reply(m1)
			return
		}
	}

	// No match, or every matching rule failed: fall back to the defaults.
	m1, err := r.query(m, r.c.servers)
	switch {
	case err != nil:
		log.Println(err)
		dns.HandleFailed(w, m)
	case m1 == nil:
		dns.HandleFailed(w, m)
	default:
		reply(m1)
	}
}
|
||
|
|
||
|
// initRouters builds a routers value from configuration c and registers it
// with dns.Handle under the root pattern ".". Presumably this makes it the
// catch-all handler for every query name; confirm against the miekg/dns mux
// rules. c.Timeout is interpreted as a number of seconds and applied to both
// the TCP and the UDP upstream clients.
func initRouters(c *cfg) {
	timeout := time.Duration(c.Timeout) * time.Second

	dns.Handle(".", &routers{
		c:   c,
		tcp: &dns.Client{Net: "tcp", Timeout: timeout},
		udp: &dns.Client{Net: "udp", Timeout: timeout},
	})
}
|