diff --git a/proxy_local/.gitignore b/proxy_local/.gitignore new file mode 100644 index 0000000..218e3e4 --- /dev/null +++ b/proxy_local/.gitignore @@ -0,0 +1 @@ +proxy_local* diff --git a/proxy_local/README.md b/proxy_local/README.md new file mode 100644 index 0000000..288cdff --- /dev/null +++ b/proxy_local/README.md @@ -0,0 +1,9 @@ +proxy_local +=========== + +accept http proxy request and forward to upstream proxy server via http2 + +usage +==== + +use `./proxy_local -h` to see options \ No newline at end of file diff --git a/http2_proxy.go b/proxy_local/http2_proxy.go similarity index 81% rename from http2_proxy.go rename to proxy_local/http2_proxy.go index d77f3d6..74c8ca8 100644 --- a/http2_proxy.go +++ b/proxy_local/http2_proxy.go @@ -1,5 +1,3 @@ -// +build ignore - package main /* @@ -18,16 +16,17 @@ import ( "crypto/tls" "flag" "fmt" - "golang.org/x/net/http2" "io" "io/ioutil" - "log" "net" "net/http" "net/http/httputil" "os" "sync" "time" + + log "github.com/fangdingjun/go-log" + "golang.org/x/net/http2" ) type clientConn struct { @@ -65,7 +64,7 @@ type handler struct { func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if debug { req, _ := httputil.DumpRequest(r, false) - log.Printf("%s", string(req)) + log.Debugf("%s", string(req)) } if r.Method == http.MethodConnect { @@ -88,9 +87,8 @@ func (h *handler) handleConnect(w http.ResponseWriter, r *http.Request) { resp, err := h.transport.RoundTrip(r) if err != nil { - log.Printf("roundtrip: %s", err) - w.WriteHeader(http.StatusServiceUnavailable) - fmt.Fprintf(w, "%s", err) + log.Errorf("roundtrip: %s", err) + http.Error(w, err.Error(), http.StatusServiceUnavailable) return } @@ -98,19 +96,20 @@ func (h *handler) handleConnect(w http.ResponseWriter, r *http.Request) { if debug { d, _ := httputil.DumpResponse(resp, false) - log.Printf("%s", string(d)) + log.Debugf("%s", string(d)) } if resp.StatusCode != http.StatusOK { w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) + resp.Body.Close() return } c, _, err := w.(http.Hijacker).Hijack() if err != nil { - log.Println("hijack: %s", err) - w.WriteHeader(http.StatusServiceUnavailable) - fmt.Fprintf(w, "%s", err) + log.Errorf("hijack: %s", err) + http.Error(w, err.Error(), http.StatusServiceUnavailable) return } @@ -135,16 +134,15 @@ func (h *handler) handleConnect(w http.ResponseWriter, r *http.Request) { func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { resp, err := h.transport.RoundTrip(r) if err != nil { - log.Println(err) - w.WriteHeader(http.StatusServiceUnavailable) - fmt.Fprintf(w, "%s", err) + log.Errorln(err) + http.Error(w, err.Error(), http.StatusServiceUnavailable) return } defer resp.Body.Close() if debug { d, _ := httputil.DumpResponse(resp, false) - log.Printf("%s", string(d)) + log.Debugf("%s", string(d)) } hdr := w.Header() @@ -176,17 +174,21 @@ func (p *clientConn) GetClientConn(req *http.Request, addr string) (*http2.Clien return p.conn, nil } - if debug { - log.Printf("dial to %s:%s", p.host, p.port) + if p.conn != nil { + p.conn.Close() + p.conn = nil } + log.Infof("dial to %s:%s", p.host, p.port) + c, err := net.Dial("tcp", net.JoinHostPort(p.host, p.port)) if err != nil { - log.Println(err) + log.Errorln(err) return nil, err } - cc := &timeoutConn{c, time.Duration(idleTimeout) * time.Second} + // cc := &timeoutConn{c, time.Duration(idleTimeout) * time.Second} + cc := c config := &tls.Config{ ServerName: p.hostname, NextProtos: []string{"h2"}, @@ -195,14 +197,14 @@ func (p *clientConn) GetClientConn(req *http.Request, addr string) (*http2.Clien conn := tls.Client(cc, config) if err := conn.Handshake(); err != nil { - log.Println(err) + log.Errorln(err) return nil, err } http2conn, err := p.transport.NewClientConn(conn) if err != nil { conn.Close() - log.Println(err) + log.Errorln(err) return nil, err } @@ -212,14 +214,12 @@ func (p *clientConn) GetClientConn(req *http.Request, addr string) (*http2.Clien } func (p *clientConn) MarkDead(conn *http2.ClientConn) { - //p.lock.Lock() - //defer p.lock.Unlock() - - if debug { - log.Println("mark dead") - } + p.lock.Lock() + defer p.lock.Unlock() - //p.conn = nil + log.Errorln("mark dead") + p.conn.Close() + p.conn = nil } var debug bool @@ -235,7 +235,7 @@ func main() { flag.StringVar(&listen, "listen", ":8080", "listen address") flag.BoolVar(&debug, "debug", false, "verbose mode") flag.BoolVar(&insecure, "insecure", false, "insecure mode, not verify the server's certificate") - flag.IntVar(&idleTimeout, "idletime", 30, "idle timeout, close connection when no data transfer") + flag.IntVar(&idleTimeout, "idletime", 600, "idle timeout, close connection when no data transfer") flag.Parse() if addr == "" { @@ -243,6 +243,12 @@ func main() { os.Exit(-1) } + if debug { + log.Default.Level = log.DEBUG + } else { + log.Default.Level = log.INFO + } + host, port, err := net.SplitHostPort(addr) if err != nil { host = addr @@ -262,10 +268,8 @@ func main() { log.Printf("listen on %s", listen) - if debug { - log.Printf("use parent proxy https://%s:%s/", host, port) - log.Printf("server SNI name %s", hostname) - } + log.Printf("use parent proxy https://%s:%s/", host, port) + log.Printf("server SNI name %s", hostname) if err := http.ListenAndServe(listen, &handler{transport}); err != nil { log.Fatal(err)