diff --git a/handler.go b/handler.go index 9aa7578..13da806 100644 --- a/handler.go +++ b/handler.go @@ -32,7 +32,7 @@ var defaultTransport http.RoundTripper = &http.Transport{ IdleConnTimeout: 30 * time.Second, MaxIdleConnsPerHost: 3, DisableKeepAlives: true, - ResponseHeaderTimeout: 2 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, } // ServeHTTP implements the http.Handler interface @@ -100,7 +100,10 @@ func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { r.Body = nil } } - + if r.Method == http.MethodPost && r.ContentLength == 0 { + r.Body = http.NoBody + } + //log.Println("content-length", r.Header.Get("content-length")) resp, err = defaultTransport.RoundTrip(r) if err != nil { log.Printf("RoundTrip: %s", err) @@ -114,9 +117,15 @@ func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { hdr := w.Header() - resp.Header.Del("connection") - for k, v := range resp.Header { + _k := strings.ToLower(k) + if _k == "connection" || _k == "transfer-encoding" || + _k == "keep-alive" || _k == "upgrade" || _k == "te" { + continue + } + if resp.StatusCode == 204 && _k == "content-length" { + continue + } for _, v1 := range v { hdr.Add(k, v1) } diff --git a/http2_proxy.go b/http2_proxy.go index d77f3d6..6b03893 100644 --- a/http2_proxy.go +++ b/http2_proxy.go @@ -15,29 +15,20 @@ usage example */ import ( - "crypto/tls" "flag" "fmt" - "golang.org/x/net/http2" "io" - "io/ioutil" "log" "net" "net/http" - "net/http/httputil" + _ "net/http/pprof" "os" "sync" "time" -) -type clientConn struct { - host string - port string - hostname string - transport *http2.Transport - conn *http2.ClientConn - lock *sync.Mutex -} + "github.com/fangdingjun/gnutls" + "github.com/fangdingjun/nghttp2-go" +) type timeoutConn struct { net.Conn @@ -48,61 +39,131 @@ func (tc *timeoutConn) Read(b []byte) (n int, err error) { if err = tc.Conn.SetReadDeadline(time.Now().Add(tc.timeout)); err != nil { return 0, err } - return tc.Conn.Read(b) + n, err = tc.Conn.Read(b) + //log.Printf("read %d bytes from network", n) + return } func (tc *timeoutConn) Write(b []byte) (n int, err error) { if err = tc.Conn.SetWriteDeadline(time.Now().Add(tc.timeout)); err != nil { return 0, err } - return tc.Conn.Write(b) + n, err = tc.Conn.Write(b) + //log.Printf("write %d bytes to network", n) + return } type handler struct { - transport *http2.Transport + h2conn *nghttp2.ClientConn + addr string + hostname string + insecure bool + lock *sync.Mutex } -func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if debug { - req, _ := httputil.DumpRequest(r, false) - log.Printf("%s", string(req)) +func (h *handler) createConnection() (*nghttp2.ClientConn, error) { + log.Println("create connection to ", h.addr) + c, err := net.DialTimeout("tcp", h.addr, 5*time.Second) + if err != nil { + return nil, err } - - if r.Method == http.MethodConnect { - h.handleConnect(w, r) - } else { - h.handleHTTP(w, r) + conn, err := gnutls.Client( + &timeoutConn{c, 20 * time.Second}, + &gnutls.Config{ + ServerName: h.hostname, + InsecureSkipVerify: h.insecure, + NextProtos: []string{"h2"}, + }) + if err != nil { + return nil, err + } + if err := conn.Handshake(); err != nil { + return nil, err } + client, err := nghttp2.Client(conn) + if err != nil { + return nil, err + } + return client, nil } -func (h *handler) handleConnect(w http.ResponseWriter, r *http.Request) { - pr, pw := io.Pipe() +func (h *handler) getConn() (*nghttp2.ClientConn, error) { + h.lock.Lock() + defer h.lock.Unlock() - defer pr.Close() - defer pw.Close() + if h.h2conn != nil { + if h.h2conn.CanTakeNewRequest() { + return h.h2conn, nil + } + h.h2conn.Close() + } - r.Body = ioutil.NopCloser(pr) - r.URL.Scheme = "https" + for i := 0; i < 2; i++ { + h2conn, err := h.createConnection() + if err == nil { + h.h2conn = h2conn + return h2conn, nil + } + } + return nil, fmt.Errorf("create conn failed") +} - r.Header.Del("proxy-connection") +func (h *handler) checkError() { + h.lock.Lock() + defer h.lock.Unlock() - resp, err := h.transport.RoundTrip(r) - if err != nil { - log.Printf("roundtrip: %s", err) - w.WriteHeader(http.StatusServiceUnavailable) - fmt.Fprintf(w, "%s", err) + if h.h2conn == nil { return } - defer resp.Body.Close() + if err := h.h2conn.Error(); err != nil { + //log.Println("connection has error ", err) + h.h2conn.Close() + h.h2conn = nil + } +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if debug { - d, _ := httputil.DumpResponse(resp, false) - log.Printf("%s", string(d)) + if r.Method == http.MethodConnect { + h.handleConnect(w, r) + } else { + h.handleHTTP(w, r) } +} - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) +func (h *handler) handleConnect(w http.ResponseWriter, r *http.Request) { + var err error + var h2conn *nghttp2.ClientConn + var code int + //var resp *http.Response + + var cs *nghttp2.ClientStream + + for i := 0; i < 2; i++ { + h2conn, err = h.getConn() + if err != nil { + log.Println("connection error ", err) + w.WriteHeader(http.StatusBadGateway) + return + } + cs, code, err = h2conn.Connect(r.RequestURI) + if cs != nil { + break + } + h.checkError() + } + + if err != nil || cs == nil { + log.Println("send connect error ", err) + h.checkError() + w.WriteHeader(http.StatusBadGateway) + return + } + + defer cs.Close() + if code != http.StatusOK { + w.WriteHeader(code) return } @@ -120,12 +181,12 @@ func (h *handler) handleConnect(w http.ResponseWriter, r *http.Request) { ch := make(chan struct{}, 2) go func() { - io.Copy(pw, c) + io.Copy(cs, c) ch <- struct{}{} }() go func() { - io.Copy(c, resp.Body) + io.Copy(c, cs) ch <- struct{}{} }() @@ -133,19 +194,42 @@ func (h *handler) handleConnect(w http.ResponseWriter, r *http.Request) { } func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { - resp, err := h.transport.RoundTrip(r) - if err != nil { - log.Println(err) + var err error + var resp *http.Response + var h2conn *nghttp2.ClientConn + + if r.RequestURI[0] == '/' { + http.DefaultServeMux.ServeHTTP(w, r) + return + } + + for i := 0; i < 2; i++ { + h2conn, err = h.getConn() + if err != nil { + //log.Println("create connection ", err) + w.WriteHeader(http.StatusBadGateway) + return + } + resp, err = h2conn.CreateRequest(r) + if resp != nil { + break + } + h.checkError() + } + + if err != nil || resp == nil { + log.Println("create request error ", err) + h.checkError() w.WriteHeader(http.StatusServiceUnavailable) fmt.Fprintf(w, "%s", err) return } - defer resp.Body.Close() - if debug { - d, _ := httputil.DumpResponse(resp, false) - log.Printf("%s", string(d)) - } + defer func() { + if resp.Body != nil { + resp.Body.Close() + } + }() hdr := w.Header() for k, v := range resp.Header { @@ -158,73 +242,7 @@ func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { io.Copy(w, resp.Body) } -func newClientConn(host string, port string, hostname string, t *http2.Transport) *clientConn { - return &clientConn{ - host: host, - port: port, - hostname: hostname, - transport: t, - lock: new(sync.Mutex), - } -} - -func (p *clientConn) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) { - p.lock.Lock() - defer p.lock.Unlock() - - if p.conn != nil && p.conn.CanTakeNewRequest() { - return p.conn, nil - } - - if debug { - log.Printf("dial to %s:%s", p.host, p.port) - } - - c, err := net.Dial("tcp", net.JoinHostPort(p.host, p.port)) - if err != nil { - log.Println(err) - return nil, err - } - - cc := &timeoutConn{c, time.Duration(idleTimeout) * time.Second} - config := &tls.Config{ - ServerName: p.hostname, - NextProtos: []string{"h2"}, - InsecureSkipVerify: insecure, - } - - conn := tls.Client(cc, config) - if err := conn.Handshake(); err != nil { - log.Println(err) - return nil, err - } - - http2conn, err := p.transport.NewClientConn(conn) - if err != nil { - conn.Close() - log.Println(err) - return nil, err - } - - p.conn = http2conn - - return http2conn, err -} - -func (p *clientConn) MarkDead(conn *http2.ClientConn) { - //p.lock.Lock() - //defer p.lock.Unlock() - - if debug { - log.Println("mark dead") - } - - //p.conn = nil -} - -var debug bool var insecure bool -var idleTimeout int func main() { var addr string @@ -233,9 +251,7 @@ func main() { flag.StringVar(&addr, "server", "", "server address") flag.StringVar(&hostname, "name", "", "server 's SNI name") flag.StringVar(&listen, "listen", ":8080", "listen address") - flag.BoolVar(&debug, "debug", false, "verbose mode") flag.BoolVar(&insecure, "insecure", false, "insecure mode, not verify the server's certificate") - flag.IntVar(&idleTimeout, "idletime", 30, "idle timeout, close connection when no data transfer") flag.Parse() if addr == "" { @@ -243,31 +259,25 @@ func main() { os.Exit(-1) } - host, port, err := net.SplitHostPort(addr) + host, _, err := net.SplitHostPort(addr) if err != nil { host = addr - port = "443" + addr = fmt.Sprintf("%s:443", addr) } if hostname == "" { hostname = host } - transport := &http2.Transport{ - AllowHTTP: true, - } - - p := newClientConn(host, port, hostname, transport) - transport.ConnPool = p - log.Printf("listen on %s", listen) - if debug { - log.Printf("use parent proxy https://%s:%s/", host, port) - log.Printf("server SNI name %s", hostname) + hdr := &handler{ + addr: addr, + hostname: hostname, + insecure: insecure, + lock: new(sync.Mutex), } - - if err := http.ListenAndServe(listen, &handler{transport}); err != nil { + if err := http.ListenAndServe(listen, hdr); err != nil { log.Fatal(err) } } diff --git a/routers.go b/routers.go index 504f60b..2241404 100644 --- a/routers.go +++ b/routers.go @@ -16,7 +16,6 @@ import ( "github.com/fangdingjun/gnutls" auth "github.com/fangdingjun/go-http-auth" "github.com/fangdingjun/gofast" - loghandler "github.com/gorilla/handlers" "github.com/gorilla/mux" ) @@ -33,18 +32,19 @@ func (lw *logwriter) Write(buf []byte) (int, error) { func initRouters(cfg conf) { - logout := os.Stdout - - if logfile != "" { - fp, err := os.OpenFile(logfile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) - if err != nil { - log.Println(err) - } else { - logout = fp + //logout := os.Stdout + /* + if logfile != "" { + fp, err := os.OpenFile(logfile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + log.Println(err) + } else { + logout = fp + } } - } + */ - w := &logwriter{logout, new(sync.Mutex)} + //w := &logwriter{logout, new(sync.Mutex)} for _, l := range cfg { router := mux.NewRouter() @@ -137,7 +137,8 @@ func initRouters(cfg conf) { log.Fatal(err) } - handler := loghandler.CombinedLoggingHandler(w, hdlr) + //handler := loghandler.CombinedLoggingHandler(w, hdlr) + handler := hdlr log.Printf("listen https on %s", addr) go func() { defer listener.Close() @@ -153,9 +154,10 @@ func initRouters(cfg conf) { } else { log.Printf("listen http on %s", addr) + //handler := loghandler.CombinedLoggingHandler(w, hdlr), + handler := hdlr if err := http.ListenAndServe( - addr, - loghandler.CombinedLoggingHandler(w, hdlr), + addr, handler, ); err != nil { log.Fatal(err) }