|
|
|
@ -7,19 +7,21 @@ import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"log"
|
|
|
|
|
"net"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/url"
|
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
|
|
"github.com/fangdingjun/go-log"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func initHandler() {
|
|
|
|
|
var _certs []tls.Certificate
|
|
|
|
|
for _, _c := range _config.Certificate {
|
|
|
|
|
log.Debugf("load certificate %s, %s...", _c.Cert, _c.Key)
|
|
|
|
|
_cert, err := _c.load()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Println("load certificate failed", err)
|
|
|
|
|
log.Errorln("load certificate failed", err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
_certs = append(_certs, _cert)
|
|
|
|
@ -41,7 +43,7 @@ func initHandler() {
|
|
|
|
|
for {
|
|
|
|
|
c, err := l.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("accept error: %s", err)
|
|
|
|
|
log.Errorf("accept error: %s", err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
go handleConnection(c)
|
|
|
|
@ -51,20 +53,26 @@ func initHandler() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func handleConnection(c net.Conn) {
|
|
|
|
|
//log.Printf("connection from %s", c.RemoteAddr().String())
|
|
|
|
|
defer func() {
|
|
|
|
|
log.Debugf("close connection %s", c.RemoteAddr())
|
|
|
|
|
if err := c.Close(); err != nil {
|
|
|
|
|
log.Debugf("close connection %s error: %s", c.RemoteAddr(), err)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
log.Debugf("connection from %s\n", c.RemoteAddr().String())
|
|
|
|
|
tlsconn := c.(*tls.Conn)
|
|
|
|
|
defer tlsconn.Close()
|
|
|
|
|
|
|
|
|
|
connstate := tlsconn.ConnectionState()
|
|
|
|
|
for !connstate.HandshakeComplete {
|
|
|
|
|
if err := tlsconn.Handshake(); err != nil {
|
|
|
|
|
log.Printf("handshake error: %s", err)
|
|
|
|
|
log.Errorf("%s tls handshake error: %s", c.RemoteAddr(), err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
connstate = tlsconn.ConnectionState()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//log.Printf("handshake complete")
|
|
|
|
|
log.Debugf("%s handshake complete", c.RemoteAddr())
|
|
|
|
|
|
|
|
|
|
servername := connstate.ServerName
|
|
|
|
|
|
|
|
|
@ -79,13 +87,13 @@ func handleConnection(c net.Conn) {
|
|
|
|
|
if backend == nil {
|
|
|
|
|
_b, err := url.Parse(_config.DefaultBackend)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("parse addr error: %s", err)
|
|
|
|
|
log.Errorf("parse addr error: %s", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
backend = _b
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.Printf("connection from %s, tls version 0x%x, sni: %s, forward to: %s\n",
|
|
|
|
|
log.Debugf("connection from %s, tls version 0x%x, sni: %s, forward to: %s",
|
|
|
|
|
c.RemoteAddr().String(),
|
|
|
|
|
connstate.Version,
|
|
|
|
|
servername,
|
|
|
|
@ -99,39 +107,33 @@ func handleForward(c *tls.Conn, b *url.URL) {
|
|
|
|
|
var remote net.Conn
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
//log.Printf("forward to %s", b.String())
|
|
|
|
|
log.Debugf("%s forward to %s", c.RemoteAddr(), b.String())
|
|
|
|
|
switch b.Scheme {
|
|
|
|
|
case "http":
|
|
|
|
|
httpForward(c, b)
|
|
|
|
|
return
|
|
|
|
|
case "tcp":
|
|
|
|
|
//log.Debugf("forward to tcp backend %s", b.Host)
|
|
|
|
|
remote, err = net.Dial("tcp", b.Host)
|
|
|
|
|
case "unix":
|
|
|
|
|
//log.Debugf("forward to unix backend %s", b.Host)
|
|
|
|
|
remote, err = net.Dial("unix", b.Host)
|
|
|
|
|
case "http":
|
|
|
|
|
if !strings.Contains(b.Host, ":") {
|
|
|
|
|
b.Host = fmt.Sprintf("%s:80", b.Host)
|
|
|
|
|
}
|
|
|
|
|
remote, err = net.Dial("tcp", b.Host)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("dial to %s error: %s", b.Host, err)
|
|
|
|
|
writeErrResponse(c, http.StatusBadGateway, err.Error())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
httpForward(c, remote)
|
|
|
|
|
return
|
|
|
|
|
case "tls":
|
|
|
|
|
//log.Debugf("forward to tls backend %s", b.Host)
|
|
|
|
|
h, _, _ := net.SplitHostPort(b.Host)
|
|
|
|
|
remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h})
|
|
|
|
|
default:
|
|
|
|
|
log.Printf("backend type '%s' is not supported", b.Scheme)
|
|
|
|
|
log.Errorf("backend type '%s' is not supported", b.Scheme)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("dail to backend %s error: %s", b.String(), err)
|
|
|
|
|
log.Errorf("dail to backend %s error: %s", b.String(), err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if remote == nil {
|
|
|
|
|
log.Println("remote is nil")
|
|
|
|
|
log.Warnln("remote is nil")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
//log.Println("begin data forward")
|
|
|
|
@ -173,18 +175,29 @@ func writeErrResponse(w io.Writer, status int, content string) {
|
|
|
|
|
res.Write(w)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func httpForward(r, b net.Conn) {
|
|
|
|
|
defer b.Close()
|
|
|
|
|
func httpForward(r net.Conn, b *url.URL) {
|
|
|
|
|
//log.Debugf("forward to http backend %s", b.Host)
|
|
|
|
|
if !strings.Contains(b.Host, ":") {
|
|
|
|
|
b.Host = fmt.Sprintf("%s:80", b.Host)
|
|
|
|
|
}
|
|
|
|
|
backend, err := net.Dial("tcp", b.Host)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("dial to %s error: %s", b.Host, err)
|
|
|
|
|
// consume the client request,
|
|
|
|
|
// or client will get forcely close connection error
|
|
|
|
|
http.ReadRequest(bufio.NewReader(r))
|
|
|
|
|
writeErrResponse(r, http.StatusBadGateway, err.Error())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer backend.Close()
|
|
|
|
|
|
|
|
|
|
rb := bufio.NewReader(r)
|
|
|
|
|
bb := bufio.NewReader(b)
|
|
|
|
|
bb := bufio.NewReader(backend)
|
|
|
|
|
for {
|
|
|
|
|
req, err := http.ReadRequest(rb)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if err != io.EOF {
|
|
|
|
|
log.Printf("read http request error: %s", err)
|
|
|
|
|
//fmt.Fprintf(b, "HTTP/1.1 504 Bad gateway\r\nConnection: close\r\n\r\n")
|
|
|
|
|
//writeErrResponse(b, http.StatusBadGateway, err.Error())
|
|
|
|
|
log.Errorf("read http request error: %s", err)
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
@ -195,14 +208,14 @@ func httpForward(r, b net.Conn) {
|
|
|
|
|
req.Header.Add("X-Real-Ip", addr.IP.String())
|
|
|
|
|
|
|
|
|
|
//log.Printf("%+v\n", req.Header)
|
|
|
|
|
err = req.Write(b)
|
|
|
|
|
err = req.Write(backend)
|
|
|
|
|
|
|
|
|
|
if req.Body != nil {
|
|
|
|
|
req.Body.Close()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("write request to backend error: %s", err)
|
|
|
|
|
log.Errorf("write request to backend error: %s", err)
|
|
|
|
|
writeErrResponse(r, http.StatusBadGateway, err.Error())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
@ -210,7 +223,7 @@ func httpForward(r, b net.Conn) {
|
|
|
|
|
res, err := http.ReadResponse(bb, req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if err != io.EOF {
|
|
|
|
|
log.Printf("read http response from backend error: %s", err)
|
|
|
|
|
log.Errorf("read http response from backend error: %s", err)
|
|
|
|
|
}
|
|
|
|
|
writeErrResponse(r, http.StatusBadGateway, "gateway error")
|
|
|
|
|
return
|
|
|
|
@ -223,7 +236,7 @@ func httpForward(r, b net.Conn) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("write response to client error: %s", err)
|
|
|
|
|
log.Errorf("write response to client error: %s", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|