You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

244 lines
5.1 KiB
Go

7 years ago
package main
import (
"bufio"
"bytes"
7 years ago
"crypto/tls"
"fmt"
"io"
"io/ioutil"
7 years ago
"net"
"net/http"
7 years ago
"net/url"
"strings"
"github.com/fangdingjun/go-log"
7 years ago
)
func initHandler() {
var _certs []tls.Certificate
for _, _c := range _config.Certificate {
log.Debugf("load certificate %s, %s...", _c.Cert, _c.Key)
7 years ago
_cert, err := _c.load()
if err != nil {
log.Errorln("load certificate failed", err)
7 years ago
continue
}
_certs = append(_certs, _cert)
}
_tlsconfig := &tls.Config{
Certificates: _certs,
}
_tlsconfig.BuildNameToCertificate()
for _, _l := range _config.Listen {
l, err := tls.Listen("tcp", _l, _tlsconfig)
if err != nil {
log.Fatalf("listen on %s error: %s", _l, err)
7 years ago
}
log.Printf("Listen on %s", l.Addr().String())
7 years ago
go func(l net.Listener) {
defer l.Close()
for {
c, err := l.Accept()
if err != nil {
log.Errorf("accept error: %s", err)
continue
7 years ago
}
go handleConnection(c)
}
}(l)
}
}
func handleConnection(c net.Conn) {
defer func() {
log.Debugf("close connection %s", c.RemoteAddr())
if err := c.Close(); err != nil {
log.Debugf("close connection %s error: %s", c.RemoteAddr(), err)
}
}()
log.Debugf("connection from %s\n", c.RemoteAddr().String())
7 years ago
tlsconn := c.(*tls.Conn)
7 years ago
connstate := tlsconn.ConnectionState()
for !connstate.HandshakeComplete {
if err := tlsconn.Handshake(); err != nil {
log.Errorf("%s tls handshake error: %s", c.RemoteAddr(), err)
7 years ago
return
}
connstate = tlsconn.ConnectionState()
}
log.Debugf("%s handshake complete", c.RemoteAddr())
7 years ago
servername := connstate.ServerName
7 years ago
var backend *url.URL
for _, f := range _config.Forward {
if !f.match(servername) {
continue
}
backend = f.getBackend()
break
}
if backend == nil {
_b, err := url.Parse(_config.DefaultBackend)
if err != nil {
log.Errorf("parse addr error: %s", err)
7 years ago
return
}
backend = _b
}
log.Debugf("connection from %s, tls version 0x%x, sni: %s, forward to: %s",
c.RemoteAddr().String(),
connstate.Version,
servername,
backend.String(),
)
7 years ago
handleForward(tlsconn, backend)
}
func handleForward(c *tls.Conn, b *url.URL) {
var remote net.Conn
var err error
log.Debugf("%s forward to %s", c.RemoteAddr(), b.String())
7 years ago
switch b.Scheme {
case "http":
httpForward(c, b)
return
7 years ago
case "tcp":
//log.Debugf("forward to tcp backend %s", b.Host)
7 years ago
remote, err = net.Dial("tcp", b.Host)
case "unix":
//log.Debugf("forward to unix backend %s", b.Host)
7 years ago
remote, err = net.Dial("unix", b.Host)
case "tls":
//log.Debugf("forward to tls backend %s", b.Host)
7 years ago
h, _, _ := net.SplitHostPort(b.Host)
remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h})
default:
log.Errorf("backend type '%s' is not supported", b.Scheme)
return
7 years ago
}
if err != nil {
log.Errorf("dail to backend %s error: %s", b.String(), err)
return
}
if remote == nil {
log.Warnln("remote is nil")
7 years ago
return
}
//log.Println("begin data forward")
7 years ago
defer remote.Close()
7 years ago
ch := make(chan struct{}, 2)
go func() {
io.Copy(c, remote)
ch <- struct{}{}
}()
7 years ago
go func() {
io.Copy(remote, c)
ch <- struct{}{}
}()
<-ch
}
// writeErrResponse writes a complete HTTP/1.1 response to w carrying the
// given status code (with its standard reason phrase) and content as a
// text/plain body with an explicit Content-Length. The response includes
// "Connection: close" and a hard-coded "Server: nginx/1.10.1" header.
// Any error from writing is discarded.
func writeErrResponse(w io.Writer, status int, content string) {
	resp := http.Response{
		StatusCode:    status,
		Status:        http.StatusText(status),
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		ContentLength: int64(len(content)),
		Body:          ioutil.NopCloser(bytes.NewBufferString(content)),
		Header: http.Header{
			"Connection":   {"close"},
			"Server":       {"nginx/1.10.1"},
			"Content-Type": {"text/plain"},
		},
	}
	_ = resp.Write(w)
}
// httpForward proxies plain HTTP/1.x traffic from the client connection r to
// the backend at b.Host, one request/response exchange at a time.
//
// If b.Host carries no port, port 80 is used; b itself is never modified.
// Every request is tagged with the client's IP: appended to any existing
// X-Forwarded-For list and set (replacing any client-supplied value) as
// X-Real-Ip. Failure to dial the backend, write the request, or read the
// response is answered with 502 Bad Gateway. The loop ends when the client
// stops sending, either side signals that the connection will close, or an
// I/O error occurs. The backend connection is closed on return.
func httpForward(r net.Conn, b *url.URL) {
	// Compute the dial address locally instead of rewriting b.Host: b may
	// be shared between concurrent connections, so mutating it is a data
	// race. JoinHostPort also handles a port-less IPv6 literal such as
	// "[::1]", which a "contains ':'" test would mistake for host:port.
	addr := b.Host
	if b.Port() == "" {
		addr = net.JoinHostPort(b.Hostname(), "80")
	}

	backend, err := net.Dial("tcp", addr)
	if err != nil {
		log.Errorf("dial to %s error: %s", addr, err)
		// Read the client's request before answering; otherwise the client
		// may see the connection forcibly closed instead of the 502.
		http.ReadRequest(bufio.NewReader(r))
		writeErrResponse(r, http.StatusBadGateway, err.Error())
		return
	}
	defer backend.Close()

	// The client IP is the same for every request on this connection.
	// Check the address type instead of asserting it so a non-TCP conn
	// cannot panic the goroutine.
	var clientIP string
	if tcpAddr, ok := r.RemoteAddr().(*net.TCPAddr); ok {
		clientIP = tcpAddr.IP.String()
	} else if host, _, splitErr := net.SplitHostPort(r.RemoteAddr().String()); splitErr == nil {
		clientIP = host
	}

	rb := bufio.NewReader(r)
	bb := bufio.NewReader(backend)
	for {
		req, err := http.ReadRequest(rb)
		if err != nil {
			if err != io.EOF {
				log.Errorf("read http request error: %s", err)
			}
			return
		}

		if clientIP != "" {
			xff := clientIP
			if prior := req.Header["X-Forwarded-For"]; len(prior) > 0 {
				xff = fmt.Sprintf("%s, %s", strings.Join(prior, ", "), clientIP)
			}
			req.Header.Set("X-Forwarded-For", xff)
			// Set, not Add: a client-supplied X-Real-Ip must not reach the
			// backend ahead of the real address.
			req.Header.Set("X-Real-Ip", clientIP)
		}

		err = req.Write(backend)
		if req.Body != nil {
			req.Body.Close()
		}
		if err != nil {
			log.Errorf("write request to backend error: %s", err)
			writeErrResponse(r, http.StatusBadGateway, err.Error())
			return
		}

		res, err := http.ReadResponse(bb, req)
		if err != nil {
			if err != io.EOF {
				log.Errorf("read http response from backend error: %s", err)
			}
			writeErrResponse(r, http.StatusBadGateway, "gateway error")
			return
		}

		err = res.Write(r)
		if res.Body != nil {
			res.Body.Close()
		}
		if err != nil {
			log.Errorf("write response to client error: %s", err)
			return
		}

		// Honour "Connection: close" (or non-keep-alive HTTP/1.0) from
		// either side instead of reusing a connection that is going away.
		if req.Close || res.Close {
			return
		}
	}
}