master
fangdingjun 6 years ago
parent 2c4c546551
commit e168de539b

@ -1,176 +1,176 @@
package gnutls package gnutls
/* /*
#include "_gnutls.h" #include "_gnutls.h"
*/ */
import "C" import "C"
import ( import (
"fmt" "fmt"
"log" "log"
"runtime" "runtime"
"strings" "strings"
"unsafe" "unsafe"
) )
// Certificate x509 certificate // Certificate x509 certificate
type Certificate struct { type Certificate struct {
cert *C.gnutls_pcert_st cert *C.gnutls_pcert_st
privkey C.gnutls_privkey_t privkey C.gnutls_privkey_t
certSize C.int certSize C.int
} }
// Free free the certificate context // Free free the certificate context
func (c *Certificate) Free() { func (c *Certificate) Free() {
if c.cert != nil { if c.cert != nil {
C.free_cert_list(c.cert, c.certSize) C.free_cert_list(c.cert, c.certSize)
} }
if c.privkey != nil { if c.privkey != nil {
C.gnutls_privkey_deinit(c.privkey) C.gnutls_privkey_deinit(c.privkey)
} }
c.cert = nil c.cert = nil
c.privkey = nil c.privkey = nil
c.certSize = 0 c.certSize = 0
} }
func (c *Certificate) free() { func (c *Certificate) free() {
//log.Println("free certificate") //log.Println("free certificate")
c.Free() c.Free()
} }
func (c *Certificate) matchName(name string) bool { func (c *Certificate) matchName(name string) bool {
cname := C.CString(name) cname := C.CString(name)
defer C.free(unsafe.Pointer(cname)) defer C.free(unsafe.Pointer(cname))
ret := C.cert_check_hostname(c.cert, c.certSize, cname) ret := C.cert_check_hostname(c.cert, c.certSize, cname)
if int(ret) < 0 { if int(ret) < 0 {
log.Println(C.GoString(C.gnutls_strerror(ret))) log.Println(C.GoString(C.gnutls_strerror(ret)))
return false return false
} }
if int(ret) > 0 { if int(ret) > 0 {
return true return true
} }
return false return false
} }
// CommonName get CN field in subject, // CommonName get CN field in subject,
// //
// example: subject C=xx,ST=bbb,CN=abc will return abc // example: subject C=xx,ST=bbb,CN=abc will return abc
func (c *Certificate) CommonName() string { func (c *Certificate) CommonName() string {
return c.commonName(0) return c.commonName(0)
} }
func (c *Certificate) commonName(index int) string { func (c *Certificate) commonName(index int) string {
dnMap := make(map[string]string) dnMap := make(map[string]string)
dn := c.getDN(0) dn := c.getDN(0)
if dn != "" { if dn != "" {
dnFields := strings.Split(dn, ",") dnFields := strings.Split(dn, ",")
for _, d := range dnFields { for _, d := range dnFields {
dd := strings.Split(d, "=") dd := strings.Split(d, "=")
k := dd[0] k := dd[0]
v := dd[1] v := dd[1]
dnMap[k] = v dnMap[k] = v
} }
} }
if cn, ok := dnMap["CN"]; ok { if cn, ok := dnMap["CN"]; ok {
return cn return cn
} }
return "" return ""
} }
// GetAltName get altname in certificate // GetAltName get altname in certificate
func (c *Certificate) GetAltName() string { func (c *Certificate) GetAltName() string {
return c.getAltName(0, 0) return c.getAltName(0, 0)
} }
func (c *Certificate) getAltName(index int, nameindex int) string { func (c *Certificate) getAltName(index int, nameindex int) string {
out := C.malloc(1024) out := C.malloc(1024)
defer C.free(out) defer C.free(out)
size := C.get_pcert_alt_name( size := C.get_pcert_alt_name(
c.cert, C.int(index), C.int(nameindex), (*C.char)(out)) c.cert, C.int(index), C.int(nameindex), (*C.char)(out))
if int(size) < 0 { if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size))) log.Println(C.GoString(C.gnutls_strerror(size)))
return "" return ""
} }
name := C.GoBytes(out, size) name := C.GoBytes(out, size)
return string(name) return string(name)
} }
//GetCertString return certificate info string in one line //GetCertString return certificate info string in one line
func (c *Certificate) GetCertString() string { func (c *Certificate) GetCertString() string {
return c.getCertString(0, 1) return c.getCertString(0, 1)
} }
func (c *Certificate) getCertString(index int, flag int) string { func (c *Certificate) getCertString(index int, flag int) string {
out := C.malloc(4096) out := C.malloc(4096)
defer C.free(out) defer C.free(out)
size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out)) size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out))
if int(size) < 0 { if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size))) log.Println(C.GoString(C.gnutls_strerror(size)))
return "" return ""
} }
s := C.GoBytes(out, size) s := C.GoBytes(out, size)
return string(s) return string(s)
} }
// GetDN get the certificate subject, like O=st,C=aa,CN=localhost // GetDN get the certificate subject, like O=st,C=aa,CN=localhost
func (c *Certificate) GetDN() string { func (c *Certificate) GetDN() string {
return c.getDN(0) return c.getDN(0)
} }
func (c *Certificate) getDN(index int) string { func (c *Certificate) getDN(index int) string {
cbuf := C.malloc(200) cbuf := C.malloc(200)
defer C.free(cbuf) defer C.free(cbuf)
size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf)) size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf))
if int(size) < 0 { if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size))) log.Println(C.GoString(C.gnutls_strerror(size)))
return "" return ""
} }
s := C.GoBytes(cbuf, size) s := C.GoBytes(cbuf, size)
return string(s) return string(s)
} }
// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost // GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost
func (c *Certificate) GetIssuerDN() string { func (c *Certificate) GetIssuerDN() string {
return c.getIssuerDN(0) return c.getIssuerDN(0)
} }
func (c *Certificate) getIssuerDN(index int) string { func (c *Certificate) getIssuerDN(index int) string {
cbuf := C.malloc(200) cbuf := C.malloc(200)
defer C.free(cbuf) defer C.free(cbuf)
size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf)) size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf))
if int(size) < 0 { if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size))) log.Println(C.GoString(C.gnutls_strerror(size)))
return "" return ""
} }
s := C.GoBytes(cbuf, size) s := C.GoBytes(cbuf, size)
return string(s) return string(s)
} }
// LoadX509KeyPair load certificate pair, // LoadX509KeyPair load certificate pair,
// the return Certifciate must be freed by call Free(), // the return Certifciate must be freed by call Free(),
func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) { func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
_certfile := C.CString(certfile) _certfile := C.CString(certfile)
_keyfile := C.CString(keyfile) _keyfile := C.CString(keyfile)
defer C.free(unsafe.Pointer(_certfile)) defer C.free(unsafe.Pointer(_certfile))
defer C.free(unsafe.Pointer(_keyfile)) defer C.free(unsafe.Pointer(_keyfile))
certificate := &Certificate{} certificate := &Certificate{}
var ret C.int var ret C.int
var certSize C.int var certSize C.int
cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)), cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)),
(*C.int)(unsafe.Pointer(&ret))) (*C.int)(unsafe.Pointer(&ret)))
if int(ret) < 0 { if int(ret) < 0 {
return nil, fmt.Errorf("load cert failed: %s", return nil, fmt.Errorf("load cert failed: %s",
C.GoString(C.gnutls_strerror(ret))) C.GoString(C.gnutls_strerror(ret)))
} }
privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret))) privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret)))
if int(ret) < 0 { if int(ret) < 0 {
return nil, fmt.Errorf("load privkey: %s", return nil, fmt.Errorf("load privkey: %s",
C.GoString(C.gnutls_strerror(ret))) C.GoString(C.gnutls_strerror(ret)))
} }
certificate.cert = cert certificate.cert = cert
certificate.privkey = privkey certificate.privkey = privkey
certificate.certSize = certSize certificate.certSize = certSize
runtime.SetFinalizer(certificate, (*Certificate).free) runtime.SetFinalizer(certificate, (*Certificate).free)
return certificate, nil return certificate, nil
} }

