gdns: first version

main
fangdingjun 8 years ago
commit 13c78296f2

2
.gitignore vendored

@ -0,0 +1,2 @@
*~
dns*

@ -0,0 +1,21 @@
gdns
====
gdns is a dns proxy server
features
=======
- support forward the query by rule,
different domains use different upstream server
- support ip black list
- support google https dns
usage
=====
go get github.com/fangdingjun/gdns
cp $GOPATH/src/github.com/fangdingjun/gdns/config_example.yaml config.yaml
vim config.yaml
$GOPATH/bin/gdns -c config.yaml

@ -0,0 +1,212 @@
package main
import (
"bufio"
"github.com/go-yaml/yaml"
"io/ioutil"
"os"
"strings"
)
type conf struct {
Listen []addr
BlacklistFile string
HostFile string
ForwardRules []rule
DefaultUpstream []addr
Timeout int
Debug bool
blacklist item
hosts hostitem
}
type rule struct {
Server []addr
DomainFile string
domains item
}
type addr struct {
Host string
Port int
Network string
}
func loadConfig(f string) (*conf, error) {
c := new(conf)
data, err := ioutil.ReadFile(f)
if err != nil {
return nil, err
}
if err := yaml.Unmarshal(data, c); err != nil {
return nil, err
}
if c.Debug {
logLevel = DEBUG
}
if c.blacklist == nil {
c.blacklist = item{}
}
if c.Timeout == 0 {
c.Timeout = 2
}
if err := loadItemFile(c.blacklist, c.BlacklistFile); err != nil {
return nil, err
}
for i := range c.ForwardRules {
if c.ForwardRules[i].domains == nil {
c.ForwardRules[i].domains = item{}
}
if err := loadItemFile(c.ForwardRules[i].domains,
c.ForwardRules[i].DomainFile); err != nil {
return nil, err
}
}
if c.hosts == nil {
c.hosts = hostitem{}
}
if err := loadHostsFile(c.hosts, c.HostFile); err != nil {
return nil, err
}
return c, nil
}
func loadHostsFile(h hostitem, f string) error {
if f == "" {
return nil
}
fd, err := os.Open(f)
if err != nil {
return err
}
defer fd.Close()
r := bufio.NewReader(fd)
for {
s, err := r.ReadString('\n')
if err != nil {
break
}
s1 := strings.Trim(s, " \t\r\n")
// ignore blank line and comment
if s1 == "" || s1[0] == '#' {
continue
}
s1 = strings.Replace(s1, "\t", " ", -1)
s1 = strings.Trim(s1, " \t\r\n")
ss := strings.Split(s1, " ")
// ipv4
t := 1
if strings.Index(ss[0], ":") != -1 {
// ipv6
t = 28
}
for _, s2 := range ss[1:] {
if s2 == "" {
continue
}
h.add(s2, ss[0], t)
}
}
return nil
}
func loadItemFile(it item, f string) error {
if f == "" {
return nil
}
fd, err := os.Open(f)
if err != nil {
return err
}
defer fd.Close()
r := bufio.NewReader(fd)
for {
s, err := r.ReadString('\n')
if s != "" {
s1 := strings.Trim(s, " \r\n")
if s1 != "" && s1[0] != '#' {
it.add(s1)
}
}
if err != nil {
break
}
}
return nil
}
type item map[string]int
func (it item) has(s string) bool {
ss := strings.Split(s, ".")
for i := 0; i < len(ss); i++ {
s1 := strings.Join(ss[i:], ".")
if _, ok := it[s1]; ok {
return true
}
}
return false
}
func (it item) exists(s string) bool {
_, ok := it[s]
return ok
}
func (it item) add(s string) {
it[s] = 1
}
type hostitem map[string][]hostentry
func (ht hostitem) get(domain string, t int) string {
if v, ok := ht[domain]; ok {
for _, v1 := range v {
if v1.domain == domain && v1.t == t {
return v1.ip
}
}
}
return ""
}
func (ht hostitem) add(domain, ip string, t int) {
if v, ok := ht[domain]; ok {
exists := false
for _, v1 := range v {
if v1.domain == domain && v1.ip == ip && v1.t == t {
exists = true
break
}
}
if !exists {
ht[domain] = append(ht[domain], hostentry{domain, ip, t})
}
} else {
v1 := []hostentry{{domain, ip, t}}
ht[domain] = v1
}
}
type hostentry struct {
domain string
ip string
t int
}

