go fmt and fix static check warnings

master
dingjun 2 years ago
parent 12f4938134
commit c39f4f526a

@ -69,9 +69,11 @@ const (
) )
// TLS compression types. // TLS compression types.
/*
const ( const (
compressionNone uint8 = 0 compressionNone uint8 = 0
) )
*/
// TLS extension numbers // TLS extension numbers
const ( const (
@ -107,9 +109,11 @@ const (
// TLS Elliptic Curve Point Formats // TLS Elliptic Curve Point Formats
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
/*
const ( const (
pointFormatUncompressed uint8 = 0 pointFormatUncompressed uint8 = 0
) )
*/
// TLS CertificateStatusType (RFC 3546) // TLS CertificateStatusType (RFC 3546)
const ( const (
@ -117,6 +121,7 @@ const (
) )
// Certificate types (for certificateRequestMsg) // Certificate types (for certificateRequestMsg)
/*
const ( const (
certTypeRSASign = 1 // A certificate containing an RSA key certTypeRSASign = 1 // A certificate containing an RSA key
certTypeDSSSign = 2 // A certificate containing a DSA key certTypeDSSSign = 2 // A certificate containing a DSA key
@ -130,19 +135,24 @@ const (
// Rest of these are reserved by the TLS spec // Rest of these are reserved by the TLS spec
) )
*/
// Hash functions for TLS 1.2 (See RFC 5246, section A.4.1) // Hash functions for TLS 1.2 (See RFC 5246, section A.4.1)
/*
const ( const (
hashSHA1 uint8 = 2 hashSHA1 uint8 = 2
hashSHA256 uint8 = 4 hashSHA256 uint8 = 4
hashSHA384 uint8 = 5 hashSHA384 uint8 = 5
) )
*/
// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1) // Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1)
/*
const ( const (
signatureRSA uint8 = 1 signatureRSA uint8 = 1
signatureECDSA uint8 = 3 signatureECDSA uint8 = 3
) )
*/
// signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See // signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See
// RFC 5246, section A.4.1. // RFC 5246, section A.4.1.
@ -153,6 +163,7 @@ type signatureAndHash struct {
// supportedSignatureAlgorithms contains the signature and hash algorithms that // supportedSignatureAlgorithms contains the signature and hash algorithms that
// the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2 // the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2
// CertificateRequest. // CertificateRequest.
/*
var supportedSignatureAlgorithms = []signatureAndHash{ var supportedSignatureAlgorithms = []signatureAndHash{
{hashSHA256, signatureRSA}, {hashSHA256, signatureRSA},
{hashSHA256, signatureECDSA}, {hashSHA256, signatureECDSA},
@ -161,6 +172,7 @@ var supportedSignatureAlgorithms = []signatureAndHash{
{hashSHA1, signatureRSA}, {hashSHA1, signatureRSA},
{hashSHA1, signatureECDSA}, {hashSHA1, signatureECDSA},
} }
*/
// ConnectionState records basic TLS details about the connection. // ConnectionState records basic TLS details about the connection.
type ConnectionState struct { type ConnectionState struct {
@ -201,6 +213,7 @@ const (
// ClientSessionState contains the state needed by clients to resume TLS // ClientSessionState contains the state needed by clients to resume TLS
// sessions. // sessions.
/*
type ClientSessionState struct { type ClientSessionState struct {
sessionTicket []uint8 // Encrypted ticket used for session resumption with server sessionTicket []uint8 // Encrypted ticket used for session resumption with server
vers uint16 // SSL/TLS version negotiated for the session vers uint16 // SSL/TLS version negotiated for the session
@ -209,12 +222,14 @@ type ClientSessionState struct {
serverCertificates []*x509.Certificate // Certificate chain presented by the server serverCertificates []*x509.Certificate // Certificate chain presented by the server
verifiedChains [][]*x509.Certificate // Certificate chains we built for verification verifiedChains [][]*x509.Certificate // Certificate chains we built for verification
} }
*/
// ClientSessionCache is a cache of ClientSessionState objects that can be used // ClientSessionCache is a cache of ClientSessionState objects that can be used
// by a client to resume a TLS session with a given server. ClientSessionCache // by a client to resume a TLS session with a given server. ClientSessionCache
// implementations should expect to be called concurrently from different // implementations should expect to be called concurrently from different
// goroutines. Only ticket-based resumption is supported, not SessionID-based // goroutines. Only ticket-based resumption is supported, not SessionID-based
// resumption. // resumption.
/*
type ClientSessionCache interface { type ClientSessionCache interface {
// Get searches for a ClientSessionState associated with the given key. // Get searches for a ClientSessionState associated with the given key.
// On return, ok is true if one was found. // On return, ok is true if one was found.
@ -223,6 +238,7 @@ type ClientSessionCache interface {
// Put adds the ClientSessionState to the cache with the given key. // Put adds the ClientSessionState to the cache with the given key.
Put(sessionKey string, cs *ClientSessionState) Put(sessionKey string, cs *ClientSessionState)
} }
*/
/* /*
// SignatureScheme identifies a signature algorithm supported by TLS. See // SignatureScheme identifies a signature algorithm supported by TLS. See

@ -1,14 +1,14 @@
package main package main
import ( import (
"io/ioutil" "os"
"testing" "testing"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
func TestConf(t *testing.T) { func TestConf(t *testing.T) {
data, err := ioutil.ReadFile("config.sample.yaml") data, err := os.ReadFile("config.sample.yaml")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -33,272 +33,272 @@ type clientHelloMsg struct {
} }
/* /*
func (m *clientHelloMsg) equal(i interface{}) bool { func (m *clientHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*clientHelloMsg) m1, ok := i.(*clientHelloMsg)
if !ok { if !ok {
return false return false
} }
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
m.scts == m1.scts &&
eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
eqStrings(m.alpnProtocols, m1.alpnProtocols)
}
func (m *clientHelloMsg) marshal() []byte { return bytes.Equal(m.raw, m1.raw) &&
if m.raw != nil { m.vers == m1.vers &&
return m.raw bytes.Equal(m.random, m1.random) &&
} bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
m.scts == m1.scts &&
eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
eqStrings(m.alpnProtocols, m1.alpnProtocols)
}
func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
numExtensions := 0 numExtensions := 0
extensionsLength := 0 extensionsLength := 0
if m.nextProtoNeg { if m.nextProtoNeg {
numExtensions++ numExtensions++
} }
if m.ocspStapling { if m.ocspStapling {
extensionsLength += 1 + 2 + 2 extensionsLength += 1 + 2 + 2
numExtensions++ numExtensions++
} }
if len(m.serverName) > 0 { if len(m.serverName) > 0 {
extensionsLength += 5 + len(m.serverName) extensionsLength += 5 + len(m.serverName)
numExtensions++ numExtensions++
} }
if len(m.supportedCurves) > 0 { if len(m.supportedCurves) > 0 {
extensionsLength += 2 + 2*len(m.supportedCurves) extensionsLength += 2 + 2*len(m.supportedCurves)
numExtensions++ numExtensions++
} }
if len(m.supportedPoints) > 0 { if len(m.supportedPoints) > 0 {
extensionsLength += 1 + len(m.supportedPoints) extensionsLength += 1 + len(m.supportedPoints)
numExtensions++ numExtensions++
} }
if m.ticketSupported { if m.ticketSupported {
extensionsLength += len(m.sessionTicket) extensionsLength += len(m.sessionTicket)
numExtensions++ numExtensions++
} }
if len(m.signatureAndHashes) > 0 { if len(m.signatureAndHashes) > 0 {
extensionsLength += 2 + 2*len(m.signatureAndHashes) extensionsLength += 2 + 2*len(m.signatureAndHashes)
numExtensions++ numExtensions++
} }
if m.secureRenegotiationSupported { if m.secureRenegotiationSupported {
extensionsLength += 1 + len(m.secureRenegotiation) extensionsLength += 1 + len(m.secureRenegotiation)
numExtensions++ numExtensions++
} }
if len(m.alpnProtocols) > 0 { if len(m.alpnProtocols) > 0 {
extensionsLength += 2 extensionsLength += 2
for _, s := range m.alpnProtocols { for _, s := range m.alpnProtocols {
if l := len(s); l == 0 || l > 255 { if l := len(s); l == 0 || l > 255 {
panic("invalid ALPN protocol") panic("invalid ALPN protocol")
}
extensionsLength++
extensionsLength += len(s)
} }
extensionsLength++ numExtensions++
extensionsLength += len(s) }
if m.scts {
numExtensions++
}
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
} }
numExtensions++
}
if m.scts {
numExtensions++
}
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
}
x := make([]byte, 4+length)
x[0] = typeClientHello
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
y := x[39+len(m.sessionId):]
y[0] = uint8(len(m.cipherSuites) >> 7)
y[1] = uint8(len(m.cipherSuites) << 1)
for i, suite := range m.cipherSuites {
y[2+i*2] = uint8(suite >> 8)
y[3+i*2] = uint8(suite)
}
z := y[2+len(m.cipherSuites)*2:]
z[0] = uint8(len(m.compressionMethods))
copy(z[1:], m.compressionMethods)
z = z[1+len(m.compressionMethods):]
if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength)
z = z[2:]
}
if m.nextProtoNeg {
z[0] = byte(extensionNextProtoNeg >> 8)
z[1] = byte(extensionNextProtoNeg & 0xff)
// The length is always 0
z = z[4:]
}
if len(m.serverName) > 0 {
z[0] = byte(extensionServerName >> 8)
z[1] = byte(extensionServerName & 0xff)
l := len(m.serverName) + 5
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
// RFC 3546, section 3.1 x := make([]byte, 4+length)
// x[0] = typeClientHello
// struct { x[1] = uint8(length >> 16)
// NameType name_type; x[2] = uint8(length >> 8)
// select (name_type) { x[3] = uint8(length)
// case host_name: HostName; x[4] = uint8(m.vers >> 8)
// } name; x[5] = uint8(m.vers)
// } ServerName; copy(x[6:38], m.random)
// x[38] = uint8(len(m.sessionId))
// enum { copy(x[39:39+len(m.sessionId)], m.sessionId)
// host_name(0), (255) y := x[39+len(m.sessionId):]
// } NameType; y[0] = uint8(len(m.cipherSuites) >> 7)
// y[1] = uint8(len(m.cipherSuites) << 1)
// opaque HostName<1..2^16-1>; for i, suite := range m.cipherSuites {
// y[2+i*2] = uint8(suite >> 8)
// struct { y[3+i*2] = uint8(suite)
// ServerName server_name_list<1..2^16-1> }
// } ServerNameList; z := y[2+len(m.cipherSuites)*2:]
z[0] = uint8(len(m.compressionMethods))
z[0] = byte((len(m.serverName) + 3) >> 8) copy(z[1:], m.compressionMethods)
z[1] = byte(len(m.serverName) + 3)
z[3] = byte(len(m.serverName) >> 8) z = z[1+len(m.compressionMethods):]
z[4] = byte(len(m.serverName)) if numExtensions > 0 {
copy(z[5:], []byte(m.serverName)) z[0] = byte(extensionsLength >> 8)
z = z[l:] z[1] = byte(extensionsLength)
}
if m.ocspStapling {
// RFC 4366, section 3.6
z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest)
z[2] = 0
z[3] = 5
z[4] = 1 // OCSP type
// Two zero valued uint16s for the two lengths.
z = z[9:]
}
if len(m.supportedCurves) > 0 {
// http://tools.ietf.org/html/rfc4492#section-5.5.1
z[0] = byte(extensionSupportedCurves >> 8)
z[1] = byte(extensionSupportedCurves)
l := 2 + 2*len(m.supportedCurves)
z[2] = byte(l >> 8)
z[3] = byte(l)
l -= 2
z[4] = byte(l >> 8)
z[5] = byte(l)
z = z[6:]
for _, curve := range m.supportedCurves {
z[0] = byte(curve >> 8)
z[1] = byte(curve)
z = z[2:] z = z[2:]
} }
} if m.nextProtoNeg {
if len(m.supportedPoints) > 0 { z[0] = byte(extensionNextProtoNeg >> 8)
// http://tools.ietf.org/html/rfc4492#section-5.5.2 z[1] = byte(extensionNextProtoNeg & 0xff)
z[0] = byte(extensionSupportedPoints >> 8) // The length is always 0
z[1] = byte(extensionSupportedPoints) z = z[4:]
l := 1 + len(m.supportedPoints)
z[2] = byte(l >> 8)
z[3] = byte(l)
l--
z[4] = byte(l)
z = z[5:]
for _, pointFormat := range m.supportedPoints {
z[0] = pointFormat
z = z[1:]
} }
} if len(m.serverName) > 0 {
if m.ticketSupported { z[0] = byte(extensionServerName >> 8)
// http://tools.ietf.org/html/rfc5077#section-3.2 z[1] = byte(extensionServerName & 0xff)
z[0] = byte(extensionSessionTicket >> 8) l := len(m.serverName) + 5
z[1] = byte(extensionSessionTicket) z[2] = byte(l >> 8)
l := len(m.sessionTicket) z[3] = byte(l)
z[2] = byte(l >> 8) z = z[4:]
z[3] = byte(l)
z = z[4:] // RFC 3546, section 3.1
copy(z, m.sessionTicket) //
z = z[len(m.sessionTicket):] // struct {
} // NameType name_type;
if len(m.signatureAndHashes) > 0 { // select (name_type) {
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 // case host_name: HostName;
z[0] = byte(extensionSignatureAlgorithms >> 8) // } name;
z[1] = byte(extensionSignatureAlgorithms) // } ServerName;
l := 2 + 2*len(m.signatureAndHashes) //
z[2] = byte(l >> 8) // enum {
z[3] = byte(l) // host_name(0), (255)
z = z[4:] // } NameType;
//
l -= 2 // opaque HostName<1..2^16-1>;
z[0] = byte(l >> 8) //
z[1] = byte(l) // struct {
z = z[2:] // ServerName server_name_list<1..2^16-1>
for _, sigAndHash := range m.signatureAndHashes { // } ServerNameList;
z[0] = sigAndHash.hash
z[1] = sigAndHash.signature z[0] = byte((len(m.serverName) + 3) >> 8)
z[1] = byte(len(m.serverName) + 3)
z[3] = byte(len(m.serverName) >> 8)
z[4] = byte(len(m.serverName))
copy(z[5:], []byte(m.serverName))
z = z[l:]
}
if m.ocspStapling {
// RFC 4366, section 3.6
z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest)
z[2] = 0
z[3] = 5
z[4] = 1 // OCSP type
// Two zero valued uint16s for the two lengths.
z = z[9:]
}
if len(m.supportedCurves) > 0 {
// http://tools.ietf.org/html/rfc4492#section-5.5.1
z[0] = byte(extensionSupportedCurves >> 8)
z[1] = byte(extensionSupportedCurves)
l := 2 + 2*len(m.supportedCurves)
z[2] = byte(l >> 8)
z[3] = byte(l)
l -= 2
z[4] = byte(l >> 8)
z[5] = byte(l)
z = z[6:]
for _, curve := range m.supportedCurves {
z[0] = byte(curve >> 8)
z[1] = byte(curve)
z = z[2:]
}
}
if len(m.supportedPoints) > 0 {
// http://tools.ietf.org/html/rfc4492#section-5.5.2
z[0] = byte(extensionSupportedPoints >> 8)
z[1] = byte(extensionSupportedPoints)
l := 1 + len(m.supportedPoints)
z[2] = byte(l >> 8)
z[3] = byte(l)
l--
z[4] = byte(l)
z = z[5:]
for _, pointFormat := range m.supportedPoints {
z[0] = pointFormat
z = z[1:]
}
}
if m.ticketSupported {
// http://tools.ietf.org/html/rfc5077#section-3.2
z[0] = byte(extensionSessionTicket >> 8)
z[1] = byte(extensionSessionTicket)
l := len(m.sessionTicket)
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
copy(z, m.sessionTicket)
z = z[len(m.sessionTicket):]
}
if len(m.signatureAndHashes) > 0 {
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
z[0] = byte(extensionSignatureAlgorithms >> 8)
z[1] = byte(extensionSignatureAlgorithms)
l := 2 + 2*len(m.signatureAndHashes)
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
l -= 2
z[0] = byte(l >> 8)
z[1] = byte(l)
z = z[2:] z = z[2:]
for _, sigAndHash := range m.signatureAndHashes {
z[0] = sigAndHash.hash
z[1] = sigAndHash.signature
z = z[2:]
}
} }
} if m.secureRenegotiationSupported {
if m.secureRenegotiationSupported { z[0] = byte(extensionRenegotiationInfo >> 8)
z[0] = byte(extensionRenegotiationInfo >> 8) z[1] = byte(extensionRenegotiationInfo & 0xff)
z[1] = byte(extensionRenegotiationInfo & 0xff) z[2] = 0
z[2] = 0 z[3] = byte(len(m.secureRenegotiation) + 1)
z[3] = byte(len(m.secureRenegotiation) + 1) z[4] = byte(len(m.secureRenegotiation))
z[4] = byte(len(m.secureRenegotiation)) z = z[5:]
z = z[5:] copy(z, m.secureRenegotiation)
copy(z, m.secureRenegotiation) z = z[len(m.secureRenegotiation):]
z = z[len(m.secureRenegotiation):]
}
if len(m.alpnProtocols) > 0 {
z[0] = byte(extensionALPN >> 8)
z[1] = byte(extensionALPN & 0xff)
lengths := z[2:]
z = z[6:]
stringsLength := 0
for _, s := range m.alpnProtocols {
l := len(s)
z[0] = byte(l)
copy(z[1:], s)
z = z[1+l:]
stringsLength += 1 + l
} }
if len(m.alpnProtocols) > 0 {
z[0] = byte(extensionALPN >> 8)
z[1] = byte(extensionALPN & 0xff)
lengths := z[2:]
z = z[6:]
stringsLength := 0
for _, s := range m.alpnProtocols {
l := len(s)
z[0] = byte(l)
copy(z[1:], s)
z = z[1+l:]
stringsLength += 1 + l
}
lengths[2] = byte(stringsLength >> 8) lengths[2] = byte(stringsLength >> 8)
lengths[3] = byte(stringsLength) lengths[3] = byte(stringsLength)
stringsLength += 2 stringsLength += 2
lengths[0] = byte(stringsLength >> 8) lengths[0] = byte(stringsLength >> 8)
lengths[1] = byte(stringsLength) lengths[1] = byte(stringsLength)
} }
if m.scts { if m.scts {
// https://tools.ietf.org/html/rfc6962#section-3.3.1 // https://tools.ietf.org/html/rfc6962#section-3.3.1
z[0] = byte(extensionSCT >> 8) z[0] = byte(extensionSCT >> 8)
z[1] = byte(extensionSCT) z[1] = byte(extensionSCT)
// zero uint16 for the zero-length extension_data // zero uint16 for the zero-length extension_data
z = z[4:] z = z[4:]
} }
m.raw = x m.raw = x
return x return x
} }
*/ */
func (m *clientHelloMsg) unmarshal(data []byte) bool { func (m *clientHelloMsg) unmarshal(data []byte) bool {
if len(data) < 42 { if len(data) < 42 {

@ -5,7 +5,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"os" "os"
"os/signal" "os/signal"
@ -78,7 +77,11 @@ func forward(ctx context.Context, c net.Conn, data []byte, dst string) {
Command: proxyproto.PROXY, Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4, TransportProtocol: proxyproto.TCPv4,
SourceAddr: raddr, SourceAddr: raddr,
DestinationAddr: &net.TCPAddr{}, DestinationAddr: &net.TCPAddr{
IP: []byte{0, 0, 0, 0},
Port: 0,
Zone: "",
},
} }
switch strings.ToLower(ss[1]) { switch strings.ToLower(ss[1]) {
@ -186,7 +189,7 @@ func main() {
flag.StringVar(&loglevel, "log_level", "INFO", "log level") flag.StringVar(&loglevel, "log_level", "INFO", "log level")
flag.Parse() flag.Parse()
data, err := ioutil.ReadFile(cfgfile) data, err := os.ReadFile(cfgfile)
if err != nil { if err != nil {
glog.Fatal(err) glog.Fatal(err)
} }
@ -231,9 +234,7 @@ func main() {
ch := make(chan os.Signal, 2) ch := make(chan os.Signal, 2)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
select { s := <-ch
case s := <-ch: cancel()
cancel() glog.Printf("received signal %s, exit.", s)
glog.Printf("received signal %s, exit.", s)
}
} }

@ -1,96 +1,96 @@
package main package main
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"io/ioutil" "net"
"net" "os"
"testing" "testing"
"github.com/fangdingjun/go-log/v5" "github.com/fangdingjun/go-log/v5"
"github.com/fangdingjun/protolistener" "github.com/fangdingjun/protolistener"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
func TestProxyProto(t *testing.T) { func TestProxyProto(t *testing.T) {
log.Default.Level = log.DEBUG log.Default.Level = log.DEBUG
data, err := ioutil.ReadFile("config.sample.yaml") data, err := os.ReadFile("config.sample.yaml")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if err = yaml.Unmarshal(data, &cfg); err != nil { if err = yaml.Unmarshal(data, &cfg); err != nil {
log.Fatal(err) log.Fatal(err)
} }
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer listener.Close() defer listener.Close()
log.Printf("listen %s", listener.Addr().String()) log.Printf("listen %s", listener.Addr().String())
go func() { go func() {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
return return
} }
go serve(context.Background(), conn) go serve(context.Background(), conn)
} }
}() }()
cert, err := tls.LoadX509KeyPair("server.crt", "server.key") cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
listener2, err := net.Listen("tcp", "127.0.0.1:8443") listener2, err := net.Listen("tcp", "127.0.0.1:8443")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer listener2.Close() defer listener2.Close()
listener2 = tls.NewListener(protolistener.New(listener2), &tls.Config{ listener2 = tls.NewListener(protolistener.New(listener2), &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
}) })
go func() { go func() {
for { for {
conn, err := listener2.Accept() conn, err := listener2.Accept()
if err != nil { if err != nil {
return return
} }
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
addr := conn.RemoteAddr() addr := conn.RemoteAddr()
_conn := c.(*tls.Conn) _conn := c.(*tls.Conn)
if err := _conn.Handshake(); err != nil { if err := _conn.Handshake(); err != nil {
log.Errorf("handshake error: %s", err) log.Errorf("handshake error: %s", err)
return return
} }
conn.Write([]byte(addr.String())) conn.Write([]byte(addr.String()))
}(conn) }(conn)
} }
}() }()
conn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{ conn, err := tls.Dial("tcp", listener.Addr().String(), &tls.Config{
ServerName: "www.example.com", ServerName: "www.example.com",
InsecureSkipVerify: true, InsecureSkipVerify: true,
}) })
if err != nil { if err != nil {
log.Println("dial error") log.Println("dial error")
t.Fatal(err) t.Fatal(err)
} }
conn.Handshake() conn.Handshake()
buf := make([]byte, 200) buf := make([]byte, 200)
n, err := conn.Read(buf) n, err := conn.Read(buf)
if err != nil { if err != nil {
log.Println("read error") log.Println("read error")
t.Fatal(err) t.Fatal(err)
} }
addr1 := conn.LocalAddr().String() addr1 := conn.LocalAddr().String()
addr2 := string(buf[:n]) addr2 := string(buf[:n])
conn.Close() conn.Close()
if addr1 != addr2 { if addr1 != addr2 {
t.Errorf("expect %s, got: %s", addr1, addr2) t.Errorf("expect %s, got: %s", addr1, addr2)
} }
} }

Loading…
Cancel
Save