@ -1,35 +1,35 @@
package gnutls package gnutls
import ( import (
"runtime" "runtime"
"testing" "testing"
"time" "time"
) )
func TestGetAltname(t *testing.T) { func TestGetAltname(t *testing.T) {
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("%+v", cert) t.Logf("%+v", cert)
t.Log(cert.GetAltName()) t.Log(cert.GetAltName())
t.Log(cert.GetCertString()) t.Log(cert.GetCertString())
t.Log(cert.GetDN()) t.Log(cert.GetDN())
t.Log(cert.GetIssuerDN()) t.Log(cert.GetIssuerDN())
t.Log("CN ", cert.CommonName()) t.Log("CN ", cert.CommonName())
//t.Log("flag 0: ", cert.getCertString(0, 0)) //t.Log("flag 0: ", cert.getCertString(0, 0))
//t.Log("flag 1: ", cert.getCertString(0, 1)) //t.Log("flag 1: ", cert.getCertString(0, 1))
//t.Log("flag 2: ", cert.getCertString(0, 2)) //t.Log("flag 2: ", cert.getCertString(0, 2))
//t.Log("flag 3: ", cert.getCertString(0, 3)) //t.Log("flag 3: ", cert.getCertString(0, 3))
cert.Free() cert.Free()
} }
func _loadCert(certfile, keyfile string) (*Certificate, error) { func _loadCert(certfile, keyfile string) (*Certificate, error) {
return LoadX509KeyPair(certfile, keyfile) return LoadX509KeyPair(certfile, keyfile)
} }
func TestCertGC(t *testing.T) { func TestCertGC(t *testing.T) {
_loadCert("testdata/server.crt", "testdata/server.key") _loadCert("testdata/server.crt", "testdata/server.key")
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }

@ -1,167 +1,167 @@
package gnutls package gnutls
/* /*
#include "_gnutls.h" #include "_gnutls.h"
#include <stdlib.h> #include <stdlib.h>
*/ */
import "C" import "C"
import ( import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"runtime" "runtime"
) )
// CipherType cipher type // CipherType cipher type
type CipherType int type CipherType int
const ( const (
GNUTLS_CIPHER_AES_128_CBC CipherType = 4 GNUTLS_CIPHER_AES_128_CBC CipherType = 4
GNUTLS_CIPHER_AES_256_CBC CipherType = 5 GNUTLS_CIPHER_AES_256_CBC CipherType = 5
GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6 GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7 GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8 GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
GNUTLS_CIPHER_AES_192_CBC CipherType = 9 GNUTLS_CIPHER_AES_192_CBC CipherType = 9
GNUTLS_CIPHER_AES_128_GCM CipherType = 10 GNUTLS_CIPHER_AES_128_GCM CipherType = 10
GNUTLS_CIPHER_AES_256_GCM CipherType = 11 GNUTLS_CIPHER_AES_256_GCM CipherType = 11
GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12 GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
GNUTLS_CIPHER_SALSA20_256 CipherType = 13 GNUTLS_CIPHER_SALSA20_256 CipherType = 13
GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14 GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15 GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16 GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
GNUTLS_CIPHER_RC2_40_CBC CipherType = 17 GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
GNUTLS_CIPHER_DES_CBC CipherType = 18 GNUTLS_CIPHER_DES_CBC CipherType = 18
GNUTLS_CIPHER_AES_128_CCM CipherType = 19 GNUTLS_CIPHER_AES_128_CCM CipherType = 19
GNUTLS_CIPHER_AES_256_CCM CipherType = 20 GNUTLS_CIPHER_AES_256_CCM CipherType = 20
GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21 GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22 GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23 GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
) )
var ( var (
// ErrBlockSize wrong block size // ErrBlockSize wrong block size
ErrBlockSize = errors.New("wrong block size") ErrBlockSize = errors.New("wrong block size")
// ErrKeyLength wrong key length // ErrKeyLength wrong key length
ErrKeyLength = errors.New("wrong key length") ErrKeyLength = errors.New("wrong key length")
// ErrIVLength wrong iv length // ErrIVLength wrong iv length
ErrIVLength = errors.New("wrong iv length") ErrIVLength = errors.New("wrong iv length")
) )
// Cipher gnutls cipher struct // Cipher gnutls cipher struct
type Cipher struct { type Cipher struct {
cipher C.gnutls_cipher_hd_t cipher C.gnutls_cipher_hd_t
t CipherType t CipherType
blockSize int blockSize int
} }
// NewCipher create a new cipher by give type, key, iv // NewCipher create a new cipher by give type, key, iv
// //
// example: // example:
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789")) // NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
// //
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size // you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) { func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
keysize := GetCipherKeySize(t) keysize := GetCipherKeySize(t)
ivSize := GetCipherIVSize(t) ivSize := GetCipherIVSize(t)
blocksize := GetCipherBlockSize(t) blocksize := GetCipherBlockSize(t)
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize)) //log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
if len(key) != int(keysize) { if len(key) != int(keysize) {
return nil, ErrKeyLength return nil, ErrKeyLength
} }
if len(iv) != int(ivSize) { if len(iv) != int(ivSize) {
return nil, ErrIVLength return nil, ErrIVLength
} }
ckey := C.CBytes(key) ckey := C.CBytes(key)
civ := C.CBytes(iv) civ := C.CBytes(iv)
defer C.free(ckey) defer C.free(ckey)
defer C.free(civ) defer C.free(civ)
c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv))) c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv)))
if c == nil { if c == nil {
log.Println("new cipher return nil") log.Println("new cipher return nil")
return nil, nil return nil, nil
} }
cipher := &Cipher{c, t, blocksize} cipher := &Cipher{c, t, blocksize}
runtime.SetFinalizer(cipher, (*Cipher).free) runtime.SetFinalizer(cipher, (*Cipher).free)
return cipher, nil return cipher, nil
} }
// Encrypt encrypt the buf and place the encrypted data in dst, // Encrypt encrypt the buf and place the encrypted data in dst,
// the buf size must multiple of cipher's block size // the buf size must multiple of cipher's block size
func (c *Cipher) Encrypt(dst, buf []byte) error { func (c *Cipher) Encrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 { if len(buf)%c.blockSize != 0 {
return ErrBlockSize return ErrBlockSize
} }
cbuf := C.CBytes(buf) cbuf := C.CBytes(buf)
defer C.free(cbuf) defer C.free(cbuf)
bufLen := C.size_t(len(buf)) bufLen := C.size_t(len(buf))
dstBuf := C.malloc(bufLen) dstBuf := C.malloc(bufLen)
defer C.free(dstBuf) defer C.free(dstBuf)
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 { if int(ret) < 0 {
return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret))) return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
} }
_buf := C.GoBytes(dstBuf, C.int(bufLen)) _buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf) copy(dst, _buf)
return nil return nil
} }
// Decrypt decrypt the buf and place the decrypted data in dst, // Decrypt decrypt the buf and place the decrypted data in dst,
// the buf size must multiple of cipher's block size // the buf size must multiple of cipher's block size
func (c *Cipher) Decrypt(dst, buf []byte) error { func (c *Cipher) Decrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 { if len(buf)%c.blockSize != 0 {
return ErrBlockSize return ErrBlockSize
} }
cbuf := C.CBytes(buf) cbuf := C.CBytes(buf)
defer C.free(cbuf) defer C.free(cbuf)
bufLen := C.size_t(len(buf)) bufLen := C.size_t(len(buf))
dstBuf := C.malloc(C.size_t(len(buf))) dstBuf := C.malloc(C.size_t(len(buf)))
defer C.free(dstBuf) defer C.free(dstBuf)
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 { if int(ret) < 0 {
return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret))) return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
} }
_buf := C.GoBytes(dstBuf, C.int(bufLen)) _buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf) copy(dst, _buf)
return nil return nil
} }
// Close destroy the cipher context // Close destroy the cipher context
func (c *Cipher) Close() error { func (c *Cipher) Close() error {
if c.cipher != nil { if c.cipher != nil {
C.gnutls_cipher_deinit(c.cipher) C.gnutls_cipher_deinit(c.cipher)
c.cipher = nil c.cipher = nil
} }
return nil return nil
} }
func (c *Cipher) free() { func (c *Cipher) free() {
//log.Println("free cipher") //log.Println("free cipher")
c.Close() c.Close()
} }
// GetCipherKeySize get the cipher algorithm key length // GetCipherKeySize get the cipher algorithm key length
func GetCipherKeySize(t CipherType) int { func GetCipherKeySize(t CipherType) int {
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
} }
// GetCipherIVSize get the cipher algorithm iv length // GetCipherIVSize get the cipher algorithm iv length
func GetCipherIVSize(t CipherType) int { func GetCipherIVSize(t CipherType) int {
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t))) return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
} }
// GetCipherBlockSize get the cipher algorithm block size // GetCipherBlockSize get the cipher algorithm block size
func GetCipherBlockSize(t CipherType) int { func GetCipherBlockSize(t CipherType) int {
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t))) return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
} }

