0.0.1
commit
d1dcee056f
@ -0,0 +1,89 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/vharitonsky/iniflags"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var bind_addr string
|
||||||
|
|
||||||
|
var default_server string
|
||||||
|
|
||||||
|
var srv ArgSrvs
|
||||||
|
|
||||||
|
var logfile string
|
||||||
|
|
||||||
|
type ArgSrvs []string
|
||||||
|
|
||||||
|
var DefaultServer *UpstreamServer
|
||||||
|
|
||||||
|
var blacklist_file string
|
||||||
|
|
||||||
|
func (s *ArgSrvs) String() string {
|
||||||
|
//Sprintf("%s", s)
|
||||||
|
return "filter1.txt,udp:8.8.8.8:53"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ArgSrvs) Set(s1 string) error {
|
||||||
|
*s = append(*s, s1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parse_flags() {
|
||||||
|
iniflags.Parse()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for _, s := range srv {
|
||||||
|
sv, err := parse_server(s)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
} else {
|
||||||
|
Servers = append(Servers, sv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proto, addr, err := parse_addr(default_server)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var c *dns.Client
|
||||||
|
if proto == "udp" {
|
||||||
|
c = client_udp
|
||||||
|
} else {
|
||||||
|
c = client_tcp
|
||||||
|
}
|
||||||
|
|
||||||
|
DefaultServer = &UpstreamServer{
|
||||||
|
Addr: addr,
|
||||||
|
Proto: proto,
|
||||||
|
client: c,
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := load_domain(blacklist_file)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
} else {
|
||||||
|
Blacklist_ips = a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
|
||||||
|
flag.Var(&srv, "server", `special the filter and the upstream server to use when match
|
||||||
|
format:
|
||||||
|
FILTER_FILE_NAME,PROTOCOL:SERVER_NAME:PORT
|
||||||
|
example:
|
||||||
|
filter1.json,udp:8.8.8.8:53
|
||||||
|
means the domains in the filter1.json will use the google dns server by udp
|
||||||
|
you can specail multiple filter and upstream server
|
||||||
|
`)
|
||||||
|
|
||||||
|
flag.StringVar(&bind_addr, "bind", ":53", "the address bind to")
|
||||||
|
flag.StringVar(&default_server, "upstream", "udp:114.114.114.114:53", "the default upstream server to use")
|
||||||
|
flag.StringVar(&logfile, "logfile", "error.log", "the logfile, default stdout")
|
||||||
|
flag.StringVar(&blacklist_file, "blacklist", "", "the blacklist file")
|
||||||
|
flag.BoolVar(&debug, "debug", false, "output debug log, default false")
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
//var LogLevel int
|
||||||
|
|
||||||
|
type LogOut struct {
|
||||||
|
//out *os.File
|
||||||
|
debug bool
|
||||||
|
dbglog *log.Logger
|
||||||
|
errlog *log.Logger
|
||||||
|
infolog *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogger(logfile string, debug bool) *LogOut {
|
||||||
|
var out *os.File
|
||||||
|
var err error
|
||||||
|
if logfile != "" {
|
||||||
|
out, err = os.OpenFile(logfile, os.O_APPEND|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
out = os.Stdout
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out = os.Stdout
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LogOut{
|
||||||
|
debug,
|
||||||
|
log.New(out, "DEBUG: ", log.LstdFlags),
|
||||||
|
log.New(out, "ERROR: ", log.LstdFlags),
|
||||||
|
log.New(out, "INFO: ", log.LstdFlags),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LogOut) Debug(format string, args ...interface{}) {
|
||||||
|
if l.debug {
|
||||||
|
l.dbglog.Printf(format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LogOut) Error(format string, args ...interface{}) {
|
||||||
|
l.errlog.Printf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LogOut) Info(format string, args ...interface{}) {
|
||||||
|
l.infolog.Printf(format, args...)
|
||||||
|
}
|
@ -0,0 +1,76 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
. "fmt"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func load_domain(f string) (Kv, error) {
|
||||||
|
var m1 Kv
|
||||||
|
c, err := ioutil.ReadFile(f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(c, &m1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parse_addr(s string) (string, string, error) {
|
||||||
|
s2 := strings.Split(s, ":")
|
||||||
|
if len(s2) != 3 {
|
||||||
|
msg := Sprintf("error %s not well formatted\n", s2)
|
||||||
|
err := errors.New(msg)
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
if s2[0] != "tcp" && s2[0] != "udp" {
|
||||||
|
msg := Sprintf("invalid %s, only tcp or udp allowed\n", s2[0])
|
||||||
|
err := errors.New(msg)
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
t := Sprintf("%s:%s", s2[1], s2[2])
|
||||||
|
return s2[0], t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parse_server(s string) (*UpstreamServer, error) {
|
||||||
|
s1 := strings.Split(s, ",")
|
||||||
|
|
||||||
|
if len(s1) != 2 {
|
||||||
|
msg := Sprintf("error %s not well formatted\n", s)
|
||||||
|
err := errors.New(msg)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
proto, addr, err := parse_addr(s1[1])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var c *dns.Client
|
||||||
|
if proto == "tcp" {
|
||||||
|
c = client_tcp
|
||||||
|
} else {
|
||||||
|
c = client_udp
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := load_domain(s1[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sv *UpstreamServer = &UpstreamServer{
|
||||||
|
Addr: addr,
|
||||||
|
domains: d,
|
||||||
|
client: c,
|
||||||
|
Proto: proto,
|
||||||
|
}
|
||||||
|
|
||||||
|
return sv, nil
|
||||||
|
}
|
@ -0,0 +1,109 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var client_udp *dns.Client = &dns.Client{}
|
||||||
|
|
||||||
|
var client_tcp *dns.Client = &dns.Client{Net: "tcp"}
|
||||||
|
|
||||||
|
var Servers []*UpstreamServer = nil
|
||||||
|
|
||||||
|
var logger *LogOut = nil
|
||||||
|
|
||||||
|
var Blacklist_ips Kv = nil
|
||||||
|
|
||||||
|
var debug bool = false
|
||||||
|
|
||||||
|
func in_blacklist(m *dns.Msg) bool {
|
||||||
|
if Blacklist_ips == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if m == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rr := range m.Answer {
|
||||||
|
if t, ok := rr.(*dns.A); ok {
|
||||||
|
ip := t.A.String()
|
||||||
|
if _, ok1 := Blacklist_ips[ip]; ok1 {
|
||||||
|
logger.Debug("%s is in blacklist\n", ip)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func handleRoot(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
var err error
|
||||||
|
var res *dns.Msg
|
||||||
|
domain := r.Question[0].Name
|
||||||
|
|
||||||
|
var done int
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
done = 0
|
||||||
|
for _, sv := range Servers {
|
||||||
|
if sv.match(domain) {
|
||||||
|
logger.Debug("query %s %s %s, forward to %s:%s\n",
|
||||||
|
domain,
|
||||||
|
dns.ClassToString[r.Question[0].Qclass],
|
||||||
|
dns.TypeToString[r.Question[0].Qtype],
|
||||||
|
sv.Proto, sv.Addr)
|
||||||
|
res, err = sv.query(r)
|
||||||
|
if err == nil {
|
||||||
|
if !in_blacklist(res) && res.Rcode != dns.RcodeServerFailure {
|
||||||
|
w.WriteMsg(res)
|
||||||
|
done = 1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Error("%s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if done != 1 {
|
||||||
|
res, err = DefaultServer.query(r)
|
||||||
|
logger.Debug("query %s %s %s, use default server %s:%s\n",
|
||||||
|
domain,
|
||||||
|
dns.ClassToString[r.Question[0].Qclass],
|
||||||
|
dns.TypeToString[r.Question[0].Qtype],
|
||||||
|
DefaultServer.Proto, DefaultServer.Addr)
|
||||||
|
if err == nil {
|
||||||
|
if !in_blacklist(res) && res.Rcode != dns.RcodeServerFailure {
|
||||||
|
w.WriteMsg(res)
|
||||||
|
done = 1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Error("%s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if done == 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if done != 1 {
|
||||||
|
dns.HandleFailed(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
parse_flags()
|
||||||
|
|
||||||
|
dns.HandleFunc(".", handleRoot)
|
||||||
|
|
||||||
|
logger = NewLogger(logfile, debug)
|
||||||
|
|
||||||
|
logger.Info("Listen on %s\n", bind_addr)
|
||||||
|
|
||||||
|
err := dns.ListenAndServe(bind_addr, "udp", nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Kv map[string]int
|
||||||
|
|
||||||
|
type UpstreamServer struct {
|
||||||
|
domains Kv
|
||||||
|
Proto string
|
||||||
|
Addr string
|
||||||
|
client *dns.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *UpstreamServer) match(d string) bool {
|
||||||
|
if srv.domains == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s := strings.Split(strings.Trim(d, "."), ".")
|
||||||
|
|
||||||
|
for i := 0; i < len(s)-1; i++ {
|
||||||
|
s1 := strings.Join(s[i:], ".")
|
||||||
|
if _, ok := srv.domains[s1]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *UpstreamServer) query(req *dns.Msg) (*dns.Msg, error) {
|
||||||
|
res, _, err := srv.client.Exchange(req, srv.Addr)
|
||||||
|
return res, err
|
||||||
|
}
|
Loading…
Reference in New Issue