diff --git a/conf.go b/conf.go index e380423..90cdd47 100644 --- a/conf.go +++ b/conf.go @@ -6,9 +6,10 @@ import ( ) type conf struct { - Listen []listen - Docroot string - URLRules []rule + Listen []listen + Docroot string + URLRules []rule + LocalDomains []string } type listen struct { diff --git a/config_example.yaml b/config_example.yaml index af8eb8f..32bd98b 100644 --- a/config_example.yaml +++ b/config_example.yaml @@ -1,6 +1,10 @@ # document root docroot: /var/www/html +localdomains: + - www.simicloud.com + - localhost + - 127.0.0.1 # listener listen: diff --git a/handler.go b/handler.go index 12afc34..a1c984e 100644 --- a/handler.go +++ b/handler.go @@ -11,7 +11,8 @@ import ( ) type handler struct { - enableProxy bool + enableProxy bool + localDomains []string } var defaultTransport http.RoundTripper = &http.Transport{ @@ -23,7 +24,12 @@ var defaultTransport http.RoundTripper = &http.Transport{ } func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.RequestURI[0] == '/' { + if r.ProtoMajor == 1 && r.RequestURI[0] == '/' { + http.DefaultServeMux.ServeHTTP(w, r) + return + } + + if r.ProtoMajor == 2 && h.isLocalRequest(r) { http.DefaultServeMux.ServeHTTP(w, r) return } @@ -33,6 +39,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "

page not found!

") return } + if r.Method == http.MethodConnect { h.handleCONNECT(w, r) } else { @@ -51,6 +58,11 @@ func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = "http" r.URL.Host = r.Host r.RequestURI = r.URL.String() + if r.Method != http.MethodPost && r.Method != http.MethodPut { + r.ContentLength = 0 + r.Body.Close() + r.Body = nil + } } resp, err = defaultTransport.RoundTrip(r) @@ -79,6 +91,16 @@ func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) { io.Copy(w, resp.Body) } +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(buf []byte) (int, error) { + n, err := fw.w.Write(buf) + fw.w.(http.Flusher).Flush() + return n, err +} + func (h *handler) handleCONNECT(w http.ResponseWriter, r *http.Request) { host := r.RequestURI @@ -126,13 +148,31 @@ func (h *handler) handleCONNECT(w http.ResponseWriter, r *http.Request) { }() go func() { - io.Copy(w, conn) + io.Copy(flushWriter{w}, conn) ch <- 1 }() <-ch } +func (h *handler) isLocalRequest(r *http.Request) bool { + if len(h.localDomains) == 0 { + return true + } + + host := r.Host + if h1, _, err := net.SplitHostPort(r.Host); err == nil { + host = h1 + } + + for _, s := range h.localDomains { + if strings.HasSuffix(host, s) { + return true + } + } + + return false +} func pipeAndClose(r1, r2 io.ReadWriteCloser) { ch := make(chan int, 2) go func() { diff --git a/server.go b/server.go index 22f2169..e897132 100644 --- a/server.go +++ b/server.go @@ -11,14 +11,13 @@ func initListeners(c *conf) { for _, l := range c.Listen { go func(l listen) { addr := fmt.Sprintf("%s:%d", l.Host, l.Port) + h := &handler{enableProxy: l.EnableProxy, localDomains: c.LocalDomains} if l.Cert != "" && l.Key != "" { - if err := http.ListenAndServeTLS(addr, l.Cert, l.Key, - &handler{enableProxy: l.EnableProxy}); err != nil { + if err := http.ListenAndServeTLS(addr, l.Cert, l.Key, h); err != nil { log.Fatal(err) } } else { - if err := http.ListenAndServe(addr, - &handler{enableProxy: l.EnableProxy}); err != nil { + if err := http.ListenAndServe(addr, h); err != nil { log.Fatal(err) } }