add more functions

1. support ipv6 address in blacklist
2. listen tcp port
3. add remote address on debug log
nghttp2
fangdingjun 9 years ago
parent cd0d993f1e
commit a86b993931

@ -52,6 +52,7 @@ func in_blacklist(m *dns.Msg) bool {
} }
for _, rr := range m.Answer { for _, rr := range m.Answer {
/* A */
if t, ok := rr.(*dns.A); ok { if t, ok := rr.(*dns.A); ok {
ip := t.A.String() ip := t.A.String()
if _, ok1 := Blacklist_ips[ip]; ok1 { if _, ok1 := Blacklist_ips[ip]; ok1 {
@ -59,9 +60,20 @@ func in_blacklist(m *dns.Msg) bool {
return true return true
} }
} }
/* AAAA */
if t, ok := rr.(*dns.AAAA); ok {
ip := t.AAAA.String()
if _, ok1 := Blacklist_ips[ip]; ok1 {
logger.Debug("%s is in blacklist\n", ip)
return true
}
}
} }
return false return false
} }
func handleRoot(w dns.ResponseWriter, r *dns.Msg) { func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
var err error var err error
var res *dns.Msg var res *dns.Msg
@ -79,7 +91,8 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
msg.SetReply(r) msg.SetReply(r)
msg.Answer = append(msg.Answer, rr) msg.Answer = append(msg.Answer, rr)
w.WriteMsg(msg) w.WriteMsg(msg)
logger.Debug("query %s %s %s, reply from hosts\n", logger.Debug("%s query %s %s %s, reply from hosts\n",
w.RemoteAddr(),
domain, domain,
dns.ClassToString[r.Question[0].Qclass], dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype], dns.TypeToString[r.Question[0].Qtype],
@ -92,7 +105,8 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
done = 0 done = 0
for _, sv := range Servers { for _, sv := range Servers {
if sv.match(domain) { if sv.match(domain) {
logger.Debug("query %s %s %s, forward to %s:%s\n", logger.Debug("%s query %s %s %s, forward to %s:%s\n",
w.RemoteAddr(),
domain, domain,
dns.ClassToString[r.Question[0].Qclass], dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype], dns.TypeToString[r.Question[0].Qtype],
@ -112,7 +126,8 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
if done != 1 { if done != 1 {
res, err = DefaultServer.query(r) res, err = DefaultServer.query(r)
logger.Debug("query %s %s %s, use default server %s:%s\n", logger.Debug("%s query %s %s %s, use default server %s:%s\n",
w.RemoteAddr(),
domain, domain,
dns.ClassToString[r.Question[0].Qclass], dns.ClassToString[r.Question[0].Qclass],
dns.TypeToString[r.Question[0].Qtype], dns.TypeToString[r.Question[0].Qtype],
@ -146,6 +161,15 @@ func main() {
logger.Info("Listen on %s\n", bind_addr) logger.Info("Listen on %s\n", bind_addr)
go func() {
/* listen tcp */
err := dns.ListenAndServe(bind_addr, "tcp", nil)
if err != nil {
log.Fatal(err)
}
}()
/* listen udp */
err := dns.ListenAndServe(bind_addr, "udp", nil) err := dns.ListenAndServe(bind_addr, "udp", nil)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

Loading…
Cancel
Save