package main

import (
	//"crypto/tls"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	//"log"
	"net"
	"strings"

	"github.com/go-yaml/yaml"
	"github.com/golang/glog"
	proxyproto "github.com/pires/go-proxyproto"
)

// getSNIServerName extracts the SNI server name from the first bytes received
// on a connection. buf is expected to start with a TLS record header (5 bytes)
// followed by a ClientHello handshake message. It returns "" (after logging
// the reason) when buf does not look like a TLS ClientHello or cannot be
// parsed.
func getSNIServerName(buf []byte) string {
	n := len(buf)
	// The 5-byte record header plus the handshake message type at buf[5] are
	// read below, so at least 6 bytes are required.
	if n < 6 {
		glog.Error("not tls handshake")
		return ""
	}

	// tls record type
	if recordType(buf[0]) != recordTypeHandshake {
		glog.Error("not tls")
		return ""
	}

	// tls major version
	if buf[1] != 3 {
		glog.Error("TLS version < 3 not supported")
		return ""
	}

	// buf[3:5] holds the record payload length; it is not validated here.

	// handshake message type
	if uint8(buf[5]) != typeClientHello {
		glog.Error("not client hello")
		return ""
	}

	// parse client hello message; the message does not include the 5-byte
	// tls record header
	msg := &clientHelloMsg{}
	if !msg.unmarshal(buf[5:n]) {
		glog.Error("parse hello message return false")
		return ""
	}
	return msg.serverName
}

// forward connects to the backend described by dst, replays data (the bytes
// already consumed from c) and then relays traffic in both directions until
// either direction finishes.
//
// dst is a whitespace-separated string: the first field is the backend dial
// address; an optional second field of "proxy-v1" or "proxy-v2"
// (case-insensitive) makes forward send a PROXY protocol header carrying the
// client's address before any payload. Any other second field is ignored.
func forward(c net.Conn, data []byte, dst string) {
	addr := dst
	proxyProto := 0
	ss := strings.Fields(dst)
	var hdr proxyproto.Header

	if len(ss) > 1 {
		addr = ss[0]
		raddr := c.RemoteAddr().(*net.TCPAddr)
		hdr = proxyproto.Header{
			Version:            1,
			Command:            proxyproto.PROXY,
			TransportProtocol:  proxyproto.TCPv4,
			SourceAddress:      raddr.IP.To4(),
			DestinationAddress: net.IP{0, 0, 0, 0},
			SourcePort:         uint16(raddr.Port),
			DestinationPort:    0,
		}
		// To4 returns nil for a genuine IPv6 peer; describe it as TCPv6 with
		// 16-byte addresses instead of emitting a header with a nil source.
		if hdr.SourceAddress == nil {
			hdr.TransportProtocol = proxyproto.TCPv6
			hdr.SourceAddress = raddr.IP.To16()
			hdr.DestinationAddress = net.IPv6zero
		}

		switch strings.ToLower(ss[1]) {
		case "proxy-v1":
			proxyProto = 1
			hdr.Version = 1
		case "proxy-v2":
			proxyProto = 2
			hdr.Version = 2
		}
	}

	c1, err := net.Dial("tcp", addr)
	if err != nil {
		glog.Error(err)
		return
	}
	defer c1.Close()

	if proxyProto != 0 {
		// The backend expects the header first; without it the stream would
		// be misinterpreted, so give up if it cannot be written.
		if _, err = hdr.WriteTo(c1); err != nil {
			glog.Error(err)
			return
		}
	}

	if _, err = c1.Write(data); err != nil {
		glog.Error(err)
		return
	}

	// Buffered for both senders so the goroutine that finishes second never
	// blocks on send after forward has returned.
	ch := make(chan struct{}, 2)

	go func() {
		io.Copy(c1, c)
		ch <- struct{}{}
	}()

	go func() {
		io.Copy(c, c1)
		ch <- struct{}{}
	}()

	// Return as soon as one direction ends; the deferred c1.Close here and
	// the caller's close of c unblock the remaining copy.
	<-ch
}

// getDST looks up the forward rule for the given server name and the local
// port the connection was accepted on. It returns whatever
// cfg.ForwardRules.GetN returns; serve treats "" as "no rule".
func getDST(c net.Conn, name string) string {
	addr := c.LocalAddr().(*net.TCPAddr)
	dst := cfg.ForwardRules.GetN(name, addr.Port)
	return dst
}

// getDefaultDST returns the configured default destination.
func getDefaultDST() string {
	return cfg.Default
}

// readFirstRecord reads the initial bytes from c. If they begin with a TLS
// handshake record header whose payload was not fully received by the first
// read, it keeps reading until that record is complete, so that a ClientHello
// larger than the initial buffer (or split across TCP segments) can still be
// parsed. For anything else it returns exactly what the first read produced.
// An error is returned only when the first read fails.
func readFirstRecord(c net.Conn) ([]byte, error) {
	const (
		headerLen    = 5     // TLS record header size
		maxPayload   = 16384 // maximum TLS plaintext record payload (2^14)
		initialBytes = 1024
	)

	buf := make([]byte, initialBytes)
	n, err := c.Read(buf)
	if err != nil {
		return nil, err
	}
	buf = buf[:n]

	if n < headerLen || recordType(buf[0]) != recordTypeHandshake || buf[1] != 3 {
		return buf, nil
	}

	payload := int(buf[3])<<8 | int(buf[4])
	want := headerLen + payload
	// Already complete, or a length no valid record can have: leave as is.
	if want <= n || payload > maxPayload {
		return buf, nil
	}

	full := make([]byte, want)
	copy(full, buf)
	m, err := io.ReadFull(c, full[n:])
	if err != nil {
		// Could not complete the record; hand back what was received so the
		// caller can still forward it.
		return full[:n+m], nil
	}
	return full, nil
}

// serve handles one accepted connection: it reads the first bytes, tries to
// extract the SNI server name, picks the destination (falling back to the
// default when there is no name or no matching rule) and forwards the
// connection. c is closed when serve returns.
func serve(c net.Conn) {
	defer c.Close()

	data, err := readFirstRecord(c)
	if err != nil {
		glog.Error(err)
		return
	}

	dst := ""
	if servername := getSNIServerName(data); servername != "" {
		dst = getDST(c, servername)
	}
	if dst == "" {
		dst = getDefaultDST()
	}

	forward(c, data, dst)
}

// cfg holds the configuration loaded from the YAML config file in main.
var cfg conf

// main loads the configuration, starts one accept loop per configured listen
// port and then blocks forever.
func main() {
	var cfgfile string

	flag.StringVar(&cfgfile, "c", "config.yaml", "config file")
	flag.Set("logtostderr", "true")
	flag.Parse()

	data, err := ioutil.ReadFile(cfgfile)
	if err != nil {
		glog.Fatal(err)
	}

	if err := yaml.Unmarshal(data, &cfg); err != nil {
		glog.Fatal(err)
	}

	for _, d := range cfg.Listen {
		glog.Infof("listen on :%d", d)
		l, err := net.Listen("tcp", fmt.Sprintf(":%d", d))
		if err != nil {
			glog.Fatal(err)
		}
		go func(l net.Listener) {
			defer l.Close()
			for {
				c1, err := l.Accept()
				if err != nil {
					// NOTE(review): any accept error terminates the whole
					// process, not just this listener.
					glog.Fatal(err)
				}
				go serve(c1)
			}
		}(l)
	}

	// Block forever; the listener goroutines do all the work.
	select {}
}