You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

191 lines
3.1 KiB
Go

package main
import (
	//"crypto/tls"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	//"log"
	"net"
	"strings"
	"time"

	"github.com/go-yaml/yaml"
	"github.com/golang/glog"
	proxyproto "github.com/pires/go-proxyproto"
)
// getSNIServerName extracts the SNI host name from the first bytes a client
// sent. buf is expected to begin with a TLS record header (5 bytes: record
// type, 2-byte version, 2-byte payload length) followed by a handshake
// message whose first byte is the handshake type.
//
// It returns "" (after logging the reason) when buf is not a TLS handshake
// record carrying a ClientHello that clientHelloMsg.unmarshal accepts.
func getSNIServerName(buf []byte) string {
	// Need the 5-byte record header plus the handshake type at buf[5];
	// checking only for 5 bytes would let buf[5] panic below.
	if len(buf) < 6 {
		glog.Error("not tls handshake")
		return ""
	}
	// tls record type
	if recordType(buf[0]) != recordTypeHandshake {
		glog.Error("not tls")
		return ""
	}
	// tls major version
	if buf[1] != 3 {
		glog.Error("TLS version < 3 not supported")
		return ""
	}
	// buf[3:5] holds the big-endian record payload length; it is not checked
	// here because unmarshal validates the handshake message itself.

	// handshake message type
	if buf[5] != typeClientHello {
		glog.Error("not client hello")
		return ""
	}
	// The handshake message starts after the 5-byte record header.
	msg := &clientHelloMsg{}
	if !msg.unmarshal(buf[5:]) {
		glog.Error("parse hello message return false")
		return ""
	}
	return msg.serverName
}
// forward dials the backend described by dst, replays data (the bytes
// already read from client connection c) to it, and then relays traffic in
// both directions until either direction ends.
//
// dst is "host:port", optionally followed by whitespace and "proxy-v1" or
// "proxy-v2" (case-insensitive). With one of those options a PROXY protocol
// header describing c is written to the backend before data. Any other
// option is logged and ignored.
func forward(c net.Conn, data []byte, dst string) {
	fields := strings.Fields(dst)
	addr := dst
	if len(fields) > 0 {
		addr = fields[0]
	}

	var version byte // 0 means: do not send a PROXY header
	if len(fields) > 1 {
		switch strings.ToLower(fields[1]) {
		case "proxy-v1":
			version = 1
		case "proxy-v2":
			version = 2
		default:
			glog.Warningf("unknown forward option %q for %s, ignored", fields[1], addr)
		}
	}

	c1, err := net.Dial("tcp", addr)
	if err != nil {
		glog.Error(err)
		return
	}
	defer c1.Close()

	if version != 0 {
		if _, err := proxyHeader(c, version).WriteTo(c1); err != nil {
			glog.Error(err)
			return
		}
	}
	if _, err = c1.Write(data); err != nil {
		glog.Error(err)
		return
	}

	// Return as soon as either copy finishes. Closing c1 (deferred above),
	// together with the caller closing c (serve does), unblocks the other
	// copy. The buffer of 2 lets that goroutine send and exit with no
	// receiver left.
	ch := make(chan struct{}, 2)
	go func() {
		io.Copy(c1, c)
		ch <- struct{}{}
	}()
	go func() {
		io.Copy(c, c1)
		ch <- struct{}{}
	}()
	<-ch
}

