add http forward handler

master
fangdingjun 7 years ago
parent a336f89a1e
commit 270d89c89a

@ -1,11 +1,13 @@
package main package main
import ( import (
"bufio"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
) )
@ -46,10 +48,10 @@ func initHandler() {
} }
func handleConnection(c net.Conn) { func handleConnection(c net.Conn) {
defer c.Close() //log.Printf("connection from %s", c.RemoteAddr().String())
log.Printf("connection from %s", c.RemoteAddr().String())
tlsconn := c.(*tls.Conn) tlsconn := c.(*tls.Conn)
defer tlsconn.Close()
connstate := tlsconn.ConnectionState() connstate := tlsconn.ConnectionState()
for !connstate.HandshakeComplete { for !connstate.HandshakeComplete {
if err := tlsconn.Handshake(); err != nil { if err := tlsconn.Handshake(); err != nil {
@ -59,7 +61,7 @@ func handleConnection(c net.Conn) {
connstate = tlsconn.ConnectionState() connstate = tlsconn.ConnectionState()
} }
log.Printf("handshake complete") //log.Printf("handshake complete")
servername := connstate.ServerName servername := connstate.ServerName
var backend *url.URL var backend *url.URL
for _, f := range _config.Forward { for _, f := range _config.Forward {
@ -77,7 +79,7 @@ func handleConnection(c net.Conn) {
} }
backend = _b backend = _b
} }
log.Printf("sni name %s, get backend: %s", servername, backend.String()) //log.Printf("sni name %s, get backend: %s", servername, backend.String())
handleForward(tlsconn, backend) handleForward(tlsconn, backend)
} }
@ -85,7 +87,7 @@ func handleForward(c *tls.Conn, b *url.URL) {
var remote net.Conn var remote net.Conn
var err error var err error
log.Printf("forward to %s", b.String()) //log.Printf("forward to %s", b.String())
switch b.Scheme { switch b.Scheme {
case "tcp": case "tcp":
remote, err = net.Dial("tcp", b.Host) remote, err = net.Dial("tcp", b.Host)
@ -96,6 +98,8 @@ func handleForward(c *tls.Conn, b *url.URL) {
b.Host = fmt.Sprintf("%s:80", b.Host) b.Host = fmt.Sprintf("%s:80", b.Host)
} }
remote, err = net.Dial("tcp", b.Host) remote, err = net.Dial("tcp", b.Host)
httpForward(c, remote)
return
case "tls": case "tls":
h, _, _ := net.SplitHostPort(b.Host) h, _, _ := net.SplitHostPort(b.Host)
remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h}) remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h})
@ -105,7 +109,7 @@ func handleForward(c *tls.Conn, b *url.URL) {
log.Println(err) log.Println(err)
return return
} }
log.Println("begin data forward") //log.Println("begin data forward")
defer remote.Close() defer remote.Close()
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
@ -119,3 +123,33 @@ func handleForward(c *tls.Conn, b *url.URL) {
}() }()
<-ch <-ch
} }
func httpForward(r, b net.Conn) {
defer b.Close()
rb := bufio.NewReader(r)
bb := bufio.NewReader(b)
for {
req, err := http.ReadRequest(rb)
if err != nil {
if err != io.EOF {
log.Println(err)
fmt.Fprintf(b, "HTTP/1.1 504 Bad gateway\r\nConnection: close\r\n\r\n")
}
return
}
addr := r.RemoteAddr().(*net.TCPAddr)
req.Header.Add("X-Forwarded-For", addr.IP.String())
req.Header.Add("X-Real-Ip", addr.IP.String())
//log.Printf("%+v\n", req.Header)
req.Write(b)
res, err := http.ReadResponse(bb, req)
if err != nil {
log.Println(err)
return
}
res.Write(r)
}
}

Loading…
Cancel
Save