diff --git a/http2_proxy.go b/http2_proxy.go index a1f21e9..86f5383 100644 --- a/http2_proxy.go +++ b/http2_proxy.go @@ -27,6 +27,7 @@ import ( "net/http/httputil" "os" "sync" + "time" ) type clientConn struct { @@ -38,6 +39,25 @@ type clientConn struct { lock *sync.Mutex } +type timeoutConn struct { + net.Conn + timeout time.Duration +} + +func (tc *timeoutConn) Read(b []byte) (n int, err error) { + if err = tc.Conn.SetReadDeadline(time.Now().Add(tc.timeout)); err != nil { + return 0, err + } + return tc.Conn.Read(b) +} + +func (tc *timeoutConn) Write(b []byte) (n int, err error) { + if err = tc.Conn.SetWriteDeadline(time.Now().Add(tc.timeout)); err != nil { + return 0, err + } + return tc.Conn.Write(b) +} + type handler struct { transport *http2.Transport } @@ -155,21 +175,29 @@ func (p *clientConn) GetClientConn(req *http.Request, addr string) (*http2.Clien return p.conn, nil } - config := &tls.Config{ - ServerName: p.hostname, - NextProtos: []string{"h2"}, - } - if debug { log.Printf("dial to %s:%s", p.host, p.port) } - conn, err := tls.Dial("tcp", net.JoinHostPort(p.host, p.port), config) + c, err := net.Dial("tcp", net.JoinHostPort(p.host, p.port)) if err != nil { log.Println(err) return nil, err } + cc := &timeoutConn{c, time.Duration(idleTimeout) * time.Second} + config := &tls.Config{ + ServerName: p.hostname, + NextProtos: []string{"h2"}, + InsecureSkipVerify: insecure, + } + + conn := tls.Client(cc, config) + if err := conn.Handshake(); err != nil { + log.Println(err) + return nil, err + } + http2conn, err := p.transport.NewClientConn(conn) if err != nil { conn.Close() @@ -194,6 +222,8 @@ func (p *clientConn) MarkDead(conn *http2.ClientConn) { } var debug bool +var insecure bool +var idleTimeout int func main() { var addr string @@ -203,6 +233,8 @@ func main() { flag.StringVar(&hostname, "name", "", "server 's SNI name") flag.StringVar(&listen, "listen", ":8080", "listen address") flag.BoolVar(&debug, "debug", false, "verbose mode") + flag.BoolVar(&insecure, "insecure", false, "insecure mode, not verify the server's certificate") + flag.IntVar(&idleTimeout, "idletime", 30, "idle timeout, close connection when no data transfer") flag.Parse() if addr == "" {