query upstream on concurrency

master
fangdingjun 6 years ago
parent e52286ebc7
commit 20c5e4794d

@ -50,28 +50,18 @@ func (srv *server) handleHTTP2Req(w http.ResponseWriter, r *http.Request) {
log.Errorln("parse dns message", err) log.Errorln("parse dns message", err)
return return
} }
reply := false m, err := getResponseFromUpstream(msg, srv.upstreams)
for _, up := range srv.upstreams { if err != nil {
log.Debugf("from %s query upstream %s", r.RemoteAddr, up.String()) log.Debugln("query", msg.Question[0].String(), "timeout")
log.Debugln("query", msg.Question[0].String()) w.WriteHeader(http.StatusServiceUnavailable)
m, err := queryUpstream(msg, up) return
if err == nil { }
w.Header().Set("content-type", "application/dns-message")
w.WriteHeader(http.StatusOK)
for _, a := range m.Answer { for _, a := range m.Answer {
log.Debugln("result", a.String()) log.Debugln("result", a.String())
} }
d, _ := m.Pack() d, _ := m.Pack()
w.Write(d) w.Write(d)
reply = true
break
} else {
log.Errorf("https query upstream %s", err)
}
}
if !reply {
w.WriteHeader(http.StatusServiceUnavailable)
}
} }
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {

@ -17,24 +17,16 @@ func (srv *server) handleTCP(c net.Conn) {
log.Debugln("tcp read message", err) log.Debugln("tcp read message", err)
break break
} }
reply := false
for _, up := range srv.upstreams { m, err := getResponseFromUpstream(msg, srv.upstreams)
log.Debugf("from %s query upstream %s", conn.RemoteAddr(), up.String()) if err != nil {
log.Debugln("query", msg.Question[0].String()) log.Debugln("query", msg.Question[0].String(), "timeout")
m, err := queryUpstream(msg, up) break
if err == nil { }
for _, a := range m.Answer { for _, a := range m.Answer {
log.Debugln("result", a.String()) log.Debugln("result", a.String())
} }
log.Debugln("got reply", m.String())
conn.WriteMsg(m) conn.WriteMsg(m)
reply = true
break
}
log.Debugln("query upstream", up.String(), err)
}
if !reply {
break
}
} }
} }

@ -13,19 +13,16 @@ func (srv *server) handleUDP(buf []byte, addr net.Addr, conn *net.UDPConn) {
log.Debugln("udp parse msg", err) log.Debugln("udp parse msg", err)
return return
} }
for _, up := range srv.upstreams {
log.Debugf("from %s query upstream %s", addr, up.String()) m, err := getResponseFromUpstream(msg, srv.upstreams)
log.Debugln("query", msg.Question[0].String()) if err != nil {
m, err := queryUpstream(msg, up) log.Debugln("query", msg.Question[0].String(), "timeout")
if err == nil { return
}
for _, a := range m.Answer { for _, a := range m.Answer {
log.Debugln("result", a.String()) log.Debugln("result", a.String())
} }
d, _ := m.Pack() d, _ := m.Pack()
conn.WriteTo(d, addr) conn.WriteTo(d, addr)
break
} else {
log.Debugln("udp query upstream err", err)
}
}
} }

@ -20,6 +20,31 @@ var dnsClientHTTPS *dns.Client
var dnsClientUDP *dns.Client var dnsClientUDP *dns.Client
var dnsClientTLS *dns.Client var dnsClientTLS *dns.Client
func getResponseFromUpstream(msg *dns.Msg, upstreams []*url.URL) (*dns.Msg, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
resch := make(chan *dns.Msg, len(upstreams))
for _, up := range upstreams {
go func(u *url.URL) {
m, err := queryUpstream(msg, u)
if err == nil {
resch <- m
return
}
log.Errorln(u.String(), err)
}(up)
}
select {
case <-ctx.Done():
return nil, errors.New("time out")
case m := <-resch:
return m, nil
}
}
func queryUpstream(msg *dns.Msg, upstream *url.URL) (*dns.Msg, error) { func queryUpstream(msg *dns.Msg, upstream *url.URL) (*dns.Msg, error) {
switch upstream.Scheme { switch upstream.Scheme {
case "tcp": case "tcp":

Loading…
Cancel
Save