diff --git a/server_https.go b/server_https.go index 98be95f..3a23350 100644 --- a/server_https.go +++ b/server_https.go @@ -50,28 +50,18 @@ func (srv *server) handleHTTP2Req(w http.ResponseWriter, r *http.Request) { log.Errorln("parse dns message", err) return } - reply := false - for _, up := range srv.upstreams { - log.Debugf("from %s query upstream %s", r.RemoteAddr, up.String()) - log.Debugln("query", msg.Question[0].String()) - m, err := queryUpstream(msg, up) - if err == nil { - w.Header().Set("content-type", "application/dns-message") - w.WriteHeader(http.StatusOK) - for _, a := range m.Answer { - log.Debugln("result", a.String()) - } - d, _ := m.Pack() - w.Write(d) - reply = true - break - } else { - log.Errorf("https query upstream %s", err) - } - } - if !reply { + m, err := getResponseFromUpstream(msg, srv.upstreams) + if err != nil { + log.Debugln("query", msg.Question[0].String(), "timeout") w.WriteHeader(http.StatusServiceUnavailable) + return + } + + for _, a := range m.Answer { + log.Debugln("result", a.String()) } + d, _ := m.Pack() + w.Write(d) } func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/server_tcp.go b/server_tcp.go index dce483b..02133ca 100644 --- a/server_tcp.go +++ b/server_tcp.go @@ -17,24 +17,16 @@ func (srv *server) handleTCP(c net.Conn) { log.Debugln("tcp read message", err) break } - reply := false - for _, up := range srv.upstreams { - log.Debugf("from %s query upstream %s", conn.RemoteAddr(), up.String()) - log.Debugln("query", msg.Question[0].String()) - m, err := queryUpstream(msg, up) - if err == nil { - for _, a := range m.Answer { - log.Debugln("result", a.String()) - } - log.Debugln("got reply", m.String()) - conn.WriteMsg(m) - reply = true - break - } - log.Debugln("query upstream", up.String(), err) - } - if !reply { + + m, err := getResponseFromUpstream(msg, srv.upstreams) + if err != nil { + log.Debugln("query", msg.Question[0].String(), "timeout") break } + + for _, a := range m.Answer { + log.Debugln("result", a.String()) + } + conn.WriteMsg(m) } } diff --git a/server_udp.go b/server_udp.go index 5c7f909..5216d12 100644 --- a/server_udp.go +++ b/server_udp.go @@ -13,19 +13,16 @@ func (srv *server) handleUDP(buf []byte, addr net.Addr, conn *net.UDPConn) { log.Debugln("udp parse msg", err) return } - for _, up := range srv.upstreams { - log.Debugf("from %s query upstream %s", addr, up.String()) - log.Debugln("query", msg.Question[0].String()) - m, err := queryUpstream(msg, up) - if err == nil { - for _, a := range m.Answer { - log.Debugln("result", a.String()) - } - d, _ := m.Pack() - conn.WriteTo(d, addr) - break - } else { - log.Debugln("udp query upstream err", err) - } + + m, err := getResponseFromUpstream(msg, srv.upstreams) + if err != nil { + log.Debugln("query", msg.Question[0].String(), "timeout") + return + } + + for _, a := range m.Answer { + log.Debugln("result", a.String()) } + d, _ := m.Pack() + conn.WriteTo(d, addr) } diff --git a/upstream.go b/upstream.go index e7c85e6..b41c77a 100644 --- a/upstream.go +++ b/upstream.go @@ -20,6 +20,31 @@ var dnsClientHTTPS *dns.Client var dnsClientUDP *dns.Client var dnsClientTLS *dns.Client +func getResponseFromUpstream(msg *dns.Msg, upstreams []*url.URL) (*dns.Msg, error) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + resch := make(chan *dns.Msg, len(upstreams)) + + for _, up := range upstreams { + go func(u *url.URL) { + m, err := queryUpstream(msg, u) + if err == nil { + resch <- m + return + } + log.Errorln(u.String(), err) + }(up) + } + + select { + case <-ctx.Done(): + return nil, errors.New("time out") + case m := <-resch: + return m, nil + } +} + func queryUpstream(msg *dns.Msg, upstream *url.URL) (*dns.Msg, error) { switch upstream.Scheme { case "tcp":