diff --git a/obfsshd/proxy_protocol.go b/obfsshd/proxy_protocol.go new file mode 100644 index 0000000..6536fed --- /dev/null +++ b/obfsshd/proxy_protocol.go @@ -0,0 +1,50 @@ +package main + +import ( + "bufio" + "net" + + proxyproto "github.com/pires/go-proxyproto" +) + +type protoListener struct { + net.Listener +} + +type protoConn struct { + net.Conn + headerDone bool + r *bufio.Reader + proxy *proxyproto.Header +} + +func (l *protoListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return &protoConn{Conn: c}, err +} + +func (c *protoConn) Read(buf []byte) (int, error) { + var err error + if !c.headerDone { + c.r = bufio.NewReader(c.Conn) + c.proxy, err = proxyproto.Read(c.r) + if err != nil && err != proxyproto.ErrNoProxyProtocol { + return 0, err + } + c.headerDone = true + return c.r.Read(buf) + } + return c.r.Read(buf) +} + +func (c *protoConn) RemoteAddr() net.Addr { + if c.proxy == nil { + return c.Conn.RemoteAddr() + } + return &net.TCPAddr{ + IP: c.proxy.SourceAddress, + Port: int(c.proxy.SourcePort)} +} diff --git a/obfsshd/server.go b/obfsshd/server.go index b1a3fc8..dc665c8 100644 --- a/obfsshd/server.go +++ b/obfsshd/server.go @@ -115,23 +115,23 @@ func main() { go func(lst listen) { var l net.Listener var err error - if lst.Key == "" || lst.Cert == "" { - l, err = net.Listen("tcp", fmt.Sprintf(":%d", lst.Port)) - } else { + + l, err = net.Listen("tcp", fmt.Sprintf(":%d", lst.Port)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + if lst.Key != "" && lst.Cert != "" { cert, err := tls.LoadX509KeyPair(lst.Cert, lst.Key) if err != nil { log.Fatal(err) } - l, err = tls.Listen("tcp", fmt.Sprintf(":%d", lst.Port), &tls.Config{ + l = tls.NewListener(&protoListener{l}, &tls.Config{ Certificates: []tls.Certificate{cert}, }) } - if err != nil { - log.Fatal(err) - } - defer l.Close() - for { c, err := l.Accept() if err != nil {