diff --git a/proxy_protocol.go b/proxy_protocol.go index ef3bf0a..3e960bf 100644 --- a/proxy_protocol.go +++ b/proxy_protocol.go @@ -4,6 +4,7 @@ import ( "bufio" "net" + "github.com/fangdingjun/go-log" proxyproto "github.com/pires/go-proxyproto" ) @@ -16,6 +17,7 @@ type protoConn struct { headerDone bool r *bufio.Reader proxy *proxyproto.Header + err error } // New create a wrapped listener @@ -31,12 +33,24 @@ func (l *protoListener) Accept() (net.Conn, error) { return &protoConn{Conn: c}, err } -func (c *protoConn) Read(buf []byte) (int, error) { +func (c *protoConn) readHeader() error { var err error + c.r = bufio.NewReader(c.Conn) + c.proxy, err = proxyproto.Read(c.r) + if err != nil && err != proxyproto.ErrNoProxyProtocol { + c.err = err + return err + } + return nil +} + +func (c *protoConn) Read(buf []byte) (int, error) { + if c.err != nil { + return 0, c.err + } if !c.headerDone { - c.r = bufio.NewReader(c.Conn) - c.proxy, err = proxyproto.Read(c.r) - if err != nil && err != proxyproto.ErrNoProxyProtocol { + if err := c.readHeader(); err != nil { + c.headerDone = true return 0, err } c.headerDone = true @@ -46,6 +60,12 @@ func (c *protoConn) Read(buf []byte) (int, error) { } func (c *protoConn) RemoteAddr() net.Addr { + if !c.headerDone { + if err := c.readHeader(); err != nil { + log.Errorln(err) + } + c.headerDone = true + } if c.proxy == nil { return c.Conn.RemoteAddr() }