|
|
@ -1,11 +1,13 @@
|
|
|
|
package main
|
|
|
|
package main
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
import (
|
|
|
|
|
|
|
|
"bufio"
|
|
|
|
"crypto/tls"
|
|
|
|
"crypto/tls"
|
|
|
|
"fmt"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
"log"
|
|
|
|
"net"
|
|
|
|
"net"
|
|
|
|
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
|
|
|
"net/url"
|
|
|
|
"strings"
|
|
|
|
"strings"
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -46,10 +48,10 @@ func initHandler() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func handleConnection(c net.Conn) {
|
|
|
|
func handleConnection(c net.Conn) {
|
|
|
|
defer c.Close()
|
|
|
|
//log.Printf("connection from %s", c.RemoteAddr().String())
|
|
|
|
|
|
|
|
|
|
|
|
log.Printf("connection from %s", c.RemoteAddr().String())
|
|
|
|
|
|
|
|
tlsconn := c.(*tls.Conn)
|
|
|
|
tlsconn := c.(*tls.Conn)
|
|
|
|
|
|
|
|
defer tlsconn.Close()
|
|
|
|
|
|
|
|
|
|
|
|
connstate := tlsconn.ConnectionState()
|
|
|
|
connstate := tlsconn.ConnectionState()
|
|
|
|
for !connstate.HandshakeComplete {
|
|
|
|
for !connstate.HandshakeComplete {
|
|
|
|
if err := tlsconn.Handshake(); err != nil {
|
|
|
|
if err := tlsconn.Handshake(); err != nil {
|
|
|
@ -59,7 +61,7 @@ func handleConnection(c net.Conn) {
|
|
|
|
connstate = tlsconn.ConnectionState()
|
|
|
|
connstate = tlsconn.ConnectionState()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
log.Printf("handshake complete")
|
|
|
|
//log.Printf("handshake complete")
|
|
|
|
servername := connstate.ServerName
|
|
|
|
servername := connstate.ServerName
|
|
|
|
var backend *url.URL
|
|
|
|
var backend *url.URL
|
|
|
|
for _, f := range _config.Forward {
|
|
|
|
for _, f := range _config.Forward {
|
|
|
@ -77,7 +79,7 @@ func handleConnection(c net.Conn) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
backend = _b
|
|
|
|
backend = _b
|
|
|
|
}
|
|
|
|
}
|
|
|
|
log.Printf("sni name %s, get backend: %s", servername, backend.String())
|
|
|
|
//log.Printf("sni name %s, get backend: %s", servername, backend.String())
|
|
|
|
handleForward(tlsconn, backend)
|
|
|
|
handleForward(tlsconn, backend)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -85,7 +87,7 @@ func handleForward(c *tls.Conn, b *url.URL) {
|
|
|
|
var remote net.Conn
|
|
|
|
var remote net.Conn
|
|
|
|
var err error
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
|
|
log.Printf("forward to %s", b.String())
|
|
|
|
//log.Printf("forward to %s", b.String())
|
|
|
|
switch b.Scheme {
|
|
|
|
switch b.Scheme {
|
|
|
|
case "tcp":
|
|
|
|
case "tcp":
|
|
|
|
remote, err = net.Dial("tcp", b.Host)
|
|
|
|
remote, err = net.Dial("tcp", b.Host)
|
|
|
@ -96,6 +98,8 @@ func handleForward(c *tls.Conn, b *url.URL) {
|
|
|
|
b.Host = fmt.Sprintf("%s:80", b.Host)
|
|
|
|
b.Host = fmt.Sprintf("%s:80", b.Host)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
remote, err = net.Dial("tcp", b.Host)
|
|
|
|
remote, err = net.Dial("tcp", b.Host)
|
|
|
|
|
|
|
|
httpForward(c, remote)
|
|
|
|
|
|
|
|
return
|
|
|
|
case "tls":
|
|
|
|
case "tls":
|
|
|
|
h, _, _ := net.SplitHostPort(b.Host)
|
|
|
|
h, _, _ := net.SplitHostPort(b.Host)
|
|
|
|
remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h})
|
|
|
|
remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h})
|
|
|
@ -105,7 +109,7 @@ func handleForward(c *tls.Conn, b *url.URL) {
|
|
|
|
log.Println(err)
|
|
|
|
log.Println(err)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
log.Println("begin data forward")
|
|
|
|
//log.Println("begin data forward")
|
|
|
|
|
|
|
|
|
|
|
|
defer remote.Close()
|
|
|
|
defer remote.Close()
|
|
|
|
ch := make(chan struct{}, 2)
|
|
|
|
ch := make(chan struct{}, 2)
|
|
|
@ -119,3 +123,33 @@ func handleForward(c *tls.Conn, b *url.URL) {
|
|
|
|
}()
|
|
|
|
}()
|
|
|
|
<-ch
|
|
|
|
<-ch
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func httpForward(r, b net.Conn) {
|
|
|
|
|
|
|
|
defer b.Close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rb := bufio.NewReader(r)
|
|
|
|
|
|
|
|
bb := bufio.NewReader(b)
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
|
|
|
req, err := http.ReadRequest(rb)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
if err != io.EOF {
|
|
|
|
|
|
|
|
log.Println(err)
|
|
|
|
|
|
|
|
fmt.Fprintf(b, "HTTP/1.1 504 Bad gateway\r\nConnection: close\r\n\r\n")
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
addr := r.RemoteAddr().(*net.TCPAddr)
|
|
|
|
|
|
|
|
req.Header.Add("X-Forwarded-For", addr.IP.String())
|
|
|
|
|
|
|
|
req.Header.Add("X-Real-Ip", addr.IP.String())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//log.Printf("%+v\n", req.Header)
|
|
|
|
|
|
|
|
req.Write(b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
res, err := http.ReadResponse(bb, req)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
log.Println(err)
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
res.Write(r)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|