@ -0,0 +1,121 @@
package main
import (
"fmt"
//"github.com/go-yaml/yaml"
//"io/ioutil"
"testing"
)
func TestConf(t *testing.T) {
c, err := loadConfig("testdata/config.yaml")
if err != nil {
t.Fatal(err)
}
//fmt.Printf("%#v\n", c)
if len(c.Listen) != 2 {
t.Errorf("expected listers 2, got %d", len(c.Listen))
}
if !c.blacklist.exists("1.2.3.4") {
fmt.Printf("%#v\n", c.blacklist)
t.Errorf("blacklist load failed")
}
if c.hosts.get("localhost", 1) != "127.0.0.1" {
fmt.Printf("%#v\n", c.hosts)
t.Errorf("hosts load failed")
}
if c.hosts.get("localhost", 28) != "::1" {
fmt.Printf("%#v\n", c.hosts)
t.Errorf("hosts load failed")
}
if len(c.ForwardRules) != 2 {
fmt.Printf("%#v\n", c.ForwardRules)
t.Errorf("expected rules 2, got %d", len(c.ForwardRules))
}
if !c.ForwardRules[0].domains.has("a.com") {
fmt.Printf("%#v\n", c.ForwardRules[0].domains)
t.Errorf("some domains should exit, may be load config failed")
}
if !c.ForwardRules[1].domains.has("d.com") {
fmt.Printf("%#v\n", c.ForwardRules[1].domains)
t.Errorf("some domains should exit, may be load config failed")
}
}
func TestItemExists(t *testing.T) {
it := item{
"google.cn": 1,
"www.baidu.com": 1,
"org": 1,
}
testdata := []struct {
d string
b bool
}{
{"google.cn", true},
{"www.google.cn", false},
{"www.a.org", false},
}
for _, d := range testdata {
b1 := it.exists(d.d)
if b1 != d.b {
t.Errorf("%s, expected %v, got %v", d.d, d.b, b1)
}
}
}
func TestItemHas(t *testing.T) {
it := item{
"google.cn": 1,
"www.baidu.com": 1,
"org": 1,
}
testdata := []struct {
d string
b bool
}{
{"google.cn", true},
{"www.google.cn", true},
{"www.a.org", true},
{"pan.baidu.com", false},
{"abc.org", true},
}
for _, d := range testdata {
b1 := it.has(d.d)
if b1 != d.b {
t.Errorf("%s, expected %v, got %v", d.d, d.b, b1)
}
}
}
func TestItemAdd(t *testing.T) {
it := item{}
it.add("www.example.org")
if !it.exists("www.example.org") {
t.Errorf("add failed")
}
}
func TestHostitem(t *testing.T) {
ht := hostitem{}
testdata := []hostentry{
{"www.google.com", "127.0.0.1", 1},
{"www.google.com", "127.0.0.2", 28},
{"www.example.org", "127.0.0.3", 28},
{"www.abc.org", "127.0.0.4", 1},
}
for _, d := range testdata {
ht.add(d.domain, d.ip, d.t)
ip := ht.get(d.domain, d.t)
if ip != d.ip {
t.Errorf("%s, expected %s, got %s", d.domain, d.ip, ip)
}
}
//fmt.Printf("%v\n", ht)
}

