query upstream on concurrency

master
fangdingjun 6 years ago
parent e52286ebc7
commit 20c5e4794d

@ -50,28 +50,18 @@ func (srv *server) handleHTTP2Req(w http.ResponseWriter, r *http.Request) {
log.Errorln("parse dns message", err)
return
}
reply := false
for _, up := range srv.upstreams {
log.Debugf("from %s query upstream %s", r.RemoteAddr, up.String())
log.Debugln("query", msg.Question[0].String())
m, err := queryUpstream(msg, up)
if err == nil {
w.Header().Set("content-type", "application/dns-message")
w.WriteHeader(http.StatusOK)
m, err := getResponseFromUpstream(msg, srv.upstreams)
if err != nil {
log.Debugln("query", msg.Question[0].String(), "timeout")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
for _, a := range m.Answer {
log.Debugln("result", a.String())
}
d, _ := m.Pack()
w.Write(d)
reply = true
break
} else {
log.Errorf("https query upstream %s", err)
}
}
if !reply {
w.WriteHeader(http.StatusServiceUnavailable)
}
}
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {

@ -17,24 +17,16 @@ func (srv *server) handleTCP(c net.Conn) {
log.Debugln("tcp read message", err)
break
}
reply := false
for _, up := range srv.upstreams {
log.Debugf("from %s query upstream %s", conn.RemoteAddr(), up.String())
log.Debugln("query", msg.Question[0].String())
m, err := queryUpstream(msg, up)
if err == nil {
m, err := getResponseFromUpstream(msg, srv.upstreams)
if err != nil {
log.Debugln("query", msg.Question[0].String(), "timeout")
break
}
for _, a := range m.Answer {
log.Debugln("result", a.String())
}
log.Debugln("got reply", m.String())
conn.WriteMsg(m)
reply = true
break
}
log.Debugln("query upstream", up.String(), err)
}
if !reply {
break
}
}
}

@ -13,19 +13,16 @@ func (srv *server) handleUDP(buf []byte, addr net.Addr, conn *net.UDPConn) {
log.Debugln("udp parse msg", err)
return
}
for _, up := range srv.upstreams {
log.Debugf("from %s query upstream %s", addr, up.String())
log.Debugln("query", msg.Question[0].String())
m, err := queryUpstream(msg, up)
if err == nil {
m, err := getResponseFromUpstream(msg, srv.upstreams)
if err != nil {
log.Debugln("query", msg.Question[0].String(), "timeout")
return
}
for _, a := range m.Answer {
log.Debugln("result", a.String())
}
d, _ := m.Pack()
conn.WriteTo(d, addr)
break
} else {
log.Debugln("udp query upstream err", err)
}
}
}

@ -20,6 +20,31 @@ var dnsClientHTTPS *dns.Client
var dnsClientUDP *dns.Client
var dnsClientTLS *dns.Client
func getResponseFromUpstream(msg *dns.Msg, upstreams []*url.URL) (*dns.Msg, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
resch := make(chan *dns.Msg, len(upstreams))
for _, up := range upstreams {
go func(u *url.URL) {
m, err := queryUpstream(msg, u)
if err == nil {
resch <- m
return
}
log.Errorln(u.String(), err)
}(up)
}
select {
case <-ctx.Done():
return nil, errors.New("time out")
case m := <-resch:
return m, nil
}
}
func queryUpstream(msg *dns.Msg, upstream *url.URL) (*dns.Msg, error) {
switch upstream.Scheme {
case "tcp":

Loading…
Cancel
Save