@ -1,130 +1,130 @@
package gnutls package gnutls
import ( import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"runtime" "runtime"
"testing" "testing"
"time" "time"
) )
func TestCipherSize(t *testing.T) { func TestCipherSize(t *testing.T) {
data := []struct { data := []struct {
t CipherType t CipherType
bsize int bsize int
isize int isize int
}{ }{
{GNUTLS_CIPHER_AES_128_CBC, 16, 16}, {GNUTLS_CIPHER_AES_128_CBC, 16, 16},
{GNUTLS_CIPHER_AES_192_CBC, 24, 16}, {GNUTLS_CIPHER_AES_192_CBC, 24, 16},
{GNUTLS_CIPHER_AES_256_CBC, 32, 16}, {GNUTLS_CIPHER_AES_256_CBC, 32, 16},
} }
for _, d := range data { for _, d := range data {
blocksize := GetCipherKeySize(d.t) blocksize := GetCipherKeySize(d.t)
if blocksize != d.bsize { if blocksize != d.bsize {
t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize) t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize)
} }
ivsize := GetCipherIVSize(d.t) ivsize := GetCipherIVSize(d.t)
if ivsize != d.isize { if ivsize != d.isize {
t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize) t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize)
} }
} }
} }
func TestEncryptDecrypt(t *testing.T) { func TestEncryptDecrypt(t *testing.T) {
cipherName := GNUTLS_CIPHER_AES_256_CBC cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName) keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName) ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName) blocksize := GetCipherBlockSize(cipherName)
key := make([]byte, keysize) key := make([]byte, keysize)
iv := make([]byte, ivsize) iv := make([]byte, ivsize)
rand.Reader.Read(key) rand.Reader.Read(key)
rand.Reader.Read(iv) rand.Reader.Read(iv)
c, err := NewCipher(cipherName, key, iv) c, err := NewCipher(cipherName, key, iv)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
//defer c.Close() //defer c.Close()
c1, err := NewCipher(cipherName, key, iv) c1, err := NewCipher(cipherName, key, iv)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
//defer c1.Close() //defer c1.Close()
data := make([]byte, blocksize*10) data := make([]byte, blocksize*10)
if c == nil { if c == nil {
t.Fatal("new ciphoer failed") t.Fatal("new ciphoer failed")
} }
cdata := make([]byte, len(data)) cdata := make([]byte, len(data))
err = c.Encrypt(cdata, data) err = c.Encrypt(cdata, data)
if err != nil { if err != nil {
t.Fatal("encrypt failed", err) t.Fatal("encrypt failed", err)
} }
data1 := make([]byte, len(data)) data1 := make([]byte, len(data))
err = c1.Decrypt(data1, cdata) err = c1.Decrypt(data1, cdata)
if err != nil { if err != nil {
t.Fatal("decrypt failed", err) t.Fatal("decrypt failed", err)
} }
if !bytes.Equal(data, data1) { if !bytes.Equal(data, data1) {
t.Fatal("encrypt/decrypt failed", string(data), string(data1)) t.Fatal("encrypt/decrypt failed", string(data), string(data1))
} }
block, _ := aes.NewCipher(key) block, _ := aes.NewCipher(key)
mode := cipher.NewCBCEncrypter(block, iv) mode := cipher.NewCBCEncrypter(block, iv)
dst := make([]byte, len(data)) dst := make([]byte, len(data))
mode.CryptBlocks(dst, data) mode.CryptBlocks(dst, data)
if !bytes.Equal(dst, cdata) { if !bytes.Equal(dst, cdata) {
t.Fatal("cipher text not equal to cypto/aes") t.Fatal("cipher text not equal to cypto/aes")
} }
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
func BenchmarkAESEncrypt(b *testing.B) { func BenchmarkAESEncrypt(b *testing.B) {
cipherName := GNUTLS_CIPHER_AES_256_CBC cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName) keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName) ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName) blocksize := GetCipherBlockSize(cipherName)
datalen := blocksize * 500 datalen := blocksize * 500
key := make([]byte, keysize) key := make([]byte, keysize)
iv := make([]byte, ivsize) iv := make([]byte, ivsize)
rand.Reader.Read(key) rand.Reader.Read(key)
rand.Reader.Read(iv) rand.Reader.Read(iv)
buf := make([]byte, datalen) buf := make([]byte, datalen)
dst := make([]byte, datalen) dst := make([]byte, datalen)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c, err := NewCipher(cipherName, key, iv) c, err := NewCipher(cipherName, key, iv)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
c.Encrypt(dst, buf) c.Encrypt(dst, buf)
c.Close() c.Close()
} }
} }
func BenchmarkAESEncrypt2(b *testing.B) { func BenchmarkAESEncrypt2(b *testing.B) {
cipherName := GNUTLS_CIPHER_AES_256_CBC cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName) keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName) ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName) blocksize := GetCipherBlockSize(cipherName)
datalen := blocksize * 500 datalen := blocksize * 500
key := make([]byte, keysize) key := make([]byte, keysize)
iv := make([]byte, ivsize) iv := make([]byte, ivsize)
buf := make([]byte, datalen) buf := make([]byte, datalen)
rand.Reader.Read(buf) rand.Reader.Read(buf)
rand.Reader.Read(key) rand.Reader.Read(key)
rand.Reader.Read(iv) rand.Reader.Read(iv)
dst := make([]byte, datalen) dst := make([]byte, datalen)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
block, _ := aes.NewCipher(key) block, _ := aes.NewCipher(key)
mode := cipher.NewCBCEncrypter(block, iv) mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(dst, buf) mode.CryptBlocks(dst, buf)
} }
} }

216
doc.go

