commit 590ea98c0b1b82c89a7370e618e4658412d71af2 Author: Dingjun Date: Tue Aug 8 11:00:01 2017 +0800 first version diff --git a/common.go b/common.go new file mode 100644 index 0000000..4e3e644 --- /dev/null +++ b/common.go @@ -0,0 +1,980 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// this file is from $GOROOT/src/crypto/tls + +package main + +import ( + //"container/list" + //"crypto/tls" + //"crypto/internal/cipherhw" + //"crypto/rand" + //"crypto/sha512" + "crypto/x509" + //"errors" + //"fmt" + //"io" + //"math/big" + //"net" + //"strings" + //"sync" + //"time" +) + +/* +const ( + VersionSSL30 = 0x0300 + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + + minVersion = VersionTLS10 + maxVersion = VersionTLS12 +) +*/ +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeNextProtocol uint8 = 67 // Not IANA assigned +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6 + extensionSessionTicket uint16 = 35 + extensionNextProtoNeg uint16 = 13172 // not IANA assigned + extensionRenegotiationInfo uint16 = 0xff01 +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +// CurveID is the type of a TLS identifier for an elliptic curve. See +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 +type CurveID uint16 + +/* +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 +) +*/ + +// TLS Elliptic Curve Point Formats +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 // A certificate containing an RSA key + certTypeDSSSign = 2 // A certificate containing a DSA key + certTypeRSAFixedDH = 3 // A certificate containing a static DH key + certTypeDSSFixedDH = 4 // A certificate containing a static DH key + + // See RFC 4492 sections 3 and 5.5. + certTypeECDSASign = 64 // A certificate containing an ECDSA-capable public key, signed with ECDSA. + certTypeRSAFixedECDH = 65 // A certificate containing an ECDH-capable public key, signed with RSA. + certTypeECDSAFixedECDH = 66 // A certificate containing an ECDH-capable public key, signed with ECDSA. + + // Rest of these are reserved by the TLS spec +) + +// Hash functions for TLS 1.2 (See RFC 5246, section A.4.1) +const ( + hashSHA1 uint8 = 2 + hashSHA256 uint8 = 4 + hashSHA384 uint8 = 5 +) + +// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1) +const ( + signatureRSA uint8 = 1 + signatureECDSA uint8 = 3 +) + +// signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See +// RFC 5246, section A.4.1. +type signatureAndHash struct { + hash, signature uint8 +} + +// supportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2 +// CertificateRequest. +var supportedSignatureAlgorithms = []signatureAndHash{ + {hashSHA256, signatureRSA}, + {hashSHA256, signatureECDSA}, + {hashSHA384, signatureRSA}, + {hashSHA384, signatureECDSA}, + {hashSHA1, signatureRSA}, + {hashSHA1, signatureECDSA}, +} + +// ConnectionState records basic TLS details about the connection. +type ConnectionState struct { + Version uint16 // TLS version used by the connection (e.g. VersionTLS12) + HandshakeComplete bool // TLS handshake is complete + DidResume bool // connection resumes a previous TLS connection + CipherSuite uint16 // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) + NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos) + NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only) + ServerName string // server name requested by client, if any (server side only) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates + SignedCertificateTimestamps [][]byte // SCTs from the server, if any + OCSPResponse []byte // stapled OCSP response from server, if any + + // TLSUnique contains the "tls-unique" channel binding value (see RFC + // 5929, section 3). For resumed sessions this value will be nil + // because resumption does not include enough context (see + // https://mitls.org/pages/attacks/3SHAKE#channelbindings). This will + // change in future versions of Go once the TLS master-secret fix has + // been standardized and implemented. + TLSUnique []byte +} + +// ClientAuthType declares the policy the server will follow for +// TLS Client Authentication. +type ClientAuthType int + +/* +const ( + NoClientCert ClientAuthType = iota + RequestClientCert + RequireAnyClientCert + VerifyClientCertIfGiven + RequireAndVerifyClientCert +) +*/ + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // SSL/TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // MasterSecret generated by client on a full handshake + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Only ticket-based resumption is supported, not SessionID-based +// resumption. +type ClientSessionCache interface { + // Get searches for a ClientSessionState associated with the given key. + // On return, ok is true if one was found. + Get(sessionKey string) (session *ClientSessionState, ok bool) + + // Put adds the ClientSessionState to the cache with the given key. + Put(sessionKey string, cs *ClientSessionState) +} + +/* +// SignatureScheme identifies a signature algorithm supported by TLS. See +// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3. +type SignatureScheme uint16 + +const ( + PKCS1WithSHA1 SignatureScheme = 0x0201 + PKCS1WithSHA256 SignatureScheme = 0x0401 + PKCS1WithSHA384 SignatureScheme = 0x0501 + PKCS1WithSHA512 SignatureScheme = 0x0601 + + PSSWithSHA256 SignatureScheme = 0x0804 + PSSWithSHA384 SignatureScheme = 0x0805 + PSSWithSHA512 SignatureScheme = 0x0806 + + ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 + ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 + ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide certificate selection in the GetCertificate callback. +type ClientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_RSA_WITH_RC4_128_SHA). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see + // http://tools.ietf.org/html/rfc4366#section-3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see + // http://tools.ietf.org/html/rfc4492#section-5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see + // http://tools.ietf.org/html/rfc4492#section-5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see + // https://tools.ietf.org/html/rfc7301#section-3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn net.Conn +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme +} + +// RenegotiationSupport enumerates the different levels of support for TLS +// renegotiation. TLS renegotiation is the act of performing subsequent +// handshakes on a connection after the first. This significantly complicates +// the state machine and has been the source of numerous, subtle security +// issues. Initiating a renegotiation is not supported, but support for +// accepting renegotiation requests may be enabled. +// +// Even when enabled, the server may not change its identity between handshakes +// (i.e. the leaf certificate must be the same). Additionally, concurrent +// handshake and application data flow is not permitted so renegotiation can +// only be used with protocols that synchronise with the renegotiation, such as +// HTTPS. +type RenegotiationSupport int + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever RenegotiationSupport = iota + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to + // the other side of the connection. Server configurations must include + // at least one certificate or else set GetCertificate. Clients doing + // client-authentication may set either Certificates or + // GetClientCertificate. + Certificates []tls.Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // See Config.BuildNameToCertificate + // The nil value causes the first element of Certificates to be used + // for all connections. + NameToCertificate map[string]*tls.Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // first element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // Uniquely for the fields in the returned Config, session ticket keys + // will be duplicated from the original Config if not set. + // Specifically, if SetSessionTicketKeys was called on the original + // config but not on the returned config then the ticket keys from the + // original config will be copied into the new config before use. + // Otherwise, if SessionTicketKey was set in the original config but + // not in the returned config then it will be copied into the returned + // config before use. If neither of those cases applies then the key + // material from the returned config will be used for session tickets. + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify then this callback will be considered but + // the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported, application level protocols. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the + // server's certificate chain and host name. + // If InsecureSkipVerify is true, TLS accepts any certificate + // presented by the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + InsecureSkipVerify bool + + // CipherSuites is a list of supported cipher suites. If CipherSuites + // is nil, TLS uses a list of suites supported by the implementation. + CipherSuites []uint16 + + // PreferServerCipherSuites controls whether the server selects the + // client's most preferred ciphersuite, or the server's most preferred + // ciphersuite. If true then the server's preference, as expressed in + // the order of elements in CipherSuites, is used. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket + // (resumption) support. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session + // resumption. See RFC 5077. If zero, it will be filled with + // random data before the first server handshake. + // + // If multiple servers are terminating connections for the same host + // they should all have the same SessionTicketKey. If the + // SessionTicketKey leaks, previously recorded and future TLS + // connections using that key are compromised. + SessionTicketKey [32]byte + + // SessionCache is a cache of ClientSessionState entries for TLS session + // resumption. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum SSL/TLS version that is acceptable. + // If zero, then TLS 1.0 is taken as the minimum. + MinVersion uint16 + + // MaxVersion contains the maximum SSL/TLS version that is acceptable. + // If zero, then the maximum version supported by this package is used, + // which is currently TLS 1.2. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + serverInitOnce sync.Once // guards calling (*Config).serverInit + + // mutex protects sessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If the length + // is zero, SessionTicketsDisabled must be true. The first key is used + // for new tickets and any subsequent keys can be used to decrypt old + // tickets. + sessionTicketKeys []ticketKey +} + +// ticketKeyNameLen is the number of bytes of identifier that is prepended to +// an encrypted session ticket in order to identify the key used to encrypt it. +const ticketKeyNameLen = 16 + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + return key +} + +// Clone returns a shallow clone of c. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *Config) Clone() *Config { + // Running serverInit ensures that it's safe to read + // SessionTicketsDisabled. + c.serverInitOnce.Do(func() { c.serverInit(nil) }) + + var sessionTicketKeys []ticketKey + c.mutex.RLock() + sessionTicketKeys = c.sessionTicketKeys + c.mutex.RUnlock() + + return &Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: sessionTicketKeys, + } +} + +// serverInit is run under c.serverInitOnce to do initialization of c. If c was +// returned by a GetConfigForClient callback then the argument should be the +// Config that was passed to Server, otherwise it should be nil. +func (c *Config) serverInit(originalConfig *Config) { + if c.SessionTicketsDisabled || len(c.ticketKeys()) != 0 { + return + } + + alreadySet := false + for _, b := range c.SessionTicketKey { + if b != 0 { + alreadySet = true + break + } + } + + if !alreadySet { + if originalConfig != nil { + copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:]) + } else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + c.SessionTicketsDisabled = true + return + } + } + + if originalConfig != nil { + originalConfig.mutex.RLock() + c.sessionTicketKeys = originalConfig.sessionTicketKeys + originalConfig.mutex.RUnlock() + } else { + c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)} + } +} + +func (c *Config) ticketKeys() []ticketKey { + c.mutex.RLock() + // c.sessionTicketKeys is constant once created. SetSessionTicketKeys + // will only update it by replacing it with a new value. + ret := c.sessionTicketKeys + c.mutex.RUnlock() + return ret +} + +// SetSessionTicketKeys updates the session ticket keys for a server. The first +// key will be used when creating new tickets, while all keys can be used for +// decrypting tickets. It is safe to call this function while the server is +// running in order to rotate the session ticket keys. The function will panic +// if keys is empty. +func (c *Config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *Config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *Config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *Config) cipherSuites() []uint16 { + s := c.CipherSuites + if s == nil { + s = defaultCipherSuites() + } + return s +} + +func (c *Config) minVersion() uint16 { + if c == nil || c.MinVersion == 0 { + return minVersion + } + return c.MinVersion +} + +func (c *Config) maxVersion() uint16 { + if c == nil || c.MaxVersion == 0 { + return maxVersion + } + return c.MaxVersion +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *Config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +// mutualVersion returns the protocol version to use given the advertised +// version of the peer. +func (c *Config) mutualVersion(vers uint16) (uint16, bool) { + minVersion := c.minVersion() + maxVersion := c.maxVersion() + + if vers < minVersion { + return 0, false + } + if vers > maxVersion { + vers = maxVersion + } + return vers, true +} + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errors.New("tls: no certificates configured") + } + + if len(c.Certificates) == 1 || c.NameToCertificate == nil { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + name := strings.ToLower(clientHello.ServerName) + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + + // try replacing labels in the name with wildcards until we get a + // match. + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[candidate]; ok { + return cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +func (c *Config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + continue + } + if len(x509Cert.Subject.CommonName) > 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +// writeKeyLog logs client random and master secret if logging was enabled by +// setting c.KeyLogWriter. +func (c *Config) writeKeyLog(clientRandom, masterSecret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("CLIENT_RANDOM %x %x\n", clientRandom, masterSecret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate struct { + Certificate [][]byte + // PrivateKey contains the private key corresponding to the public key + // in Leaf. For a server, this must implement crypto.Signer and/or + // crypto.Decrypter, with an RSA or ECDSA PublicKey. For a client + // (performing client authentication), this must be a crypto.Signer + // with an RSA or ECDSA PublicKey. + PrivateKey crypto.PrivateKey + // OCSPStaple contains an optional OCSP response which will be served + // to clients that request it. + OCSPStaple []byte + // SignedCertificateTimestamps contains an optional list of Signed + // Certificate Timestamps which will be served to clients that request it. + SignedCertificateTimestamps [][]byte + // Leaf is the parsed form of the leaf certificate, which may be + // initialized using x509.ParseCertificate to reduce per-handshake + // processing for TLS clients doing client authentication. If nil, the + // leaf certificate will be parsed as needed. + Leaf *x509.Certificate +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +// TODO(jsing): Make these available to both crypto/x509 and crypto/tls. +type dsaSignature struct { + R, S *big.Int +} + +type ecdsaSignature dsaSignature + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +var ( + once sync.Once + varDefaultCipherSuites []uint16 +) + +func defaultCipherSuites() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuites +} + +func initDefaultCipherSuites() { + var topCipherSuites []uint16 + if cipherhw.AESGCMSupport() { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + } + + varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) + varDefaultCipherSuites = append(varDefaultCipherSuites, topCipherSuites...) + +NextCipherSuite: + for _, suite := range cipherSuites { + if suite.flags&suiteDefaultOff != 0 { + continue + } + for _, existing := range varDefaultCipherSuites { + if existing == suite.id { + continue NextCipherSuite + } + } + varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) + } +} + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAndHash(sigHash signatureAndHash, sigHashes []signatureAndHash) bool { + for _, s := range sigHashes { + if s == sigHash { + return true + } + } + return false +} +*/ diff --git a/conf.go b/conf.go new file mode 100644 index 0000000..e7a93e4 --- /dev/null +++ b/conf.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "strings" +) + +type forwardRules map[string]string + +func (fr forwardRules) Get(uri string) string { + if fw, ok := fr[uri]; ok { + return fw + } + return "" +} + +func (fr forwardRules) GetN(uri string, port int) string { + fw := fr.Get(uri) + if fw != "" { + return fw + } + + if fw = fr.Get(fmt.Sprintf("*:%d", port)); fw != "" { + fw = strings.Replace(fw, "*", uri, 1) + return fw + } + return "" +} + +type conf struct { + Timeout int + Listen []int + Default string + ForwardRules forwardRules `yaml:"forward_rules"` +} diff --git a/conf_test.go b/conf_test.go new file mode 100644 index 0000000..70cf334 --- /dev/null +++ b/conf_test.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "github.com/go-yaml/yaml" + "io/ioutil" + "testing" +) + +func TestConf(t *testing.T) { + data, err := ioutil.ReadFile("config.sample.yaml") + if err != nil { + t.Fatal(err) + } + var c conf + if err := yaml.Unmarshal(data, &c); err != nil { + t.Fatal(err) + } + fmt.Printf("%+v\n", c) + var testdata = map[string]string{ + "www.example.com": "127.0.0.1:8443", + "b.example.com": "127.0.0.1:8541", + } + r := forwardRules(c.ForwardRules) + for k, v := range testdata { + s := r.Get(k) + if s != v { + t.Errorf("expected: %s, got: %s", v, s) + } + } + + if r.GetN("a.com", 9999) != "a.com:443" { + t.Errorf("expected a.com:9999, got %s", r.GetN("a.com", 9999)) + } +} diff --git a/config.sample.yaml b/config.sample.yaml new file mode 100644 index 0000000..6e3e14f --- /dev/null +++ b/config.sample.yaml @@ -0,0 +1,13 @@ + + +timeout: 10 +default: 127.0.0.1:8443 + +listen: + - 443 + - 9999 + +forward_rules: + www.example.com: 127.0.0.1:8443 + b.example.com: 127.0.0.1:8541 + "*:9999": "*:443" diff --git a/handshake_messages.go b/handshake_messages.go new file mode 100644 index 0000000..558b81f --- /dev/null +++ b/handshake_messages.go @@ -0,0 +1,1573 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// this file is from $GOROOT/src/crypto/tls + +package main + +import ( + //"bytes" + "strings" +) + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionID []byte + cipherSuites []uint16 + compressionMethods []uint8 + nextProtoNeg bool + serverName string + ocspStapling bool + scts bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + signatureAndHashes []signatureAndHash + secureRenegotiation []byte + secureRenegotiationSupported bool + alpnProtocols []string +} + +/* +func (m *clientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*clientHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.compressionMethods, m1.compressionMethods) && + m.nextProtoNeg == m1.nextProtoNeg && + m.serverName == m1.serverName && + m.ocspStapling == m1.ocspStapling && + m.scts == m1.scts && + eqCurveIDs(m.supportedCurves, m1.supportedCurves) && + bytes.Equal(m.supportedPoints, m1.supportedPoints) && + m.ticketSupported == m1.ticketSupported && + bytes.Equal(m.sessionTicket, m1.sessionTicket) && + eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && + m.secureRenegotiationSupported == m1.secureRenegotiationSupported && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + eqStrings(m.alpnProtocols, m1.alpnProtocols) +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) + numExtensions := 0 + extensionsLength := 0 + if m.nextProtoNeg { + numExtensions++ + } + if m.ocspStapling { + extensionsLength += 1 + 2 + 2 + numExtensions++ + } + if len(m.serverName) > 0 { + extensionsLength += 5 + len(m.serverName) + numExtensions++ + } + if len(m.supportedCurves) > 0 { + extensionsLength += 2 + 2*len(m.supportedCurves) + numExtensions++ + } + if len(m.supportedPoints) > 0 { + extensionsLength += 1 + len(m.supportedPoints) + numExtensions++ + } + if m.ticketSupported { + extensionsLength += len(m.sessionTicket) + numExtensions++ + } + if len(m.signatureAndHashes) > 0 { + extensionsLength += 2 + 2*len(m.signatureAndHashes) + numExtensions++ + } + if m.secureRenegotiationSupported { + extensionsLength += 1 + len(m.secureRenegotiation) + numExtensions++ + } + if len(m.alpnProtocols) > 0 { + extensionsLength += 2 + for _, s := range m.alpnProtocols { + if l := len(s); l == 0 || l > 255 { + panic("invalid ALPN protocol") + } + extensionsLength++ + extensionsLength += len(s) + } + numExtensions++ + } + if m.scts { + numExtensions++ + } + if numExtensions > 0 { + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength + } + + x := make([]byte, 4+length) + x[0] = typeClientHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + y := x[39+len(m.sessionId):] + y[0] = uint8(len(m.cipherSuites) >> 7) + y[1] = uint8(len(m.cipherSuites) << 1) + for i, suite := range m.cipherSuites { + y[2+i*2] = uint8(suite >> 8) + y[3+i*2] = uint8(suite) + } + z := y[2+len(m.cipherSuites)*2:] + z[0] = uint8(len(m.compressionMethods)) + copy(z[1:], m.compressionMethods) + + z = z[1+len(m.compressionMethods):] + if numExtensions > 0 { + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) + z = z[2:] + } + if m.nextProtoNeg { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + // The length is always 0 + z = z[4:] + } + if len(m.serverName) > 0 { + z[0] = byte(extensionServerName >> 8) + z[1] = byte(extensionServerName & 0xff) + l := len(m.serverName) + 5 + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + // RFC 3546, section 3.1 + // + // struct { + // NameType name_type; + // select (name_type) { + // case host_name: HostName; + // } name; + // } ServerName; + // + // enum { + // host_name(0), (255) + // } NameType; + // + // opaque HostName<1..2^16-1>; + // + // struct { + // ServerName server_name_list<1..2^16-1> + // } ServerNameList; + + z[0] = byte((len(m.serverName) + 3) >> 8) + z[1] = byte(len(m.serverName) + 3) + z[3] = byte(len(m.serverName) >> 8) + z[4] = byte(len(m.serverName)) + copy(z[5:], []byte(m.serverName)) + z = z[l:] + } + if m.ocspStapling { + // RFC 4366, section 3.6 + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z[2] = 0 + z[3] = 5 + z[4] = 1 // OCSP type + // Two zero valued uint16s for the two lengths. + z = z[9:] + } + if len(m.supportedCurves) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.1 + z[0] = byte(extensionSupportedCurves >> 8) + z[1] = byte(extensionSupportedCurves) + l := 2 + 2*len(m.supportedCurves) + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + z = z[6:] + for _, curve := range m.supportedCurves { + z[0] = byte(curve >> 8) + z[1] = byte(curve) + z = z[2:] + } + } + if len(m.supportedPoints) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.2 + z[0] = byte(extensionSupportedPoints >> 8) + z[1] = byte(extensionSupportedPoints) + l := 1 + len(m.supportedPoints) + z[2] = byte(l >> 8) + z[3] = byte(l) + l-- + z[4] = byte(l) + z = z[5:] + for _, pointFormat := range m.supportedPoints { + z[0] = pointFormat + z = z[1:] + } + } + if m.ticketSupported { + // http://tools.ietf.org/html/rfc5077#section-3.2 + z[0] = byte(extensionSessionTicket >> 8) + z[1] = byte(extensionSessionTicket) + l := len(m.sessionTicket) + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + copy(z, m.sessionTicket) + z = z[len(m.sessionTicket):] + } + if len(m.signatureAndHashes) > 0 { + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + z[0] = byte(extensionSignatureAlgorithms >> 8) + z[1] = byte(extensionSignatureAlgorithms) + l := 2 + 2*len(m.signatureAndHashes) + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + l -= 2 + z[0] = byte(l >> 8) + z[1] = byte(l) + z = z[2:] + for _, sigAndHash := range m.signatureAndHashes { + z[0] = sigAndHash.hash + z[1] = sigAndHash.signature + z = z[2:] + } + } + if m.secureRenegotiationSupported { + z[0] = byte(extensionRenegotiationInfo >> 8) + z[1] = byte(extensionRenegotiationInfo & 0xff) + z[2] = 0 + z[3] = byte(len(m.secureRenegotiation) + 1) + z[4] = byte(len(m.secureRenegotiation)) + z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] + } + if len(m.alpnProtocols) > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + lengths := z[2:] + z = z[6:] + + stringsLength := 0 + for _, s := range m.alpnProtocols { + l := len(s) + z[0] = byte(l) + copy(z[1:], s) + z = z[1+l:] + stringsLength += 1 + l + } + + lengths[2] = byte(stringsLength >> 8) + lengths[3] = byte(stringsLength) + stringsLength += 2 + lengths[0] = byte(stringsLength >> 8) + lengths[1] = byte(stringsLength) + } + if m.scts { + // https://tools.ietf.org/html/rfc6962#section-3.3.1 + z[0] = byte(extensionSCT >> 8) + z[1] = byte(extensionSCT) + // zero uint16 for the zero-length extension_data + z = z[4:] + } + + m.raw = x + + return x +} +*/ +func (m *clientHelloMsg) unmarshal(data []byte) bool { + if len(data) < 42 { + return false + } + m.raw = data + m.vers = uint16(data[4])<<8 | uint16(data[5]) + m.random = data[6:38] + sessionIDLen := int(data[38]) + if sessionIDLen > 32 || len(data) < 39+sessionIDLen { + return false + } + m.sessionID = data[39 : 39+sessionIDLen] + data = data[39+sessionIDLen:] + if len(data) < 2 { + return false + } + // cipherSuiteLen is the number of bytes of cipher suite numbers. Since + // they are uint16s, the number must be even. + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + return false + } + numCipherSuites := cipherSuiteLen / 2 + m.cipherSuites = make([]uint16, numCipherSuites) + for i := 0; i < numCipherSuites; i++ { + m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) + if m.cipherSuites[i] == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + } + data = data[2+cipherSuiteLen:] + if len(data) < 1 { + return false + } + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + return false + } + m.compressionMethods = data[1 : 1+compressionMethodsLen] + + data = data[1+compressionMethodsLen:] + + m.nextProtoNeg = false + m.serverName = "" + m.ocspStapling = false + m.ticketSupported = false + m.sessionTicket = nil + m.signatureAndHashes = nil + m.alpnProtocols = nil + m.scts = false + + if len(data) == 0 { + // ClientHello is optionally followed by extension data + return true + } + if len(data) < 2 { + return false + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if extensionsLength != len(data) { + return false + } + + for len(data) != 0 { + if len(data) < 4 { + return false + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return false + } + + switch extension { + case extensionServerName: + d := data[:length] + if len(d) < 2 { + return false + } + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + return false + } + for len(d) > 0 { + if len(d) < 3 { + return false + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + return false + } + if nameType == 0 { + m.serverName = string(d[:nameLen]) + // An SNI value may not include a + // trailing dot. See + // https://tools.ietf.org/html/rfc6066#section-3. + if strings.HasSuffix(m.serverName, ".") { + return false + } + break + } + d = d[nameLen:] + } + case extensionNextProtoNeg: + if length > 0 { + return false + } + m.nextProtoNeg = true + case extensionStatusRequest: + m.ocspStapling = length > 0 && data[0] == statusTypeOCSP + case extensionSupportedCurves: + // http://tools.ietf.org/html/rfc4492#section-5.5.1 + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l%2 == 1 || length != l+2 { + return false + } + numCurves := l / 2 + m.supportedCurves = make([]CurveID, numCurves) + d := data[2:] + for i := 0; i < numCurves; i++ { + m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1]) + d = d[2:] + } + case extensionSupportedPoints: + // http://tools.ietf.org/html/rfc4492#section-5.5.2 + if length < 1 { + return false + } + l := int(data[0]) + if length != l+1 { + return false + } + m.supportedPoints = make([]uint8, l) + copy(m.supportedPoints, data[1:]) + case extensionSessionTicket: + // http://tools.ietf.org/html/rfc5077#section-3.2 + m.ticketSupported = true + m.sessionTicket = data[:length] + case extensionSignatureAlgorithms: + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + if length < 2 || length&1 != 0 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return false + } + n := l / 2 + d := data[2:] + m.signatureAndHashes = make([]signatureAndHash, n) + for i := range m.signatureAndHashes { + m.signatureAndHashes[i].hash = d[0] + m.signatureAndHashes[i].signature = d[1] + d = d[2:] + } + case extensionRenegotiationInfo: + if length == 0 { + return false + } + d := data[:length] + l := int(d[0]) + d = d[1:] + if l != len(d) { + return false + } + + m.secureRenegotiation = d + m.secureRenegotiationSupported = true + case extensionALPN: + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return false + } + d := data[2:length] + for len(d) != 0 { + stringLen := int(d[0]) + d = d[1:] + if stringLen == 0 || stringLen > len(d) { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) + d = d[stringLen:] + } + case extensionSCT: + m.scts = true + if length != 0 { + return false + } + } + data = data[length:] + } + + return true +} + +/* +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + nextProtoNeg bool + nextProtos []string + ocspStapling bool + scts [][]byte + ticketSupported bool + secureRenegotiation []byte + secureRenegotiationSupported bool + alpnProtocol string +} + +func (m *serverHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*serverHelloMsg) + if !ok { + return false + } + + if len(m.scts) != len(m1.scts) { + return false + } + for i, sct := range m.scts { + if !bytes.Equal(sct, m1.scts[i]) { + return false + } + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + m.cipherSuite == m1.cipherSuite && + m.compressionMethod == m1.compressionMethod && + m.nextProtoNeg == m1.nextProtoNeg && + eqStrings(m.nextProtos, m1.nextProtos) && + m.ocspStapling == m1.ocspStapling && + m.ticketSupported == m1.ticketSupported && + m.secureRenegotiationSupported == m1.secureRenegotiationSupported && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + m.alpnProtocol == m1.alpnProtocol +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 38 + len(m.sessionId) + numExtensions := 0 + extensionsLength := 0 + + nextProtoLen := 0 + if m.nextProtoNeg { + numExtensions++ + for _, v := range m.nextProtos { + nextProtoLen += len(v) + } + nextProtoLen += len(m.nextProtos) + extensionsLength += nextProtoLen + } + if m.ocspStapling { + numExtensions++ + } + if m.ticketSupported { + numExtensions++ + } + if m.secureRenegotiationSupported { + extensionsLength += 1 + len(m.secureRenegotiation) + numExtensions++ + } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + if alpnLen >= 256 { + panic("invalid ALPN protocol") + } + extensionsLength += 2 + 1 + alpnLen + numExtensions++ + } + sctLen := 0 + if len(m.scts) > 0 { + for _, sct := range m.scts { + sctLen += len(sct) + 2 + } + extensionsLength += 2 + sctLen + numExtensions++ + } + + if numExtensions > 0 { + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength + } + + x := make([]byte, 4+length) + x[0] = typeServerHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + z := x[39+len(m.sessionId):] + z[0] = uint8(m.cipherSuite >> 8) + z[1] = uint8(m.cipherSuite) + z[2] = m.compressionMethod + + z = z[3:] + if numExtensions > 0 { + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) + z = z[2:] + } + if m.nextProtoNeg { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + z[2] = byte(nextProtoLen >> 8) + z[3] = byte(nextProtoLen) + z = z[4:] + + for _, v := range m.nextProtos { + l := len(v) + if l > 255 { + l = 255 + } + z[0] = byte(l) + copy(z[1:], []byte(v[0:l])) + z = z[1+l:] + } + } + if m.ocspStapling { + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z = z[4:] + } + if m.ticketSupported { + z[0] = byte(extensionSessionTicket >> 8) + z[1] = byte(extensionSessionTicket) + z = z[4:] + } + if m.secureRenegotiationSupported { + z[0] = byte(extensionRenegotiationInfo >> 8) + z[1] = byte(extensionRenegotiationInfo & 0xff) + z[2] = 0 + z[3] = byte(len(m.secureRenegotiation) + 1) + z[4] = byte(len(m.secureRenegotiation)) + z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] + } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + l := 2 + 1 + alpnLen + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + l -= 1 + z[6] = byte(l) + copy(z[7:], []byte(m.alpnProtocol)) + z = z[7+alpnLen:] + } + if sctLen > 0 { + z[0] = byte(extensionSCT >> 8) + z[1] = byte(extensionSCT) + l := sctLen + 2 + z[2] = byte(l >> 8) + z[3] = byte(l) + z[4] = byte(sctLen >> 8) + z[5] = byte(sctLen) + + z = z[6:] + for _, sct := range m.scts { + z[0] = byte(len(sct) >> 8) + z[1] = byte(len(sct)) + copy(z[2:], sct) + z = z[len(sct)+2:] + } + } + + m.raw = x + + return x +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + if len(data) < 42 { + return false + } + m.raw = data + m.vers = uint16(data[4])<<8 | uint16(data[5]) + m.random = data[6:38] + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return false + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + if len(data) < 3 { + return false + } + m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) + m.compressionMethod = data[2] + data = data[3:] + + m.nextProtoNeg = false + m.nextProtos = nil + m.ocspStapling = false + m.scts = nil + m.ticketSupported = false + m.alpnProtocol = "" + + if len(data) == 0 { + // ServerHello is optionally followed by extension data + return true + } + if len(data) < 2 { + return false + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if len(data) != extensionsLength { + return false + } + + for len(data) != 0 { + if len(data) < 4 { + return false + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return false + } + + switch extension { + case extensionNextProtoNeg: + m.nextProtoNeg = true + d := data[:length] + for len(d) > 0 { + l := int(d[0]) + d = d[1:] + if l == 0 || l > len(d) { + return false + } + m.nextProtos = append(m.nextProtos, string(d[:l])) + d = d[l:] + } + case extensionStatusRequest: + if length > 0 { + return false + } + m.ocspStapling = true + case extensionSessionTicket: + if length > 0 { + return false + } + m.ticketSupported = true + case extensionRenegotiationInfo: + if length == 0 { + return false + } + d := data[:length] + l := int(d[0]) + d = d[1:] + if l != len(d) { + return false + } + + m.secureRenegotiation = d + m.secureRenegotiationSupported = true + case extensionALPN: + d := data[:length] + if len(d) < 3 { + return false + } + l := int(d[0])<<8 | int(d[1]) + if l != len(d)-2 { + return false + } + d = d[2:] + l = int(d[0]) + if l != len(d)-1 { + return false + } + d = d[1:] + if len(d) == 0 { + // ALPN protocols must not be empty. + return false + } + m.alpnProtocol = string(d) + case extensionSCT: + d := data[:length] + + if len(d) < 2 { + return false + } + l := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != l || l == 0 { + return false + } + + m.scts = make([][]byte, 0, 3) + for len(d) != 0 { + if len(d) < 2 { + return false + } + sctLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if sctLen == 0 || len(d) < sctLen { + return false + } + m.scts = append(m.scts, d[:sctLen]) + d = d[sctLen:] + } + } + data = data[length:] + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + eqByteSlices(m.certificates, m1.certificates) +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*serverKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.key, m1.key) +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + statusType uint8 + response []byte +} + +func (m *certificateStatusMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateStatusMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.statusType == m1.statusType && + bytes.Equal(m.response, m1.response) +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var x []byte + if m.statusType == statusTypeOCSP { + x = make([]byte, 4+4+len(m.response)) + x[0] = typeCertificateStatus + l := len(m.response) + 4 + x[1] = byte(l >> 16) + x[2] = byte(l >> 8) + x[3] = byte(l) + x[4] = statusTypeOCSP + + l -= 4 + x[5] = byte(l >> 16) + x[6] = byte(l >> 8) + x[7] = byte(l) + copy(x[8:], m.response) + } else { + x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} + } + + m.raw = x + return x +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 5 { + return false + } + m.statusType = data[4] + + m.response = nil + if m.statusType == statusTypeOCSP { + if len(data) < 8 { + return false + } + respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) + if uint32(len(data)) != 4+4+respLen { + return false + } + m.response = data[8:] + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) equal(i interface{}) bool { + _, ok := i.(*serverHelloDoneMsg) + return ok +} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*clientKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ciphertext, m1.ciphertext) +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) equal(i interface{}) bool { + m1, ok := i.(*finishedMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.verifyData, m1.verifyData) +} + +func (m *finishedMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + x = make([]byte, 4+len(m.verifyData)) + x[0] = typeFinished + x[3] = byte(len(m.verifyData)) + copy(x[4:], m.verifyData) + m.raw = x + return +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.verifyData = data[4:] + return true +} + +type nextProtoMsg struct { + raw []byte + proto string +} + +func (m *nextProtoMsg) equal(i interface{}) bool { + m1, ok := i.(*nextProtoMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.proto == m1.proto +} + +func (m *nextProtoMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + l := len(m.proto) + if l > 255 { + l = 255 + } + + padding := 32 - (l+2)%32 + length := l + padding + 2 + x := make([]byte, length+4) + x[0] = typeNextProtocol + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + y := x[4:] + y[0] = byte(l) + copy(y[1:], []byte(m.proto[0:l])) + y = y[1+l:] + y[0] = byte(padding) + + m.raw = x + + return x +} + +func (m *nextProtoMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + data = data[4:] + protoLen := int(data[0]) + data = data[1:] + if len(data) < protoLen { + return false + } + m.proto = string(data[0:protoLen]) + data = data[protoLen:] + + if len(data) < 1 { + return false + } + paddingLen := int(data[0]) + data = data[1:] + if len(data) != paddingLen { + return false + } + + return true +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAndHash indicates whether this message includes a list + // of signature and hash functions. This change was introduced with TLS + // 1.2. + hasSignatureAndHash bool + + certificateTypes []byte + signatureAndHashes []signatureAndHash + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.certificateTypes, m1.certificateTypes) && + eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && + eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc4346#section-7.4.4 + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAndHash { + length += 2 + 2*len(m.signatureAndHashes) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAndHash { + n := len(m.signatureAndHashes) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAndHash := range m.signatureAndHashes { + y[0] = sigAndHash.hash + y[1] = sigAndHash.signature + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAndHash { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAndHash := sigAndHashLen / 2 + m.signatureAndHashes = make([]signatureAndHash, numSigAndHash) + for i := range m.signatureAndHashes { + m.signatureAndHashes[i].hash = data[0] + m.signatureAndHashes[i].signature = data[1] + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + return len(data) == 0 +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAndHash bool + signatureAndHash signatureAndHash + signature []byte +} + +func (m *certificateVerifyMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateVerifyMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.hasSignatureAndHash == m1.hasSignatureAndHash && + m.signatureAndHash.hash == m1.signatureAndHash.hash && + m.signatureAndHash.signature == m1.signatureAndHash.signature && + bytes.Equal(m.signature, m1.signature) +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc4346#section-7.4.8 + siglength := len(m.signature) + length := 2 + siglength + if m.hasSignatureAndHash { + length += 2 + } + x = make([]byte, 4+length) + x[0] = typeCertificateVerify + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + y := x[4:] + if m.hasSignatureAndHash { + y[0] = m.signatureAndHash.hash + y[1] = m.signatureAndHash.signature + y = y[2:] + } + y[0] = uint8(siglength >> 8) + y[1] = uint8(siglength) + copy(y[2:], m.signature) + + m.raw = x + + return +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 6 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + data = data[4:] + if m.hasSignatureAndHash { + m.signatureAndHash.hash = data[0] + m.signatureAndHash.signature = data[1] + data = data[2:] + } + + if len(data) < 2 { + return false + } + siglength := int(data[0])<<8 + int(data[1]) + data = data[2:] + if len(data) != siglength { + return false + } + + m.signature = data + + return true +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) equal(i interface{}) bool { + m1, ok := i.(*newSessionTicketMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ticket, m1.ticket) +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc5077#section-3.3 + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +func eqUint16s(x, y []uint16) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqCurveIDs(x, y []CurveID) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqStrings(x, y []string) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqByteSlices(x, y [][]byte) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if !bytes.Equal(v, y[i]) { + return false + } + } + return true +} + +func eqSignatureAndHashes(x, y []signatureAndHash) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + v2 := y[i] + if v.hash != v2.hash || v.signature != v2.signature { + return false + } + } + return true +} +*/ diff --git a/main.go b/main.go new file mode 100644 index 0000000..7acef1b --- /dev/null +++ b/main.go @@ -0,0 +1,152 @@ +package main + +import ( + "fmt" + "github.com/go-yaml/yaml" + "github.com/golang/glog" + "io" + "io/ioutil" + //"crypto/tls" + "flag" + //"log" + "net" +) + +func getSNIServerName(buf []byte) string { + n := len(buf) + if n < 5 { + glog.Error("not tls handshake") + return "" + } + + // tls record type + if recordType(buf[0]) != recordTypeHandshake { + glog.Error("not tls") + return "" + } + + // tls major version + if buf[1] != 3 { + glog.Error("TLS version < 3 not supported") + return "" + } + + // payload length + //l := int(buf[3])<<16 + int(buf[4]) + + //log.Printf("length: %d, got: %d", l, n) + + // handshake message type + if uint8(buf[5]) != typeClientHello { + glog.Error("not client hello") + return "" + } + + // parse client hello message + + msg := &clientHelloMsg{} + + // client hello message not include tls header, 5 bytes + ret := msg.unmarshal(buf[5:n]) + if !ret { + glog.Error("parse hello message return false") + return "" + } + return msg.serverName +} + +func forward(c net.Conn, data []byte, dst string) { + c1, err := net.Dial("tcp", dst) + if err != nil { + glog.Error(err) + return + } + + defer c1.Close() + + if _, err = c1.Write(data); err != nil { + glog.Error(err) + return + } + + ch := make(chan struct{}, 2) + + go func() { + io.Copy(c1, c) + ch <- struct{}{} + }() + + go func() { + io.Copy(c, c1) + ch <- struct{}{} + }() + + <-ch +} + +func getDST(c net.Conn, name string) string { + addr := c.LocalAddr().(*net.TCPAddr) + dst := cfg.ForwardRules.GetN(name, addr.Port) + return dst +} + +func getDefaultDST() string { + return cfg.Default +} + +func serve(c net.Conn) { + defer c.Close() + + buf := make([]byte, 1024) + n, err := c.Read(buf) + if err != nil { + glog.Error(err) + return + } + servername := getSNIServerName(buf[:n]) + if servername == "" { + forward(c, buf[:n], getDefaultDST()) + return + } + dst := getDST(c, servername) + if dst == "" { + dst = getDefaultDST() + } + forward(c, buf[:n], dst) +} + +var cfg conf + +func main() { + var cfgfile string + flag.StringVar(&cfgfile, "c", "config.yaml", "config file") + flag.Set("logtostderr", "true") + flag.Parse() + + data, err := ioutil.ReadFile(cfgfile) + if err != nil { + glog.Fatal(err) + } + if err := yaml.Unmarshal(data, &cfg); err != nil { + glog.Fatal(err) + } + + for _, d := range cfg.Listen { + glog.Infof("listen on :%d", d) + l, err := net.Listen("tcp", fmt.Sprintf(":%d", d)) + if err != nil { + glog.Fatal(err) + } + go func(l net.Listener) { + defer l.Close() + for { + c1, err := l.Accept() + if err != nil { + glog.Fatal(err) + } + go serve(c1) + } + }(l) + } + select {} +}