@ -0,0 +1,41 @@
listen:
- network: tcp
host: 0.0.0.0
port: 1053
- network: udp
host: 0.0.0.0
port: 1053
blacklistfile: ./testdata/blacklist.dat
hostfile: ./testdata/hosts
timeout: 2
debug: true
defaultupstream:
-
network: tcp
host: 8.8.8.8
port: 53
-
network:udp
host: 8.8.8.8
port: 53
forwardrules:
- domainfile: ./testdata/cn.dat
server:
- network: tcp
host: 114.114.114.114
port: 53
- domainfile: testdata/us.dat
server:
- network: tcp
host: 8.8.8.8
port: 53
- network: https
host: 74.175.200.100
port: 443

@ -0,0 +1,183 @@
package main
import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/miekg/dns"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
// ServerAddr is Google dns server ip
var ServerAddr = "74.125.200.100"
var queryIPApi = "https://www.simicloud.com/media/httpbin/ip"
// GoogleHTTPDns struct
type GoogleHTTPDns struct {
myip string
l sync.Mutex
}
func (h *GoogleHTTPDns) getMyIP() string {
if h.myip != "" {
return h.myip
}
go h.queryMyIP()
return ""
}
type ipAPI struct {
IP string `json:"origin"`
}
func (h *GoogleHTTPDns) queryMyIP() {
h.l.Lock()
defer h.l.Unlock()
if h.myip != "" {
//fmt.Printf("myip: %s\n", h.myip)
return
}
//fmt.Println("get ip...")
res, err := http.Get(queryIPApi)
if err != nil {
//fmt.Println(err)
return
}
defer res.Body.Close()
d, err := ioutil.ReadAll(res.Body)
if err != nil {
//fmt.Println(err)
return
}
//fmt.Printf("%s\n", string(d))
ip := ipAPI{}
err = json.Unmarshal(d, &ip)
if err != nil {
//fmt.Println(err)
return
}
//fmt.Printf("got: %s\n", ip.Ip)
h.myip = ip.IP
}
func (h *GoogleHTTPDns) getMyNet() string {
ip := h.getMyIP()
if ip == "" {
return ""
}
mask := net.IPv4Mask(255, 255, 255, 0)
ipByte := net.ParseIP(ip)
ipnet := net.IPNet{ipByte.Mask(mask), mask}
return ipnet.String()
}
// Exchange send query to server and return the response
func (h *GoogleHTTPDns) Exchange(m *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
name := m.Question[0].Name
t := dns.TypeToString[m.Question[0].Qtype]
mynet := h.getMyNet()
r, err := queryGoogleHTTPDNS(name, t, mynet, "", addr)
if err != nil {
return nil, 0, err
}
m1 := new(dns.Msg)
m1.SetRcode(m, r.Status)
for _, rr := range r.Answer {
_rr := fmt.Sprintf("%s %d IN %s %s", rr.Name, rr.TTL,
dns.TypeToString[uint16(rr.Type)], rr.Data)
an, err := dns.NewRR(_rr)
if err != nil {
return nil, 0, err
}
m1.Answer = append(m1.Answer, an)
}
m1.Truncated = r.TC
m1.RecursionDesired = r.RD
m1.RecursionAvailable = r.RA
m1.AuthenticatedData = r.AD
m1.CheckingDisabled = r.CD
return m1, 0, nil
}
// Response represent the dns response from server
type Response struct {
Status int
TC bool
RD bool
RA bool
AD bool
CD bool
Question []RR
Answer []RR
Additional []RR
EDNSClientSubnet string `json:"edns_client_subnet"`
Comment string
}
// RR represent the RR record
type RR struct {
Name string `json:"name"`
Type int `json:"type"`
TTL int
Data string `json:"data"`
}
var httpclient = &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ServerName: "dns.google.com"},
TLSHandshakeTimeout: 3 * time.Second,
},
}
func queryGoogleHTTPDNS(name, t, ednsClientSubnet, padding, srvAddr string) (*Response, error) {
srvaddr := ServerAddr
if srvAddr != "" {
srvaddr = srvAddr
}
v := url.Values{}
v.Add("name", name)
v.Add("type", t)
if ednsClientSubnet != "" {
v.Add("edns_client_subnet", ednsClientSubnet)
}
if padding != "" {
v.Add("random_padding", padding)
}
u := fmt.Sprintf("https://%s/resolve?%s", srvaddr, v.Encode())
r, _ := http.NewRequest("GET", u, nil)
r.Host = "dns.google.com"
//r.URL.Host = "dns.google.com"
res, err := httpclient.Do(r)
if err != nil {
return nil, err
}
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
d := Response{}
err = json.Unmarshal(data, &d)
if err != nil {
return nil, err
}
return &d, nil
}