@ -1,108 +1,108 @@
// Package gnutls is a gnutls binding for golang. // Package gnutls is a gnutls binding for golang.
/* /*
a limit set of api is supported. a limit set of api is supported.
TLS api is very similar to crypto/tls on standard library. TLS api is very similar to crypto/tls on standard library.
TLS client example: TLS client example:
addr := "127.0.0.1:9443" addr := "127.0.0.1:9443"
c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true}) c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true})
if err != nil { if err != nil {
log.Fatal("gnutls dial ", err) log.Fatal("gnutls dial ", err)
} }
defer c.Close() defer c.Close()
data := "hello, world" data := "hello, world"
if _, err = c.Write([]byte(data)); err != nil { if _, err = c.Write([]byte(data)); err != nil {
log.Fatal("gnutls write ", err) log.Fatal("gnutls write ", err)
} }
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
log.Fatal("gnutls read ", err) log.Fatal("gnutls read ", err)
} }
TLS Server example: TLS Server example:
cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key") cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key")
if err != nil{ if err != nil{
// handle error // handle error
} }
l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{ l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{
Certificates: []*gnutls.Certificate{cert}}) Certificates: []*gnutls.Certificate{cert}})
if err != nil { if err != nil {
// handle error // handle error
} }
defer l.Close() defer l.Close()
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
log.Println("gnutls accept ", err) log.Println("gnutls accept ", err)
break break
} }
log.Println("accept connection from ", c.RemoteAddr()) log.Println("accept connection from ", c.RemoteAddr())
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
n, err := c.Read(buf[0:]) n, err := c.Read(buf[0:])
if err != nil { if err != nil {
log.Println("gnutls read ", err) log.Println("gnutls read ", err)
break break
} }
if _, err := c.Write(buf[:n]); err != nil { if _, err := c.Write(buf[:n]); err != nil {
log.Println("gnutls write ", err) log.Println("gnutls write ", err)
break break
} }
} }
}(c) }(c)
} }
AES encrypt/decrypt example: AES encrypt/decrypt example:
key := []byte("0123456789abcdef") key := []byte("0123456789abcdef")
iv := []byte("abcdefg123456789") iv := []byte("abcdefg123456789")
c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv) c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer c.Close() defer c.Close()
c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv) c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer c1.Close() defer c1.Close()
data := []byte("1234012121212121") data := []byte("1234012121212121")
if c == nil { if c == nil {
log.Fatal("new ciphoer failed") log.Fatal("new ciphoer failed")
} }
// encrypt // encrypt
dst := make([]byte, len(data)) dst := make([]byte, len(data))
err := c.Encrypt(dst, data) err := c.Encrypt(dst, data)
if err != nil { if err != nil {
log.Fatal("encrypt failed", err) log.Fatal("encrypt failed", err)
} }
// decrypt // decrypt
data1 := make([]byte, len(data)) data1 := make([]byte, len(data))
err := c1.Decrypt(data1, cdata) err := c1.Decrypt(data1, cdata)
if err != nil { if err != nil {
log.Fatal("decrypt failed", err) log.Fatal("decrypt failed", err)
} }
Hash example: Hash example:
h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512) h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512)
defer h.Close() defer h.Close()
data := []byte("1234") data := []byte("1234")
h1 := h.Sum(data) h1 := h.Sum(data)
*/ */
package gnutls package gnutls

@ -1,90 +1,90 @@
package gnutls package gnutls
/* /*
#include <stdlib.h> #include <stdlib.h>
#include "_gnutls.h" #include "_gnutls.h"
*/ */
import "C" import "C"
import ( import (
"fmt" "fmt"
"runtime" "runtime"
) )
// HashType hash type // HashType hash type
type HashType int type HashType int
const ( const (
GNUTLS_HASH_MD5 HashType = 2 GNUTLS_HASH_MD5 HashType = 2
GNUTLS_HASH_SHA1 HashType = 3 GNUTLS_HASH_SHA1 HashType = 3
GNUTLS_HASH_MD2 HashType = 5 GNUTLS_HASH_MD2 HashType = 5
GNUTLS_HASH_SHA256 HashType = 6 GNUTLS_HASH_SHA256 HashType = 6
GNUTLS_HASH_SHA384 HashType = 7 GNUTLS_HASH_SHA384 HashType = 7
GNUTLS_HASH_SHA512 HashType = 8 GNUTLS_HASH_SHA512 HashType = 8
GNUTLS_HASH_SHA224 HashType = 9 GNUTLS_HASH_SHA224 HashType = 9
) )
// Hash hash struct // Hash hash struct
type Hash struct { type Hash struct {
hash C.gnutls_hash_hd_t hash C.gnutls_hash_hd_t
t HashType t HashType
hashLen C.int hashLen C.int
} }
// NewHash new hash struct // NewHash new hash struct
func NewHash(t HashType) *Hash { func NewHash(t HashType) *Hash {
h := C.new_hash(C.int(t)) h := C.new_hash(C.int(t))
hashOutLen := GetHashOutputLen(t) hashOutLen := GetHashOutputLen(t)
hash := &Hash{h, t, C.int(hashOutLen)} hash := &Hash{h, t, C.int(hashOutLen)}
runtime.SetFinalizer(hash, (*Hash).free) runtime.SetFinalizer(hash, (*Hash).free)
return hash return hash
} }
// Write write data to hash context // Write write data to hash context
func (h *Hash) Write(buf []byte) error { func (h *Hash) Write(buf []byte) error {
dataLen := len(buf) dataLen := len(buf)
cbuf := C.CBytes(buf) cbuf := C.CBytes(buf)
defer C.free(cbuf) defer C.free(cbuf)
ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen)) ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen))
if int(ret) < 0 { if int(ret) < 0 {
return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret))) return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret)))
} }
return nil return nil
} }
// Sum get hash result // Sum get hash result
func (h *Hash) Sum(buf []byte) []byte { func (h *Hash) Sum(buf []byte) []byte {
if buf != nil { if buf != nil {
h.Write(buf) h.Write(buf)
} }
dstBuf := C.malloc(C.size_t(h.hashLen)) dstBuf := C.malloc(C.size_t(h.hashLen))
defer C.free(dstBuf) defer C.free(dstBuf)
C.gnutls_hash_output(h.hash, dstBuf) C.gnutls_hash_output(h.hash, dstBuf)
gobuf := C.GoBytes(dstBuf, h.hashLen) gobuf := C.GoBytes(dstBuf, h.hashLen)
return gobuf return gobuf
} }
// Close destroy hash context // Close destroy hash context
func (h *Hash) Close() error { func (h *Hash) Close() error {
if h.hash != nil { if h.hash != nil {
C.gnutls_hash_deinit(h.hash, nil) C.gnutls_hash_deinit(h.hash, nil)
h.hash = nil h.hash = nil
} }
return nil return nil
} }
func (h *Hash) free() { func (h *Hash) free() {
//log.Println("free hash") //log.Println("free hash")
h.Close() h.Close()
} }
// GetHashOutputLen get the hash algorithm output length // GetHashOutputLen get the hash algorithm output length
// //
// example GNUTLS_MD5 is 16 // example GNUTLS_MD5 is 16
func GetHashOutputLen(t HashType) int { func GetHashOutputLen(t HashType) int {
return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t))) return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t)))
} }

@ -1,52 +1,52 @@
package gnutls package gnutls
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"crypto/sha512" "crypto/sha512"
"encoding/hex" "encoding/hex"
"runtime" "runtime"
"testing" "testing"
"time" "time"
) )
func TestHashSHA(t *testing.T) { func TestHashSHA(t *testing.T) {
h := NewHash(GNUTLS_HASH_SHA512) h := NewHash(GNUTLS_HASH_SHA512)
//defer h.Close() //defer h.Close()
data := []byte("1234") data := []byte("1234")
h1 := h.Sum(data) h1 := h.Sum(data)
h3 := sha512.New() h3 := sha512.New()
h3.Write(data) h3.Write(data)
h4 := h3.Sum(nil) h4 := h3.Sum(nil)
if !bytes.Equal(h4[:], h1) { if !bytes.Equal(h4[:], h1) {
t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1)) t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
t.Fatal("hash not equal") t.Fatal("hash not equal")
} }
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
func BenchmarkHashSHA512(b *testing.B) { func BenchmarkHashSHA512(b *testing.B) {
buf := make([]byte, 1000*1024) buf := make([]byte, 1000*1024)
rand.Reader.Read(buf) rand.Reader.Read(buf)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
h := NewHash(GNUTLS_HASH_SHA512) h := NewHash(GNUTLS_HASH_SHA512)
h.Write(buf) h.Write(buf)
h.Sum(nil) h.Sum(nil)
h.Close() h.Close()
} }
} }
func BenchmarkHashSHA512s(b *testing.B) { func BenchmarkHashSHA512s(b *testing.B) {
buf := make([]byte, 1000*1024) buf := make([]byte, 1000*1024)
rand.Reader.Read(buf) rand.Reader.Read(buf)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
h := sha512.New() h := sha512.New()
h.Write(buf) h.Write(buf)
h.Sum(nil) h.Sum(nil)
//h.Close() //h.Close()
} }
} }

