You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
153 lines
2.5 KiB
Go
153 lines
2.5 KiB
Go
7 years ago
|
package main
|
||
|
|
||
|
import (
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"time"

	"github.com/go-yaml/yaml"
	"github.com/golang/glog"
	//"crypto/tls"
	//"log"
)
|
||
|
|
||
|
func getSNIServerName(buf []byte) string {
|
||
|
n := len(buf)
|
||
|
if n < 5 {
|
||
|
glog.Error("not tls handshake")
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// tls record type
|
||
|
if recordType(buf[0]) != recordTypeHandshake {
|
||
|
glog.Error("not tls")
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// tls major version
|
||
|
if buf[1] != 3 {
|
||
|
glog.Error("TLS version < 3 not supported")
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// payload length
|
||
|
//l := int(buf[3])<<16 + int(buf[4])
|
||
|
|
||
|
//log.Printf("length: %d, got: %d", l, n)
|
||
|
|
||
|
// handshake message type
|
||
|
if uint8(buf[5]) != typeClientHello {
|
||
|
glog.Error("not client hello")
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// parse client hello message
|
||
|
|
||
|
msg := &clientHelloMsg{}
|
||
|
|
||
|
// client hello message not include tls header, 5 bytes
|
||
|
ret := msg.unmarshal(buf[5:n])
|
||
|
if !ret {
|
||
|
glog.Error("parse hello message return false")
|
||
|
return ""
|
||
|
}
|
||
|
return msg.serverName
|
||
|
}
|
||
|
|
||
|
func forward(c net.Conn, data []byte, dst string) {
|
||
|
c1, err := net.Dial("tcp", dst)
|
||
|
if err != nil {
|
||
|
glog.Error(err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
defer c1.Close()
|
||
|
|
||
|
if _, err = c1.Write(data); err != nil {
|
||
|
glog.Error(err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
ch := make(chan struct{}, 2)
|
||
|
|
||
|
go func() {
|
||
|
io.Copy(c1, c)
|
||
|
ch <- struct{}{}
|
||
|
}()
|
||
|
|
||
|
go func() {
|
||
|
io.Copy(c, c1)
|
||
|
ch <- struct{}{}
|
||
|
}()
|
||
|
|
||
|
<-ch
|
||
|
}
|
||
|
|
||
|
func getDST(c net.Conn, name string) string {
|
||
|
addr := c.LocalAddr().(*net.TCPAddr)
|
||
|
dst := cfg.ForwardRules.GetN(name, addr.Port)
|
||
|
return dst
|
||
|
}
|
||
|
|
||
|
func getDefaultDST() string {
|
||
|
return cfg.Default
|
||
|
}
|
||
|
|
||
|
func serve(c net.Conn) {
|
||
|
defer c.Close()
|
||
|
|
||
|
buf := make([]byte, 1024)
|
||
|
n, err := c.Read(buf)
|
||
|
if err != nil {
|
||
|
glog.Error(err)
|
||
|
return
|
||
|
}
|
||
|
servername := getSNIServerName(buf[:n])
|
||
|
if servername == "" {
|
||
|
forward(c, buf[:n], getDefaultDST())
|
||
|
return
|
||
|
}
|
||
|
dst := getDST(c, servername)
|
||
|
if dst == "" {
|
||
|
dst = getDefaultDST()
|
||
|
}
|
||
|
forward(c, buf[:n], dst)
|
||
|
}
|
||
|
|
||
|
// cfg holds the configuration that main decodes from the YAML config file;
// getDST and getDefaultDST read it to choose forwarding destinations.
var cfg conf
|
||
|
|
||
|
func main() {
|
||
|
var cfgfile string
|
||
|
flag.StringVar(&cfgfile, "c", "config.yaml", "config file")
|
||
|
flag.Set("logtostderr", "true")
|
||
|
flag.Parse()
|
||
|
|
||
|
data, err := ioutil.ReadFile(cfgfile)
|
||
|
if err != nil {
|
||
|
glog.Fatal(err)
|
||
|
}
|
||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||
|
glog.Fatal(err)
|
||
|
}
|
||
|
|
||
|
for _, d := range cfg.Listen {
|
||
|
glog.Infof("listen on :%d", d)
|
||
|
l, err := net.Listen("tcp", fmt.Sprintf(":%d", d))
|
||
|
if err != nil {
|
||
|
glog.Fatal(err)
|
||
|
}
|
||
|
go func(l net.Listener) {
|
||
|
defer l.Close()
|
||
|
for {
|
||
|
c1, err := l.Accept()
|
||
|
if err != nil {
|
||
|
glog.Fatal(err)
|
||
|
}
|
||
|
go serve(c1)
|
||
|
}
|
||
|
}(l)
|
||
|
}
|
||
|
select {}
|
||
|
}
|