diff --git a/common.go b/common.go index 4e3e644..6d4225a 100644 --- a/common.go +++ b/common.go @@ -69,9 +69,11 @@ const ( ) // TLS compression types. +/* const ( compressionNone uint8 = 0 ) +*/ // TLS extension numbers const ( @@ -107,9 +109,11 @@ const ( // TLS Elliptic Curve Point Formats // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +/* const ( pointFormatUncompressed uint8 = 0 ) +*/ // TLS CertificateStatusType (RFC 3546) const ( @@ -117,6 +121,7 @@ const ( ) // Certificate types (for certificateRequestMsg) +/* const ( certTypeRSASign = 1 // A certificate containing an RSA key certTypeDSSSign = 2 // A certificate containing a DSA key @@ -130,19 +135,24 @@ const ( // Rest of these are reserved by the TLS spec ) +*/ // Hash functions for TLS 1.2 (See RFC 5246, section A.4.1) +/* const ( hashSHA1 uint8 = 2 hashSHA256 uint8 = 4 hashSHA384 uint8 = 5 ) +*/ // Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1) +/* const ( signatureRSA uint8 = 1 signatureECDSA uint8 = 3 ) +*/ // signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See // RFC 5246, section A.4.1. @@ -153,6 +163,7 @@ type signatureAndHash struct { // supportedSignatureAlgorithms contains the signature and hash algorithms that // the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2 // CertificateRequest. +/* var supportedSignatureAlgorithms = []signatureAndHash{ {hashSHA256, signatureRSA}, {hashSHA256, signatureECDSA}, @@ -161,6 +172,7 @@ var supportedSignatureAlgorithms = []signatureAndHash{ {hashSHA1, signatureRSA}, {hashSHA1, signatureECDSA}, } +*/ // ConnectionState records basic TLS details about the connection. type ConnectionState struct { @@ -201,6 +213,7 @@ const ( // ClientSessionState contains the state needed by clients to resume TLS // sessions. +/* type ClientSessionState struct { sessionTicket []uint8 // Encrypted ticket used for session resumption with server vers uint16 // SSL/TLS version negotiated for the session @@ -209,12 +222,14 @@ type ClientSessionState struct { serverCertificates []*x509.Certificate // Certificate chain presented by the server verifiedChains [][]*x509.Certificate // Certificate chains we built for verification } +*/ // ClientSessionCache is a cache of ClientSessionState objects that can be used // by a client to resume a TLS session with a given server. ClientSessionCache // implementations should expect to be called concurrently from different // goroutines. Only ticket-based resumption is supported, not SessionID-based // resumption. +/* type ClientSessionCache interface { // Get searches for a ClientSessionState associated with the given key. // On return, ok is true if one was found. @@ -223,6 +238,7 @@ type ClientSessionCache interface { // Put adds the ClientSessionState to the cache with the given key. Put(sessionKey string, cs *ClientSessionState) } +*/ /* // SignatureScheme identifies a signature algorithm supported by TLS. See diff --git a/conf_test.go b/conf_test.go index 53c1ffb..2561fd2 100644 --- a/conf_test.go +++ b/conf_test.go @@ -1,14 +1,14 @@ package main import ( - "io/ioutil" + "os" "testing" yaml "gopkg.in/yaml.v2" ) func TestConf(t *testing.T) { - data, err := ioutil.ReadFile("config.sample.yaml") + data, err := os.ReadFile("config.sample.yaml") if err != nil { t.Fatal(err) } diff --git a/handshake_messages.go b/handshake_messages.go index 558b81f..b0cf7c4 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -33,272 +33,272 @@ type clientHelloMsg struct { } /* -func (m *clientHelloMsg) equal(i interface{}) bool { - m1, ok := i.(*clientHelloMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.vers == m1.vers && - bytes.Equal(m.random, m1.random) && - bytes.Equal(m.sessionId, m1.sessionId) && - eqUint16s(m.cipherSuites, m1.cipherSuites) && - bytes.Equal(m.compressionMethods, m1.compressionMethods) && - m.nextProtoNeg == m1.nextProtoNeg && - m.serverName == m1.serverName && - m.ocspStapling == m1.ocspStapling && - m.scts == m1.scts && - eqCurveIDs(m.supportedCurves, m1.supportedCurves) && - bytes.Equal(m.supportedPoints, m1.supportedPoints) && - m.ticketSupported == m1.ticketSupported && - bytes.Equal(m.sessionTicket, m1.sessionTicket) && - eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && - m.secureRenegotiationSupported == m1.secureRenegotiationSupported && - bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && - eqStrings(m.alpnProtocols, m1.alpnProtocols) -} + func (m *clientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*clientHelloMsg) + if !ok { + return false + } -func (m *clientHelloMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.compressionMethods, m1.compressionMethods) && + m.nextProtoNeg == m1.nextProtoNeg && + m.serverName == m1.serverName && + m.ocspStapling == m1.ocspStapling && + m.scts == m1.scts && + eqCurveIDs(m.supportedCurves, m1.supportedCurves) && + bytes.Equal(m.supportedPoints, m1.supportedPoints) && + m.ticketSupported == m1.ticketSupported && + bytes.Equal(m.sessionTicket, m1.sessionTicket) && + eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && + m.secureRenegotiationSupported == m1.secureRenegotiationSupported && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + eqStrings(m.alpnProtocols, m1.alpnProtocols) + } + + func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } - length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) - numExtensions := 0 - extensionsLength := 0 - if m.nextProtoNeg { - numExtensions++ - } - if m.ocspStapling { - extensionsLength += 1 + 2 + 2 - numExtensions++ - } - if len(m.serverName) > 0 { - extensionsLength += 5 + len(m.serverName) - numExtensions++ - } - if len(m.supportedCurves) > 0 { - extensionsLength += 2 + 2*len(m.supportedCurves) - numExtensions++ - } - if len(m.supportedPoints) > 0 { - extensionsLength += 1 + len(m.supportedPoints) - numExtensions++ - } - if m.ticketSupported { - extensionsLength += len(m.sessionTicket) - numExtensions++ - } - if len(m.signatureAndHashes) > 0 { - extensionsLength += 2 + 2*len(m.signatureAndHashes) - numExtensions++ - } - if m.secureRenegotiationSupported { - extensionsLength += 1 + len(m.secureRenegotiation) - numExtensions++ - } - if len(m.alpnProtocols) > 0 { - extensionsLength += 2 - for _, s := range m.alpnProtocols { - if l := len(s); l == 0 || l > 255 { - panic("invalid ALPN protocol") + length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) + numExtensions := 0 + extensionsLength := 0 + if m.nextProtoNeg { + numExtensions++ + } + if m.ocspStapling { + extensionsLength += 1 + 2 + 2 + numExtensions++ + } + if len(m.serverName) > 0 { + extensionsLength += 5 + len(m.serverName) + numExtensions++ + } + if len(m.supportedCurves) > 0 { + extensionsLength += 2 + 2*len(m.supportedCurves) + numExtensions++ + } + if len(m.supportedPoints) > 0 { + extensionsLength += 1 + len(m.supportedPoints) + numExtensions++ + } + if m.ticketSupported { + extensionsLength += len(m.sessionTicket) + numExtensions++ + } + if len(m.signatureAndHashes) > 0 { + extensionsLength += 2 + 2*len(m.signatureAndHashes) + numExtensions++ + } + if m.secureRenegotiationSupported { + extensionsLength += 1 + len(m.secureRenegotiation) + numExtensions++ + } + if len(m.alpnProtocols) > 0 { + extensionsLength += 2 + for _, s := range m.alpnProtocols { + if l := len(s); l == 0 || l > 255 { + panic("invalid ALPN protocol") + } + extensionsLength++ + extensionsLength += len(s) } - extensionsLength++ - extensionsLength += len(s) + numExtensions++ + } + if m.scts { + numExtensions++ + } + if numExtensions > 0 { + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength } - numExtensions++ - } - if m.scts { - numExtensions++ - } - if numExtensions > 0 { - extensionsLength += 4 * numExtensions - length += 2 + extensionsLength - } - - x := make([]byte, 4+length) - x[0] = typeClientHello - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[4] = uint8(m.vers >> 8) - x[5] = uint8(m.vers) - copy(x[6:38], m.random) - x[38] = uint8(len(m.sessionId)) - copy(x[39:39+len(m.sessionId)], m.sessionId) - y := x[39+len(m.sessionId):] - y[0] = uint8(len(m.cipherSuites) >> 7) - y[1] = uint8(len(m.cipherSuites) << 1) - for i, suite := range m.cipherSuites { - y[2+i*2] = uint8(suite >> 8) - y[3+i*2] = uint8(suite) - } - z := y[2+len(m.cipherSuites)*2:] - z[0] = uint8(len(m.compressionMethods)) - copy(z[1:], m.compressionMethods) - - z = z[1+len(m.compressionMethods):] - if numExtensions > 0 { - z[0] = byte(extensionsLength >> 8) - z[1] = byte(extensionsLength) - z = z[2:] - } - if m.nextProtoNeg { - z[0] = byte(extensionNextProtoNeg >> 8) - z[1] = byte(extensionNextProtoNeg & 0xff) - // The length is always 0 - z = z[4:] - } - if len(m.serverName) > 0 { - z[0] = byte(extensionServerName >> 8) - z[1] = byte(extensionServerName & 0xff) - l := len(m.serverName) + 5 - z[2] = byte(l >> 8) - z[3] = byte(l) - z = z[4:] - // RFC 3546, section 3.1 - // - // struct { - // NameType name_type; - // select (name_type) { - // case host_name: HostName; - // } name; - // } ServerName; - // - // enum { - // host_name(0), (255) - // } NameType; - // - // opaque HostName<1..2^16-1>; - // - // struct { - // ServerName server_name_list<1..2^16-1> - // } ServerNameList; - - z[0] = byte((len(m.serverName) + 3) >> 8) - z[1] = byte(len(m.serverName) + 3) - z[3] = byte(len(m.serverName) >> 8) - z[4] = byte(len(m.serverName)) - copy(z[5:], []byte(m.serverName)) - z = z[l:] - } - if m.ocspStapling { - // RFC 4366, section 3.6 - z[0] = byte(extensionStatusRequest >> 8) - z[1] = byte(extensionStatusRequest) - z[2] = 0 - z[3] = 5 - z[4] = 1 // OCSP type - // Two zero valued uint16s for the two lengths. - z = z[9:] - } - if len(m.supportedCurves) > 0 { - // http://tools.ietf.org/html/rfc4492#section-5.5.1 - z[0] = byte(extensionSupportedCurves >> 8) - z[1] = byte(extensionSupportedCurves) - l := 2 + 2*len(m.supportedCurves) - z[2] = byte(l >> 8) - z[3] = byte(l) - l -= 2 - z[4] = byte(l >> 8) - z[5] = byte(l) - z = z[6:] - for _, curve := range m.supportedCurves { - z[0] = byte(curve >> 8) - z[1] = byte(curve) + x := make([]byte, 4+length) + x[0] = typeClientHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + y := x[39+len(m.sessionId):] + y[0] = uint8(len(m.cipherSuites) >> 7) + y[1] = uint8(len(m.cipherSuites) << 1) + for i, suite := range m.cipherSuites { + y[2+i*2] = uint8(suite >> 8) + y[3+i*2] = uint8(suite) + } + z := y[2+len(m.cipherSuites)*2:] + z[0] = uint8(len(m.compressionMethods)) + copy(z[1:], m.compressionMethods) + + z = z[1+len(m.compressionMethods):] + if numExtensions > 0 { + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) z = z[2:] } - } - if len(m.supportedPoints) > 0 { - // http://tools.ietf.org/html/rfc4492#section-5.5.2 - z[0] = byte(extensionSupportedPoints >> 8) - z[1] = byte(extensionSupportedPoints) - l := 1 + len(m.supportedPoints) - z[2] = byte(l >> 8) - z[3] = byte(l) - l-- - z[4] = byte(l) - z = z[5:] - for _, pointFormat := range m.supportedPoints { - z[0] = pointFormat - z = z[1:] + if m.nextProtoNeg { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + // The length is always 0 + z = z[4:] } - } - if m.ticketSupported { - // http://tools.ietf.org/html/rfc5077#section-3.2 - z[0] = byte(extensionSessionTicket >> 8) - z[1] = byte(extensionSessionTicket) - l := len(m.sessionTicket) - z[2] = byte(l >> 8) - z[3] = byte(l) - z = z[4:] - copy(z, m.sessionTicket) - z = z[len(m.sessionTicket):] - } - if len(m.signatureAndHashes) > 0 { - // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 - z[0] = byte(extensionSignatureAlgorithms >> 8) - z[1] = byte(extensionSignatureAlgorithms) - l := 2 + 2*len(m.signatureAndHashes) - z[2] = byte(l >> 8) - z[3] = byte(l) - z = z[4:] - - l -= 2 - z[0] = byte(l >> 8) - z[1] = byte(l) - z = z[2:] - for _, sigAndHash := range m.signatureAndHashes { - z[0] = sigAndHash.hash - z[1] = sigAndHash.signature + if len(m.serverName) > 0 { + z[0] = byte(extensionServerName >> 8) + z[1] = byte(extensionServerName & 0xff) + l := len(m.serverName) + 5 + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + // RFC 3546, section 3.1 + // + // struct { + // NameType name_type; + // select (name_type) { + // case host_name: HostName; + // } name; + // } ServerName; + // + // enum { + // host_name(0), (255) + // } NameType; + // + // opaque HostName<1..2^16-1>; + // + // struct { + // ServerName server_name_list<1..2^16-1> + // } ServerNameList; + + z[0] = byte((len(m.serverName) + 3) >> 8) + z[1] = byte(len(m.serverName) + 3) + z[3] = byte(len(m.serverName) >> 8) + z[4] = byte(len(m.serverName)) + copy(z[5:], []byte(m.serverName)) + z = z[l:] + } + if m.ocspStapling { + // RFC 4366, section 3.6 + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z[2] = 0 + z[3] = 5 + z[4] = 1 // OCSP type + // Two zero valued uint16s for the two lengths. + z = z[9:] + } + if len(m.supportedCurves) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.1 + z[0] = byte(extensionSupportedCurves >> 8) + z[1] = byte(extensionSupportedCurves) + l := 2 + 2*len(m.supportedCurves) + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + z = z[6:] + for _, curve := range m.supportedCurves { + z[0] = byte(curve >> 8) + z[1] = byte(curve) + z = z[2:] + } + } + if len(m.supportedPoints) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.2 + z[0] = byte(extensionSupportedPoints >> 8) + z[1] = byte(extensionSupportedPoints) + l := 1 + len(m.supportedPoints) + z[2] = byte(l >> 8) + z[3] = byte(l) + l-- + z[4] = byte(l) + z = z[5:] + for _, pointFormat := range m.supportedPoints { + z[0] = pointFormat + z = z[1:] + } + } + if m.ticketSupported { + // http://tools.ietf.org/html/rfc5077#section-3.2 + z[0] = byte(extensionSessionTicket >> 8) + z[1] = byte(extensionSessionTicket) + l := len(m.sessionTicket) + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + copy(z, m.sessionTicket) + z = z[len(m.sessionTicket):] + } + if len(m.signatureAndHashes) > 0 { + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + z[0] = byte(extensionSignatureAlgorithms >> 8) + z[1] = byte(extensionSignatureAlgorithms) + l := 2 + 2*len(m.signatureAndHashes) + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + l -= 2 + z[0] = byte(l >> 8) + z[1] = byte(l) z = z[2:] + for _, sigAndHash := range m.signatureAndHashes { + z[0] = sigAndHash.hash + z[1] = sigAndHash.signature + z = z[2:] + } } - } - if m.secureRenegotiationSupported { - z[0] = byte(extensionRenegotiationInfo >> 8) - z[1] = byte(extensionRenegotiationInfo & 0xff) - z[2] = 0 - z[3] = byte(len(m.secureRenegotiation) + 1) - z[4] = byte(len(m.secureRenegotiation)) - z = z[5:] - copy(z, m.secureRenegotiation) - z = z[len(m.secureRenegotiation):] - } - if len(m.alpnProtocols) > 0 { - z[0] = byte(extensionALPN >> 8) - z[1] = byte(extensionALPN & 0xff) - lengths := z[2:] - z = z[6:] - - stringsLength := 0 - for _, s := range m.alpnProtocols { - l := len(s) - z[0] = byte(l) - copy(z[1:], s) - z = z[1+l:] - stringsLength += 1 + l + if m.secureRenegotiationSupported { + z[0] = byte(extensionRenegotiationInfo >> 8) + z[1] = byte(extensionRenegotiationInfo & 0xff) + z[2] = 0 + z[3] = byte(len(m.secureRenegotiation) + 1) + z[4] = byte(len(m.secureRenegotiation)) + z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] } + if len(m.alpnProtocols) > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + lengths := z[2:] + z = z[6:] + + stringsLength := 0 + for _, s := range m.alpnProtocols { + l := len(s) + z[0] = byte(l) + copy(z[1:], s) + z = z[1+l:] + stringsLength += 1 + l + } - lengths[2] = byte(stringsLength >> 8) - lengths[3] = byte(stringsLength) - stringsLength += 2 - lengths[0] = byte(stringsLength >> 8) - lengths[1] = byte(stringsLength) - } - if m.scts { - // https://tools.ietf.org/html/rfc6962#section-3.3.1 - z[0] = byte(extensionSCT >> 8) - z[1] = byte(extensionSCT) - // zero uint16 for the zero-length extension_data - z = z[4:] - } + lengths[2] = byte(stringsLength >> 8) + lengths[3] = byte(stringsLength) + stringsLength += 2 + lengths[0] = byte(stringsLength >> 8) + lengths[1] = byte(stringsLength) + } + if m.scts { + // https://tools.ietf.org/html/rfc6962#section-3.3.1 + z[0] = byte(extensionSCT >> 8) + z[1] = byte(extensionSCT) + // zero uint16 for the zero-length extension_data + z = z[4:] + } - m.raw = x + m.raw = x - return x -} + return x + } */ func (m *clientHelloMsg) unmarshal(data []byte) bool { if len(data) < 42 { diff --git a/main.go b/main.go index 902531d..5acd292 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "net" "os" "os/signal" @@ -78,7 +77,11 @@ func forward(ctx context.Context, c net.Conn, data []byte, dst string) { Command: proxyproto.PROXY, TransportProtocol: proxyproto.TCPv4, SourceAddr: raddr, - DestinationAddr: &net.TCPAddr{}, + DestinationAddr: &net.TCPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + Zone: "", + }, } switch strings.ToLower(ss[1]) { @@ -186,7 +189,7 @@ func main() { flag.StringVar(&loglevel, "log_level", "INFO", "log level") flag.Parse() - data, err := ioutil.ReadFile(cfgfile) + data, err := os.ReadFile(cfgfile) if err != nil { glog.Fatal(err) } @@ -231,9 +234,7 @@ func main() { ch := make(chan os.Signal, 2) signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) - select { - case s := <-ch: - cancel() - glog.Printf("received signal %s, exit.", s) - } + s := <-ch + cancel() + glog.Printf("received signal %s, exit.", s) } diff --git a/proto_test.go b/proto_test.go index 989bb1f..2c35f21 100644 --- a/proto_test.go +++ b/proto_test.go @@ -1,96 +1,96 @@ -package main - -import ( - "context" - "crypto/tls" - "io/ioutil" - "net" - "testing" - - "github.com/fangdingjun/go-log/v5" - "github.com/fangdingjun/protolistener" - yaml "gopkg.in/yaml.v2" -) - -func TestProxyProto(t *testing.T) { - log.Default.Level = log.DEBUG - - data, err := ioutil.ReadFile("config.sample.yaml") - if err != nil { - log.Fatal(err) - } - if err = yaml.Unmarshal(data, &cfg); err != nil { - log.Fatal(err) - } - - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer listener.Close() - log.Printf("listen %s", listener.Addr().String()) - - go func() { - for { - conn, err := listener.Accept() - if err != nil { - return - } - go serve(context.Background(), conn) - } - }() - cert, err := tls.LoadX509KeyPair("server.crt", "server.key") - if err != nil { - t.Fatal(err) - } - listener2, err := net.Listen("tcp", "127.0.0.1:8443") - if err != nil { - t.Fatal(err) - } - defer listener2.Close() - - listener2 = tls.NewListener(protolistener.New(listener2), &tls.Config{ - Certificates: []tls.Certificate{cert}, - }) - - go func() { - for { - conn, err := listener2.Accept() - if err != nil { - return - } - go func(c net.Conn) { - defer c.Close() - addr := conn.RemoteAddr() - _conn := c.(*tls.Conn) - if err := _conn.Handshake(); err != nil { - log.Errorf("handshake error: %s", err) - return - } - conn.Write([]byte(addr.String())) - }(conn) - } - }() - - conn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{ - ServerName: "www.example.com", - InsecureSkipVerify: true, - }) - if err != nil { - log.Println("dial error") - t.Fatal(err) - } - conn.Handshake() - buf := make([]byte, 200) - n, err := conn.Read(buf) - if err != nil { - log.Println("read error") - t.Fatal(err) - } - addr1 := conn.LocalAddr().String() - addr2 := string(buf[:n]) - conn.Close() - if addr1 != addr2 { - t.Errorf("expect %s, got: %s", addr1, addr2) - } -} +package main + +import ( + "context" + "crypto/tls" + "net" + "os" + "testing" + + "github.com/fangdingjun/go-log/v5" + "github.com/fangdingjun/protolistener" + yaml "gopkg.in/yaml.v2" +) + +func TestProxyProto(t *testing.T) { + log.Default.Level = log.DEBUG + + data, err := os.ReadFile("config.sample.yaml") + if err != nil { + log.Fatal(err) + } + if err = yaml.Unmarshal(data, &cfg); err != nil { + log.Fatal(err) + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + log.Printf("listen %s", listener.Addr().String()) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go serve(context.Background(), conn) + } + }() + cert, err := tls.LoadX509KeyPair("server.crt", "server.key") + if err != nil { + t.Fatal(err) + } + listener2, err := net.Listen("tcp", "127.0.0.1:8443") + if err != nil { + t.Fatal(err) + } + defer listener2.Close() + + listener2 = tls.NewListener(protolistener.New(listener2), &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + + go func() { + for { + conn, err := listener2.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + addr := conn.RemoteAddr() + _conn := c.(*tls.Conn) + if err := _conn.Handshake(); err != nil { + log.Errorf("handshake error: %s", err) + return + } + conn.Write([]byte(addr.String())) + }(conn) + } + }() + + conn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{ + ServerName: "www.example.com", + InsecureSkipVerify: true, + }) + if err != nil { + log.Println("dial error") + t.Fatal(err) + } + conn.Handshake() + buf := make([]byte, 200) + n, err := conn.Read(buf) + if err != nil { + log.Println("read error") + t.Fatal(err) + } + addr1 := conn.LocalAddr().String() + addr2 := string(buf[:n]) + conn.Close() + if addr1 != addr2 { + t.Errorf("expect %s, got: %s", addr1, addr2) + } +}