984
tls.go

@ -1,492 +1,492 @@
package gnutls package gnutls
/* /*
#include "_gnutls.h" #include "_gnutls.h"
#cgo pkg-config: gnutls #cgo pkg-config: gnutls
*/ */
import "C" import "C"
import ( import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"unsafe" "unsafe"
) )
const ( const (
GNUTLS_NAME_DNS = 1 GNUTLS_NAME_DNS = 1
GNUTLS_X509_FMT_PEM = 1 GNUTLS_X509_FMT_PEM = 1
GNUTLS_ALPN_MANDATORY = 1 GNUTLS_ALPN_MANDATORY = 1
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1 GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
) )
// Conn gnutls TLS connection, // Conn gnutls TLS connection,
// use Listen, Dial, Server or Client create this object // use Listen, Dial, Server or Client create this object
type Conn struct { type Conn struct {
c net.Conn c net.Conn
handshake bool handshake bool
sess *C.struct_session sess *C.struct_session
cservname *C.char cservname *C.char
state *ConnectionState state *ConnectionState
cfg *Config cfg *Config
closed bool closed bool
lock *sync.Mutex lock *sync.Mutex
} }
// Config gnutls TLS configure, // Config gnutls TLS configure,
type Config struct { type Config struct {
ServerName string ServerName string
Certificates []*Certificate Certificates []*Certificate
InsecureSkipVerify bool InsecureSkipVerify bool
NextProtos []string NextProtos []string
} }
// ConnectionState gnutls TLS connection state // ConnectionState gnutls TLS connection state
type ConnectionState struct { type ConnectionState struct {
// SNI name client send // SNI name client send
ServerName string ServerName string
// selected ALPN protocl // selected ALPN protocl
NegotiatedProtocol string NegotiatedProtocol string
HandshakeComplete bool HandshakeComplete bool
// TLS version number, ex: 0x303 // TLS version number, ex: 0x303
Version uint16 Version uint16
// TLS version number, ex: TLS1.0 // TLS version number, ex: TLS1.0
VersionName string VersionName string
// peer's certificate // peer's certificate
PeerCertificate *Certificate PeerCertificate *Certificate
} }
type listener struct { type listener struct {
l net.Listener l net.Listener
c *Config c *Config
} }
// Accept // Accept
func (l *listener) Accept() (net.Conn, error) { func (l *listener) Accept() (net.Conn, error) {
c, err := l.l.Accept() c, err := l.l.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Server(c, l.c) return Server(c, l.c)
} }
// Close // Close
func (l *listener) Close() error { func (l *listener) Close() error {
return l.l.Close() return l.l.Close()
} }
// Addr // Addr
func (l *listener) Addr() net.Addr { func (l *listener) Addr() net.Addr {
return l.l.Addr() return l.l.Addr()
} }
// Dial dial to (network, addr) and create a gnutls Conn // Dial dial to (network, addr) and create a gnutls Conn
func Dial(network, addr string, cfg *Config) (*Conn, error) { func Dial(network, addr string, cfg *Config) (*Conn, error) {
c, err := net.Dial(network, addr) c, err := net.Dial(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Client(c, cfg) return Client(c, cfg)
} }
// Listen create a gnutls listener on (network, addr), // Listen create a gnutls listener on (network, addr),
func Listen(network, addr string, cfg *Config) (net.Listener, error) { func Listen(network, addr string, cfg *Config) (net.Listener, error) {
if cfg == nil { if cfg == nil {
return nil, errors.New("config is needed") return nil, errors.New("config is needed")
} }
l, err := net.Listen(network, addr) l, err := net.Listen(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &listener{l, cfg}, nil return &listener{l, cfg}, nil
} }
// Server create a server TLS Conn on c // Server create a server TLS Conn on c
func Server(c net.Conn, cfg *Config) (*Conn, error) { func Server(c net.Conn, cfg *Config) (*Conn, error) {
if cfg == nil { if cfg == nil {
return nil, errors.New("config is needed") return nil, errors.New("config is needed")
} }
var sess = C.init_gnutls_server_session() var sess = C.init_gnutls_server_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_callback(sess) C.set_callback(sess)
if cfg.NextProtos != nil { if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil { if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err) log.Println(err)
} }
} }
runtime.SetFinalizer(conn, (*Conn).free) runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil return conn, nil
} }
// Client create a client TLS Conn on c // Client create a client TLS Conn on c
func Client(c net.Conn, cfg *Config) (*Conn, error) { func Client(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_client_session() var sess = C.init_gnutls_client_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_callback(sess) C.set_callback(sess)
if cfg != nil { if cfg != nil {
if cfg.ServerName != "" { if cfg.ServerName != "" {
srvname := C.CString(cfg.ServerName) srvname := C.CString(cfg.ServerName)
//defer C.free(unsafe.Pointer(srvname)) //defer C.free(unsafe.Pointer(srvname))
conn.cservname = srvname conn.cservname = srvname
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS, C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName))) unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
} }
if !cfg.InsecureSkipVerify { if !cfg.InsecureSkipVerify {
if conn.cservname != nil { if conn.cservname != nil {
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0) C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
} else { } else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0) C.gnutls_session_set_verify_cert(sess.session, nil, 0)
} }
} }
if cfg.NextProtos != nil { if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil { if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err) log.Println(err)
} }
} }
} else { } else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0) C.gnutls_session_set_verify_cert(sess.session, nil, 0)
} }
runtime.SetFinalizer(conn, (*Conn).free) runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil return conn, nil
} }
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error { func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
arg := make([](*C.char), 0) arg := make([](*C.char), 0)
for _, s := range cfg.NextProtos { for _, s := range cfg.NextProtos {
cbuf := C.CString(s) cbuf := C.CString(s)
defer C.free(unsafe.Pointer(cbuf)) defer C.free(unsafe.Pointer(cbuf))
arg = append(arg, (*C.char)(cbuf)) arg = append(arg, (*C.char)(cbuf))
} }
ret := C.alpn_set_protocols(sess, ret := C.alpn_set_protocols(sess,
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos))) (**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
if int(ret) < 0 { if int(ret) < 0 {
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret))) return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
} }
return nil return nil
} }
// Handshake call handshake for TLS Conn, // Handshake call handshake for TLS Conn,
// this function will call automatic on Read/Write, if not handshake yet // this function will call automatic on Read/Write, if not handshake yet
func (c *Conn) Handshake() error { func (c *Conn) Handshake() error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if c.handshake { if c.handshake {
return nil return nil
} }
ret := C.handshake(c.sess) ret := C.handshake(c.sess)
if int(ret) < 0 { if int(ret) < 0 {
return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret))) return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret)))
} }
c.handshake = true c.handshake = true
//log.Println("handshake done") //log.Println("handshake done")
return nil return nil
} }
// Read read application data from TLS connection // Read read application data from TLS connection
func (c *Conn) Read(buf []byte) (n int, err error) { func (c *Conn) Read(buf []byte) (n int, err error) {
err = c.Handshake() err = c.Handshake()
if err != nil { if err != nil {
return return
} }
if len(buf) == 0 { if len(buf) == 0 {
return 0, nil return 0, nil
} }
if c.closed { if c.closed {
return 0, io.EOF return 0, io.EOF
} }
//bufLen := len(buf) //bufLen := len(buf)
//cbuf := C.malloc(C.size_t(bufLen)) //cbuf := C.malloc(C.size_t(bufLen))
//defer C.free(cbuf) //defer C.free(cbuf)
ret := C.gnutls_record_recv(c.sess.session, ret := C.gnutls_record_recv(c.sess.session,
unsafe.Pointer(&buf[0]), C.size_t(len(buf))) unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
if int(ret) < 0 { if int(ret) < 0 {
return 0, fmt.Errorf("read error: %s", return 0, fmt.Errorf("read error: %s",
C.GoString(C.gnutls_strerror(C.int(ret)))) C.GoString(C.gnutls_strerror(C.int(ret))))
} }
if int(ret) == 0 { if int(ret) == 0 {
return 0, io.EOF return 0, io.EOF
} }
n = int(ret) n = int(ret)
//gobuf2 := C.GoBytes(cbuf, C.int(ret)) //gobuf2 := C.GoBytes(cbuf, C.int(ret))
//copy(buf, gobuf2) //copy(buf, gobuf2)
return n, nil return n, nil
} }
// Write write application data to TLS connection // Write write application data to TLS connection
func (c *Conn) Write(buf []byte) (n int, err error) { func (c *Conn) Write(buf []byte) (n int, err error) {
err = c.Handshake() err = c.Handshake()
if err != nil { if err != nil {
return return
} }
// user may call Write(nil) to do handshake // user may call Write(nil) to do handshake
if len(buf) == 0 { if len(buf) == 0 {
return 0, nil return 0, nil
} }
if c.closed { if c.closed {
return 0, io.EOF return 0, io.EOF
} }
//cbuf := C.CBytes(buf) //cbuf := C.CBytes(buf)
//defer C.free(cbuf) //defer C.free(cbuf)
ret := C.gnutls_record_send(c.sess.session, ret := C.gnutls_record_send(c.sess.session,
unsafe.Pointer(&buf[0]), C.size_t(len(buf))) unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
n = int(ret) n = int(ret)
if n < 0 { if n < 0 {
return 0, fmt.Errorf("write error: %s", return 0, fmt.Errorf("write error: %s",
C.GoString(C.gnutls_strerror(C.int(ret)))) C.GoString(C.gnutls_strerror(C.int(ret))))
} }
if int(ret) == 0 { if int(ret) == 0 {
return 0, io.EOF return 0, io.EOF
} }
return n, nil return n, nil
} }
// Close close the TLS conn and destroy the tls context // Close close the TLS conn and destroy the tls context
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if c.closed { if c.closed {
return nil return nil
} }
C.gnutls_record_send(c.sess.session, nil, 0) C.gnutls_record_send(c.sess.session, nil, 0)
C.session_destroy(c.sess) C.session_destroy(c.sess)
c.c.Close() c.c.Close()
if c.cservname != nil { if c.cservname != nil {
C.free(unsafe.Pointer(c.cservname)) C.free(unsafe.Pointer(c.cservname))
} }
if c.state != nil && c.state.PeerCertificate != nil { if c.state != nil && c.state.PeerCertificate != nil {
c.state.PeerCertificate.Free() c.state.PeerCertificate.Free()
} }
c.closed = true c.closed = true
return nil return nil
} }
func (c *Conn) free() { func (c *Conn) free() {
//log.Println("free conn") //log.Println("free conn")
c.Close() c.Close()
} }
// SetWriteDeadline implements net.Conn // SetWriteDeadline implements net.Conn
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.c.SetWriteDeadline(t) return c.c.SetWriteDeadline(t)
} }
// SetReadDeadline implements net.Conn // SetReadDeadline implements net.Conn
func (c *Conn) SetReadDeadline(t time.Time) error { func (c *Conn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t) return c.c.SetReadDeadline(t)
} }
// RemoteAddr implements net.Conn // RemoteAddr implements net.Conn
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.c.RemoteAddr() return c.c.RemoteAddr()
} }
// LocalAddr implements net.Conn // LocalAddr implements net.Conn
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {
return c.c.LocalAddr() return c.c.LocalAddr()
} }
// SetDeadline implements net.Conn // SetDeadline implements net.Conn
func (c *Conn) SetDeadline(t time.Time) error { func (c *Conn) SetDeadline(t time.Time) error {
return c.c.SetDeadline(t) return c.c.SetDeadline(t)
} }
// ConnectionState get TLS connection state // ConnectionState get TLS connection state
func (c *Conn) ConnectionState() *ConnectionState { func (c *Conn) ConnectionState() *ConnectionState {
if c.state != nil { if c.state != nil {
return c.state return c.state
} }
version := version :=
uint16(C.gnutls_protocol_get_version(c.sess.session)) uint16(C.gnutls_protocol_get_version(c.sess.session))
versionname := C.GoString( versionname := C.GoString(
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version))) C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
state := &ConnectionState{ state := &ConnectionState{
NegotiatedProtocol: c.getAlpnSelectedProtocol(), NegotiatedProtocol: c.getAlpnSelectedProtocol(),
Version: version, Version: version,
HandshakeComplete: c.handshake, HandshakeComplete: c.handshake,
ServerName: c.getServerName(), ServerName: c.getServerName(),
VersionName: versionname, VersionName: versionname,
PeerCertificate: c.getPeerCertificate(), PeerCertificate: c.getPeerCertificate(),
} }
c.state = state c.state = state
return state return state
} }
func (c *Conn) getPeerCertificate() *Certificate { func (c *Conn) getPeerCertificate() *Certificate {
var size int var size int
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size))) st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
if st == nil { if st == nil {
return nil return nil
} }
cert := &Certificate{cert: st, certSize: C.int(size)} cert := &Certificate{cert: st, certSize: C.int(size)}
runtime.SetFinalizer(cert, (*Certificate).free) runtime.SetFinalizer(cert, (*Certificate).free)
return cert return cert
} }
func (c *Conn) getAlpnSelectedProtocol() string { func (c *Conn) getAlpnSelectedProtocol() string {
cbuf := C.malloc(100) cbuf := C.malloc(100)
defer C.free(cbuf) defer C.free(cbuf)
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf)) ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
if int(ret) < 0 { if int(ret) < 0 {
return "" return ""
} }
alpnname := C.GoString((*C.char)(cbuf)) alpnname := C.GoString((*C.char)(cbuf))
return alpnname return alpnname
} }
func (c *Conn) getServerName() string { func (c *Conn) getServerName() string {
buflen := 100 buflen := 100
nametype := GNUTLS_NAME_DNS nametype := GNUTLS_NAME_DNS
cbuf := C.malloc(C.size_t(buflen)) cbuf := C.malloc(C.size_t(buflen))
defer C.free(cbuf) defer C.free(cbuf)
ret := C.gnutls_server_name_get(c.sess.session, cbuf, ret := C.gnutls_server_name_get(c.sess.session, cbuf,
(*C.size_t)(unsafe.Pointer(&buflen)), (*C.size_t)(unsafe.Pointer(&buflen)),
(*C.uint)(unsafe.Pointer(&nametype)), 0) (*C.uint)(unsafe.Pointer(&nametype)), 0)
if int(ret) < 0 { if int(ret) < 0 {
return "" return ""
} }
name := C.GoString((*C.char)(cbuf)) name := C.GoString((*C.char)(cbuf))
return name return name
} }
/* /*
var bufPool = &sync.Pool{ var bufPool = &sync.Pool{
New: func() interface{} { New: func() interface{} {
return make([]byte, 16*1024) return make([]byte, 16*1024)
}, },
} }
*/ */
// onDataReadCallback callback function for gnutls library want to read data from network // onDataReadCallback callback function for gnutls library want to read data from network
// //
//export onDataReadCallback //export onDataReadCallback
func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
//log.Println("read addr ", uintptr(d)) //log.Println("read addr ", uintptr(d))
conn := (*Conn)(unsafe.Pointer((uintptr(d)))) conn := (*Conn)(unsafe.Pointer((uintptr(d))))
buf := make([]byte, int(bufLen)) buf := make([]byte, int(bufLen))
/* /*
_length := int(bufLen) _length := int(bufLen)
buf := bufPool.Get().([]byte) buf := bufPool.Get().([]byte)
if len(buf) < _length { if len(buf) < _length {
buf = make([]byte, _length) buf = make([]byte, _length)
} }
defer bufPool.Put(buf) defer bufPool.Put(buf)
*/ */
n, err := conn.c.Read(buf[0:]) n, err := conn.c.Read(buf[0:])
if err != nil { if err != nil {
//log.Println(err) //log.Println(err)
// 0 indicates connection termination // 0 indicates connection termination
return 0 return 0
} }
//cbuf2 := C.CBytes(buf[:n]) //cbuf2 := C.CBytes(buf[:n])
// d := C.CString(string(buf[:n])) // d := C.CString(string(buf[:n]))
//defer C.free(cbuf2) //defer C.free(cbuf2)
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n)) C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
return C.int(n) return C.int(n)
} }
// onDataWriteCallback callback function for gnutls library want to send data to network // onDataWriteCallback callback function for gnutls library want to send data to network
// //
//export onDataWriteCallback //export onDataWriteCallback
func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
//log.Println("write addr ", uintptr(d), int(_l)) //log.Println("write addr ", uintptr(d), int(_l))
conn := (*Conn)(unsafe.Pointer((uintptr(d)))) conn := (*Conn)(unsafe.Pointer((uintptr(d))))
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen) gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
n, err := conn.c.Write(gobuf) n, err := conn.c.Write(gobuf)
if err != nil { if err != nil {
//log.Println(err) //log.Println(err)
return -1 return -1
} }
return C.int(n) return C.int(n)
} }
// onDataTimeoutRead callback function for timeout read // onDataTimeoutRead callback function for timeout read
// //
//export onDataTimeoutRead //export onDataTimeoutRead
func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int { func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
log.Println("timeout pull function") log.Println("timeout pull function")
return 0 return 0
} }
// onCertSelectCallback callback function for ceritificate select, // onCertSelectCallback callback function for ceritificate select,
// this function select certificate from Config.Certificates field, // this function select certificate from Config.Certificates field,
// //
// on server side, this function select the certificate depend on SNI what client send, // on server side, this function select the certificate depend on SNI what client send,
// if client not send SNI, select the Config.Certificates[0] // if client not send SNI, select the Config.Certificates[0]
// //
//export onCertSelectCallback //export onCertSelectCallback
func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char, func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int { namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
servername := C.GoStringN(hostname, namelen) servername := C.GoStringN(hostname, namelen)
//log.Println("go cert select callback ", servername) //log.Println("go cert select callback ", servername)
conn := (*Conn)(unsafe.Pointer((uintptr(ptr)))) conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
//log.Println(conn) //log.Println(conn)
if int(namelen) == 0 && conn.cfg.Certificates != nil { if int(namelen) == 0 && conn.cfg.Certificates != nil {
_cert := conn.cfg.Certificates[0] _cert := conn.cfg.Certificates[0]
*pcertLength = _cert.certSize *pcertLength = _cert.certSize
*cert = _cert.cert *cert = _cert.cert
*privkey = _cert.privkey *privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize) //log.Println("set pcert length ", _cert.certSize)
return 0 return 0
} }
for _, _cert := range conn.cfg.Certificates { for _, _cert := range conn.cfg.Certificates {
//log.Println(cert) //log.Println(cert)
if _cert.matchName(servername) { if _cert.matchName(servername) {
//log.Println("matched name ", _cert.names) //log.Println("matched name ", _cert.names)
*pcertLength = _cert.certSize *pcertLength = _cert.certSize
*cert = _cert.cert *cert = _cert.cert
*privkey = _cert.privkey *privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize) //log.Println("set pcert length ", _cert.certSize)
return 0 return 0
} }
} }
if conn.cfg.Certificates != nil { if conn.cfg.Certificates != nil {
_cert := conn.cfg.Certificates[0] _cert := conn.cfg.Certificates[0]
*pcertLength = _cert.certSize *pcertLength = _cert.certSize
*cert = _cert.cert *cert = _cert.cert
*privkey = _cert.privkey *privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize) //log.Println("set pcert length ", _cert.certSize)
return 0 return 0
} }
*pcertLength = 0 *pcertLength = 0
//log.Println("set pcert length 0") //log.Println("set pcert length 0")
return -1 return -1
} }
func init() { func init() {
C.init_xcred() C.init_xcred()
C.init_priority_cache() C.init_priority_cache()
} }

