Compare commits

...

3 Commits

Author SHA1 Message Date
Dingjun Fang c43912d031
Merge pull request #7 from fangdingjun/dependabot/go_modules/github.com/pires/go-proxyproto-0.6.1
Bump github.com/pires/go-proxyproto from 0.6.0 to 0.6.1
2 years ago
dependabot[bot] e8971aca3b
Bump github.com/pires/go-proxyproto from 0.6.0 to 0.6.1
Bumps [github.com/pires/go-proxyproto](https://github.com/pires/go-proxyproto) from 0.6.0 to 0.6.1.
- [Release notes](https://github.com/pires/go-proxyproto/releases)
- [Commits](https://github.com/pires/go-proxyproto/compare/v0.6.0...v0.6.1)

---
updated-dependencies:
- dependency-name: github.com/pires/go-proxyproto
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2 years ago
dingjun c39f4f526a go fmt and fix static check warnings 2 years ago

@ -69,9 +69,11 @@ const (
)
// TLS compression types.
/*
const (
compressionNone uint8 = 0
)
*/
// TLS extension numbers
const (
@ -107,9 +109,11 @@ const (
// TLS Elliptic Curve Point Formats
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
/*
const (
pointFormatUncompressed uint8 = 0
)
*/
// TLS CertificateStatusType (RFC 3546)
const (
@ -117,6 +121,7 @@ const (
)
// Certificate types (for certificateRequestMsg)
/*
const (
certTypeRSASign = 1 // A certificate containing an RSA key
certTypeDSSSign = 2 // A certificate containing a DSA key
@ -130,19 +135,24 @@ const (
// Rest of these are reserved by the TLS spec
)
*/
// Hash functions for TLS 1.2 (See RFC 5246, section A.4.1)
/*
const (
hashSHA1 uint8 = 2
hashSHA256 uint8 = 4
hashSHA384 uint8 = 5
)
*/
// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1)
/*
const (
signatureRSA uint8 = 1
signatureECDSA uint8 = 3
)
*/
// signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See
// RFC 5246, section A.4.1.
@ -153,6 +163,7 @@ type signatureAndHash struct {
// supportedSignatureAlgorithms contains the signature and hash algorithms that
// the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2
// CertificateRequest.
/*
var supportedSignatureAlgorithms = []signatureAndHash{
{hashSHA256, signatureRSA},
{hashSHA256, signatureECDSA},
@ -161,6 +172,7 @@ var supportedSignatureAlgorithms = []signatureAndHash{
{hashSHA1, signatureRSA},
{hashSHA1, signatureECDSA},
}
*/
// ConnectionState records basic TLS details about the connection.
type ConnectionState struct {
@ -201,6 +213,7 @@ const (
// ClientSessionState contains the state needed by clients to resume TLS
// sessions.
/*
type ClientSessionState struct {
sessionTicket []uint8 // Encrypted ticket used for session resumption with server
vers uint16 // SSL/TLS version negotiated for the session
@ -209,12 +222,14 @@ type ClientSessionState struct {
serverCertificates []*x509.Certificate // Certificate chain presented by the server
verifiedChains [][]*x509.Certificate // Certificate chains we built for verification
}
*/
// ClientSessionCache is a cache of ClientSessionState objects that can be used
// by a client to resume a TLS session with a given server. ClientSessionCache
// implementations should expect to be called concurrently from different
// goroutines. Only ticket-based resumption is supported, not SessionID-based
// resumption.
/*
type ClientSessionCache interface {
// Get searches for a ClientSessionState associated with the given key.
// On return, ok is true if one was found.
@ -223,6 +238,7 @@ type ClientSessionCache interface {
// Put adds the ClientSessionState to the cache with the given key.
Put(sessionKey string, cs *ClientSessionState)
}
*/
/*
// SignatureScheme identifies a signature algorithm supported by TLS. See

@ -1,14 +1,14 @@
package main
import (
"io/ioutil"
"os"
"testing"
yaml "gopkg.in/yaml.v2"
)
func TestConf(t *testing.T) {
data, err := ioutil.ReadFile("config.sample.yaml")
data, err := os.ReadFile("config.sample.yaml")
if err != nil {
t.Fatal(err)
}

@ -5,7 +5,7 @@ go 1.17
require (
github.com/fangdingjun/go-log/v5 v5.0.0
github.com/fangdingjun/protolistener v0.0.0-20210804081554-626e6590d6e7
github.com/pires/go-proxyproto v0.6.0
github.com/pires/go-proxyproto v0.6.1
gopkg.in/yaml.v2 v2.4.0
)

@ -2,8 +2,9 @@ github.com/fangdingjun/go-log/v5 v5.0.0 h1:vdh9Bk9C4ZFL6KoO6rII73zQIyaLf7hFdBvuc
github.com/fangdingjun/go-log/v5 v5.0.0/go.mod h1:V012Oxo0/pSbccX4OFSp9MJglXwNsZo2ByBBorr7zzM=
github.com/fangdingjun/protolistener v0.0.0-20210804081554-626e6590d6e7 h1:z5NlLvUNbGZxQhtUZELvIorAqpDmcIhGQXE5GdO7+5I=
github.com/fangdingjun/protolistener v0.0.0-20210804081554-626e6590d6e7/go.mod h1:ljbjhI4fVrT5GwMu1iBhWTZwLJSqsXKwKlGKas5GudM=
github.com/pires/go-proxyproto v0.6.0 h1:cLJUPnuQdiNf7P/wbeOKmM1khVdaMgTFDLj8h9ZrVYk=
github.com/pires/go-proxyproto v0.6.0/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
github.com/pires/go-proxyproto v0.6.1 h1:EBupykFmo22SDjv4fQVQd2J9NOoLPmyZA/15ldOGkPw=
github.com/pires/go-proxyproto v0.6.1/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=

@ -33,272 +33,272 @@ type clientHelloMsg struct {
}
/*
func (m *clientHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*clientHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
m.scts == m1.scts &&
eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
eqStrings(m.alpnProtocols, m1.alpnProtocols)
}
func (m *clientHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*clientHelloMsg)
if !ok {
return false
}
func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
m.scts == m1.scts &&
eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
eqStrings(m.alpnProtocols, m1.alpnProtocols)
}
func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
numExtensions := 0
extensionsLength := 0
if m.nextProtoNeg {
numExtensions++
}
if m.ocspStapling {
extensionsLength += 1 + 2 + 2
numExtensions++
}
if len(m.serverName) > 0 {
extensionsLength += 5 + len(m.serverName)
numExtensions++
}
if len(m.supportedCurves) > 0 {
extensionsLength += 2 + 2*len(m.supportedCurves)
numExtensions++
}
if len(m.supportedPoints) > 0 {
extensionsLength += 1 + len(m.supportedPoints)
numExtensions++
}
if m.ticketSupported {
extensionsLength += len(m.sessionTicket)
numExtensions++
}
if len(m.signatureAndHashes) > 0 {
extensionsLength += 2 + 2*len(m.signatureAndHashes)
numExtensions++
}
if m.secureRenegotiationSupported {
extensionsLength += 1 + len(m.secureRenegotiation)
numExtensions++
}
if len(m.alpnProtocols) > 0 {
extensionsLength += 2
for _, s := range m.alpnProtocols {
if l := len(s); l == 0 || l > 255 {
panic("invalid ALPN protocol")
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
numExtensions := 0
extensionsLength := 0
if m.nextProtoNeg {
numExtensions++
}
if m.ocspStapling {
extensionsLength += 1 + 2 + 2
numExtensions++
}
if len(m.serverName) > 0 {
extensionsLength += 5 + len(m.serverName)
numExtensions++
}
if len(m.supportedCurves) > 0 {
extensionsLength += 2 + 2*len(m.supportedCurves)
numExtensions++
}
if len(m.supportedPoints) > 0 {
extensionsLength += 1 + len(m.supportedPoints)
numExtensions++
}
if m.ticketSupported {
extensionsLength += len(m.sessionTicket)
numExtensions++
}
if len(m.signatureAndHashes) > 0 {
extensionsLength += 2 + 2*len(m.signatureAndHashes)
numExtensions++
}
if m.secureRenegotiationSupported {
extensionsLength += 1 + len(m.secureRenegotiation)
numExtensions++
}
if len(m.alpnProtocols) > 0 {
extensionsLength += 2
for _, s := range m.alpnProtocols {
if l := len(s); l == 0 || l > 255 {
panic("invalid ALPN protocol")
}
extensionsLength++
extensionsLength += len(s)
}
extensionsLength++
extensionsLength += len(s)
numExtensions++
}
if m.scts {
numExtensions++
}
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
}
numExtensions++
}
if m.scts {
numExtensions++
}
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
}
x := make([]byte, 4+length)
x[0] = typeClientHello
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
y := x[39+len(m.sessionId):]
y[0] = uint8(len(m.cipherSuites) >> 7)
y[1] = uint8(len(m.cipherSuites) << 1)
for i, suite := range m.cipherSuites {
y[2+i*2] = uint8(suite >> 8)
y[3+i*2] = uint8(suite)
}
z := y[2+len(m.cipherSuites)*2:]
z[0] = uint8(len(m.compressionMethods))
copy(z[1:], m.compressionMethods)
z = z[1+len(m.compressionMethods):]
if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength)
z = z[2:]
}
if m.nextProtoNeg {
z[0] = byte(extensionNextProtoNeg >> 8)
z[1] = byte(extensionNextProtoNeg & 0xff)
// The length is always 0
z = z[4:]
}
if len(m.serverName) > 0 {
z[0] = byte(extensionServerName >> 8)
z[1] = byte(extensionServerName & 0xff)
l := len(m.serverName) + 5
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
// RFC 3546, section 3.1
//
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
//
// enum {
// host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
z[0] = byte((len(m.serverName) + 3) >> 8)
z[1] = byte(len(m.serverName) + 3)
z[3] = byte(len(m.serverName) >> 8)
z[4] = byte(len(m.serverName))
copy(z[5:], []byte(m.serverName))
z = z[l:]
}
if m.ocspStapling {
// RFC 4366, section 3.6
z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest)
z[2] = 0
z[3] = 5
z[4] = 1 // OCSP type
// Two zero valued uint16s for the two lengths.
z = z[9:]
}
if len(m.supportedCurves) > 0 {
// http://tools.ietf.org/html/rfc4492#section-5.5.1
z[0] = byte(extensionSupportedCurves >> 8)
z[1] = byte(extensionSupportedCurves)
l := 2 + 2*len(m.supportedCurves)
z[2] = byte(l >> 8)
z[3] = byte(l)
l -= 2
z[4] = byte(l >> 8)
z[5] = byte(l)
z = z[6:]
for _, curve := range m.supportedCurves {
z[0] = byte(curve >> 8)
z[1] = byte(curve)
x := make([]byte, 4+length)
x[0] = typeClientHello
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
y := x[39+len(m.sessionId):]
y[0] = uint8(len(m.cipherSuites) >> 7)
y[1] = uint8(len(m.cipherSuites) << 1)
for i, suite := range m.cipherSuites {
y[2+i*2] = uint8(suite >> 8)
y[3+i*2] = uint8(suite)
}
z := y[2+len(m.cipherSuites)*2:]
z[0] = uint8(len(m.compressionMethods))
copy(z[1:], m.compressionMethods)
z = z[1+len(m.compressionMethods):]
if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength)
z = z[2:]
}
}
if len(m.supportedPoints) > 0 {
// http://tools.ietf.org/html/rfc4492#section-5.5.2
z[0] = byte(extensionSupportedPoints >> 8)
z[1] = byte(extensionSupportedPoints)
l := 1 + len(m.supportedPoints)
z[2] = byte(l >> 8)
z[3] = byte(l)
l--
z[4] = byte(l)
z = z[5:]
for _, pointFormat := range m.supportedPoints {
z[0] = pointFormat
z = z[1:]
if m.nextProtoNeg {
z[0] = byte(extensionNextProtoNeg >> 8)
z[1] = byte(extensionNextProtoNeg & 0xff)
// The length is always 0
z = z[4:]
}
}
if m.ticketSupported {
// http://tools.ietf.org/html/rfc5077#section-3.2
z[0] = byte(extensionSessionTicket >> 8)
z[1] = byte(extensionSessionTicket)
l := len(m.sessionTicket)
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
copy(z, m.sessionTicket)
z = z[len(m.sessionTicket):]
}
if len(m.signatureAndHashes) > 0 {
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
z[0] = byte(extensionSignatureAlgorithms >> 8)
z[1] = byte(extensionSignatureAlgorithms)
l := 2 + 2*len(m.signatureAndHashes)
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
l -= 2
z[0] = byte(l >> 8)
z[1] = byte(l)
z = z[2:]
for _, sigAndHash := range m.signatureAndHashes {
z[0] = sigAndHash.hash
z[1] = sigAndHash.signature
if len(m.serverName) > 0 {
z[0] = byte(extensionServerName >> 8)
z[1] = byte(extensionServerName & 0xff)
l := len(m.serverName) + 5
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
// RFC 3546, section 3.1
//
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
//
// enum {
// host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
z[0] = byte((len(m.serverName) + 3) >> 8)
z[1] = byte(len(m.serverName) + 3)
z[3] = byte(len(m.serverName) >> 8)
z[4] = byte(len(m.serverName))
copy(z[5:], []byte(m.serverName))
z = z[l:]
}
if m.ocspStapling {
// RFC 4366, section 3.6
z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest)
z[2] = 0
z[3] = 5
z[4] = 1 // OCSP type
// Two zero valued uint16s for the two lengths.
z = z[9:]
}
if len(m.supportedCurves) > 0 {
// http://tools.ietf.org/html/rfc4492#section-5.5.1
z[0] = byte(extensionSupportedCurves >> 8)
z[1] = byte(extensionSupportedCurves)
l := 2 + 2*len(m.supportedCurves)
z[2] = byte(l >> 8)
z[3] = byte(l)
l -= 2
z[4] = byte(l >> 8)
z[5] = byte(l)
z = z[6:]
for _, curve := range m.supportedCurves {
z[0] = byte(curve >> 8)
z[1] = byte(curve)
z = z[2:]
}
}
if len(m.supportedPoints) > 0 {
// http://tools.ietf.org/html/rfc4492#section-5.5.2
z[0] = byte(extensionSupportedPoints >> 8)
z[1] = byte(extensionSupportedPoints)
l := 1 + len(m.supportedPoints)
z[2] = byte(l >> 8)
z[3] = byte(l)
l--
z[4] = byte(l)
z = z[5:]
for _, pointFormat := range m.supportedPoints {
z[0] = pointFormat
z = z[1:]
}
}
if m.ticketSupported {
// http://tools.ietf.org/html/rfc5077#section-3.2
z[0] = byte(extensionSessionTicket >> 8)
z[1] = byte(extensionSessionTicket)
l := len(m.sessionTicket)
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
copy(z, m.sessionTicket)
z = z[len(m.sessionTicket):]
}
if len(m.signatureAndHashes) > 0 {
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
z[0] = byte(extensionSignatureAlgorithms >> 8)
z[1] = byte(extensionSignatureAlgorithms)
l := 2 + 2*len(m.signatureAndHashes)
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
l -= 2
z[0] = byte(l >> 8)
z[1] = byte(l)
z = z[2:]
for _, sigAndHash := range m.signatureAndHashes {
z[0] = sigAndHash.hash
z[1] = sigAndHash.signature
z = z[2:]
}
}
}
if m.secureRenegotiationSupported {
z[0] = byte(extensionRenegotiationInfo >> 8)
z[1] = byte(extensionRenegotiationInfo & 0xff)
z[2] = 0
z[3] = byte(len(m.secureRenegotiation) + 1)
z[4] = byte(len(m.secureRenegotiation))
z = z[5:]
copy(z, m.secureRenegotiation)
z = z[len(m.secureRenegotiation):]
}
if len(m.alpnProtocols) > 0 {
z[0] = byte(extensionALPN >> 8)
z[1] = byte(extensionALPN & 0xff)
lengths := z[2:]
z = z[6:]
stringsLength := 0
for _, s := range m.alpnProtocols {
l := len(s)
z[0] = byte(l)
copy(z[1:], s)
z = z[1+l:]
stringsLength += 1 + l
if m.secureRenegotiationSupported {
z[0] = byte(extensionRenegotiationInfo >> 8)
z[1] = byte(extensionRenegotiationInfo & 0xff)
z[2] = 0
z[3] = byte(len(m.secureRenegotiation) + 1)
z[4] = byte(len(m.secureRenegotiation))
z = z[5:]
copy(z, m.secureRenegotiation)
z = z[len(m.secureRenegotiation):]
}
if len(m.alpnProtocols) > 0 {
z[0] = byte(extensionALPN >> 8)
z[1] = byte(extensionALPN & 0xff)
lengths := z[2:]
z = z[6:]
stringsLength := 0
for _, s := range m.alpnProtocols {
l := len(s)
z[0] = byte(l)
copy(z[1:], s)
z = z[1+l:]
stringsLength += 1 + l
}
lengths[2] = byte(stringsLength >> 8)
lengths[3] = byte(stringsLength)
stringsLength += 2
lengths[0] = byte(stringsLength >> 8)
lengths[1] = byte(stringsLength)
}
if m.scts {
// https://tools.ietf.org/html/rfc6962#section-3.3.1
z[0] = byte(extensionSCT >> 8)
z[1] = byte(extensionSCT)
// zero uint16 for the zero-length extension_data
z = z[4:]
}
lengths[2] = byte(stringsLength >> 8)
lengths[3] = byte(stringsLength)
stringsLength += 2
lengths[0] = byte(stringsLength >> 8)
lengths[1] = byte(stringsLength)
}
if m.scts {
// https://tools.ietf.org/html/rfc6962#section-3.3.1
z[0] = byte(extensionSCT >> 8)
z[1] = byte(extensionSCT)
// zero uint16 for the zero-length extension_data
z = z[4:]
}
m.raw = x
m.raw = x
return x
}
return x
}
*/
func (m *clientHelloMsg) unmarshal(data []byte) bool {
if len(data) < 42 {

@ -5,7 +5,6 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/signal"
@ -78,7 +77,11 @@ func forward(ctx context.Context, c net.Conn, data []byte, dst string) {
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: raddr,
DestinationAddr: &net.TCPAddr{},
DestinationAddr: &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
Port: 0,
Zone: "",
},
}
switch strings.ToLower(ss[1]) {
@ -186,7 +189,7 @@ func main() {
flag.StringVar(&loglevel, "log_level", "INFO", "log level")
flag.Parse()
data, err := ioutil.ReadFile(cfgfile)
data, err := os.ReadFile(cfgfile)
if err != nil {
glog.Fatal(err)
}
@ -231,9 +234,7 @@ func main() {
ch := make(chan os.Signal, 2)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
select {
case s := <-ch:
cancel()
glog.Printf("received signal %s, exit.", s)
}
s := <-ch
cancel()
glog.Printf("received signal %s, exit.", s)
}

@ -1,96 +1,96 @@
package main
import (
"context"
"crypto/tls"
"io/ioutil"
"net"
"testing"
"github.com/fangdingjun/go-log/v5"
"github.com/fangdingjun/protolistener"
yaml "gopkg.in/yaml.v2"
)
func TestProxyProto(t *testing.T) {
log.Default.Level = log.DEBUG
data, err := ioutil.ReadFile("config.sample.yaml")
if err != nil {
log.Fatal(err)
}
if err = yaml.Unmarshal(data, &cfg); err != nil {
log.Fatal(err)
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
log.Printf("listen %s", listener.Addr().String())
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
go serve(context.Background(), conn)
}
}()
cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
if err != nil {
t.Fatal(err)
}
listener2, err := net.Listen("tcp", "127.0.0.1:8443")
if err != nil {
t.Fatal(err)
}
defer listener2.Close()
listener2 = tls.NewListener(protolistener.New(listener2), &tls.Config{
Certificates: []tls.Certificate{cert},
})
go func() {
for {
conn, err := listener2.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
addr := conn.RemoteAddr()
_conn := c.(*tls.Conn)
if err := _conn.Handshake(); err != nil {
log.Errorf("handshake error: %s", err)
return
}
conn.Write([]byte(addr.String()))
}(conn)
}
}()
conn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{
ServerName: "www.example.com",
InsecureSkipVerify: true,
})
if err != nil {
log.Println("dial error")
t.Fatal(err)
}
conn.Handshake()
buf := make([]byte, 200)
n, err := conn.Read(buf)
if err != nil {
log.Println("read error")
t.Fatal(err)
}
addr1 := conn.LocalAddr().String()
addr2 := string(buf[:n])
conn.Close()
if addr1 != addr2 {
t.Errorf("expect %s, got: %s", addr1, addr2)
}
}
package main
import (
"context"
"crypto/tls"
"net"
"os"
"testing"
"github.com/fangdingjun/go-log/v5"
"github.com/fangdingjun/protolistener"
yaml "gopkg.in/yaml.v2"
)
func TestProxyProto(t *testing.T) {
log.Default.Level = log.DEBUG
data, err := os.ReadFile("config.sample.yaml")
if err != nil {
log.Fatal(err)
}
if err = yaml.Unmarshal(data, &cfg); err != nil {
log.Fatal(err)
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
log.Printf("listen %s", listener.Addr().String())
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
go serve(context.Background(), conn)
}
}()
cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
if err != nil {
t.Fatal(err)
}
listener2, err := net.Listen("tcp", "127.0.0.1:8443")
if err != nil {
t.Fatal(err)
}
defer listener2.Close()
listener2 = tls.NewListener(protolistener.New(listener2), &tls.Config{
Certificates: []tls.Certificate{cert},
})
go func() {
for {
conn, err := listener2.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
addr := conn.RemoteAddr()
_conn := c.(*tls.Conn)
if err := _conn.Handshake(); err != nil {
log.Errorf("handshake error: %s", err)
return
}
conn.Write([]byte(addr.String()))
}(conn)
}
}()
conn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{
ServerName: "www.example.com",
InsecureSkipVerify: true,
})
if err != nil {
log.Println("dial error")
t.Fatal(err)
}
conn.Handshake()
buf := make([]byte, 200)
n, err := conn.Read(buf)
if err != nil {
log.Println("read error")
t.Fatal(err)
}
addr1 := conn.LocalAddr().String()
addr2 := string(buf[:n])
conn.Close()
if addr1 != addr2 {
t.Errorf("expect %s, got: %s", addr1, addr2)
}
}

Loading…
Cancel
Save