commit d1dcee056f40e6088be9a259b18422316cbc0b07 Author: fangdingjun Date: Mon Aug 31 17:40:51 2015 +0800 0.0.1 diff --git a/arguments.go b/arguments.go new file mode 100644 index 0000000..5cd0866 --- /dev/null +++ b/arguments.go @@ -0,0 +1,89 @@ +package main + +import ( + "flag" + "github.com/miekg/dns" + "github.com/vharitonsky/iniflags" + "log" +) + +var bind_addr string + +var default_server string + +var srv ArgSrvs + +var logfile string + +type ArgSrvs []string + +var DefaultServer *UpstreamServer + +var blacklist_file string + +func (s *ArgSrvs) String() string { + //Sprintf("%s", s) + return "filter1.txt,udp:8.8.8.8:53" +} + +func (s *ArgSrvs) Set(s1 string) error { + *s = append(*s, s1) + return nil +} + +func parse_flags() { + iniflags.Parse() + + var err error + for _, s := range srv { + sv, err := parse_server(s) + if err != nil { + log.Print(err) + } else { + Servers = append(Servers, sv) + } + } + + proto, addr, err := parse_addr(default_server) + if err != nil { + log.Fatal(err) + } + + var c *dns.Client + if proto == "udp" { + c = client_udp + } else { + c = client_tcp + } + + DefaultServer = &UpstreamServer{ + Addr: addr, + Proto: proto, + client: c, + } + + a, err := load_domain(blacklist_file) + if err != nil { + log.Println(err) + } else { + Blacklist_ips = a + } +} + +func init() { + + flag.Var(&srv, "server", `special the filter and the upstream server to use when match + format: + FILTER_FILE_NAME,PROTOCOL:SERVER_NAME:PORT + example: + filter1.json,udp:8.8.8.8:53 + means the domains in the filter1.json will use the google dns server by udp + you can specail multiple filter and upstream server + `) + + flag.StringVar(&bind_addr, "bind", ":53", "the address bind to") + flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use") + flag.StringVar(&logfile, "logfile", "error.log", "the logfile, default stdout") + flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file") + flag.BoolVar(&debug, "debug", false, "output debug log, default false") +} diff --git a/log.go b/log.go new file mode 100644 index 0000000..6c8128a --- /dev/null +++ b/log.go @@ -0,0 +1,51 @@ +package main + +import ( + "log" + "os" +) + +//var LogLevel int + +type LogOut struct { + //out *os.File + debug bool + dbglog *log.Logger + errlog *log.Logger + infolog *log.Logger +} + +func NewLogger(logfile string, debug bool) *LogOut { + var out *os.File + var err error + if logfile != "" { + out, err = os.OpenFile(logfile, os.O_APPEND|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Println(err) + out = os.Stdout + } + } else { + out = os.Stdout + } + + return &LogOut{ + debug, + log.New(out, "DEBUG: ", log.LstdFlags), + log.New(out, "ERROR: ", log.LstdFlags), + log.New(out, "INFO: ", log.LstdFlags), + } +} + +func (l *LogOut) Debug(format string, args ...interface{}) { + if l.debug { + l.dbglog.Printf(format, args...) + } +} + +func (l *LogOut) Error(format string, args ...interface{}) { + l.errlog.Printf(format, args...) +} + +func (l *LogOut) Info(format string, args ...interface{}) { + l.infolog.Printf(format, args...) +} diff --git a/parse.go b/parse.go new file mode 100644 index 0000000..68de35d --- /dev/null +++ b/parse.go @@ -0,0 +1,76 @@ +package main + +import ( + "encoding/json" + "errors" + . "fmt" + "github.com/miekg/dns" + "io/ioutil" + "log" + "strings" +) + +func load_domain(f string) (Kv, error) { + var m1 Kv + c, err := ioutil.ReadFile(f) + if err != nil { + return nil, err + } + err = json.Unmarshal(c, &m1) + if err != nil { + return nil, err + } + return m1, nil +} + +func parse_addr(s string) (string, string, error) { + s2 := strings.Split(s, ":") + if len(s2) != 3 { + msg := Sprintf("error %s not well formatted\n", s2) + err := errors.New(msg) + return "", "", err + } + if s2[0] != "tcp" && s2[0] != "udp" { + msg := Sprintf("invalid %s, only tcp or udp allowed\n", s2[0]) + err := errors.New(msg) + return "", "", err + } + t := Sprintf("%s:%s", s2[1], s2[2]) + return s2[0], t, nil +} + +func parse_server(s string) (*UpstreamServer, error) { + s1 := strings.Split(s, ",") + + if len(s1) != 2 { + msg := Sprintf("error %s not well formatted\n", s) + err := errors.New(msg) + return nil, err + } + + proto, addr, err := parse_addr(s1[1]) + if err != nil { + log.Fatal(err) + } + + var c *dns.Client + if proto == "tcp" { + c = client_tcp + } else { + c = client_udp + } + + d, err := load_domain(s1[0]) + if err != nil { + log.Print(err) + } + + var sv *UpstreamServer = &UpstreamServer{ + Addr: addr, + domains: d, + client: c, + Proto: proto, + } + + return sv, nil +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..bcf8d42 --- /dev/null +++ b/server.go @@ -0,0 +1,109 @@ +package main + +import ( + "github.com/miekg/dns" + "log" +) + +var client_udp *dns.Client = &dns.Client{} + +var client_tcp *dns.Client = &dns.Client{Net: "tcp"} + +var Servers []*UpstreamServer = nil + +var logger *LogOut = nil + +var Blacklist_ips Kv = nil + +var debug bool = false + +func in_blacklist(m *dns.Msg) bool { + if Blacklist_ips == nil { + return false + } + + if m == nil { + return false + } + + for _, rr := range m.Answer { + if t, ok := rr.(*dns.A); ok { + ip := t.A.String() + if _, ok1 := Blacklist_ips[ip]; ok1 { + logger.Debug("%s is in blacklist\n", ip) + return true + } + } + } + return false +} +func handleRoot(w dns.ResponseWriter, r *dns.Msg) { + var err error + var res *dns.Msg + domain := r.Question[0].Name + + var done int + + for i := 0; i < 2; i++ { + done = 0 + for _, sv := range Servers { + if sv.match(domain) { + logger.Debug("query %s %s %s, forward to %s:%s\n", + domain, + dns.ClassToString[r.Question[0].Qclass], + dns.TypeToString[r.Question[0].Qtype], + sv.Proto, sv.Addr) + res, err = sv.query(r) + if err == nil { + if !in_blacklist(res) && res.Rcode != dns.RcodeServerFailure { + w.WriteMsg(res) + done = 1 + break + } + } else { + logger.Error("%s", err) + } + } + } + + if done != 1 { + res, err = DefaultServer.query(r) + logger.Debug("query %s %s %s, use default server %s:%s\n", + domain, + dns.ClassToString[r.Question[0].Qclass], + dns.TypeToString[r.Question[0].Qtype], + DefaultServer.Proto, DefaultServer.Addr) + if err == nil { + if !in_blacklist(res) && res.Rcode != dns.RcodeServerFailure { + w.WriteMsg(res) + done = 1 + } + } else { + logger.Error("%s", err) + } + } + + if done == 1 { + break + } + } + + if done != 1 { + dns.HandleFailed(w, r) + } +} + +func main() { + parse_flags() + + dns.HandleFunc(".", handleRoot) + + logger = NewLogger(logfile, debug) + + logger.Info("Listen on %s\n", bind_addr) + + err := dns.ListenAndServe(bind_addr, "udp", nil) + if err != nil { + log.Fatal(err) + } +} diff --git a/upstream.go b/upstream.go new file mode 100644 index 0000000..162526e --- /dev/null +++ b/upstream.go @@ -0,0 +1,37 @@ +package main + +import ( + "github.com/miekg/dns" + "strings" +) + +type Kv map[string]int + +type UpstreamServer struct { + domains Kv + Proto string + Addr string + client *dns.Client +} + +func (srv *UpstreamServer) match(d string) bool { + if srv.domains == nil { + return false + } + + s := strings.Split(strings.Trim(d, "."), ".") + + for i := 0; i < len(s)-1; i++ { + s1 := strings.Join(s[i:], ".") + if _, ok := srv.domains[s1]; ok { + return true + } + } + + return false +} + +func (srv *UpstreamServer) query(req *dns.Msg) (*dns.Msg, error) { + res, _, err := srv.client.Exchange(req, srv.Addr) + return res, err +}