diff --git a/certificate.go b/certificate.go index e5b17f8..7f63483 100644 --- a/certificate.go +++ b/certificate.go @@ -1,176 +1,176 @@ -package gnutls - -/* -#include "_gnutls.h" -*/ -import "C" -import ( - "fmt" - "log" - "runtime" - "strings" - "unsafe" -) - -// Certificate x509 certificate -type Certificate struct { - cert *C.gnutls_pcert_st - privkey C.gnutls_privkey_t - certSize C.int -} - -// Free free the certificate context -func (c *Certificate) Free() { - if c.cert != nil { - C.free_cert_list(c.cert, c.certSize) - } - if c.privkey != nil { - C.gnutls_privkey_deinit(c.privkey) - } - c.cert = nil - c.privkey = nil - c.certSize = 0 -} - -func (c *Certificate) free() { - //log.Println("free certificate") - c.Free() -} - -func (c *Certificate) matchName(name string) bool { - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - ret := C.cert_check_hostname(c.cert, c.certSize, cname) - if int(ret) < 0 { - log.Println(C.GoString(C.gnutls_strerror(ret))) - return false - } - if int(ret) > 0 { - return true - } - return false -} - -// CommonName get CN field in subject, -// -// example: subject C=xx,ST=bbb,CN=abc will return abc -func (c *Certificate) CommonName() string { - return c.commonName(0) -} -func (c *Certificate) commonName(index int) string { - dnMap := make(map[string]string) - dn := c.getDN(0) - if dn != "" { - dnFields := strings.Split(dn, ",") - for _, d := range dnFields { - dd := strings.Split(d, "=") - k := dd[0] - v := dd[1] - dnMap[k] = v - } - } - - if cn, ok := dnMap["CN"]; ok { - return cn - } - return "" -} - -// GetAltName get altname in certificate -func (c *Certificate) GetAltName() string { - return c.getAltName(0, 0) -} -func (c *Certificate) getAltName(index int, nameindex int) string { - out := C.malloc(1024) - defer C.free(out) - size := C.get_pcert_alt_name( - c.cert, C.int(index), C.int(nameindex), (*C.char)(out)) - if int(size) < 0 { - log.Println(C.GoString(C.gnutls_strerror(size))) - return "" - } - - name := C.GoBytes(out, size) - return string(name) -} - -//GetCertString return certificate info string in one line -func (c *Certificate) GetCertString() string { - return c.getCertString(0, 1) -} - -func (c *Certificate) getCertString(index int, flag int) string { - out := C.malloc(4096) - defer C.free(out) - size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out)) - if int(size) < 0 { - log.Println(C.GoString(C.gnutls_strerror(size))) - return "" - } - s := C.GoBytes(out, size) - return string(s) -} - -// GetDN get the certificate subject, like O=st,C=aa,CN=localhost -func (c *Certificate) GetDN() string { - return c.getDN(0) -} - -func (c *Certificate) getDN(index int) string { - cbuf := C.malloc(200) - defer C.free(cbuf) - size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf)) - if int(size) < 0 { - log.Println(C.GoString(C.gnutls_strerror(size))) - return "" - } - s := C.GoBytes(cbuf, size) - return string(s) -} - -// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost -func (c *Certificate) GetIssuerDN() string { - return c.getIssuerDN(0) -} - -func (c *Certificate) getIssuerDN(index int) string { - cbuf := C.malloc(200) - defer C.free(cbuf) - size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf)) - if int(size) < 0 { - log.Println(C.GoString(C.gnutls_strerror(size))) - return "" - } - s := C.GoBytes(cbuf, size) - return string(s) -} - -// LoadX509KeyPair load certificate pair, -// the return Certifciate must be freed by call Free(), -func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) { - _certfile := C.CString(certfile) - _keyfile := C.CString(keyfile) - - defer C.free(unsafe.Pointer(_certfile)) - defer C.free(unsafe.Pointer(_keyfile)) - - certificate := &Certificate{} - var ret C.int - var certSize C.int - cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)), - (*C.int)(unsafe.Pointer(&ret))) - if int(ret) < 0 { - return nil, fmt.Errorf("load cert failed: %s", - C.GoString(C.gnutls_strerror(ret))) - } - privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret))) - if int(ret) < 0 { - return nil, fmt.Errorf("load privkey: %s", - C.GoString(C.gnutls_strerror(ret))) - } - certificate.cert = cert - certificate.privkey = privkey - certificate.certSize = certSize - runtime.SetFinalizer(certificate, (*Certificate).free) - return certificate, nil -} +package gnutls + +/* +#include "_gnutls.h" +*/ +import "C" +import ( + "fmt" + "log" + "runtime" + "strings" + "unsafe" +) + +// Certificate x509 certificate +type Certificate struct { + cert *C.gnutls_pcert_st + privkey C.gnutls_privkey_t + certSize C.int +} + +// Free free the certificate context +func (c *Certificate) Free() { + if c.cert != nil { + C.free_cert_list(c.cert, c.certSize) + } + if c.privkey != nil { + C.gnutls_privkey_deinit(c.privkey) + } + c.cert = nil + c.privkey = nil + c.certSize = 0 +} + +func (c *Certificate) free() { + //log.Println("free certificate") + c.Free() +} + +func (c *Certificate) matchName(name string) bool { + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + ret := C.cert_check_hostname(c.cert, c.certSize, cname) + if int(ret) < 0 { + log.Println(C.GoString(C.gnutls_strerror(ret))) + return false + } + if int(ret) > 0 { + return true + } + return false +} + +// CommonName get CN field in subject, +// +// example: subject C=xx,ST=bbb,CN=abc will return abc +func (c *Certificate) CommonName() string { + return c.commonName(0) +} +func (c *Certificate) commonName(index int) string { + dnMap := make(map[string]string) + dn := c.getDN(0) + if dn != "" { + dnFields := strings.Split(dn, ",") + for _, d := range dnFields { + dd := strings.Split(d, "=") + k := dd[0] + v := dd[1] + dnMap[k] = v + } + } + + if cn, ok := dnMap["CN"]; ok { + return cn + } + return "" +} + +// GetAltName get altname in certificate +func (c *Certificate) GetAltName() string { + return c.getAltName(0, 0) +} +func (c *Certificate) getAltName(index int, nameindex int) string { + out := C.malloc(1024) + defer C.free(out) + size := C.get_pcert_alt_name( + c.cert, C.int(index), C.int(nameindex), (*C.char)(out)) + if int(size) < 0 { + log.Println(C.GoString(C.gnutls_strerror(size))) + return "" + } + + name := C.GoBytes(out, size) + return string(name) +} + +//GetCertString return certificate info string in one line +func (c *Certificate) GetCertString() string { + return c.getCertString(0, 1) +} + +func (c *Certificate) getCertString(index int, flag int) string { + out := C.malloc(4096) + defer C.free(out) + size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out)) + if int(size) < 0 { + log.Println(C.GoString(C.gnutls_strerror(size))) + return "" + } + s := C.GoBytes(out, size) + return string(s) +} + +// GetDN get the certificate subject, like O=st,C=aa,CN=localhost +func (c *Certificate) GetDN() string { + return c.getDN(0) +} + +func (c *Certificate) getDN(index int) string { + cbuf := C.malloc(200) + defer C.free(cbuf) + size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf)) + if int(size) < 0 { + log.Println(C.GoString(C.gnutls_strerror(size))) + return "" + } + s := C.GoBytes(cbuf, size) + return string(s) +} + +// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost +func (c *Certificate) GetIssuerDN() string { + return c.getIssuerDN(0) +} + +func (c *Certificate) getIssuerDN(index int) string { + cbuf := C.malloc(200) + defer C.free(cbuf) + size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf)) + if int(size) < 0 { + log.Println(C.GoString(C.gnutls_strerror(size))) + return "" + } + s := C.GoBytes(cbuf, size) + return string(s) +} + +// LoadX509KeyPair load certificate pair, +// the return Certifciate must be freed by call Free(), +func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) { + _certfile := C.CString(certfile) + _keyfile := C.CString(keyfile) + + defer C.free(unsafe.Pointer(_certfile)) + defer C.free(unsafe.Pointer(_keyfile)) + + certificate := &Certificate{} + var ret C.int + var certSize C.int + cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)), + (*C.int)(unsafe.Pointer(&ret))) + if int(ret) < 0 { + return nil, fmt.Errorf("load cert failed: %s", + C.GoString(C.gnutls_strerror(ret))) + } + privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret))) + if int(ret) < 0 { + return nil, fmt.Errorf("load privkey: %s", + C.GoString(C.gnutls_strerror(ret))) + } + certificate.cert = cert + certificate.privkey = privkey + certificate.certSize = certSize + runtime.SetFinalizer(certificate, (*Certificate).free) + return certificate, nil +} diff --git a/certificate_test.go b/certificate_test.go index 14990b5..36a2fe7 100644 --- a/certificate_test.go +++ b/certificate_test.go @@ -1,35 +1,35 @@ -package gnutls - -import ( - "runtime" - "testing" - "time" -) - -func TestGetAltname(t *testing.T) { - cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") - if err != nil { - t.Fatal(err) - } - t.Logf("%+v", cert) - t.Log(cert.GetAltName()) - t.Log(cert.GetCertString()) - t.Log(cert.GetDN()) - t.Log(cert.GetIssuerDN()) - t.Log("CN ", cert.CommonName()) - //t.Log("flag 0: ", cert.getCertString(0, 0)) - //t.Log("flag 1: ", cert.getCertString(0, 1)) - //t.Log("flag 2: ", cert.getCertString(0, 2)) - //t.Log("flag 3: ", cert.getCertString(0, 3)) - cert.Free() -} - -func _loadCert(certfile, keyfile string) (*Certificate, error) { - return LoadX509KeyPair(certfile, keyfile) -} - -func TestCertGC(t *testing.T) { - _loadCert("testdata/server.crt", "testdata/server.key") - runtime.GC() - time.Sleep(1 * time.Second) -} +package gnutls + +import ( + "runtime" + "testing" + "time" +) + +func TestGetAltname(t *testing.T) { + cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal(err) + } + t.Logf("%+v", cert) + t.Log(cert.GetAltName()) + t.Log(cert.GetCertString()) + t.Log(cert.GetDN()) + t.Log(cert.GetIssuerDN()) + t.Log("CN ", cert.CommonName()) + //t.Log("flag 0: ", cert.getCertString(0, 0)) + //t.Log("flag 1: ", cert.getCertString(0, 1)) + //t.Log("flag 2: ", cert.getCertString(0, 2)) + //t.Log("flag 3: ", cert.getCertString(0, 3)) + cert.Free() +} + +func _loadCert(certfile, keyfile string) (*Certificate, error) { + return LoadX509KeyPair(certfile, keyfile) +} + +func TestCertGC(t *testing.T) { + _loadCert("testdata/server.crt", "testdata/server.key") + runtime.GC() + time.Sleep(1 * time.Second) +} diff --git a/cipher.go b/cipher.go index 08e7437..76a68de 100644 --- a/cipher.go +++ b/cipher.go @@ -1,167 +1,167 @@ -package gnutls - -/* -#include "_gnutls.h" -#include -*/ -import "C" -import ( - "errors" - "fmt" - "log" - "runtime" -) - -// CipherType cipher type -type CipherType int - -const ( - GNUTLS_CIPHER_AES_128_CBC CipherType = 4 - GNUTLS_CIPHER_AES_256_CBC CipherType = 5 - GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6 - GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7 - GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8 - GNUTLS_CIPHER_AES_192_CBC CipherType = 9 - GNUTLS_CIPHER_AES_128_GCM CipherType = 10 - GNUTLS_CIPHER_AES_256_GCM CipherType = 11 - GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12 - GNUTLS_CIPHER_SALSA20_256 CipherType = 13 - GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14 - GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15 - GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16 - GNUTLS_CIPHER_RC2_40_CBC CipherType = 17 - GNUTLS_CIPHER_DES_CBC CipherType = 18 - GNUTLS_CIPHER_AES_128_CCM CipherType = 19 - GNUTLS_CIPHER_AES_256_CCM CipherType = 20 - GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21 - GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22 - GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23 -) - -var ( - // ErrBlockSize wrong block size - ErrBlockSize = errors.New("wrong block size") - // ErrKeyLength wrong key length - ErrKeyLength = errors.New("wrong key length") - // ErrIVLength wrong iv length - ErrIVLength = errors.New("wrong iv length") -) - -// Cipher gnutls cipher struct -type Cipher struct { - cipher C.gnutls_cipher_hd_t - t CipherType - blockSize int -} - -// NewCipher create a new cipher by give type, key, iv -// -// example: -// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789")) -// -// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size -func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) { - keysize := GetCipherKeySize(t) - ivSize := GetCipherIVSize(t) - blocksize := GetCipherBlockSize(t) - //log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize)) - if len(key) != int(keysize) { - return nil, ErrKeyLength - } - - if len(iv) != int(ivSize) { - return nil, ErrIVLength - } - - ckey := C.CBytes(key) - civ := C.CBytes(iv) - - defer C.free(ckey) - defer C.free(civ) - - c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv))) - if c == nil { - log.Println("new cipher return nil") - return nil, nil - } - cipher := &Cipher{c, t, blocksize} - runtime.SetFinalizer(cipher, (*Cipher).free) - return cipher, nil -} - -// Encrypt encrypt the buf and place the encrypted data in dst, -// the buf size must multiple of cipher's block size -func (c *Cipher) Encrypt(dst, buf []byte) error { - if len(buf)%c.blockSize != 0 { - return ErrBlockSize - } - - cbuf := C.CBytes(buf) - defer C.free(cbuf) - - bufLen := C.size_t(len(buf)) - dstBuf := C.malloc(bufLen) - - defer C.free(dstBuf) - - ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) - if int(ret) < 0 { - return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret))) - } - _buf := C.GoBytes(dstBuf, C.int(bufLen)) - copy(dst, _buf) - return nil -} - -// Decrypt decrypt the buf and place the decrypted data in dst, -// the buf size must multiple of cipher's block size -func (c *Cipher) Decrypt(dst, buf []byte) error { - if len(buf)%c.blockSize != 0 { - return ErrBlockSize - } - - cbuf := C.CBytes(buf) - defer C.free(cbuf) - - bufLen := C.size_t(len(buf)) - dstBuf := C.malloc(C.size_t(len(buf))) - - defer C.free(dstBuf) - - ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) - if int(ret) < 0 { - return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret))) - } - _buf := C.GoBytes(dstBuf, C.int(bufLen)) - copy(dst, _buf) - return nil -} - -// Close destroy the cipher context -func (c *Cipher) Close() error { - if c.cipher != nil { - C.gnutls_cipher_deinit(c.cipher) - c.cipher = nil - } - return nil -} - -func (c *Cipher) free() { - //log.Println("free cipher") - c.Close() -} - -// GetCipherKeySize get the cipher algorithm key length -func GetCipherKeySize(t CipherType) int { - return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) -} - -// GetCipherIVSize get the cipher algorithm iv length -func GetCipherIVSize(t CipherType) int { - return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t))) -} - -// GetCipherBlockSize get the cipher algorithm block size -func GetCipherBlockSize(t CipherType) int { - return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t))) -} +package gnutls + +/* +#include "_gnutls.h" +#include +*/ +import "C" +import ( + "errors" + "fmt" + "log" + "runtime" +) + +// CipherType cipher type +type CipherType int + +const ( + GNUTLS_CIPHER_AES_128_CBC CipherType = 4 + GNUTLS_CIPHER_AES_256_CBC CipherType = 5 + GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6 + GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7 + GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8 + GNUTLS_CIPHER_AES_192_CBC CipherType = 9 + GNUTLS_CIPHER_AES_128_GCM CipherType = 10 + GNUTLS_CIPHER_AES_256_GCM CipherType = 11 + GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12 + GNUTLS_CIPHER_SALSA20_256 CipherType = 13 + GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14 + GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15 + GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16 + GNUTLS_CIPHER_RC2_40_CBC CipherType = 17 + GNUTLS_CIPHER_DES_CBC CipherType = 18 + GNUTLS_CIPHER_AES_128_CCM CipherType = 19 + GNUTLS_CIPHER_AES_256_CCM CipherType = 20 + GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21 + GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22 + GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23 +) + +var ( + // ErrBlockSize wrong block size + ErrBlockSize = errors.New("wrong block size") + // ErrKeyLength wrong key length + ErrKeyLength = errors.New("wrong key length") + // ErrIVLength wrong iv length + ErrIVLength = errors.New("wrong iv length") +) + +// Cipher gnutls cipher struct +type Cipher struct { + cipher C.gnutls_cipher_hd_t + t CipherType + blockSize int +} + +// NewCipher create a new cipher by give type, key, iv +// +// example: +// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789")) +// +// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size +func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) { + keysize := GetCipherKeySize(t) + ivSize := GetCipherIVSize(t) + blocksize := GetCipherBlockSize(t) + //log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize)) + if len(key) != int(keysize) { + return nil, ErrKeyLength + } + + if len(iv) != int(ivSize) { + return nil, ErrIVLength + } + + ckey := C.CBytes(key) + civ := C.CBytes(iv) + + defer C.free(ckey) + defer C.free(civ) + + c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv))) + if c == nil { + log.Println("new cipher return nil") + return nil, nil + } + cipher := &Cipher{c, t, blocksize} + runtime.SetFinalizer(cipher, (*Cipher).free) + return cipher, nil +} + +// Encrypt encrypt the buf and place the encrypted data in dst, +// the buf size must multiple of cipher's block size +func (c *Cipher) Encrypt(dst, buf []byte) error { + if len(buf)%c.blockSize != 0 { + return ErrBlockSize + } + + cbuf := C.CBytes(buf) + defer C.free(cbuf) + + bufLen := C.size_t(len(buf)) + dstBuf := C.malloc(bufLen) + + defer C.free(dstBuf) + + ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) + if int(ret) < 0 { + return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret))) + } + _buf := C.GoBytes(dstBuf, C.int(bufLen)) + copy(dst, _buf) + return nil +} + +// Decrypt decrypt the buf and place the decrypted data in dst, +// the buf size must multiple of cipher's block size +func (c *Cipher) Decrypt(dst, buf []byte) error { + if len(buf)%c.blockSize != 0 { + return ErrBlockSize + } + + cbuf := C.CBytes(buf) + defer C.free(cbuf) + + bufLen := C.size_t(len(buf)) + dstBuf := C.malloc(C.size_t(len(buf))) + + defer C.free(dstBuf) + + ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) + if int(ret) < 0 { + return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret))) + } + _buf := C.GoBytes(dstBuf, C.int(bufLen)) + copy(dst, _buf) + return nil +} + +// Close destroy the cipher context +func (c *Cipher) Close() error { + if c.cipher != nil { + C.gnutls_cipher_deinit(c.cipher) + c.cipher = nil + } + return nil +} + +func (c *Cipher) free() { + //log.Println("free cipher") + c.Close() +} + +// GetCipherKeySize get the cipher algorithm key length +func GetCipherKeySize(t CipherType) int { + return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) +} + +// GetCipherIVSize get the cipher algorithm iv length +func GetCipherIVSize(t CipherType) int { + return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t))) +} + +// GetCipherBlockSize get the cipher algorithm block size +func GetCipherBlockSize(t CipherType) int { + return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t))) +} diff --git a/cipher_test.go b/cipher_test.go index a6b3566..efe57e7 100644 --- a/cipher_test.go +++ b/cipher_test.go @@ -1,130 +1,130 @@ -package gnutls - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "runtime" - "testing" - "time" -) - -func TestCipherSize(t *testing.T) { - data := []struct { - t CipherType - bsize int - isize int - }{ - {GNUTLS_CIPHER_AES_128_CBC, 16, 16}, - {GNUTLS_CIPHER_AES_192_CBC, 24, 16}, - {GNUTLS_CIPHER_AES_256_CBC, 32, 16}, - } - for _, d := range data { - blocksize := GetCipherKeySize(d.t) - if blocksize != d.bsize { - t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize) - } - ivsize := GetCipherIVSize(d.t) - if ivsize != d.isize { - t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize) - } - } -} -func TestEncryptDecrypt(t *testing.T) { - cipherName := GNUTLS_CIPHER_AES_256_CBC - keysize := GetCipherKeySize(cipherName) - ivsize := GetCipherIVSize(cipherName) - blocksize := GetCipherBlockSize(cipherName) - - key := make([]byte, keysize) - iv := make([]byte, ivsize) - rand.Reader.Read(key) - rand.Reader.Read(iv) - - c, err := NewCipher(cipherName, key, iv) - if err != nil { - t.Fatal(err) - } - //defer c.Close() - - c1, err := NewCipher(cipherName, key, iv) - if err != nil { - t.Fatal(err) - } - //defer c1.Close() - - data := make([]byte, blocksize*10) - if c == nil { - t.Fatal("new ciphoer failed") - } - cdata := make([]byte, len(data)) - err = c.Encrypt(cdata, data) - if err != nil { - t.Fatal("encrypt failed", err) - } - data1 := make([]byte, len(data)) - err = c1.Decrypt(data1, cdata) - if err != nil { - t.Fatal("decrypt failed", err) - } - if !bytes.Equal(data, data1) { - t.Fatal("encrypt/decrypt failed", string(data), string(data1)) - } - block, _ := aes.NewCipher(key) - mode := cipher.NewCBCEncrypter(block, iv) - dst := make([]byte, len(data)) - mode.CryptBlocks(dst, data) - if !bytes.Equal(dst, cdata) { - t.Fatal("cipher text not equal to cypto/aes") - } - runtime.GC() - time.Sleep(1 * time.Second) -} - -func BenchmarkAESEncrypt(b *testing.B) { - cipherName := GNUTLS_CIPHER_AES_256_CBC - keysize := GetCipherKeySize(cipherName) - ivsize := GetCipherIVSize(cipherName) - blocksize := GetCipherBlockSize(cipherName) - datalen := blocksize * 500 - - key := make([]byte, keysize) - iv := make([]byte, ivsize) - rand.Reader.Read(key) - rand.Reader.Read(iv) - buf := make([]byte, datalen) - dst := make([]byte, datalen) - for i := 0; i < b.N; i++ { - c, err := NewCipher(cipherName, key, iv) - if err != nil { - b.Fatal(err) - } - c.Encrypt(dst, buf) - c.Close() - } -} - -func BenchmarkAESEncrypt2(b *testing.B) { - cipherName := GNUTLS_CIPHER_AES_256_CBC - keysize := GetCipherKeySize(cipherName) - ivsize := GetCipherIVSize(cipherName) - blocksize := GetCipherBlockSize(cipherName) - datalen := blocksize * 500 - - key := make([]byte, keysize) - iv := make([]byte, ivsize) - buf := make([]byte, datalen) - - rand.Reader.Read(buf) - rand.Reader.Read(key) - rand.Reader.Read(iv) - - dst := make([]byte, datalen) - - for i := 0; i < b.N; i++ { - block, _ := aes.NewCipher(key) - mode := cipher.NewCBCEncrypter(block, iv) - mode.CryptBlocks(dst, buf) - } -} +package gnutls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "runtime" + "testing" + "time" +) + +func TestCipherSize(t *testing.T) { + data := []struct { + t CipherType + bsize int + isize int + }{ + {GNUTLS_CIPHER_AES_128_CBC, 16, 16}, + {GNUTLS_CIPHER_AES_192_CBC, 24, 16}, + {GNUTLS_CIPHER_AES_256_CBC, 32, 16}, + } + for _, d := range data { + blocksize := GetCipherKeySize(d.t) + if blocksize != d.bsize { + t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize) + } + ivsize := GetCipherIVSize(d.t) + if ivsize != d.isize { + t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize) + } + } +} +func TestEncryptDecrypt(t *testing.T) { + cipherName := GNUTLS_CIPHER_AES_256_CBC + keysize := GetCipherKeySize(cipherName) + ivsize := GetCipherIVSize(cipherName) + blocksize := GetCipherBlockSize(cipherName) + + key := make([]byte, keysize) + iv := make([]byte, ivsize) + rand.Reader.Read(key) + rand.Reader.Read(iv) + + c, err := NewCipher(cipherName, key, iv) + if err != nil { + t.Fatal(err) + } + //defer c.Close() + + c1, err := NewCipher(cipherName, key, iv) + if err != nil { + t.Fatal(err) + } + //defer c1.Close() + + data := make([]byte, blocksize*10) + if c == nil { + t.Fatal("new ciphoer failed") + } + cdata := make([]byte, len(data)) + err = c.Encrypt(cdata, data) + if err != nil { + t.Fatal("encrypt failed", err) + } + data1 := make([]byte, len(data)) + err = c1.Decrypt(data1, cdata) + if err != nil { + t.Fatal("decrypt failed", err) + } + if !bytes.Equal(data, data1) { + t.Fatal("encrypt/decrypt failed", string(data), string(data1)) + } + block, _ := aes.NewCipher(key) + mode := cipher.NewCBCEncrypter(block, iv) + dst := make([]byte, len(data)) + mode.CryptBlocks(dst, data) + if !bytes.Equal(dst, cdata) { + t.Fatal("cipher text not equal to cypto/aes") + } + runtime.GC() + time.Sleep(1 * time.Second) +} + +func BenchmarkAESEncrypt(b *testing.B) { + cipherName := GNUTLS_CIPHER_AES_256_CBC + keysize := GetCipherKeySize(cipherName) + ivsize := GetCipherIVSize(cipherName) + blocksize := GetCipherBlockSize(cipherName) + datalen := blocksize * 500 + + key := make([]byte, keysize) + iv := make([]byte, ivsize) + rand.Reader.Read(key) + rand.Reader.Read(iv) + buf := make([]byte, datalen) + dst := make([]byte, datalen) + for i := 0; i < b.N; i++ { + c, err := NewCipher(cipherName, key, iv) + if err != nil { + b.Fatal(err) + } + c.Encrypt(dst, buf) + c.Close() + } +} + +func BenchmarkAESEncrypt2(b *testing.B) { + cipherName := GNUTLS_CIPHER_AES_256_CBC + keysize := GetCipherKeySize(cipherName) + ivsize := GetCipherIVSize(cipherName) + blocksize := GetCipherBlockSize(cipherName) + datalen := blocksize * 500 + + key := make([]byte, keysize) + iv := make([]byte, ivsize) + buf := make([]byte, datalen) + + rand.Reader.Read(buf) + rand.Reader.Read(key) + rand.Reader.Read(iv) + + dst := make([]byte, datalen) + + for i := 0; i < b.N; i++ { + block, _ := aes.NewCipher(key) + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(dst, buf) + } +} diff --git a/doc.go b/doc.go index df2f9a1..a3e08d8 100644 --- a/doc.go +++ b/doc.go @@ -1,108 +1,108 @@ -// Package gnutls is a gnutls binding for golang. -/* -a limit set of api is supported. - -TLS api is very similar to crypto/tls on standard library. - -TLS client example: - - addr := "127.0.0.1:9443" - c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true}) - if err != nil { - log.Fatal("gnutls dial ", err) - } - defer c.Close() - - data := "hello, world" - if _, err = c.Write([]byte(data)); err != nil { - log.Fatal("gnutls write ", err) - } - buf := make([]byte, 100) - n, err := c.Read(buf) - if err != nil { - log.Fatal("gnutls read ", err) - } - -TLS Server example: - cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key") - if err != nil{ - // handle error - } - l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{ - Certificates: []*gnutls.Certificate{cert}}) - if err != nil { - // handle error - } - defer l.Close() - for { - c, err := l.Accept() - if err != nil { - log.Println("gnutls accept ", err) - break - } - log.Println("accept connection from ", c.RemoteAddr()) - go func(c net.Conn) { - defer c.Close() - - buf := make([]byte, 4096) - for { - n, err := c.Read(buf[0:]) - if err != nil { - log.Println("gnutls read ", err) - break - } - if _, err := c.Write(buf[:n]); err != nil { - log.Println("gnutls write ", err) - break - } - } - }(c) - } - -AES encrypt/decrypt example: - - key := []byte("0123456789abcdef") - iv := []byte("abcdefg123456789") - c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv) - if err != nil { - log.Fatal(err) - } - defer c.Close() - - c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv) - if err != nil { - log.Fatal(err) - } - defer c1.Close() - - data := []byte("1234012121212121") - if c == nil { - log.Fatal("new ciphoer failed") - } - - // encrypt - dst := make([]byte, len(data)) - err := c.Encrypt(dst, data) - if err != nil { - log.Fatal("encrypt failed", err) - } - - // decrypt - data1 := make([]byte, len(data)) - err := c1.Decrypt(data1, cdata) - if err != nil { - log.Fatal("decrypt failed", err) - } - - -Hash example: - - h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512) - defer h.Close() - - data := []byte("1234") - - h1 := h.Sum(data) - -*/ -package gnutls +// Package gnutls is a gnutls binding for golang. +/* +a limit set of api is supported. + +TLS api is very similar to crypto/tls on standard library. + +TLS client example: + + addr := "127.0.0.1:9443" + c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true}) + if err != nil { + log.Fatal("gnutls dial ", err) + } + defer c.Close() + + data := "hello, world" + if _, err = c.Write([]byte(data)); err != nil { + log.Fatal("gnutls write ", err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if err != nil { + log.Fatal("gnutls read ", err) + } + +TLS Server example: + cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key") + if err != nil{ + // handle error + } + l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{ + Certificates: []*gnutls.Certificate{cert}}) + if err != nil { + // handle error + } + defer l.Close() + for { + c, err := l.Accept() + if err != nil { + log.Println("gnutls accept ", err) + break + } + log.Println("accept connection from ", c.RemoteAddr()) + go func(c net.Conn) { + defer c.Close() + + buf := make([]byte, 4096) + for { + n, err := c.Read(buf[0:]) + if err != nil { + log.Println("gnutls read ", err) + break + } + if _, err := c.Write(buf[:n]); err != nil { + log.Println("gnutls write ", err) + break + } + } + }(c) + } + +AES encrypt/decrypt example: + + key := []byte("0123456789abcdef") + iv := []byte("abcdefg123456789") + c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv) + if err != nil { + log.Fatal(err) + } + defer c.Close() + + c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv) + if err != nil { + log.Fatal(err) + } + defer c1.Close() + + data := []byte("1234012121212121") + if c == nil { + log.Fatal("new ciphoer failed") + } + + // encrypt + dst := make([]byte, len(data)) + err := c.Encrypt(dst, data) + if err != nil { + log.Fatal("encrypt failed", err) + } + + // decrypt + data1 := make([]byte, len(data)) + err := c1.Decrypt(data1, cdata) + if err != nil { + log.Fatal("decrypt failed", err) + } + + +Hash example: + + h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512) + defer h.Close() + + data := []byte("1234") + + h1 := h.Sum(data) + +*/ +package gnutls diff --git a/hash.go b/hash.go index 3270a13..3e212c4 100644 --- a/hash.go +++ b/hash.go @@ -1,90 +1,90 @@ -package gnutls - -/* -#include -#include "_gnutls.h" -*/ -import "C" -import ( - "fmt" - "runtime" -) - -// HashType hash type -type HashType int - -const ( - GNUTLS_HASH_MD5 HashType = 2 - GNUTLS_HASH_SHA1 HashType = 3 - GNUTLS_HASH_MD2 HashType = 5 - GNUTLS_HASH_SHA256 HashType = 6 - GNUTLS_HASH_SHA384 HashType = 7 - GNUTLS_HASH_SHA512 HashType = 8 - GNUTLS_HASH_SHA224 HashType = 9 -) - -// Hash hash struct -type Hash struct { - hash C.gnutls_hash_hd_t - t HashType - hashLen C.int -} - -// NewHash new hash struct -func NewHash(t HashType) *Hash { - h := C.new_hash(C.int(t)) - hashOutLen := GetHashOutputLen(t) - hash := &Hash{h, t, C.int(hashOutLen)} - runtime.SetFinalizer(hash, (*Hash).free) - return hash -} - -// Write write data to hash context -func (h *Hash) Write(buf []byte) error { - dataLen := len(buf) - - cbuf := C.CBytes(buf) - defer C.free(cbuf) - - ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen)) - if int(ret) < 0 { - return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret))) - } - return nil -} - -// Sum get hash result -func (h *Hash) Sum(buf []byte) []byte { - if buf != nil { - h.Write(buf) - } - - dstBuf := C.malloc(C.size_t(h.hashLen)) - defer C.free(dstBuf) - - C.gnutls_hash_output(h.hash, dstBuf) - - gobuf := C.GoBytes(dstBuf, h.hashLen) - - return gobuf -} - -// Close destroy hash context -func (h *Hash) Close() error { - if h.hash != nil { - C.gnutls_hash_deinit(h.hash, nil) - h.hash = nil - } - return nil -} -func (h *Hash) free() { - //log.Println("free hash") - h.Close() -} - -// GetHashOutputLen get the hash algorithm output length -// -// example GNUTLS_MD5 is 16 -func GetHashOutputLen(t HashType) int { - return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t))) -} +package gnutls + +/* +#include +#include "_gnutls.h" +*/ +import "C" +import ( + "fmt" + "runtime" +) + +// HashType hash type +type HashType int + +const ( + GNUTLS_HASH_MD5 HashType = 2 + GNUTLS_HASH_SHA1 HashType = 3 + GNUTLS_HASH_MD2 HashType = 5 + GNUTLS_HASH_SHA256 HashType = 6 + GNUTLS_HASH_SHA384 HashType = 7 + GNUTLS_HASH_SHA512 HashType = 8 + GNUTLS_HASH_SHA224 HashType = 9 +) + +// Hash hash struct +type Hash struct { + hash C.gnutls_hash_hd_t + t HashType + hashLen C.int +} + +// NewHash new hash struct +func NewHash(t HashType) *Hash { + h := C.new_hash(C.int(t)) + hashOutLen := GetHashOutputLen(t) + hash := &Hash{h, t, C.int(hashOutLen)} + runtime.SetFinalizer(hash, (*Hash).free) + return hash +} + +// Write write data to hash context +func (h *Hash) Write(buf []byte) error { + dataLen := len(buf) + + cbuf := C.CBytes(buf) + defer C.free(cbuf) + + ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen)) + if int(ret) < 0 { + return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret))) + } + return nil +} + +// Sum get hash result +func (h *Hash) Sum(buf []byte) []byte { + if buf != nil { + h.Write(buf) + } + + dstBuf := C.malloc(C.size_t(h.hashLen)) + defer C.free(dstBuf) + + C.gnutls_hash_output(h.hash, dstBuf) + + gobuf := C.GoBytes(dstBuf, h.hashLen) + + return gobuf +} + +// Close destroy hash context +func (h *Hash) Close() error { + if h.hash != nil { + C.gnutls_hash_deinit(h.hash, nil) + h.hash = nil + } + return nil +} +func (h *Hash) free() { + //log.Println("free hash") + h.Close() +} + +// GetHashOutputLen get the hash algorithm output length +// +// example GNUTLS_MD5 is 16 +func GetHashOutputLen(t HashType) int { + return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t))) +} diff --git a/hash_test.go b/hash_test.go index 288ba03..1c5a3fd 100644 --- a/hash_test.go +++ b/hash_test.go @@ -1,52 +1,52 @@ -package gnutls - -import ( - "bytes" - "crypto/rand" - "crypto/sha512" - "encoding/hex" - "runtime" - "testing" - "time" -) - -func TestHashSHA(t *testing.T) { - h := NewHash(GNUTLS_HASH_SHA512) - //defer h.Close() - - data := []byte("1234") - - h1 := h.Sum(data) - - h3 := sha512.New() - h3.Write(data) - h4 := h3.Sum(nil) - if !bytes.Equal(h4[:], h1) { - t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1)) - t.Fatal("hash not equal") - } - runtime.GC() - time.Sleep(1 * time.Second) -} - -func BenchmarkHashSHA512(b *testing.B) { - buf := make([]byte, 1000*1024) - rand.Reader.Read(buf) - for i := 0; i < b.N; i++ { - h := NewHash(GNUTLS_HASH_SHA512) - h.Write(buf) - h.Sum(nil) - h.Close() - } -} - -func BenchmarkHashSHA512s(b *testing.B) { - buf := make([]byte, 1000*1024) - rand.Reader.Read(buf) - for i := 0; i < b.N; i++ { - h := sha512.New() - h.Write(buf) - h.Sum(nil) - //h.Close() - } -} +package gnutls + +import ( + "bytes" + "crypto/rand" + "crypto/sha512" + "encoding/hex" + "runtime" + "testing" + "time" +) + +func TestHashSHA(t *testing.T) { + h := NewHash(GNUTLS_HASH_SHA512) + //defer h.Close() + + data := []byte("1234") + + h1 := h.Sum(data) + + h3 := sha512.New() + h3.Write(data) + h4 := h3.Sum(nil) + if !bytes.Equal(h4[:], h1) { + t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1)) + t.Fatal("hash not equal") + } + runtime.GC() + time.Sleep(1 * time.Second) +} + +func BenchmarkHashSHA512(b *testing.B) { + buf := make([]byte, 1000*1024) + rand.Reader.Read(buf) + for i := 0; i < b.N; i++ { + h := NewHash(GNUTLS_HASH_SHA512) + h.Write(buf) + h.Sum(nil) + h.Close() + } +} + +func BenchmarkHashSHA512s(b *testing.B) { + buf := make([]byte, 1000*1024) + rand.Reader.Read(buf) + for i := 0; i < b.N; i++ { + h := sha512.New() + h.Write(buf) + h.Sum(nil) + //h.Close() + } +} diff --git a/tls.go b/tls.go index b243ab3..db58bb5 100644 --- a/tls.go +++ b/tls.go @@ -1,492 +1,492 @@ -package gnutls - -/* -#include "_gnutls.h" -#cgo pkg-config: gnutls -*/ -import "C" -import ( - "errors" - "fmt" - "io" - "log" - "net" - "runtime" - "sync" - "time" - "unsafe" -) - -const ( - GNUTLS_NAME_DNS = 1 - GNUTLS_X509_FMT_PEM = 1 - GNUTLS_ALPN_MANDATORY = 1 - GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1 -) - -// Conn gnutls TLS connection, -// use Listen, Dial, Server or Client create this object -type Conn struct { - c net.Conn - handshake bool - sess *C.struct_session - cservname *C.char - state *ConnectionState - cfg *Config - closed bool - lock *sync.Mutex -} - -// Config gnutls TLS configure, -type Config struct { - ServerName string - Certificates []*Certificate - InsecureSkipVerify bool - NextProtos []string -} - -// ConnectionState gnutls TLS connection state -type ConnectionState struct { - // SNI name client send - ServerName string - // selected ALPN protocl - NegotiatedProtocol string - HandshakeComplete bool - // TLS version number, ex: 0x303 - Version uint16 - // TLS version number, ex: TLS1.0 - VersionName string - // peer's certificate - PeerCertificate *Certificate -} - -type listener struct { - l net.Listener - c *Config -} - -// Accept -func (l *listener) Accept() (net.Conn, error) { - c, err := l.l.Accept() - if err != nil { - return nil, err - } - return Server(c, l.c) -} - -// Close -func (l *listener) Close() error { - return l.l.Close() -} - -// Addr -func (l *listener) Addr() net.Addr { - return l.l.Addr() -} - -// Dial dial to (network, addr) and create a gnutls Conn -func Dial(network, addr string, cfg *Config) (*Conn, error) { - c, err := net.Dial(network, addr) - if err != nil { - return nil, err - } - return Client(c, cfg) -} - -// Listen create a gnutls listener on (network, addr), -func Listen(network, addr string, cfg *Config) (net.Listener, error) { - if cfg == nil { - return nil, errors.New("config is needed") - } - l, err := net.Listen(network, addr) - if err != nil { - return nil, err - } - return &listener{l, cfg}, nil -} - -// Server create a server TLS Conn on c -func Server(c net.Conn, cfg *Config) (*Conn, error) { - if cfg == nil { - return nil, errors.New("config is needed") - } - - var sess = C.init_gnutls_server_session() - - conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} - - C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) - C.set_callback(sess) - - if cfg.NextProtos != nil { - if err := setAlpnProtocols(sess, cfg); err != nil { - log.Println(err) - } - } - runtime.SetFinalizer(conn, (*Conn).free) - return conn, nil -} - -// Client create a client TLS Conn on c -func Client(c net.Conn, cfg *Config) (*Conn, error) { - var sess = C.init_gnutls_client_session() - - conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} - - C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) - C.set_callback(sess) - - if cfg != nil { - if cfg.ServerName != "" { - srvname := C.CString(cfg.ServerName) - //defer C.free(unsafe.Pointer(srvname)) - conn.cservname = srvname - C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS, - unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName))) - } - - if !cfg.InsecureSkipVerify { - if conn.cservname != nil { - C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0) - } else { - C.gnutls_session_set_verify_cert(sess.session, nil, 0) - } - } - - if cfg.NextProtos != nil { - if err := setAlpnProtocols(sess, cfg); err != nil { - log.Println(err) - } - } - - } else { - C.gnutls_session_set_verify_cert(sess.session, nil, 0) - } - runtime.SetFinalizer(conn, (*Conn).free) - return conn, nil -} - -func setAlpnProtocols(sess *C.struct_session, cfg *Config) error { - arg := make([](*C.char), 0) - for _, s := range cfg.NextProtos { - cbuf := C.CString(s) - defer C.free(unsafe.Pointer(cbuf)) - arg = append(arg, (*C.char)(cbuf)) - } - ret := C.alpn_set_protocols(sess, - (**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos))) - if int(ret) < 0 { - return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret))) - } - return nil - -} - -// Handshake call handshake for TLS Conn, -// this function will call automatic on Read/Write, if not handshake yet -func (c *Conn) Handshake() error { - c.lock.Lock() - defer c.lock.Unlock() - if c.handshake { - return nil - } - ret := C.handshake(c.sess) - if int(ret) < 0 { - return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret))) - } - c.handshake = true - //log.Println("handshake done") - return nil -} - -// Read read application data from TLS connection -func (c *Conn) Read(buf []byte) (n int, err error) { - err = c.Handshake() - if err != nil { - return - } - - if len(buf) == 0 { - return 0, nil - } - if c.closed { - return 0, io.EOF - } - //bufLen := len(buf) - //cbuf := C.malloc(C.size_t(bufLen)) - //defer C.free(cbuf) - - ret := C.gnutls_record_recv(c.sess.session, - unsafe.Pointer(&buf[0]), C.size_t(len(buf))) - if int(ret) < 0 { - return 0, fmt.Errorf("read error: %s", - C.GoString(C.gnutls_strerror(C.int(ret)))) - } - - if int(ret) == 0 { - return 0, io.EOF - } - - n = int(ret) - //gobuf2 := C.GoBytes(cbuf, C.int(ret)) - //copy(buf, gobuf2) - return n, nil -} - -// Write write application data to TLS connection -func (c *Conn) Write(buf []byte) (n int, err error) { - err = c.Handshake() - if err != nil { - return - } - - // user may call Write(nil) to do handshake - if len(buf) == 0 { - return 0, nil - } - if c.closed { - return 0, io.EOF - } - //cbuf := C.CBytes(buf) - //defer C.free(cbuf) - - ret := C.gnutls_record_send(c.sess.session, - unsafe.Pointer(&buf[0]), C.size_t(len(buf))) - n = int(ret) - - if n < 0 { - return 0, fmt.Errorf("write error: %s", - C.GoString(C.gnutls_strerror(C.int(ret)))) - } - - if int(ret) == 0 { - return 0, io.EOF - } - - return n, nil -} - -// Close close the TLS conn and destroy the tls context -func (c *Conn) Close() error { - c.lock.Lock() - defer c.lock.Unlock() - if c.closed { - return nil - } - C.gnutls_record_send(c.sess.session, nil, 0) - C.session_destroy(c.sess) - c.c.Close() - if c.cservname != nil { - C.free(unsafe.Pointer(c.cservname)) - } - - if c.state != nil && c.state.PeerCertificate != nil { - c.state.PeerCertificate.Free() - } - c.closed = true - return nil -} - -func (c *Conn) free() { - //log.Println("free conn") - c.Close() -} - -// SetWriteDeadline implements net.Conn -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.c.SetWriteDeadline(t) -} - -// SetReadDeadline implements net.Conn -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.c.SetReadDeadline(t) -} - -// RemoteAddr implements net.Conn -func (c *Conn) RemoteAddr() net.Addr { - return c.c.RemoteAddr() -} - -// LocalAddr implements net.Conn -func (c *Conn) LocalAddr() net.Addr { - return c.c.LocalAddr() -} - -// SetDeadline implements net.Conn -func (c *Conn) SetDeadline(t time.Time) error { - return c.c.SetDeadline(t) -} - -// ConnectionState get TLS connection state -func (c *Conn) ConnectionState() *ConnectionState { - if c.state != nil { - return c.state - } - version := - uint16(C.gnutls_protocol_get_version(c.sess.session)) - - versionname := C.GoString( - C.gnutls_protocol_get_name(C.gnutls_protocol_t(version))) - - state := &ConnectionState{ - NegotiatedProtocol: c.getAlpnSelectedProtocol(), - Version: version, - HandshakeComplete: c.handshake, - ServerName: c.getServerName(), - VersionName: versionname, - PeerCertificate: c.getPeerCertificate(), - } - c.state = state - return state -} - -func (c *Conn) getPeerCertificate() *Certificate { - var size int - st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size))) - if st == nil { - return nil - } - cert := &Certificate{cert: st, certSize: C.int(size)} - runtime.SetFinalizer(cert, (*Certificate).free) - return cert -} - -func (c *Conn) getAlpnSelectedProtocol() string { - cbuf := C.malloc(100) - defer C.free(cbuf) - - ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf)) - if int(ret) < 0 { - return "" - } - alpnname := C.GoString((*C.char)(cbuf)) - return alpnname -} - -func (c *Conn) getServerName() string { - buflen := 100 - nametype := GNUTLS_NAME_DNS - cbuf := C.malloc(C.size_t(buflen)) - defer C.free(cbuf) - - ret := C.gnutls_server_name_get(c.sess.session, cbuf, - (*C.size_t)(unsafe.Pointer(&buflen)), - (*C.uint)(unsafe.Pointer(&nametype)), 0) - if int(ret) < 0 { - return "" - } - name := C.GoString((*C.char)(cbuf)) - return name -} - -/* -var bufPool = &sync.Pool{ - New: func() interface{} { - return make([]byte, 16*1024) - }, -} -*/ - -// onDataReadCallback callback function for gnutls library want to read data from network -// -//export onDataReadCallback -func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { - //log.Println("read addr ", uintptr(d)) - conn := (*Conn)(unsafe.Pointer((uintptr(d)))) - - buf := make([]byte, int(bufLen)) - /* - _length := int(bufLen) - buf := bufPool.Get().([]byte) - if len(buf) < _length { - buf = make([]byte, _length) - } - defer bufPool.Put(buf) - */ - - n, err := conn.c.Read(buf[0:]) - if err != nil { - //log.Println(err) - // 0 indicates connection termination - return 0 - } - //cbuf2 := C.CBytes(buf[:n]) - // d := C.CString(string(buf[:n])) - //defer C.free(cbuf2) - C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n)) - return C.int(n) -} - -// onDataWriteCallback callback function for gnutls library want to send data to network -// -//export onDataWriteCallback -func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { - //log.Println("write addr ", uintptr(d), int(_l)) - conn := (*Conn)(unsafe.Pointer((uintptr(d)))) - gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen) - n, err := conn.c.Write(gobuf) - if err != nil { - //log.Println(err) - return -1 - } - return C.int(n) -} - -// onDataTimeoutRead callback function for timeout read -// -//export onDataTimeoutRead -func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int { - log.Println("timeout pull function") - return 0 -} - -// onCertSelectCallback callback function for ceritificate select, -// this function select certificate from Config.Certificates field, -// -// on server side, this function select the certificate depend on SNI what client send, -// if client not send SNI, select the Config.Certificates[0] -// -//export onCertSelectCallback -func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char, - namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int { - - servername := C.GoStringN(hostname, namelen) - //log.Println("go cert select callback ", servername) - conn := (*Conn)(unsafe.Pointer((uintptr(ptr)))) - //log.Println(conn) - if int(namelen) == 0 && conn.cfg.Certificates != nil { - _cert := conn.cfg.Certificates[0] - *pcertLength = _cert.certSize - *cert = _cert.cert - *privkey = _cert.privkey - //log.Println("set pcert length ", _cert.certSize) - return 0 - } - for _, _cert := range conn.cfg.Certificates { - //log.Println(cert) - if _cert.matchName(servername) { - //log.Println("matched name ", _cert.names) - *pcertLength = _cert.certSize - *cert = _cert.cert - *privkey = _cert.privkey - //log.Println("set pcert length ", _cert.certSize) - return 0 - } - } - if conn.cfg.Certificates != nil { - _cert := conn.cfg.Certificates[0] - *pcertLength = _cert.certSize - *cert = _cert.cert - *privkey = _cert.privkey - //log.Println("set pcert length ", _cert.certSize) - return 0 - } - *pcertLength = 0 - //log.Println("set pcert length 0") - return -1 -} - -func init() { - C.init_xcred() - C.init_priority_cache() -} +package gnutls + +/* +#include "_gnutls.h" +#cgo pkg-config: gnutls +*/ +import "C" +import ( + "errors" + "fmt" + "io" + "log" + "net" + "runtime" + "sync" + "time" + "unsafe" +) + +const ( + GNUTLS_NAME_DNS = 1 + GNUTLS_X509_FMT_PEM = 1 + GNUTLS_ALPN_MANDATORY = 1 + GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1 +) + +// Conn gnutls TLS connection, +// use Listen, Dial, Server or Client create this object +type Conn struct { + c net.Conn + handshake bool + sess *C.struct_session + cservname *C.char + state *ConnectionState + cfg *Config + closed bool + lock *sync.Mutex +} + +// Config gnutls TLS configure, +type Config struct { + ServerName string + Certificates []*Certificate + InsecureSkipVerify bool + NextProtos []string +} + +// ConnectionState gnutls TLS connection state +type ConnectionState struct { + // SNI name client send + ServerName string + // selected ALPN protocl + NegotiatedProtocol string + HandshakeComplete bool + // TLS version number, ex: 0x303 + Version uint16 + // TLS version number, ex: TLS1.0 + VersionName string + // peer's certificate + PeerCertificate *Certificate +} + +type listener struct { + l net.Listener + c *Config +} + +// Accept +func (l *listener) Accept() (net.Conn, error) { + c, err := l.l.Accept() + if err != nil { + return nil, err + } + return Server(c, l.c) +} + +// Close +func (l *listener) Close() error { + return l.l.Close() +} + +// Addr +func (l *listener) Addr() net.Addr { + return l.l.Addr() +} + +// Dial dial to (network, addr) and create a gnutls Conn +func Dial(network, addr string, cfg *Config) (*Conn, error) { + c, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + return Client(c, cfg) +} + +// Listen create a gnutls listener on (network, addr), +func Listen(network, addr string, cfg *Config) (net.Listener, error) { + if cfg == nil { + return nil, errors.New("config is needed") + } + l, err := net.Listen(network, addr) + if err != nil { + return nil, err + } + return &listener{l, cfg}, nil +} + +// Server create a server TLS Conn on c +func Server(c net.Conn, cfg *Config) (*Conn, error) { + if cfg == nil { + return nil, errors.New("config is needed") + } + + var sess = C.init_gnutls_server_session() + + conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} + + C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) + C.set_callback(sess) + + if cfg.NextProtos != nil { + if err := setAlpnProtocols(sess, cfg); err != nil { + log.Println(err) + } + } + runtime.SetFinalizer(conn, (*Conn).free) + return conn, nil +} + +// Client create a client TLS Conn on c +func Client(c net.Conn, cfg *Config) (*Conn, error) { + var sess = C.init_gnutls_client_session() + + conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} + + C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) + C.set_callback(sess) + + if cfg != nil { + if cfg.ServerName != "" { + srvname := C.CString(cfg.ServerName) + //defer C.free(unsafe.Pointer(srvname)) + conn.cservname = srvname + C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS, + unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName))) + } + + if !cfg.InsecureSkipVerify { + if conn.cservname != nil { + C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0) + } else { + C.gnutls_session_set_verify_cert(sess.session, nil, 0) + } + } + + if cfg.NextProtos != nil { + if err := setAlpnProtocols(sess, cfg); err != nil { + log.Println(err) + } + } + + } else { + C.gnutls_session_set_verify_cert(sess.session, nil, 0) + } + runtime.SetFinalizer(conn, (*Conn).free) + return conn, nil +} + +func setAlpnProtocols(sess *C.struct_session, cfg *Config) error { + arg := make([](*C.char), 0) + for _, s := range cfg.NextProtos { + cbuf := C.CString(s) + defer C.free(unsafe.Pointer(cbuf)) + arg = append(arg, (*C.char)(cbuf)) + } + ret := C.alpn_set_protocols(sess, + (**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos))) + if int(ret) < 0 { + return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret))) + } + return nil + +} + +// Handshake call handshake for TLS Conn, +// this function will call automatic on Read/Write, if not handshake yet +func (c *Conn) Handshake() error { + c.lock.Lock() + defer c.lock.Unlock() + if c.handshake { + return nil + } + ret := C.handshake(c.sess) + if int(ret) < 0 { + return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret))) + } + c.handshake = true + //log.Println("handshake done") + return nil +} + +// Read read application data from TLS connection +func (c *Conn) Read(buf []byte) (n int, err error) { + err = c.Handshake() + if err != nil { + return + } + + if len(buf) == 0 { + return 0, nil + } + if c.closed { + return 0, io.EOF + } + //bufLen := len(buf) + //cbuf := C.malloc(C.size_t(bufLen)) + //defer C.free(cbuf) + + ret := C.gnutls_record_recv(c.sess.session, + unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + if int(ret) < 0 { + return 0, fmt.Errorf("read error: %s", + C.GoString(C.gnutls_strerror(C.int(ret)))) + } + + if int(ret) == 0 { + return 0, io.EOF + } + + n = int(ret) + //gobuf2 := C.GoBytes(cbuf, C.int(ret)) + //copy(buf, gobuf2) + return n, nil +} + +// Write write application data to TLS connection +func (c *Conn) Write(buf []byte) (n int, err error) { + err = c.Handshake() + if err != nil { + return + } + + // user may call Write(nil) to do handshake + if len(buf) == 0 { + return 0, nil + } + if c.closed { + return 0, io.EOF + } + //cbuf := C.CBytes(buf) + //defer C.free(cbuf) + + ret := C.gnutls_record_send(c.sess.session, + unsafe.Pointer(&buf[0]), C.size_t(len(buf))) + n = int(ret) + + if n < 0 { + return 0, fmt.Errorf("write error: %s", + C.GoString(C.gnutls_strerror(C.int(ret)))) + } + + if int(ret) == 0 { + return 0, io.EOF + } + + return n, nil +} + +// Close close the TLS conn and destroy the tls context +func (c *Conn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + if c.closed { + return nil + } + C.gnutls_record_send(c.sess.session, nil, 0) + C.session_destroy(c.sess) + c.c.Close() + if c.cservname != nil { + C.free(unsafe.Pointer(c.cservname)) + } + + if c.state != nil && c.state.PeerCertificate != nil { + c.state.PeerCertificate.Free() + } + c.closed = true + return nil +} + +func (c *Conn) free() { + //log.Println("free conn") + c.Close() +} + +// SetWriteDeadline implements net.Conn +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.c.SetWriteDeadline(t) +} + +// SetReadDeadline implements net.Conn +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.c.SetReadDeadline(t) +} + +// RemoteAddr implements net.Conn +func (c *Conn) RemoteAddr() net.Addr { + return c.c.RemoteAddr() +} + +// LocalAddr implements net.Conn +func (c *Conn) LocalAddr() net.Addr { + return c.c.LocalAddr() +} + +// SetDeadline implements net.Conn +func (c *Conn) SetDeadline(t time.Time) error { + return c.c.SetDeadline(t) +} + +// ConnectionState get TLS connection state +func (c *Conn) ConnectionState() *ConnectionState { + if c.state != nil { + return c.state + } + version := + uint16(C.gnutls_protocol_get_version(c.sess.session)) + + versionname := C.GoString( + C.gnutls_protocol_get_name(C.gnutls_protocol_t(version))) + + state := &ConnectionState{ + NegotiatedProtocol: c.getAlpnSelectedProtocol(), + Version: version, + HandshakeComplete: c.handshake, + ServerName: c.getServerName(), + VersionName: versionname, + PeerCertificate: c.getPeerCertificate(), + } + c.state = state + return state +} + +func (c *Conn) getPeerCertificate() *Certificate { + var size int + st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size))) + if st == nil { + return nil + } + cert := &Certificate{cert: st, certSize: C.int(size)} + runtime.SetFinalizer(cert, (*Certificate).free) + return cert +} + +func (c *Conn) getAlpnSelectedProtocol() string { + cbuf := C.malloc(100) + defer C.free(cbuf) + + ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf)) + if int(ret) < 0 { + return "" + } + alpnname := C.GoString((*C.char)(cbuf)) + return alpnname +} + +func (c *Conn) getServerName() string { + buflen := 100 + nametype := GNUTLS_NAME_DNS + cbuf := C.malloc(C.size_t(buflen)) + defer C.free(cbuf) + + ret := C.gnutls_server_name_get(c.sess.session, cbuf, + (*C.size_t)(unsafe.Pointer(&buflen)), + (*C.uint)(unsafe.Pointer(&nametype)), 0) + if int(ret) < 0 { + return "" + } + name := C.GoString((*C.char)(cbuf)) + return name +} + +/* +var bufPool = &sync.Pool{ + New: func() interface{} { + return make([]byte, 16*1024) + }, +} +*/ + +// onDataReadCallback callback function for gnutls library want to read data from network +// +//export onDataReadCallback +func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { + //log.Println("read addr ", uintptr(d)) + conn := (*Conn)(unsafe.Pointer((uintptr(d)))) + + buf := make([]byte, int(bufLen)) + /* + _length := int(bufLen) + buf := bufPool.Get().([]byte) + if len(buf) < _length { + buf = make([]byte, _length) + } + defer bufPool.Put(buf) + */ + + n, err := conn.c.Read(buf[0:]) + if err != nil { + //log.Println(err) + // 0 indicates connection termination + return 0 + } + //cbuf2 := C.CBytes(buf[:n]) + // d := C.CString(string(buf[:n])) + //defer C.free(cbuf2) + C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n)) + return C.int(n) +} + +// onDataWriteCallback callback function for gnutls library want to send data to network +// +//export onDataWriteCallback +func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { + //log.Println("write addr ", uintptr(d), int(_l)) + conn := (*Conn)(unsafe.Pointer((uintptr(d)))) + gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen) + n, err := conn.c.Write(gobuf) + if err != nil { + //log.Println(err) + return -1 + } + return C.int(n) +} + +// onDataTimeoutRead callback function for timeout read +// +//export onDataTimeoutRead +func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int { + log.Println("timeout pull function") + return 0 +} + +// onCertSelectCallback callback function for ceritificate select, +// this function select certificate from Config.Certificates field, +// +// on server side, this function select the certificate depend on SNI what client send, +// if client not send SNI, select the Config.Certificates[0] +// +//export onCertSelectCallback +func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char, + namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int { + + servername := C.GoStringN(hostname, namelen) + //log.Println("go cert select callback ", servername) + conn := (*Conn)(unsafe.Pointer((uintptr(ptr)))) + //log.Println(conn) + if int(namelen) == 0 && conn.cfg.Certificates != nil { + _cert := conn.cfg.Certificates[0] + *pcertLength = _cert.certSize + *cert = _cert.cert + *privkey = _cert.privkey + //log.Println("set pcert length ", _cert.certSize) + return 0 + } + for _, _cert := range conn.cfg.Certificates { + //log.Println(cert) + if _cert.matchName(servername) { + //log.Println("matched name ", _cert.names) + *pcertLength = _cert.certSize + *cert = _cert.cert + *privkey = _cert.privkey + //log.Println("set pcert length ", _cert.certSize) + return 0 + } + } + if conn.cfg.Certificates != nil { + _cert := conn.cfg.Certificates[0] + *pcertLength = _cert.certSize + *cert = _cert.cert + *privkey = _cert.privkey + //log.Println("set pcert length ", _cert.certSize) + return 0 + } + *pcertLength = 0 + //log.Println("set pcert length 0") + return -1 +} + +func init() { + C.init_xcred() + C.init_priority_cache() +} diff --git a/tls_test.go b/tls_test.go index 1f9ce0f..e28227b 100644 --- a/tls_test.go +++ b/tls_test.go @@ -1,431 +1,431 @@ -package gnutls - -import ( - "bufio" - "bytes" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - "runtime" - "testing" - "time" -) - -func TestTLSClient(t *testing.T) { - cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") - if err != nil { - t.Fatal("load certificate failed") - } - l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ - Certificates: []tls.Certificate{cert}, - }) - if err != nil { - t.Fatal("listen failed") - } - defer l.Close() - addr := l.Addr().String() - t.Log("test server listen on ", addr) - go func() { - for { - c, err := l.Accept() - if err != nil { - break - } - t.Logf("accept connection from %s", c.RemoteAddr()) - go func(c net.Conn) { - defer c.Close() - for { - buf := make([]byte, 4096) - n, err := c.Read(buf) - if err != nil { - t.Log("connection closed") - break - } - if _, err = c.Write(buf[:n]); err != nil { - break - } - } - }(c) - } - }() - - c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true}) - if err != nil { - t.Fatal("gnutls dial ", err) - } - defer c.Close() - - data := "hello, world" - if _, err = c.Write([]byte(data)); err != nil { - t.Fatal("gnutls write ", err) - } - buf := make([]byte, 100) - n, err := c.Read(buf) - if err != nil { - t.Fatal("gnutls read ", err) - } - if string(buf[:n]) != data { - t.Errorf("need: %s, got: %s", data, string(buf[:n])) - } -} - -func TestTLSServer(t *testing.T) { - cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") - if err != nil { - t.Fatal(err) - } - l, err := Listen("tcp", "127.0.0.1:0", &Config{ - Certificates: []*Certificate{cert}, - }) - if err != nil { - t.Fatal("gnutls listen ", err) - } - addr := l.Addr().String() - t.Log("test server listen on ", addr) - defer l.Close() - go func() { - for { - c, err := l.Accept() - if err != nil { - t.Log("gnutls accept ", err) - break - } - t.Log("accept connection from ", c.RemoteAddr()) - go func(c net.Conn) { - defer c.Close() - tlsconn := c.(*Conn) - if err := tlsconn.Handshake(); err != nil { - t.Log(err) - return - } - buf := make([]byte, 4096) - for { - n, err := c.Read(buf[0:]) - if err != nil { - t.Log("gnutls read ", err) - break - } - if _, err := c.Write(buf[:n]); err != nil { - t.Log("gnutls write ", err) - break - } - } - }(c) - } - }() - - c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true}) - if err != nil { - t.Fatal("dial ", err) - } - defer c.Close() - - data := "hello, world" - if _, err := c.Write([]byte(data)); err != nil { - t.Fatal("write ", err) - } - buf := make([]byte, 100) - n, err := c.Read(buf) - if err != nil { - t.Fatal("read ", err) - } - if string(buf[:n]) != data { - t.Errorf("need: %s, got: %s", data, string(buf[:n])) - } - runtime.GC() - time.Sleep(1 * time.Second) -} - -func TestTLSALPNServer(t *testing.T) { - serveralpn := []string{"a1", "a3", "a2"} - clientalpn := []string{"a0", "a2", "a5"} - expectedAlpn := "a2" - cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") - if err != nil { - t.Fatal(err) - } - l, err := Listen("tcp", "127.0.0.1:0", &Config{ - Certificates: []*Certificate{cert}, - NextProtos: serveralpn, - }) - if err != nil { - t.Fatal("gnutls listen ", err) - } - addr := l.Addr().String() - t.Log("test server listen on ", addr) - defer l.Close() - go func() { - for { - c, err := l.Accept() - if err != nil { - t.Log("gnutls accept ", err) - break - } - t.Log("accept connection from ", c.RemoteAddr()) - go func(c net.Conn) { - defer c.Close() - tlsConn := c.(*Conn) - if err := tlsConn.Handshake(); err != nil { - t.Log(err) - return - } - connState := tlsConn.ConnectionState() - t.Logf("%+v", connState) - buf := make([]byte, 4096) - for { - n, err := c.Read(buf[0:]) - if err != nil { - t.Log("gnutls read ", err) - break - } - if _, err := c.Write(buf[:n]); err != nil { - t.Log("gnutls write ", err) - break - } - } - }(c) - } - }() - - c, err := tls.Dial("tcp", addr, &tls.Config{ - InsecureSkipVerify: true, - ServerName: "localhost", - NextProtos: clientalpn, - }) - if err != nil { - t.Fatal("dial ", err) - } - defer c.Close() - - if err := c.Handshake(); err != nil { - t.Fatal(err) - } - connState := c.ConnectionState() - t.Logf("%+v", connState) - - if connState.NegotiatedProtocol != expectedAlpn { - t.Errorf("expected alpn %s, got %s", - expectedAlpn, connState.NegotiatedProtocol) - } - - data := "hello, world" - if _, err := c.Write([]byte(data)); err != nil { - t.Fatal("write ", err) - } - buf := make([]byte, 100) - n, err := c.Read(buf) - if err != nil { - t.Fatal("read ", err) - } - if string(buf[:n]) != data { - t.Errorf("need: %s, got: %s", data, string(buf[:n])) - } - runtime.GC() - time.Sleep(1 * time.Second) -} - -func TestTLSALPNClient(t *testing.T) { - serveralpn := []string{"a1", "a3", "a2"} - clientalpn := []string{"a0", "a2", "a5"} - expectedAlpn := "a2" - - cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") - if err != nil { - t.Fatal("load key failed") - } - - l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: serveralpn, - }) - - if err != nil { - t.Fatal("tls listen ", err) - } - addr := l.Addr().String() - t.Log("test server listen on ", addr) - defer l.Close() - go func() { - for { - c, err := l.Accept() - if err != nil { - t.Log("gnutls accept ", err) - break - } - t.Log("accept connection from ", c.RemoteAddr()) - go func(c net.Conn) { - defer c.Close() - tlsConn := c.(*tls.Conn) - if err := tlsConn.Handshake(); err != nil { - t.Log(err) - return - } - connState := tlsConn.ConnectionState() - t.Logf("%+v", connState) - buf := make([]byte, 4096) - for { - n, err := c.Read(buf[0:]) - if err != nil { - t.Log("tls read ", err) - break - } - if _, err := c.Write(buf[:n]); err != nil { - t.Log("tls write ", err) - break - } - } - }(c) - } - }() - - c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true, - ServerName: "localhost", - NextProtos: clientalpn, - }) - if err != nil { - t.Fatal("dial ", err) - } - defer c.Close() - - if err := c.Handshake(); err != nil { - t.Fatal(err) - } - connState := c.ConnectionState() - t.Logf("%+v", connState) - - if connState.NegotiatedProtocol != expectedAlpn { - t.Errorf("expected alpn %s, got %s", - expectedAlpn, connState.NegotiatedProtocol) - } - - data := "hello, world" - if _, err := c.Write([]byte(data)); err != nil { - t.Fatal("write ", err) - } - buf := make([]byte, 100) - n, err := c.Read(buf) - if err != nil { - t.Fatal("read ", err) - } - if string(buf[:n]) != data { - t.Errorf("need: %s, got: %s", data, string(buf[:n])) - } - runtime.GC() - time.Sleep(1 * time.Second) -} - -func TestTLSServerSNI(t *testing.T) { - certificates := []*Certificate{} - cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") - if err != nil { - t.Fatal("load key failed") - } - - certificates = append(certificates, cert) - cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key") - if err != nil { - t.Fatal("load key failed") - } - - certificates = append(certificates, cert) - cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key") - if err != nil { - t.Fatal("load key failed") - } - certificates = append(certificates, cert) - - l, err := Listen("tcp", "127.0.0.1:0", &Config{ - Certificates: certificates, - }) - if err != nil { - t.Fatal(err) - } - defer l.Close() - addr := l.Addr().String() - go func() { - for { - c, err := l.Accept() - if err != nil { - t.Log(err) - break - } - go func(c net.Conn) { - defer c.Close() - tlsconn := c.(*Conn) - if err := tlsconn.Handshake(); err != nil { - t.Log(err) - return - } - state := tlsconn.ConnectionState() - fmt.Fprintf(c, state.ServerName) - }(c) - } - }() - - for _, cfg := range []struct { - serverName string - commonName string - }{ - {"abc.com", "abc.com"}, - {"example.com", "example.com"}, - {"a.aaa.com", "*.aaa.com"}, - {"b.aaa.com", "*.aaa.com"}, - } { - conn, err := tls.Dial("tcp", addr, &tls.Config{ - ServerName: cfg.serverName, - InsecureSkipVerify: true, - }) - if err != nil { - t.Fatal(err) - } - state := conn.ConnectionState() - _commonName := state.PeerCertificates[0].Subject.CommonName - if _commonName != cfg.commonName { - t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName) - } - buf := make([]byte, 100) - n, err := conn.Read(buf) - if err != nil && err != io.EOF { - t.Error(err) - } - if !bytes.Equal(buf[:n], []byte(cfg.serverName)) { - t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n])) - } - conn.Close() - } - runtime.GC() - time.Sleep(1 * time.Second) -} - -func TestTLSGetPeerCert(t *testing.T) { - conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{ - ServerName: "www.ratafee.nl", - }) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - //tlsconn := conn.(*Conn) - if err := conn.Handshake(); err != nil { - t.Fatal(err) - } - state := conn.ConnectionState() - for i := 0; i < int(state.PeerCertificate.certSize); i++ { - t.Log(state.PeerCertificate.getCertString(i, 1)) - } - - req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil) - req.Write(conn) - r := bufio.NewReader(conn) - resp, err := http.ReadResponse(r, req) - if err != nil { - t.Error(err) - } - var buf = new(bytes.Buffer) - resp.Write(buf) - t.Logf("%s", string(buf.Bytes())) - runtime.GC() - time.Sleep(1 * time.Second) -} +package gnutls + +import ( + "bufio" + "bytes" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "runtime" + "testing" + "time" +) + +func TestTLSClient(t *testing.T) { + cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal("load certificate failed") + } + l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + if err != nil { + t.Fatal("listen failed") + } + defer l.Close() + addr := l.Addr().String() + t.Log("test server listen on ", addr) + go func() { + for { + c, err := l.Accept() + if err != nil { + break + } + t.Logf("accept connection from %s", c.RemoteAddr()) + go func(c net.Conn) { + defer c.Close() + for { + buf := make([]byte, 4096) + n, err := c.Read(buf) + if err != nil { + t.Log("connection closed") + break + } + if _, err = c.Write(buf[:n]); err != nil { + break + } + } + }(c) + } + }() + + c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true}) + if err != nil { + t.Fatal("gnutls dial ", err) + } + defer c.Close() + + data := "hello, world" + if _, err = c.Write([]byte(data)); err != nil { + t.Fatal("gnutls write ", err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if err != nil { + t.Fatal("gnutls read ", err) + } + if string(buf[:n]) != data { + t.Errorf("need: %s, got: %s", data, string(buf[:n])) + } +} + +func TestTLSServer(t *testing.T) { + cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal(err) + } + l, err := Listen("tcp", "127.0.0.1:0", &Config{ + Certificates: []*Certificate{cert}, + }) + if err != nil { + t.Fatal("gnutls listen ", err) + } + addr := l.Addr().String() + t.Log("test server listen on ", addr) + defer l.Close() + go func() { + for { + c, err := l.Accept() + if err != nil { + t.Log("gnutls accept ", err) + break + } + t.Log("accept connection from ", c.RemoteAddr()) + go func(c net.Conn) { + defer c.Close() + tlsconn := c.(*Conn) + if err := tlsconn.Handshake(); err != nil { + t.Log(err) + return + } + buf := make([]byte, 4096) + for { + n, err := c.Read(buf[0:]) + if err != nil { + t.Log("gnutls read ", err) + break + } + if _, err := c.Write(buf[:n]); err != nil { + t.Log("gnutls write ", err) + break + } + } + }(c) + } + }() + + c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Fatal("dial ", err) + } + defer c.Close() + + data := "hello, world" + if _, err := c.Write([]byte(data)); err != nil { + t.Fatal("write ", err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if err != nil { + t.Fatal("read ", err) + } + if string(buf[:n]) != data { + t.Errorf("need: %s, got: %s", data, string(buf[:n])) + } + runtime.GC() + time.Sleep(1 * time.Second) +} + +func TestTLSALPNServer(t *testing.T) { + serveralpn := []string{"a1", "a3", "a2"} + clientalpn := []string{"a0", "a2", "a5"} + expectedAlpn := "a2" + cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal(err) + } + l, err := Listen("tcp", "127.0.0.1:0", &Config{ + Certificates: []*Certificate{cert}, + NextProtos: serveralpn, + }) + if err != nil { + t.Fatal("gnutls listen ", err) + } + addr := l.Addr().String() + t.Log("test server listen on ", addr) + defer l.Close() + go func() { + for { + c, err := l.Accept() + if err != nil { + t.Log("gnutls accept ", err) + break + } + t.Log("accept connection from ", c.RemoteAddr()) + go func(c net.Conn) { + defer c.Close() + tlsConn := c.(*Conn) + if err := tlsConn.Handshake(); err != nil { + t.Log(err) + return + } + connState := tlsConn.ConnectionState() + t.Logf("%+v", connState) + buf := make([]byte, 4096) + for { + n, err := c.Read(buf[0:]) + if err != nil { + t.Log("gnutls read ", err) + break + } + if _, err := c.Write(buf[:n]); err != nil { + t.Log("gnutls write ", err) + break + } + } + }(c) + } + }() + + c, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + NextProtos: clientalpn, + }) + if err != nil { + t.Fatal("dial ", err) + } + defer c.Close() + + if err := c.Handshake(); err != nil { + t.Fatal(err) + } + connState := c.ConnectionState() + t.Logf("%+v", connState) + + if connState.NegotiatedProtocol != expectedAlpn { + t.Errorf("expected alpn %s, got %s", + expectedAlpn, connState.NegotiatedProtocol) + } + + data := "hello, world" + if _, err := c.Write([]byte(data)); err != nil { + t.Fatal("write ", err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if err != nil { + t.Fatal("read ", err) + } + if string(buf[:n]) != data { + t.Errorf("need: %s, got: %s", data, string(buf[:n])) + } + runtime.GC() + time.Sleep(1 * time.Second) +} + +func TestTLSALPNClient(t *testing.T) { + serveralpn := []string{"a1", "a3", "a2"} + clientalpn := []string{"a0", "a2", "a5"} + expectedAlpn := "a2" + + cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal("load key failed") + } + + l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: serveralpn, + }) + + if err != nil { + t.Fatal("tls listen ", err) + } + addr := l.Addr().String() + t.Log("test server listen on ", addr) + defer l.Close() + go func() { + for { + c, err := l.Accept() + if err != nil { + t.Log("gnutls accept ", err) + break + } + t.Log("accept connection from ", c.RemoteAddr()) + go func(c net.Conn) { + defer c.Close() + tlsConn := c.(*tls.Conn) + if err := tlsConn.Handshake(); err != nil { + t.Log(err) + return + } + connState := tlsConn.ConnectionState() + t.Logf("%+v", connState) + buf := make([]byte, 4096) + for { + n, err := c.Read(buf[0:]) + if err != nil { + t.Log("tls read ", err) + break + } + if _, err := c.Write(buf[:n]); err != nil { + t.Log("tls write ", err) + break + } + } + }(c) + } + }() + + c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true, + ServerName: "localhost", + NextProtos: clientalpn, + }) + if err != nil { + t.Fatal("dial ", err) + } + defer c.Close() + + if err := c.Handshake(); err != nil { + t.Fatal(err) + } + connState := c.ConnectionState() + t.Logf("%+v", connState) + + if connState.NegotiatedProtocol != expectedAlpn { + t.Errorf("expected alpn %s, got %s", + expectedAlpn, connState.NegotiatedProtocol) + } + + data := "hello, world" + if _, err := c.Write([]byte(data)); err != nil { + t.Fatal("write ", err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if err != nil { + t.Fatal("read ", err) + } + if string(buf[:n]) != data { + t.Errorf("need: %s, got: %s", data, string(buf[:n])) + } + runtime.GC() + time.Sleep(1 * time.Second) +} + +func TestTLSServerSNI(t *testing.T) { + certificates := []*Certificate{} + cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal("load key failed") + } + + certificates = append(certificates, cert) + cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key") + if err != nil { + t.Fatal("load key failed") + } + + certificates = append(certificates, cert) + cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key") + if err != nil { + t.Fatal("load key failed") + } + certificates = append(certificates, cert) + + l, err := Listen("tcp", "127.0.0.1:0", &Config{ + Certificates: certificates, + }) + if err != nil { + t.Fatal(err) + } + defer l.Close() + addr := l.Addr().String() + go func() { + for { + c, err := l.Accept() + if err != nil { + t.Log(err) + break + } + go func(c net.Conn) { + defer c.Close() + tlsconn := c.(*Conn) + if err := tlsconn.Handshake(); err != nil { + t.Log(err) + return + } + state := tlsconn.ConnectionState() + fmt.Fprintf(c, state.ServerName) + }(c) + } + }() + + for _, cfg := range []struct { + serverName string + commonName string + }{ + {"abc.com", "abc.com"}, + {"example.com", "example.com"}, + {"a.aaa.com", "*.aaa.com"}, + {"b.aaa.com", "*.aaa.com"}, + } { + conn, err := tls.Dial("tcp", addr, &tls.Config{ + ServerName: cfg.serverName, + InsecureSkipVerify: true, + }) + if err != nil { + t.Fatal(err) + } + state := conn.ConnectionState() + _commonName := state.PeerCertificates[0].Subject.CommonName + if _commonName != cfg.commonName { + t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName) + } + buf := make([]byte, 100) + n, err := conn.Read(buf) + if err != nil && err != io.EOF { + t.Error(err) + } + if !bytes.Equal(buf[:n], []byte(cfg.serverName)) { + t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n])) + } + conn.Close() + } + runtime.GC() + time.Sleep(1 * time.Second) +} + +func TestTLSGetPeerCert(t *testing.T) { + conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{ + ServerName: "www.ratafee.nl", + }) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + //tlsconn := conn.(*Conn) + if err := conn.Handshake(); err != nil { + t.Fatal(err) + } + state := conn.ConnectionState() + for i := 0; i < int(state.PeerCertificate.certSize); i++ { + t.Log(state.PeerCertificate.getCertString(i, 1)) + } + + req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil) + req.Write(conn) + r := bufio.NewReader(conn) + resp, err := http.ReadResponse(r, req) + if err != nil { + t.Error(err) + } + var buf = new(bytes.Buffer) + resp.Write(buf) + t.Logf("%s", string(buf.Bytes())) + runtime.GC() + time.Sleep(1 * time.Second) +}