// proxyHeader builds a PROXY protocol header of the given version (1 or 2)
// reporting c's remote address as the source and c's local address as the
// destination. It uses TCPv6 when either address has no IPv4 form, because
// IP.To4 returns nil for IPv6 addresses.
func proxyHeader(c net.Conn, version byte) *proxyproto.Header {
	src := c.RemoteAddr().(*net.TCPAddr)
	dst := c.LocalAddr().(*net.TCPAddr)
	hdr := &proxyproto.Header{
		Version:            version,
		Command:            proxyproto.PROXY,
		TransportProtocol:  proxyproto.TCPv4,
		SourceAddress:      src.IP.To4(),
		DestinationAddress: dst.IP.To4(),
		SourcePort:         uint16(src.Port),
		DestinationPort:    uint16(dst.Port),
	}
	if hdr.SourceAddress == nil || hdr.DestinationAddress == nil {
		hdr.TransportProtocol = proxyproto.TCPv6
		hdr.SourceAddress = src.IP.To16()
		hdr.DestinationAddress = dst.IP.To16()
	}
	return hdr
}
// getDST returns the forwarding target configured for SNI host name on the
// local TCP port that c was accepted on, as answered by
// cfg.ForwardRules.GetN. serve treats an empty result as "no rule" and falls
// back to the default target.
func getDST(c net.Conn, name string) string {
	port := c.LocalAddr().(*net.TCPAddr).Port
	return cfg.ForwardRules.GetN(name, port)
}
// getDefaultDST returns the fallback forwarding target (cfg.Default). serve
// uses it when no SNI name could be parsed or no rule matched.
func getDefaultDST() string {
	fallback := cfg.Default
	return fallback
}
// serve handles one accepted client connection. It reads the client's
// opening bytes, routes by the TLS SNI host name when one can be parsed
// (falling back to the default target otherwise), and relays the connection
// to the chosen backend via forward. c is closed when serve returns.
func serve(c net.Conn) {
	defer c.Close()
	data, err := readFirstRecord(c)
	if err != nil {
		glog.Error(err)
		return
	}
	dst := ""
	if servername := getSNIServerName(data); servername != "" {
		dst = getDST(c, servername)
	}
	if dst == "" {
		dst = getDefaultDST()
	}
	forward(c, data, dst)
}

// readFirstRecord reads the client's opening bytes from c. When they start
// like a TLS handshake record (type handshake, major version 3), it keeps
// reading until the whole record (5-byte header plus the payload length the
// header declares) is buffered. That way a ClientHello larger than one read,
// or split across TCP segments, can still be parsed for SNI. Other data is
// returned after the first Read. All bytes read are returned so they can be
// replayed to the backend.
func readFirstRecord(c net.Conn) ([]byte, error) {
	const headerLen = 5
	buf := make([]byte, 1024)
	n, err := c.Read(buf)
	if err != nil {
		return nil, err
	}
	if n == 0 || recordType(buf[0]) != recordTypeHandshake {
		return buf[:n], nil
	}
	if n < headerLen {
		m, err := io.ReadFull(c, buf[n:headerLen])
		n += m
		if err != nil {
			return nil, err
		}
	}
	if buf[1] != 3 {
		return buf[:n], nil
	}
	// Record payload length is big-endian in header bytes 3 and 4.
	want := headerLen + (int(buf[3])<<8 | int(buf[4]))
	if n >= want {
		return buf[:n], nil
	}
	if want > len(buf) {
		grown := make([]byte, want)
		copy(grown, buf[:n])
		buf = grown
	}
	m, err := io.ReadFull(c, buf[n:want])
	n += m
	if err != nil {
		return nil, err
	}
	return buf[:n], nil
}
// cfg holds the configuration decoded from the YAML file in main. It is read
// by getDST, getDefaultDST, and main's listener setup (cfg.Listen).
var cfg conf
// main loads the YAML configuration named by the -c flag (default
// "config.yaml") into cfg. It then opens a TCP listener on every port in
// cfg.Listen and accepts connections on each one in its own goroutine. It
// never returns; configuration and listen failures are fatal.
func main() {
	var cfgfile string
	flag.StringVar(&cfgfile, "c", "config.yaml", "config file")
	flag.Set("logtostderr", "true")
	flag.Parse()

	data, err := ioutil.ReadFile(cfgfile)
	if err != nil {
		glog.Fatal(err)
	}
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		glog.Fatal(err)
	}

	for _, d := range cfg.Listen {
		glog.Infof("listen on :%d", d)
		l, err := net.Listen("tcp", fmt.Sprintf(":%d", d))
		if err != nil {
			glog.Fatal(err)
		}
		go acceptLoop(l)
	}
	select {}
}

// acceptLoop accepts connections on l forever and serves each in its own
// goroutine. Accept errors (for example running out of file descriptors) are
// logged and retried after a capped exponential backoff, so one transient
// failure does not terminate the whole proxy.
func acceptLoop(l net.Listener) {
	const (
		minDelay = 5 * time.Millisecond
		maxDelay = time.Second
	)
	var delay time.Duration
	for {
		c1, err := l.Accept()
		if err != nil {
			if delay == 0 {
				delay = minDelay
			} else {
				delay *= 2
			}
			if delay > maxDelay {
				delay = maxDelay
			}
			glog.Errorf("accept on %s: %v; retrying in %v", l.Addr(), err, delay)
			time.Sleep(delay)
			continue
		}
		delay = 0
		go serve(c1)
	}
}