@ -1,431 +1,431 @@
package gnutls package gnutls
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
"testing" "testing"
"time" "time"
) )
func TestTLSClient(t *testing.T) { func TestTLSClient(t *testing.T) {
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
t.Fatal("load certificate failed") t.Fatal("load certificate failed")
} }
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
}) })
if err != nil { if err != nil {
t.Fatal("listen failed") t.Fatal("listen failed")
} }
defer l.Close() defer l.Close()
addr := l.Addr().String() addr := l.Addr().String()
t.Log("test server listen on ", addr) t.Log("test server listen on ", addr)
go func() { go func() {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
break break
} }
t.Logf("accept connection from %s", c.RemoteAddr()) t.Logf("accept connection from %s", c.RemoteAddr())
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
for { for {
buf := make([]byte, 4096) buf := make([]byte, 4096)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
t.Log("connection closed") t.Log("connection closed")
break break
} }
if _, err = c.Write(buf[:n]); err != nil { if _, err = c.Write(buf[:n]); err != nil {
break break
} }
} }
}(c) }(c)
} }
}() }()
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true}) c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true})
if err != nil { if err != nil {
t.Fatal("gnutls dial ", err) t.Fatal("gnutls dial ", err)
} }
defer c.Close() defer c.Close()
data := "hello, world" data := "hello, world"
if _, err = c.Write([]byte(data)); err != nil { if _, err = c.Write([]byte(data)); err != nil {
t.Fatal("gnutls write ", err) t.Fatal("gnutls write ", err)
} }
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
t.Fatal("gnutls read ", err) t.Fatal("gnutls read ", err)
} }
if string(buf[:n]) != data { if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
} }
func TestTLSServer(t *testing.T) { func TestTLSServer(t *testing.T) {
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
l, err := Listen("tcp", "127.0.0.1:0", &Config{ l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: []*Certificate{cert}, Certificates: []*Certificate{cert},
}) })
if err != nil { if err != nil {
t.Fatal("gnutls listen ", err) t.Fatal("gnutls listen ", err)
} }
addr := l.Addr().String() addr := l.Addr().String()
t.Log("test server listen on ", addr) t.Log("test server listen on ", addr)
defer l.Close() defer l.Close()
go func() { go func() {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
t.Log("gnutls accept ", err) t.Log("gnutls accept ", err)
break break
} }
t.Log("accept connection from ", c.RemoteAddr()) t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
tlsconn := c.(*Conn) tlsconn := c.(*Conn)
if err := tlsconn.Handshake(); err != nil { if err := tlsconn.Handshake(); err != nil {
t.Log(err) t.Log(err)
return return
} }
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
n, err := c.Read(buf[0:]) n, err := c.Read(buf[0:])
if err != nil { if err != nil {
t.Log("gnutls read ", err) t.Log("gnutls read ", err)
break break
} }
if _, err := c.Write(buf[:n]); err != nil { if _, err := c.Write(buf[:n]); err != nil {
t.Log("gnutls write ", err) t.Log("gnutls write ", err)
break break
} }
} }
}(c) }(c)
} }
}() }()
c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true}) c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
if err != nil { if err != nil {
t.Fatal("dial ", err) t.Fatal("dial ", err)
} }
defer c.Close() defer c.Close()
data := "hello, world" data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil { if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err) t.Fatal("write ", err)
} }
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
t.Fatal("read ", err) t.Fatal("read ", err)
} }
if string(buf[:n]) != data { if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
func TestTLSALPNServer(t *testing.T) { func TestTLSALPNServer(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"} serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"} clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2" expectedAlpn := "a2"
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
l, err := Listen("tcp", "127.0.0.1:0", &Config{ l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: []*Certificate{cert}, Certificates: []*Certificate{cert},
NextProtos: serveralpn, NextProtos: serveralpn,
}) })
if err != nil { if err != nil {
t.Fatal("gnutls listen ", err) t.Fatal("gnutls listen ", err)
} }
addr := l.Addr().String() addr := l.Addr().String()
t.Log("test server listen on ", addr) t.Log("test server listen on ", addr)
defer l.Close() defer l.Close()
go func() { go func() {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
t.Log("gnutls accept ", err) t.Log("gnutls accept ", err)
break break
} }
t.Log("accept connection from ", c.RemoteAddr()) t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
tlsConn := c.(*Conn) tlsConn := c.(*Conn)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
t.Log(err) t.Log(err)
return return
} }
connState := tlsConn.ConnectionState() connState := tlsConn.ConnectionState()
t.Logf("%+v", connState) t.Logf("%+v", connState)
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
n, err := c.Read(buf[0:]) n, err := c.Read(buf[0:])
if err != nil { if err != nil {
t.Log("gnutls read ", err) t.Log("gnutls read ", err)
break break
} }
if _, err := c.Write(buf[:n]); err != nil { if _, err := c.Write(buf[:n]); err != nil {
t.Log("gnutls write ", err) t.Log("gnutls write ", err)
break break
} }
} }
}(c) }(c)
} }
}() }()
c, err := tls.Dial("tcp", addr, &tls.Config{ c, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "localhost", ServerName: "localhost",
NextProtos: clientalpn, NextProtos: clientalpn,
}) })
if err != nil { if err != nil {
t.Fatal("dial ", err) t.Fatal("dial ", err)
} }
defer c.Close() defer c.Close()
if err := c.Handshake(); err != nil { if err := c.Handshake(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
connState := c.ConnectionState() connState := c.ConnectionState()
t.Logf("%+v", connState) t.Logf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn { if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s", t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol) expectedAlpn, connState.NegotiatedProtocol)
} }
data := "hello, world" data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil { if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err) t.Fatal("write ", err)
} }
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
t.Fatal("read ", err) t.Fatal("read ", err)
} }
if string(buf[:n]) != data { if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
func TestTLSALPNClient(t *testing.T) { func TestTLSALPNClient(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"} serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"} clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2" expectedAlpn := "a2"
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
t.Fatal("load key failed") t.Fatal("load key failed")
} }
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
NextProtos: serveralpn, NextProtos: serveralpn,
}) })
if err != nil { if err != nil {
t.Fatal("tls listen ", err) t.Fatal("tls listen ", err)
} }
addr := l.Addr().String() addr := l.Addr().String()
t.Log("test server listen on ", addr) t.Log("test server listen on ", addr)
defer l.Close() defer l.Close()
go func() { go func() {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
t.Log("gnutls accept ", err) t.Log("gnutls accept ", err)
break break
} }
t.Log("accept connection from ", c.RemoteAddr()) t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
tlsConn := c.(*tls.Conn) tlsConn := c.(*tls.Conn)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
t.Log(err) t.Log(err)
return return
} }
connState := tlsConn.ConnectionState() connState := tlsConn.ConnectionState()
t.Logf("%+v", connState) t.Logf("%+v", connState)
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
n, err := c.Read(buf[0:]) n, err := c.Read(buf[0:])
if err != nil { if err != nil {
t.Log("tls read ", err) t.Log("tls read ", err)
break break
} }
if _, err := c.Write(buf[:n]); err != nil { if _, err := c.Write(buf[:n]); err != nil {
t.Log("tls write ", err) t.Log("tls write ", err)
break break
} }
} }
}(c) }(c)
} }
}() }()
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true, c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
ServerName: "localhost", ServerName: "localhost",
NextProtos: clientalpn, NextProtos: clientalpn,
}) })
if err != nil { if err != nil {
t.Fatal("dial ", err) t.Fatal("dial ", err)
} }
defer c.Close() defer c.Close()
if err := c.Handshake(); err != nil { if err := c.Handshake(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
connState := c.ConnectionState() connState := c.ConnectionState()
t.Logf("%+v", connState) t.Logf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn { if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s", t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol) expectedAlpn, connState.NegotiatedProtocol)
} }
data := "hello, world" data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil { if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err) t.Fatal("write ", err)
} }
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
t.Fatal("read ", err) t.Fatal("read ", err)
} }
if string(buf[:n]) != data { if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
func TestTLSServerSNI(t *testing.T) { func TestTLSServerSNI(t *testing.T) {
certificates := []*Certificate{} certificates := []*Certificate{}
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil { if err != nil {
t.Fatal("load key failed") t.Fatal("load key failed")
} }
certificates = append(certificates, cert) certificates = append(certificates, cert)
cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key") cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key")
if err != nil { if err != nil {
t.Fatal("load key failed") t.Fatal("load key failed")
} }
certificates = append(certificates, cert) certificates = append(certificates, cert)
cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key") cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key")
if err != nil { if err != nil {
t.Fatal("load key failed") t.Fatal("load key failed")
} }
certificates = append(certificates, cert) certificates = append(certificates, cert)
l, err := Listen("tcp", "127.0.0.1:0", &Config{ l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: certificates, Certificates: certificates,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
addr := l.Addr().String() addr := l.Addr().String()
go func() { go func() {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
t.Log(err) t.Log(err)
break break
} }
go func(c net.Conn) { go func(c net.Conn) {
defer c.Close() defer c.Close()
tlsconn := c.(*Conn) tlsconn := c.(*Conn)
if err := tlsconn.Handshake(); err != nil { if err := tlsconn.Handshake(); err != nil {
t.Log(err) t.Log(err)
return return
} }
state := tlsconn.ConnectionState() state := tlsconn.ConnectionState()
fmt.Fprintf(c, state.ServerName) fmt.Fprintf(c, state.ServerName)
}(c) }(c)
} }
}() }()
for _, cfg := range []struct { for _, cfg := range []struct {
serverName string serverName string
commonName string commonName string
}{ }{
{"abc.com", "abc.com"}, {"abc.com", "abc.com"},
{"example.com", "example.com"}, {"example.com", "example.com"},
{"a.aaa.com", "*.aaa.com"}, {"a.aaa.com", "*.aaa.com"},
{"b.aaa.com", "*.aaa.com"}, {"b.aaa.com", "*.aaa.com"},
} { } {
conn, err := tls.Dial("tcp", addr, &tls.Config{ conn, err := tls.Dial("tcp", addr, &tls.Config{
ServerName: cfg.serverName, ServerName: cfg.serverName,
InsecureSkipVerify: true, InsecureSkipVerify: true,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
state := conn.ConnectionState() state := conn.ConnectionState()
_commonName := state.PeerCertificates[0].Subject.CommonName _commonName := state.PeerCertificates[0].Subject.CommonName
if _commonName != cfg.commonName { if _commonName != cfg.commonName {
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName) t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
} }
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := conn.Read(buf) n, err := conn.Read(buf)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
t.Error(err) t.Error(err)
} }
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) { if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n])) t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
} }
conn.Close() conn.Close()
} }
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
func TestTLSGetPeerCert(t *testing.T) { func TestTLSGetPeerCert(t *testing.T) {
conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{ conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{
ServerName: "www.ratafee.nl", ServerName: "www.ratafee.nl",
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
//tlsconn := conn.(*Conn) //tlsconn := conn.(*Conn)
if err := conn.Handshake(); err != nil { if err := conn.Handshake(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
state := conn.ConnectionState() state := conn.ConnectionState()
for i := 0; i < int(state.PeerCertificate.certSize); i++ { for i := 0; i < int(state.PeerCertificate.certSize); i++ {
t.Log(state.PeerCertificate.getCertString(i, 1)) t.Log(state.PeerCertificate.getCertString(i, 1))
} }
req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil) req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil)
req.Write(conn) req.Write(conn)
r := bufio.NewReader(conn) r := bufio.NewReader(conn)
resp, err := http.ReadResponse(r, req) resp, err := http.ReadResponse(r, req)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
var buf = new(bytes.Buffer) var buf = new(bytes.Buffer)
resp.Write(buf) resp.Write(buf)
t.Logf("%s", string(buf.Bytes())) t.Logf("%s", string(buf.Bytes()))
runtime.GC() runtime.GC()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }

Loading…
Cancel
Save