@ -0,0 +1,165 @@
package main
import (
"errors"
"fmt"
"github.com/miekg/dns"
"strings"
"time"
)
type dnsClient interface {
Exchange(m *dns.Msg, addr string) (*dns.Msg, time.Duration, error)
}
type dnsHandler struct {
cfg *conf
tcpclient dnsClient
udpclient dnsClient
httpsclient dnsClient
}
func newDNSHandler(cfg *conf) *dnsHandler {
return &dnsHandler{
cfg: cfg,
tcpclient: &dns.Client{Net: "tcp", Timeout: 2 * time.Second},
udpclient: &dns.Client{Net: "udp", Timeout: 2 * time.Second},
httpsclient: &GoogleHTTPDns{},
}
}
// ServerDNS implements the dns.Handler interface
func (h *dnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
domain := r.Question[0].Name
if ok := h.answerFromHosts(w, r); ok {
return
}
srvs := h.getUpstreamServer(domain)
if srvs == nil {
srvs = h.cfg.DefaultUpstream
}
if msg, err := h.getAnswerFromUpstream(r, srvs); err == nil {
w.WriteMsg(msg)
return
}
dns.HandleFailed(w, r)
}
func (h *dnsHandler) getUpstreamServer(domain string) []addr {
for _, srv := range h.cfg.ForwardRules {
if ok := srv.domains.has(strings.Trim(domain, ".")); ok {
return srv.Server
}
}
return nil
}
func (h *dnsHandler) queryUpstream(r *dns.Msg, srv addr, ch chan *dns.Msg) {
var m *dns.Msg
var err error
switch srv.Network {
case "tcp":
info("query %s IN %s, forward to %s:%d through tcp",
r.Question[0].Name,
dns.TypeToString[r.Question[0].Qtype],
srv.Host,
srv.Port)
m, _, err = h.tcpclient.Exchange(r, fmt.Sprintf("%s:%d", srv.Host, srv.Port))
case "udp":
info("query %s IN %s, forward to %s:%d through udp",
r.Question[0].Name,
dns.TypeToString[r.Question[0].Qtype],
srv.Host,
srv.Port)
m, _, err = h.tcpclient.Exchange(r, fmt.Sprintf("%s:%d", srv.Host, srv.Port))
case "https":
info("query %s IN %s, forward to %s:%d through https",
r.Question[0].Name,
dns.TypeToString[r.Question[0].Qtype],
srv.Host,
srv.Port)
m, _, err = h.httpsclient.Exchange(r, fmt.Sprintf("%s:%d", srv.Host, srv.Port))
default:
// ignore
}
if err == nil {
select {
case ch <- m:
default:
}
} else {
errorlog("%s", err)
}
}
func (h *dnsHandler) getAnswerFromUpstream(r *dns.Msg, servers []addr) (*dns.Msg, error) {
ch := make(chan *dns.Msg, 5)
for _, srv := range servers {
go func(a addr) {
h.queryUpstream(r, a, ch)
}(srv)
}
var savedErr *dns.Msg
for {
select {
case m := <-ch:
if m.Rcode == dns.RcodeSuccess && !h.inBlacklist(m) {
return m, nil
}
savedErr = m
case <-time.After(time.Duration(h.cfg.Timeout) * time.Second):
if savedErr != nil {
return savedErr, nil
}
info("query %s IN %s, timeout", r.Question[0].Name, dns.TypeToString[r.Question[0].Qtype])
return nil, errors.New("timeout")
}
}
}
func (h *dnsHandler) inBlacklist(m *dns.Msg) bool {
var ip string
for _, rr := range m.Answer {
if a, ok := rr.(*dns.A); ok {
ip = a.String()
} else if aaaa, ok := rr.(*dns.AAAA); ok {
ip = aaaa.String()
} else {
ip = ""
}
if ip != "" && h.cfg.blacklist.exists(ip) {
info("%s in blacklist", ip)
return true
}
}
return false
}
func (h *dnsHandler) answerFromHosts(w dns.ResponseWriter, r *dns.Msg) bool {
domain := r.Question[0].Name
t := r.Question[0].Qtype
ip := h.cfg.hosts.get(strings.Trim(domain, "."), int(t))
if ip != "" {
rr, _ := dns.NewRR(fmt.Sprintf("%s 3600 IN %s %s", domain, dns.TypeToString[t], ip))
if rr == nil {
return false
}
msg := new(dns.Msg)
msg.SetReply(r)
msg.Answer = append(msg.Answer, rr)
w.WriteMsg(msg)
info("query %s IN %s, reply from hosts", domain, dns.TypeToString[t])
return true
}
return false
}

