add concurrent function

add domain blacklist
nghttp2
fangdingjun 9 years ago
parent 783790557e
commit e73cc707c5

@ -27,7 +27,7 @@ var region_file = ""
func (s *ArgSrvs) String() string { func (s *ArgSrvs) String() string {
//Sprintf("%s", s) //Sprintf("%s", s)
return "filter1.txt,udp:8.8.8.8:53" return ""
} }
func (s *ArgSrvs) Set(s1 string) error { func (s *ArgSrvs) Set(s1 string) error {

@ -90,6 +90,7 @@
"31.22.4.60": 1, "31.22.4.60": 1,
"31.222.185.202": 1, "31.222.185.202": 1,
"31.25.191.134": 1, "31.25.191.134": 1,
"31.13.99.18": 1,
"34.254.247.151": 1, "34.254.247.151": 1,
"37.1.205.21": 1, "37.1.205.21": 1,
"37.1.207.129": 1, "37.1.207.129": 1,
@ -319,6 +320,7 @@
"120.198.243.48": 1, "120.198.243.48": 1,
"120.198.243.52": 1, "120.198.243.52": 1,
"120.198.243.53": 1, "120.198.243.53": 1,
"120.197.234.71": 1,
"120.89.93.248": 1, "120.89.93.248": 1,
"122.214.2.171": 1, "122.214.2.171": 1,
"122.218.101.190": 1, "122.218.101.190": 1,

@ -2,6 +2,7 @@ package main
import ( import (
//"fmt" //"fmt"
"errors"
"github.com/miekg/dns" "github.com/miekg/dns"
"net" "net"
"time" "time"
@ -14,38 +15,53 @@ type res struct {
err error err error
} }
func _query(m *dns.Msg, s *UpstreamServer, c chan *res) {
res1 := make(chan *res)
go query_one(s, m, res1)
select {
case r := <-res1:
c <- r
case <-time.After(600 * time.Millisecond):
c <- &res{err: errors.New("timed out")}
}
}
func query(m *dns.Msg) *dns.Msg { func query(m *dns.Msg) *dns.Msg {
resch := make(chan *res, len(DefaultServer)) resch := make(chan *res, len(DefaultServer))
for _, s := range DefaultServer { for _, s := range DefaultServer {
go query_one(s, m, resch) go _query(m, s, resch)
} }
delayed := []*dns.Msg{} delayed := []*dns.Msg{}
slen := len(DefaultServer) slen := len(DefaultServer)
got := 0
for i := 0; i < slen; i++ {
loop: r := <-resch
for { r1 := *r
select { if r1.err != nil {
case r := <-resch: logger.Error("error %s\n", r1.err.Error())
r1 := *r continue
if r1.err != nil { }
logger.Error("error %s\n", r1.err.Error())
continue // drop the result with no error but has an empty result
} if r1.m.Rcode == dns.RcodeSuccess &&
if in_blacklist(r1.m) { len(r1.m.Answer) == 0 {
continue continue
} }
if answer_in_region(r1.m, ip_region) {
return r1.m // drop blacklist
} else { if in_blacklist(r1.m) {
delayed = append(delayed, r1.m) continue
} }
got += 1
if got >= slen { // check ip region
break loop if answer_in_region(r1.m, ip_region) {
} return r1.m
case <-time.After(900 * time.Millisecond): } else {
break loop delayed = append(delayed, r1.m)
} }
} }

File diff suppressed because it is too large Load Diff

@ -68,6 +68,11 @@ func in_region(ip net.IP, nets []*net.IPNet) bool {
// test dns reply A or AAAA in special ip range region // test dns reply A or AAAA in special ip range region
func answer_in_region(m *dns.Msg, nets []*net.IPNet) bool { func answer_in_region(m *dns.Msg, nets []*net.IPNet) bool {
// no region loaded, return true
if len(nets) == 0 {
return true
}
for _, rr := range m.Answer { for _, rr := range m.Answer {
if a, ok := rr.(*dns.A); ok { if a, ok := rr.(*dns.A); ok {
if in_region(a.A, nets) { if in_region(a.A, nets) {

@ -1,20 +1,21 @@
package main package main
import ( import (
"fmt" //"fmt"
"github.com/miekg/dns" "github.com/miekg/dns"
"testing" "testing"
"time"
//"os" //"os"
) )
func TestParseNet(t *testing.T) { func TestParseNet(t *testing.T) {
nets := parse_net("net_china.txt") nets := parse_net("region_cn.txt")
fmt.Printf("get %d networks\n", len(nets)) t.Logf("get %d networks\n", len(nets))
fmt.Printf("1st %s\n", nets[0].String()) t.Logf("1st %s\n", nets[0].String())
} }
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
ip_region = parse_net("net_china.txt") ip_region = parse_net("region_cn.txt")
var c *dns.Client var c *dns.Client
for _, srv := range []string{ for _, srv := range []string{
"tcp:114.114.114.114:53", "tcp:114.114.114.114:53",
@ -41,7 +42,7 @@ func TestQuery(t *testing.T) {
blacklist_file = "blacklist.txt" blacklist_file = "blacklist.txt"
a, err := load_domain(blacklist_file) a, err := load_domain(blacklist_file)
if err != nil { if err != nil {
fmt.Println(err) t.Log(err)
} else { } else {
Blacklist_ips = a Blacklist_ips = a
} }
@ -65,10 +66,18 @@ func TestQuery(t *testing.T) {
} { } {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(dn), dns.TypeA) m.SetQuestion(dns.Fqdn(dn), dns.TypeA)
t1 := time.Now()
m1 := query(m) m1 := query(m)
t2 := time.Now()
if m1 == nil { if m1 == nil {
t.Error("query failed") t.Errorf("query %s failed", dn)
} else {
t.Logf("query time: %s\n", t2.Sub(t1))
t.Logf("result of %s\n", dn)
for _, a1 := range m1.Answer {
t.Logf("%s\n", a1)
}
//.Printf("\n")
} }
fmt.Printf("%s\n", m1)
} }
} }

@ -154,7 +154,8 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
dns.RcodeToString[res.Rcode], dns.RcodeToString[res.Rcode],
) )
if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) { if res.Rcode == dns.RcodeSuccess &&
!in_blacklist(res) && len(res.Answer) > 0 {
if enable_cache { if enable_cache {
// add to cache // add to cache
v := []string{} v := []string{}
@ -180,7 +181,9 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
) )
res := query(r) res := query(r)
if res != nil { if res != nil {
if enable_cache { //logger.Debug("get: %s", res)
if enable_cache && res.Rcode == dns.RcodeSuccess &&
len(res.Answer) > 0 {
// add to cache // add to cache
v := []string{} v := []string{}
for _, as := range res.Answer { for _, as := range res.Answer {

@ -20,8 +20,7 @@ func (srv *UpstreamServer) match(d string) bool {
} }
s := strings.Split(strings.Trim(d, "."), ".") s := strings.Split(strings.Trim(d, "."), ".")
for i := 0; i < len(s); i++ {
for i := 0; i < len(s)-1; i++ {
s1 := strings.Join(s[i:], ".") s1 := strings.Join(s[i:], ".")
if _, ok := srv.domains[s1]; ok { if _, ok := srv.domains[s1]; ok {
return true return true

@ -17,7 +17,7 @@ func TestServerMathNil(t *testing.T) {
} }
func TestServerMatch(t *testing.T) { func TestServerMatch(t *testing.T) {
d := Kv{"twitter.com": 1, "google.com": 1} d := Kv{"twitter.com": 1, "google.com": 1, "cn": 1}
srv := UpstreamServer{domains: d} srv := UpstreamServer{domains: d}
test_domains := map[string]bool{ test_domains := map[string]bool{
@ -33,6 +33,7 @@ func TestServerMatch(t *testing.T) {
"twitter.com.aa.com": false, "twitter.com.aa.com": false,
"google.com.cccc.com": false, "google.com.cccc.com": false,
"google.com.aeddasdfc3.com": false, "google.com.aeddasdfc3.com": false,
"ip.cn": true,
} }
for d, r := range test_domains { for d, r := range test_domains {

Loading…
Cancel
Save