diff --git a/conf.go b/conf.go
index e380423..90cdd47 100644
--- a/conf.go
+++ b/conf.go
@@ -6,9 +6,10 @@ import (
)
type conf struct {
- Listen []listen
- Docroot string
- URLRules []rule
+ Listen []listen
+ Docroot string
+ URLRules []rule
+ LocalDomains []string
}
type listen struct {
diff --git a/config_example.yaml b/config_example.yaml
index af8eb8f..32bd98b 100644
--- a/config_example.yaml
+++ b/config_example.yaml
@@ -1,6 +1,10 @@
# document root
docroot: /var/www/html
+localdomains:
+ - www.simicloud.com
+ - localhost
+ - 127.0.0.1
# listener
listen:
diff --git a/handler.go b/handler.go
index 12afc34..a1c984e 100644
--- a/handler.go
+++ b/handler.go
@@ -11,7 +11,8 @@ import (
)
type handler struct {
- enableProxy bool
+ enableProxy bool
+ localDomains []string
}
var defaultTransport http.RoundTripper = &http.Transport{
@@ -23,7 +24,12 @@ var defaultTransport http.RoundTripper = &http.Transport{
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- if r.RequestURI[0] == '/' {
+ if r.ProtoMajor == 1 && r.RequestURI[0] == '/' {
+ http.DefaultServeMux.ServeHTTP(w, r)
+ return
+ }
+
+ if r.ProtoMajor == 2 && h.isLocalRequest(r) {
http.DefaultServeMux.ServeHTTP(w, r)
return
}
@@ -33,6 +39,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "
page not found!
")
return
}
+
if r.Method == http.MethodConnect {
h.handleCONNECT(w, r)
} else {
@@ -51,6 +58,11 @@ func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = "http"
r.URL.Host = r.Host
r.RequestURI = r.URL.String()
+ if r.Method != http.MethodPost && r.Method != http.MethodPut {
+ r.ContentLength = 0
+ r.Body.Close()
+ r.Body = nil
+ }
}
resp, err = defaultTransport.RoundTrip(r)
@@ -79,6 +91,16 @@ func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) {
io.Copy(w, resp.Body)
}
+type flushWriter struct {
+ w io.Writer
+}
+
+func (fw flushWriter) Write(buf []byte) (int, error) {
+ n, err := fw.w.Write(buf)
+ fw.w.(http.Flusher).Flush()
+ return n, err
+}
+
func (h *handler) handleCONNECT(w http.ResponseWriter, r *http.Request) {
host := r.RequestURI
@@ -126,13 +148,31 @@ func (h *handler) handleCONNECT(w http.ResponseWriter, r *http.Request) {
}()
go func() {
- io.Copy(w, conn)
+ io.Copy(flushWriter{w}, conn)
ch <- 1
}()
<-ch
}
+func (h *handler) isLocalRequest(r *http.Request) bool {
+ if len(h.localDomains) == 0 {
+ return true
+ }
+
+ host := r.Host
+ if h1, _, err := net.SplitHostPort(r.Host); err == nil {
+ host = h1
+ }
+
+ for _, s := range h.localDomains {
+ if strings.HasSuffix(host, s) {
+ return true
+ }
+ }
+
+ return false
+}
func pipeAndClose(r1, r2 io.ReadWriteCloser) {
ch := make(chan int, 2)
go func() {
diff --git a/server.go b/server.go
index 22f2169..e897132 100644
--- a/server.go
+++ b/server.go
@@ -11,14 +11,13 @@ func initListeners(c *conf) {
for _, l := range c.Listen {
go func(l listen) {
addr := fmt.Sprintf("%s:%d", l.Host, l.Port)
+ h := &handler{enableProxy: l.EnableProxy, localDomains: c.LocalDomains}
if l.Cert != "" && l.Key != "" {
- if err := http.ListenAndServeTLS(addr, l.Cert, l.Key,
- &handler{enableProxy: l.EnableProxy}); err != nil {
+ if err := http.ListenAndServeTLS(addr, l.Cert, l.Key, h); err != nil {
log.Fatal(err)
}
} else {
- if err := http.ListenAndServe(addr,
- &handler{enableProxy: l.EnableProxy}); err != nil {
+ if err := http.ListenAndServe(addr, h); err != nil {
log.Fatal(err)
}
}