add concurrent function

add domain blacklist
nghttp2
fangdingjun 9 years ago
parent 783790557e
commit e73cc707c5

@ -27,7 +27,7 @@ var region_file = ""
func (s *ArgSrvs) String() string {
//Sprintf("%s", s)
return "filter1.txt,udp:8.8.8.8:53"
return ""
}
func (s *ArgSrvs) Set(s1 string) error {

@ -90,6 +90,7 @@
"31.22.4.60": 1,
"31.222.185.202": 1,
"31.25.191.134": 1,
"31.13.99.18": 1,
"34.254.247.151": 1,
"37.1.205.21": 1,
"37.1.207.129": 1,
@ -319,6 +320,7 @@
"120.198.243.48": 1,
"120.198.243.52": 1,
"120.198.243.53": 1,
"120.197.234.71": 1,
"120.89.93.248": 1,
"122.214.2.171": 1,
"122.218.101.190": 1,

@ -2,6 +2,7 @@ package main
import (
//"fmt"
"errors"
"github.com/miekg/dns"
"net"
"time"
@ -14,38 +15,53 @@ type res struct {
err error
}
func _query(m *dns.Msg, s *UpstreamServer, c chan *res) {
res1 := make(chan *res)
go query_one(s, m, res1)
select {
case r := <-res1:
c <- r
case <-time.After(600 * time.Millisecond):
c <- &res{err: errors.New("timed out")}
}
}
func query(m *dns.Msg) *dns.Msg {
resch := make(chan *res, len(DefaultServer))
for _, s := range DefaultServer {
go query_one(s, m, resch)
go _query(m, s, resch)
}
delayed := []*dns.Msg{}
slen := len(DefaultServer)
got := 0
loop:
for {
select {
case r := <-resch:
r1 := *r
if r1.err != nil {
logger.Error("error %s\n", r1.err.Error())
continue
}
if in_blacklist(r1.m) {
continue
}
if answer_in_region(r1.m, ip_region) {
return r1.m
} else {
delayed = append(delayed, r1.m)
}
got += 1
if got >= slen {
break loop
}
case <-time.After(900 * time.Millisecond):
break loop
for i := 0; i < slen; i++ {
r := <-resch
r1 := *r
if r1.err != nil {
logger.Error("error %s\n", r1.err.Error())
continue
}
// drop the result with no error but has an empty result
if r1.m.Rcode == dns.RcodeSuccess &&
len(r1.m.Answer) == 0 {
continue
}
// drop blacklist
if in_blacklist(r1.m) {
continue
}
// check ip region
if answer_in_region(r1.m, ip_region) {
return r1.m
} else {
delayed = append(delayed, r1.m)
}
}

File diff suppressed because it is too large Load Diff

@ -68,6 +68,11 @@ func in_region(ip net.IP, nets []*net.IPNet) bool {
// test dns reply A or AAAA in special ip range region
func answer_in_region(m *dns.Msg, nets []*net.IPNet) bool {
// no region loaded, return true
if len(nets) == 0 {
return true
}
for _, rr := range m.Answer {
if a, ok := rr.(*dns.A); ok {
if in_region(a.A, nets) {

@ -1,20 +1,21 @@
package main
import (
"fmt"
//"fmt"
"github.com/miekg/dns"
"testing"
"time"
//"os"
)
func TestParseNet(t *testing.T) {
nets := parse_net("net_china.txt")
fmt.Printf("get %d networks\n", len(nets))
fmt.Printf("1st %s\n", nets[0].String())
nets := parse_net("region_cn.txt")
t.Logf("get %d networks\n", len(nets))
t.Logf("1st %s\n", nets[0].String())
}
func TestQuery(t *testing.T) {
ip_region = parse_net("net_china.txt")
ip_region = parse_net("region_cn.txt")
var c *dns.Client
for _, srv := range []string{
"tcp:114.114.114.114:53",
@ -41,7 +42,7 @@ func TestQuery(t *testing.T) {
blacklist_file = "blacklist.txt"
a, err := load_domain(blacklist_file)
if err != nil {
fmt.Println(err)
t.Log(err)
} else {
Blacklist_ips = a
}
@ -65,10 +66,18 @@ func TestQuery(t *testing.T) {
} {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(dn), dns.TypeA)
t1 := time.Now()
m1 := query(m)
t2 := time.Now()
if m1 == nil {
t.Error("query failed")
t.Errorf("query %s failed", dn)
} else {
t.Logf("query time: %s\n", t2.Sub(t1))
t.Logf("result of %s\n", dn)
for _, a1 := range m1.Answer {
t.Logf("%s\n", a1)
}
//.Printf("\n")
}
fmt.Printf("%s\n", m1)
}
}

@ -154,7 +154,8 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
dns.RcodeToString[res.Rcode],
)
if res.Rcode != dns.RcodeServerFailure && !in_blacklist(res) {
if res.Rcode == dns.RcodeSuccess &&
!in_blacklist(res) && len(res.Answer) > 0 {
if enable_cache {
// add to cache
v := []string{}
@ -180,7 +181,9 @@ func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
)
res := query(r)
if res != nil {
if enable_cache {
//logger.Debug("get: %s", res)
if enable_cache && res.Rcode == dns.RcodeSuccess &&
len(res.Answer) > 0 {
// add to cache
v := []string{}
for _, as := range res.Answer {

@ -20,8 +20,7 @@ func (srv *UpstreamServer) match(d string) bool {
}
s := strings.Split(strings.Trim(d, "."), ".")
for i := 0; i < len(s)-1; i++ {
for i := 0; i < len(s); i++ {
s1 := strings.Join(s[i:], ".")
if _, ok := srv.domains[s1]; ok {
return true

@ -17,7 +17,7 @@ func TestServerMathNil(t *testing.T) {
}
func TestServerMatch(t *testing.T) {
d := Kv{"twitter.com": 1, "google.com": 1}
d := Kv{"twitter.com": 1, "google.com": 1, "cn": 1}
srv := UpstreamServer{domains: d}
test_domains := map[string]bool{
@ -33,6 +33,7 @@ func TestServerMatch(t *testing.T) {
"twitter.com.aa.com": false,
"google.com.cccc.com": false,
"google.com.aeddasdfc3.com": false,
"ip.cn": true,
}
for d, r := range test_domains {

Loading…
Cancel
Save