diff --git a/upstream.go b/upstream.go index be63bdf..b17ae6c 100644 --- a/upstream.go +++ b/upstream.go @@ -1,12 +1,16 @@ package main import ( + "bufio" "context" "crypto/tls" "errors" + "fmt" "net" "net/http" "net/url" + "os" + "strings" "time" log "github.com/fangdingjun/go-log" @@ -147,8 +151,42 @@ func initDNSClient(c *conf) { Transport: &http2.Transport{ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { log.Debugln("dial to", network, addr) - conn, err := tls.DialWithDialer(dialer, network, addr, cfg) - return conn, err + p := os.Getenv("https_proxy") + if p == "" { + p = os.Getenv("http_proxy") + } + if p == "" { + conn, err := tls.DialWithDialer(dialer, network, addr, cfg) + return conn, err + } + u, _ := url.Parse(p) + log.Debugf("dial to proxy %s", u.Host) + conn, err := net.Dial(network, u.Host) + if err != nil { + return nil, err + } + fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", addr, addr) + r := bufio.NewReader(conn) + line, err := r.ReadString('\n') + ss := strings.Fields(line) + if ss[1] != "200" { + return nil, fmt.Errorf("http code %s", ss[1]) + } + for { + line, err := r.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.Trim(line, "\r\n") + if line == "" { + break + } + } + tlsconn := tls.Client(conn, cfg) + if err = tlsconn.Handshake(); err != nil { + return nil, err + } + return tlsconn, nil }, TLSClientConfig: &tls.Config{ InsecureSkipVerify: c.UpstreamInsecure,