From db0d02192766f28d7cd31e849dd53f4ecdfc6e05 Mon Sep 17 00:00:00 2001 From: fangdingjun Date: Fri, 5 Oct 2018 18:16:05 +0800 Subject: [PATCH] DoH: add http/1.1 support --- .gitignore | 7 +++++++ man.go | 2 +- middleware.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ server.go | 24 ++++----------------- server_https.go | 35 +++++++++---------------------- 5 files changed, 77 insertions(+), 46 deletions(-) create mode 100644 middleware.go diff --git a/.gitignore b/.gitignore index daf913b..f3dcbf7 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,10 @@ _testmain.go *.exe *.test *.prof +*~ +*.swp +gdns* +*.key +*.crt +client.yaml +server.yaml diff --git a/man.go b/man.go index eda996b..4ca98ff 100644 --- a/man.go +++ b/man.go @@ -22,7 +22,7 @@ func main() { flag.Int64Var(&logFileSize, "log_size", 10, "max log file size MB") flag.StringVar(&loglevel, "log_level", "INFO", "log level, values:\nOFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG") - flag.StringVar(&configfile, "c", "", "config file") + flag.StringVar(&configfile, "c", "gdns.yaml", "config file") flag.Parse() if logfile != "" { diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..3b9e877 --- /dev/null +++ b/middleware.go @@ -0,0 +1,55 @@ +package main + +import ( + "net/http" + + "github.com/fangdingjun/go-log" +) + +type logHandler struct { + status int + w http.ResponseWriter + size int +} + +func (lh *logHandler) WriteHeader(status int) { + lh.status = status + lh.w.WriteHeader(status) +} + +func (lh *logHandler) Write(buf []byte) (int, error) { + lh.size += len(buf) + return lh.w.Write(buf) +} + +func (lh *logHandler) Header() http.Header { + return lh.w.Header() +} + +func (lh *logHandler) Status() int { + if lh.status != 0 { + return lh.status + } + return 200 +} + +var _ http.ResponseWriter = &logHandler{} + +func LogHandler(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + log.Infof("\"%s %s %s\" - %d %d \"%s\"", + r.Method, r.RequestURI, r.Proto, 500, 0, r.UserAgent()) + } + }() + + lh := &logHandler{w: w} + handler.ServeHTTP(lh, r) + log.Infof("\"%s %s %s\" - %d %d \"%s\"", + r.Method, r.RequestURI, r.Proto, lh.Status(), lh.size, r.UserAgent()) + + }) +} diff --git a/server.go b/server.go index ab64711..5029ad3 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package main import ( "crypto/tls" "net" + "net/http" "net/url" "strconv" @@ -100,27 +101,10 @@ func (srv *server) serveTLS() { } func (srv *server) serveHTTPS() { - cert, err := tls.LoadX509KeyPair(srv.cert, srv.key) - if err != nil { - log.Fatalln("load certificate failed", err) - } - l, err := tls.Listen("tcp", srv.addr.Host, - &tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{"h2"}, - }) + log.Debugf("listen https://%s", srv.addr.Host) + err := http.ListenAndServeTLS(srv.addr.Host, srv.cert, srv.key, LogHandler(srv)) if err != nil { - log.Fatalln("listen https", err) - } - defer l.Close() - log.Debugf("listen https://%s", l.Addr().String()) - for { - conn, err := l.Accept() - if err != nil { - log.Debugln("https accept", err) - break - } - go srv.handleHTTPSConn(conn) + log.Fatal(err) } } diff --git a/server_https.go b/server_https.go index 8f34d70..722a814 100644 --- a/server_https.go +++ b/server_https.go @@ -1,40 +1,25 @@ package main import ( - "crypto/tls" "io/ioutil" - "net" "net/http" + "strings" "github.com/fangdingjun/go-log" - //"github.com/fangdingjun/nghttp2-go" "github.com/miekg/dns" - "golang.org/x/net/http2" ) -func (srv *server) handleHTTPSConn(c net.Conn) { - defer c.Close() - tlsconn := c.(*tls.Conn) - if err := tlsconn.Handshake(); err != nil { - log.Errorln("handshake", err) - return - } - state := tlsconn.ConnectionState() - if state.NegotiatedProtocol != "h2" { - log.Errorln("http2 is needed") +func (srv *server) handleHTTPReq(w http.ResponseWriter, r *http.Request) { + ctype := r.Header.Get("content-type") + if !strings.HasPrefix(ctype, "application/dns-message") { + log.Errorf("request type %s, require application/dns-message", ctype) + http.Error(w, "dns message is required", http.StatusBadRequest) return } - _srv := &http2.Server{} - _srv.ServeConn(c, &http2.ServeConnOpts{ - BaseConfig: &http.Server{}, - Handler: srv, - }) -} -func (srv *server) handleHTTP2Req(w http.ResponseWriter, r *http.Request) { - ctype := r.Header.Get("content-type") - if ctype != "application/dns-message" { - http.Error(w, "dns message is required", http.StatusBadRequest) + if r.ContentLength < 10 { + log.Errorf("message is too small, %v", r.ContentLength) + http.Error(w, "message is too small", http.StatusBadRequest) return } @@ -70,5 +55,5 @@ func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) return } - srv.handleHTTP2Req(w, r) + srv.handleHTTPReq(w, r) }