@ -0,0 +1,31 @@
package main
import (
"log"
)
const (
_ = iota
FATAL
ERROR
WARN
NOTICE
INFO
DEBUG
)
var logLevel = WARN
func logMsg(l int, fmt string, args ...interface{}) {
if l <= logLevel {
log.Printf(fmt, args...)
}
}
func info(fmt string, args ...interface{}) {
logMsg(INFO, fmt, args...)
}
func errorlog(fmt string, args ...interface{}) {
logMsg(ERROR, fmt, args...)
}

@ -0,0 +1,33 @@
package main
import (
"flag"
"fmt"
"github.com/miekg/dns"
"log"
)
func main() {
var configfile string
flag.StringVar(&configfile, "c", "config.yaml", "config file")
flag.Parse()
config, err := loadConfig(configfile)
if err != nil {
log.Fatal(err)
}
h := newDNSHandler(config)
for _, l := range config.Listen {
go func(l addr) {
if err := dns.ListenAndServe(
fmt.Sprintf("%s:%d", l.Host, l.Port), l.Network, h); err != nil {
log.Fatal(err)
}
}(l)
}
select {}
}

@ -0,0 +1,5 @@
# asdfas
1.2.3.4
2.3.1.21
3.1.1.1

8
testdata/cn.dat vendored

@ -0,0 +1,8 @@
#12
a.com
b.org
cn
c.net

@ -0,0 +1,31 @@
listen:
- network: tcp
host: 0.0.0.0
port: 1053
- network: udp
host: 0.0.0.0
port: 1053
blacklistfile: ./testdata/blacklist.dat
hostfile: ./testdata/hosts
timeout: 2
debug: true
forwardrules:
- domainfile: ./testdata/cn.dat
server:
- network: tcp
host: 114.114.114.114
port: 53
- domainfile: testdata/us.dat
server:
- network: tcp
host: 8.8.8.8
port: 53
- network: https
host: 74.175.200.200
port: 443

9
testdata/hosts vendored

@ -0,0 +1,9 @@
#asdf
1.2.3.1 a.com b.com c.aaa
1.2.9.1 addd bbb ddd aa
::1 localhost
127.0.0.1 localhost
1.2.1.1 kkddaso
# daas

12
testdata/us.dat vendored

@ -0,0 +1,12 @@
#
us
b.cc
a.cc
dda
# da
a.cf
d.com
Loading…
Cancel
Save