master
fangdingjun 7 years ago
parent 2c4c546551
commit e168de539b

@ -1,176 +1,176 @@
package gnutls
/*
#include "_gnutls.h"
*/
import "C"
import (
"fmt"
"log"
"runtime"
"strings"
"unsafe"
)
// Certificate x509 certificate
type Certificate struct {
cert *C.gnutls_pcert_st
privkey C.gnutls_privkey_t
certSize C.int
}
// Free free the certificate context
func (c *Certificate) Free() {
if c.cert != nil {
C.free_cert_list(c.cert, c.certSize)
}
if c.privkey != nil {
C.gnutls_privkey_deinit(c.privkey)
}
c.cert = nil
c.privkey = nil
c.certSize = 0
}
func (c *Certificate) free() {
//log.Println("free certificate")
c.Free()
}
func (c *Certificate) matchName(name string) bool {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
ret := C.cert_check_hostname(c.cert, c.certSize, cname)
if int(ret) < 0 {
log.Println(C.GoString(C.gnutls_strerror(ret)))
return false
}
if int(ret) > 0 {
return true
}
return false
}
// CommonName get CN field in subject,
//
// example: subject C=xx,ST=bbb,CN=abc will return abc
func (c *Certificate) CommonName() string {
return c.commonName(0)
}
func (c *Certificate) commonName(index int) string {
dnMap := make(map[string]string)
dn := c.getDN(0)
if dn != "" {
dnFields := strings.Split(dn, ",")
for _, d := range dnFields {
dd := strings.Split(d, "=")
k := dd[0]
v := dd[1]
dnMap[k] = v
}
}
if cn, ok := dnMap["CN"]; ok {
return cn
}
return ""
}
// GetAltName get altname in certificate
func (c *Certificate) GetAltName() string {
return c.getAltName(0, 0)
}
func (c *Certificate) getAltName(index int, nameindex int) string {
out := C.malloc(1024)
defer C.free(out)
size := C.get_pcert_alt_name(
c.cert, C.int(index), C.int(nameindex), (*C.char)(out))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
name := C.GoBytes(out, size)
return string(name)
}
//GetCertString return certificate info string in one line
func (c *Certificate) GetCertString() string {
return c.getCertString(0, 1)
}
func (c *Certificate) getCertString(index int, flag int) string {
out := C.malloc(4096)
defer C.free(out)
size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
s := C.GoBytes(out, size)
return string(s)
}
// GetDN get the certificate subject, like O=st,C=aa,CN=localhost
func (c *Certificate) GetDN() string {
return c.getDN(0)
}
func (c *Certificate) getDN(index int) string {
cbuf := C.malloc(200)
defer C.free(cbuf)
size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
s := C.GoBytes(cbuf, size)
return string(s)
}
// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost
func (c *Certificate) GetIssuerDN() string {
return c.getIssuerDN(0)
}
func (c *Certificate) getIssuerDN(index int) string {
cbuf := C.malloc(200)
defer C.free(cbuf)
size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
s := C.GoBytes(cbuf, size)
return string(s)
}
// LoadX509KeyPair load certificate pair,
// the return Certifciate must be freed by call Free(),
func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
_certfile := C.CString(certfile)
_keyfile := C.CString(keyfile)
defer C.free(unsafe.Pointer(_certfile))
defer C.free(unsafe.Pointer(_keyfile))
certificate := &Certificate{}
var ret C.int
var certSize C.int
cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)),
(*C.int)(unsafe.Pointer(&ret)))
if int(ret) < 0 {
return nil, fmt.Errorf("load cert failed: %s",
C.GoString(C.gnutls_strerror(ret)))
}
privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret)))
if int(ret) < 0 {
return nil, fmt.Errorf("load privkey: %s",
C.GoString(C.gnutls_strerror(ret)))
}
certificate.cert = cert
certificate.privkey = privkey
certificate.certSize = certSize
runtime.SetFinalizer(certificate, (*Certificate).free)
return certificate, nil
}
package gnutls
/*
#include "_gnutls.h"
*/
import "C"
import (
"fmt"
"log"
"runtime"
"strings"
"unsafe"
)
// Certificate x509 certificate
type Certificate struct {
cert *C.gnutls_pcert_st
privkey C.gnutls_privkey_t
certSize C.int
}
// Free free the certificate context
func (c *Certificate) Free() {
if c.cert != nil {
C.free_cert_list(c.cert, c.certSize)
}
if c.privkey != nil {
C.gnutls_privkey_deinit(c.privkey)
}
c.cert = nil
c.privkey = nil
c.certSize = 0
}
func (c *Certificate) free() {
//log.Println("free certificate")
c.Free()
}
func (c *Certificate) matchName(name string) bool {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
ret := C.cert_check_hostname(c.cert, c.certSize, cname)
if int(ret) < 0 {
log.Println(C.GoString(C.gnutls_strerror(ret)))
return false
}
if int(ret) > 0 {
return true
}
return false
}
// CommonName get CN field in subject,
//
// example: subject C=xx,ST=bbb,CN=abc will return abc
func (c *Certificate) CommonName() string {
return c.commonName(0)
}
func (c *Certificate) commonName(index int) string {
dnMap := make(map[string]string)
dn := c.getDN(0)
if dn != "" {
dnFields := strings.Split(dn, ",")
for _, d := range dnFields {
dd := strings.Split(d, "=")
k := dd[0]
v := dd[1]
dnMap[k] = v
}
}
if cn, ok := dnMap["CN"]; ok {
return cn
}
return ""
}
// GetAltName get altname in certificate
func (c *Certificate) GetAltName() string {
return c.getAltName(0, 0)
}
func (c *Certificate) getAltName(index int, nameindex int) string {
out := C.malloc(1024)
defer C.free(out)
size := C.get_pcert_alt_name(
c.cert, C.int(index), C.int(nameindex), (*C.char)(out))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
name := C.GoBytes(out, size)
return string(name)
}
//GetCertString return certificate info string in one line
func (c *Certificate) GetCertString() string {
return c.getCertString(0, 1)
}
func (c *Certificate) getCertString(index int, flag int) string {
out := C.malloc(4096)
defer C.free(out)
size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
s := C.GoBytes(out, size)
return string(s)
}
// GetDN get the certificate subject, like O=st,C=aa,CN=localhost
func (c *Certificate) GetDN() string {
return c.getDN(0)
}
func (c *Certificate) getDN(index int) string {
cbuf := C.malloc(200)
defer C.free(cbuf)
size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
s := C.GoBytes(cbuf, size)
return string(s)
}
// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost
func (c *Certificate) GetIssuerDN() string {
return c.getIssuerDN(0)
}
func (c *Certificate) getIssuerDN(index int) string {
cbuf := C.malloc(200)
defer C.free(cbuf)
size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf))
if int(size) < 0 {
log.Println(C.GoString(C.gnutls_strerror(size)))
return ""
}
s := C.GoBytes(cbuf, size)
return string(s)
}
// LoadX509KeyPair load certificate pair,
// the return Certifciate must be freed by call Free(),
func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
_certfile := C.CString(certfile)
_keyfile := C.CString(keyfile)
defer C.free(unsafe.Pointer(_certfile))
defer C.free(unsafe.Pointer(_keyfile))
certificate := &Certificate{}
var ret C.int
var certSize C.int
cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)),
(*C.int)(unsafe.Pointer(&ret)))
if int(ret) < 0 {
return nil, fmt.Errorf("load cert failed: %s",
C.GoString(C.gnutls_strerror(ret)))
}
privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret)))
if int(ret) < 0 {
return nil, fmt.Errorf("load privkey: %s",
C.GoString(C.gnutls_strerror(ret)))
}
certificate.cert = cert
certificate.privkey = privkey
certificate.certSize = certSize
runtime.SetFinalizer(certificate, (*Certificate).free)
return certificate, nil
}

@ -1,35 +1,35 @@
package gnutls
import (
"runtime"
"testing"
"time"
)
func TestGetAltname(t *testing.T) {
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal(err)
}
t.Logf("%+v", cert)
t.Log(cert.GetAltName())
t.Log(cert.GetCertString())
t.Log(cert.GetDN())
t.Log(cert.GetIssuerDN())
t.Log("CN ", cert.CommonName())
//t.Log("flag 0: ", cert.getCertString(0, 0))
//t.Log("flag 1: ", cert.getCertString(0, 1))
//t.Log("flag 2: ", cert.getCertString(0, 2))
//t.Log("flag 3: ", cert.getCertString(0, 3))
cert.Free()
}
func _loadCert(certfile, keyfile string) (*Certificate, error) {
return LoadX509KeyPair(certfile, keyfile)
}
func TestCertGC(t *testing.T) {
_loadCert("testdata/server.crt", "testdata/server.key")
runtime.GC()
time.Sleep(1 * time.Second)
}
package gnutls
import (
"runtime"
"testing"
"time"
)
func TestGetAltname(t *testing.T) {
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal(err)
}
t.Logf("%+v", cert)
t.Log(cert.GetAltName())
t.Log(cert.GetCertString())
t.Log(cert.GetDN())
t.Log(cert.GetIssuerDN())
t.Log("CN ", cert.CommonName())
//t.Log("flag 0: ", cert.getCertString(0, 0))
//t.Log("flag 1: ", cert.getCertString(0, 1))
//t.Log("flag 2: ", cert.getCertString(0, 2))
//t.Log("flag 3: ", cert.getCertString(0, 3))
cert.Free()
}
func _loadCert(certfile, keyfile string) (*Certificate, error) {
return LoadX509KeyPair(certfile, keyfile)
}
func TestCertGC(t *testing.T) {
_loadCert("testdata/server.crt", "testdata/server.key")
runtime.GC()
time.Sleep(1 * time.Second)
}

@ -1,167 +1,167 @@
package gnutls
/*
#include "_gnutls.h"
#include <stdlib.h>
*/
import "C"
import (
"errors"
"fmt"
"log"
"runtime"
)
// CipherType cipher type
type CipherType int
const (
GNUTLS_CIPHER_AES_128_CBC CipherType = 4
GNUTLS_CIPHER_AES_256_CBC CipherType = 5
GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
GNUTLS_CIPHER_AES_192_CBC CipherType = 9
GNUTLS_CIPHER_AES_128_GCM CipherType = 10
GNUTLS_CIPHER_AES_256_GCM CipherType = 11
GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
GNUTLS_CIPHER_SALSA20_256 CipherType = 13
GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
GNUTLS_CIPHER_DES_CBC CipherType = 18
GNUTLS_CIPHER_AES_128_CCM CipherType = 19
GNUTLS_CIPHER_AES_256_CCM CipherType = 20
GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
)
var (
// ErrBlockSize wrong block size
ErrBlockSize = errors.New("wrong block size")
// ErrKeyLength wrong key length
ErrKeyLength = errors.New("wrong key length")
// ErrIVLength wrong iv length
ErrIVLength = errors.New("wrong iv length")
)
// Cipher gnutls cipher struct
type Cipher struct {
cipher C.gnutls_cipher_hd_t
t CipherType
blockSize int
}
// NewCipher create a new cipher by give type, key, iv
//
// example:
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
//
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
keysize := GetCipherKeySize(t)
ivSize := GetCipherIVSize(t)
blocksize := GetCipherBlockSize(t)
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
if len(key) != int(keysize) {
return nil, ErrKeyLength
}
if len(iv) != int(ivSize) {
return nil, ErrIVLength
}
ckey := C.CBytes(key)
civ := C.CBytes(iv)
defer C.free(ckey)
defer C.free(civ)
c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv)))
if c == nil {
log.Println("new cipher return nil")
return nil, nil
}
cipher := &Cipher{c, t, blocksize}
runtime.SetFinalizer(cipher, (*Cipher).free)
return cipher, nil
}
// Encrypt encrypt the buf and place the encrypted data in dst,
// the buf size must multiple of cipher's block size
func (c *Cipher) Encrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 {
return ErrBlockSize
}
cbuf := C.CBytes(buf)
defer C.free(cbuf)
bufLen := C.size_t(len(buf))
dstBuf := C.malloc(bufLen)
defer C.free(dstBuf)
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 {
return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
}
_buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf)
return nil
}
// Decrypt decrypt the buf and place the decrypted data in dst,
// the buf size must multiple of cipher's block size
func (c *Cipher) Decrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 {
return ErrBlockSize
}
cbuf := C.CBytes(buf)
defer C.free(cbuf)
bufLen := C.size_t(len(buf))
dstBuf := C.malloc(C.size_t(len(buf)))
defer C.free(dstBuf)
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 {
return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
}
_buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf)
return nil
}
// Close destroy the cipher context
func (c *Cipher) Close() error {
if c.cipher != nil {
C.gnutls_cipher_deinit(c.cipher)
c.cipher = nil
}
return nil
}
func (c *Cipher) free() {
//log.Println("free cipher")
c.Close()
}
// GetCipherKeySize get the cipher algorithm key length
func GetCipherKeySize(t CipherType) int {
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
}
// GetCipherIVSize get the cipher algorithm iv length
func GetCipherIVSize(t CipherType) int {
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
}
// GetCipherBlockSize get the cipher algorithm block size
func GetCipherBlockSize(t CipherType) int {
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
}
package gnutls
/*
#include "_gnutls.h"
#include <stdlib.h>
*/
import "C"
import (
"errors"
"fmt"
"log"
"runtime"
)
// CipherType cipher type
type CipherType int
const (
GNUTLS_CIPHER_AES_128_CBC CipherType = 4
GNUTLS_CIPHER_AES_256_CBC CipherType = 5
GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
GNUTLS_CIPHER_AES_192_CBC CipherType = 9
GNUTLS_CIPHER_AES_128_GCM CipherType = 10
GNUTLS_CIPHER_AES_256_GCM CipherType = 11
GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
GNUTLS_CIPHER_SALSA20_256 CipherType = 13
GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
GNUTLS_CIPHER_DES_CBC CipherType = 18
GNUTLS_CIPHER_AES_128_CCM CipherType = 19
GNUTLS_CIPHER_AES_256_CCM CipherType = 20
GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
)
var (
// ErrBlockSize wrong block size
ErrBlockSize = errors.New("wrong block size")
// ErrKeyLength wrong key length
ErrKeyLength = errors.New("wrong key length")
// ErrIVLength wrong iv length
ErrIVLength = errors.New("wrong iv length")
)
// Cipher gnutls cipher struct
type Cipher struct {
cipher C.gnutls_cipher_hd_t
t CipherType
blockSize int
}
// NewCipher create a new cipher by give type, key, iv
//
// example:
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
//
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
keysize := GetCipherKeySize(t)
ivSize := GetCipherIVSize(t)
blocksize := GetCipherBlockSize(t)
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
if len(key) != int(keysize) {
return nil, ErrKeyLength
}
if len(iv) != int(ivSize) {
return nil, ErrIVLength
}
ckey := C.CBytes(key)
civ := C.CBytes(iv)
defer C.free(ckey)
defer C.free(civ)
c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv)))
if c == nil {
log.Println("new cipher return nil")
return nil, nil
}
cipher := &Cipher{c, t, blocksize}
runtime.SetFinalizer(cipher, (*Cipher).free)
return cipher, nil
}
// Encrypt encrypt the buf and place the encrypted data in dst,
// the buf size must multiple of cipher's block size
func (c *Cipher) Encrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 {
return ErrBlockSize
}
cbuf := C.CBytes(buf)
defer C.free(cbuf)
bufLen := C.size_t(len(buf))
dstBuf := C.malloc(bufLen)
defer C.free(dstBuf)
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 {
return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
}
_buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf)
return nil
}
// Decrypt decrypt the buf and place the decrypted data in dst,
// the buf size must multiple of cipher's block size
func (c *Cipher) Decrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 {
return ErrBlockSize
}
cbuf := C.CBytes(buf)
defer C.free(cbuf)
bufLen := C.size_t(len(buf))
dstBuf := C.malloc(C.size_t(len(buf)))
defer C.free(dstBuf)
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 {
return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
}
_buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf)
return nil
}
// Close destroy the cipher context
func (c *Cipher) Close() error {
if c.cipher != nil {
C.gnutls_cipher_deinit(c.cipher)
c.cipher = nil
}
return nil
}
func (c *Cipher) free() {
//log.Println("free cipher")
c.Close()
}
// GetCipherKeySize get the cipher algorithm key length
func GetCipherKeySize(t CipherType) int {
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
}
// GetCipherIVSize get the cipher algorithm iv length
func GetCipherIVSize(t CipherType) int {
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
}
// GetCipherBlockSize get the cipher algorithm block size
func GetCipherBlockSize(t CipherType) int {
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
}

@ -1,130 +1,130 @@
package gnutls
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"runtime"
"testing"
"time"
)
func TestCipherSize(t *testing.T) {
data := []struct {
t CipherType
bsize int
isize int
}{
{GNUTLS_CIPHER_AES_128_CBC, 16, 16},
{GNUTLS_CIPHER_AES_192_CBC, 24, 16},
{GNUTLS_CIPHER_AES_256_CBC, 32, 16},
}
for _, d := range data {
blocksize := GetCipherKeySize(d.t)
if blocksize != d.bsize {
t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize)
}
ivsize := GetCipherIVSize(d.t)
if ivsize != d.isize {
t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize)
}
}
}
func TestEncryptDecrypt(t *testing.T) {
cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName)
key := make([]byte, keysize)
iv := make([]byte, ivsize)
rand.Reader.Read(key)
rand.Reader.Read(iv)
c, err := NewCipher(cipherName, key, iv)
if err != nil {
t.Fatal(err)
}
//defer c.Close()
c1, err := NewCipher(cipherName, key, iv)
if err != nil {
t.Fatal(err)
}
//defer c1.Close()
data := make([]byte, blocksize*10)
if c == nil {
t.Fatal("new ciphoer failed")
}
cdata := make([]byte, len(data))
err = c.Encrypt(cdata, data)
if err != nil {
t.Fatal("encrypt failed", err)
}
data1 := make([]byte, len(data))
err = c1.Decrypt(data1, cdata)
if err != nil {
t.Fatal("decrypt failed", err)
}
if !bytes.Equal(data, data1) {
t.Fatal("encrypt/decrypt failed", string(data), string(data1))
}
block, _ := aes.NewCipher(key)
mode := cipher.NewCBCEncrypter(block, iv)
dst := make([]byte, len(data))
mode.CryptBlocks(dst, data)
if !bytes.Equal(dst, cdata) {
t.Fatal("cipher text not equal to cypto/aes")
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func BenchmarkAESEncrypt(b *testing.B) {
cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName)
datalen := blocksize * 500
key := make([]byte, keysize)
iv := make([]byte, ivsize)
rand.Reader.Read(key)
rand.Reader.Read(iv)
buf := make([]byte, datalen)
dst := make([]byte, datalen)
for i := 0; i < b.N; i++ {
c, err := NewCipher(cipherName, key, iv)
if err != nil {
b.Fatal(err)
}
c.Encrypt(dst, buf)
c.Close()
}
}
func BenchmarkAESEncrypt2(b *testing.B) {
cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName)
datalen := blocksize * 500
key := make([]byte, keysize)
iv := make([]byte, ivsize)
buf := make([]byte, datalen)
rand.Reader.Read(buf)
rand.Reader.Read(key)
rand.Reader.Read(iv)
dst := make([]byte, datalen)
for i := 0; i < b.N; i++ {
block, _ := aes.NewCipher(key)
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(dst, buf)
}
}
package gnutls
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"runtime"
"testing"
"time"
)
func TestCipherSize(t *testing.T) {
data := []struct {
t CipherType
bsize int
isize int
}{
{GNUTLS_CIPHER_AES_128_CBC, 16, 16},
{GNUTLS_CIPHER_AES_192_CBC, 24, 16},
{GNUTLS_CIPHER_AES_256_CBC, 32, 16},
}
for _, d := range data {
blocksize := GetCipherKeySize(d.t)
if blocksize != d.bsize {
t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize)
}
ivsize := GetCipherIVSize(d.t)
if ivsize != d.isize {
t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize)
}
}
}
func TestEncryptDecrypt(t *testing.T) {
cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName)
key := make([]byte, keysize)
iv := make([]byte, ivsize)
rand.Reader.Read(key)
rand.Reader.Read(iv)
c, err := NewCipher(cipherName, key, iv)
if err != nil {
t.Fatal(err)
}
//defer c.Close()
c1, err := NewCipher(cipherName, key, iv)
if err != nil {
t.Fatal(err)
}
//defer c1.Close()
data := make([]byte, blocksize*10)
if c == nil {
t.Fatal("new ciphoer failed")
}
cdata := make([]byte, len(data))
err = c.Encrypt(cdata, data)
if err != nil {
t.Fatal("encrypt failed", err)
}
data1 := make([]byte, len(data))
err = c1.Decrypt(data1, cdata)
if err != nil {
t.Fatal("decrypt failed", err)
}
if !bytes.Equal(data, data1) {
t.Fatal("encrypt/decrypt failed", string(data), string(data1))
}
block, _ := aes.NewCipher(key)
mode := cipher.NewCBCEncrypter(block, iv)
dst := make([]byte, len(data))
mode.CryptBlocks(dst, data)
if !bytes.Equal(dst, cdata) {
t.Fatal("cipher text not equal to cypto/aes")
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func BenchmarkAESEncrypt(b *testing.B) {
cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName)
datalen := blocksize * 500
key := make([]byte, keysize)
iv := make([]byte, ivsize)
rand.Reader.Read(key)
rand.Reader.Read(iv)
buf := make([]byte, datalen)
dst := make([]byte, datalen)
for i := 0; i < b.N; i++ {
c, err := NewCipher(cipherName, key, iv)
if err != nil {
b.Fatal(err)
}
c.Encrypt(dst, buf)
c.Close()
}
}
func BenchmarkAESEncrypt2(b *testing.B) {
cipherName := GNUTLS_CIPHER_AES_256_CBC
keysize := GetCipherKeySize(cipherName)
ivsize := GetCipherIVSize(cipherName)
blocksize := GetCipherBlockSize(cipherName)
datalen := blocksize * 500
key := make([]byte, keysize)
iv := make([]byte, ivsize)
buf := make([]byte, datalen)
rand.Reader.Read(buf)
rand.Reader.Read(key)
rand.Reader.Read(iv)
dst := make([]byte, datalen)
for i := 0; i < b.N; i++ {
block, _ := aes.NewCipher(key)
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(dst, buf)
}
}

216
doc.go

@ -1,108 +1,108 @@
// Package gnutls is a gnutls binding for golang.
/*
a limit set of api is supported.
TLS api is very similar to crypto/tls on standard library.
TLS client example:
addr := "127.0.0.1:9443"
c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true})
if err != nil {
log.Fatal("gnutls dial ", err)
}
defer c.Close()
data := "hello, world"
if _, err = c.Write([]byte(data)); err != nil {
log.Fatal("gnutls write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
log.Fatal("gnutls read ", err)
}
TLS Server example:
cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key")
if err != nil{
// handle error
}
l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{
Certificates: []*gnutls.Certificate{cert}})
if err != nil {
// handle error
}
defer l.Close()
for {
c, err := l.Accept()
if err != nil {
log.Println("gnutls accept ", err)
break
}
log.Println("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
log.Println("gnutls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
log.Println("gnutls write ", err)
break
}
}
}(c)
}
AES encrypt/decrypt example:
key := []byte("0123456789abcdef")
iv := []byte("abcdefg123456789")
c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
if err != nil {
log.Fatal(err)
}
defer c.Close()
c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
if err != nil {
log.Fatal(err)
}
defer c1.Close()
data := []byte("1234012121212121")
if c == nil {
log.Fatal("new ciphoer failed")
}
// encrypt
dst := make([]byte, len(data))
err := c.Encrypt(dst, data)
if err != nil {
log.Fatal("encrypt failed", err)
}
// decrypt
data1 := make([]byte, len(data))
err := c1.Decrypt(data1, cdata)
if err != nil {
log.Fatal("decrypt failed", err)
}
Hash example:
h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512)
defer h.Close()
data := []byte("1234")
h1 := h.Sum(data)
*/
package gnutls
// Package gnutls is a gnutls binding for golang.
/*
a limit set of api is supported.
TLS api is very similar to crypto/tls on standard library.
TLS client example:
addr := "127.0.0.1:9443"
c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true})
if err != nil {
log.Fatal("gnutls dial ", err)
}
defer c.Close()
data := "hello, world"
if _, err = c.Write([]byte(data)); err != nil {
log.Fatal("gnutls write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
log.Fatal("gnutls read ", err)
}
TLS Server example:
cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key")
if err != nil{
// handle error
}
l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{
Certificates: []*gnutls.Certificate{cert}})
if err != nil {
// handle error
}
defer l.Close()
for {
c, err := l.Accept()
if err != nil {
log.Println("gnutls accept ", err)
break
}
log.Println("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
log.Println("gnutls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
log.Println("gnutls write ", err)
break
}
}
}(c)
}
AES encrypt/decrypt example:
key := []byte("0123456789abcdef")
iv := []byte("abcdefg123456789")
c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
if err != nil {
log.Fatal(err)
}
defer c.Close()
c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
if err != nil {
log.Fatal(err)
}
defer c1.Close()
data := []byte("1234012121212121")
if c == nil {
log.Fatal("new ciphoer failed")
}
// encrypt
dst := make([]byte, len(data))
err := c.Encrypt(dst, data)
if err != nil {
log.Fatal("encrypt failed", err)
}
// decrypt
data1 := make([]byte, len(data))
err := c1.Decrypt(data1, cdata)
if err != nil {
log.Fatal("decrypt failed", err)
}
Hash example:
h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512)
defer h.Close()
data := []byte("1234")
h1 := h.Sum(data)
*/
package gnutls

@ -1,90 +1,90 @@
package gnutls
/*
#include <stdlib.h>
#include "_gnutls.h"
*/
import "C"
import (
"fmt"
"runtime"
)
// HashType hash type
type HashType int
const (
GNUTLS_HASH_MD5 HashType = 2
GNUTLS_HASH_SHA1 HashType = 3
GNUTLS_HASH_MD2 HashType = 5
GNUTLS_HASH_SHA256 HashType = 6
GNUTLS_HASH_SHA384 HashType = 7
GNUTLS_HASH_SHA512 HashType = 8
GNUTLS_HASH_SHA224 HashType = 9
)
// Hash hash struct
type Hash struct {
hash C.gnutls_hash_hd_t
t HashType
hashLen C.int
}
// NewHash new hash struct
func NewHash(t HashType) *Hash {
h := C.new_hash(C.int(t))
hashOutLen := GetHashOutputLen(t)
hash := &Hash{h, t, C.int(hashOutLen)}
runtime.SetFinalizer(hash, (*Hash).free)
return hash
}
// Write write data to hash context
func (h *Hash) Write(buf []byte) error {
dataLen := len(buf)
cbuf := C.CBytes(buf)
defer C.free(cbuf)
ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen))
if int(ret) < 0 {
return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret)))
}
return nil
}
// Sum get hash result
func (h *Hash) Sum(buf []byte) []byte {
if buf != nil {
h.Write(buf)
}
dstBuf := C.malloc(C.size_t(h.hashLen))
defer C.free(dstBuf)
C.gnutls_hash_output(h.hash, dstBuf)
gobuf := C.GoBytes(dstBuf, h.hashLen)
return gobuf
}
// Close destroy hash context
func (h *Hash) Close() error {
if h.hash != nil {
C.gnutls_hash_deinit(h.hash, nil)
h.hash = nil
}
return nil
}
func (h *Hash) free() {
//log.Println("free hash")
h.Close()
}
// GetHashOutputLen get the hash algorithm output length
//
// example GNUTLS_MD5 is 16
func GetHashOutputLen(t HashType) int {
return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t)))
}
package gnutls
/*
#include <stdlib.h>
#include "_gnutls.h"
*/
import "C"
import (
"fmt"
"runtime"
)
// HashType hash type
type HashType int
const (
GNUTLS_HASH_MD5 HashType = 2
GNUTLS_HASH_SHA1 HashType = 3
GNUTLS_HASH_MD2 HashType = 5
GNUTLS_HASH_SHA256 HashType = 6
GNUTLS_HASH_SHA384 HashType = 7
GNUTLS_HASH_SHA512 HashType = 8
GNUTLS_HASH_SHA224 HashType = 9
)
// Hash hash struct
type Hash struct {
hash C.gnutls_hash_hd_t
t HashType
hashLen C.int
}
// NewHash new hash struct
func NewHash(t HashType) *Hash {
h := C.new_hash(C.int(t))
hashOutLen := GetHashOutputLen(t)
hash := &Hash{h, t, C.int(hashOutLen)}
runtime.SetFinalizer(hash, (*Hash).free)
return hash
}
// Write write data to hash context
func (h *Hash) Write(buf []byte) error {
dataLen := len(buf)
cbuf := C.CBytes(buf)
defer C.free(cbuf)
ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen))
if int(ret) < 0 {
return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret)))
}
return nil
}
// Sum get hash result
func (h *Hash) Sum(buf []byte) []byte {
if buf != nil {
h.Write(buf)
}
dstBuf := C.malloc(C.size_t(h.hashLen))
defer C.free(dstBuf)
C.gnutls_hash_output(h.hash, dstBuf)
gobuf := C.GoBytes(dstBuf, h.hashLen)
return gobuf
}
// Close destroy hash context
func (h *Hash) Close() error {
if h.hash != nil {
C.gnutls_hash_deinit(h.hash, nil)
h.hash = nil
}
return nil
}
func (h *Hash) free() {
//log.Println("free hash")
h.Close()
}
// GetHashOutputLen get the hash algorithm output length
//
// example GNUTLS_MD5 is 16
func GetHashOutputLen(t HashType) int {
return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t)))
}

@ -1,52 +1,52 @@
package gnutls
import (
"bytes"
"crypto/rand"
"crypto/sha512"
"encoding/hex"
"runtime"
"testing"
"time"
)
func TestHashSHA(t *testing.T) {
h := NewHash(GNUTLS_HASH_SHA512)
//defer h.Close()
data := []byte("1234")
h1 := h.Sum(data)
h3 := sha512.New()
h3.Write(data)
h4 := h3.Sum(nil)
if !bytes.Equal(h4[:], h1) {
t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
t.Fatal("hash not equal")
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func BenchmarkHashSHA512(b *testing.B) {
buf := make([]byte, 1000*1024)
rand.Reader.Read(buf)
for i := 0; i < b.N; i++ {
h := NewHash(GNUTLS_HASH_SHA512)
h.Write(buf)
h.Sum(nil)
h.Close()
}
}
func BenchmarkHashSHA512s(b *testing.B) {
buf := make([]byte, 1000*1024)
rand.Reader.Read(buf)
for i := 0; i < b.N; i++ {
h := sha512.New()
h.Write(buf)
h.Sum(nil)
//h.Close()
}
}
package gnutls
import (
"bytes"
"crypto/rand"
"crypto/sha512"
"encoding/hex"
"runtime"
"testing"
"time"
)
func TestHashSHA(t *testing.T) {
h := NewHash(GNUTLS_HASH_SHA512)
//defer h.Close()
data := []byte("1234")
h1 := h.Sum(data)
h3 := sha512.New()
h3.Write(data)
h4 := h3.Sum(nil)
if !bytes.Equal(h4[:], h1) {
t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
t.Fatal("hash not equal")
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func BenchmarkHashSHA512(b *testing.B) {
buf := make([]byte, 1000*1024)
rand.Reader.Read(buf)
for i := 0; i < b.N; i++ {
h := NewHash(GNUTLS_HASH_SHA512)
h.Write(buf)
h.Sum(nil)
h.Close()
}
}
func BenchmarkHashSHA512s(b *testing.B) {
buf := make([]byte, 1000*1024)
rand.Reader.Read(buf)
for i := 0; i < b.N; i++ {
h := sha512.New()
h.Write(buf)
h.Sum(nil)
//h.Close()
}
}

984
tls.go

@ -1,492 +1,492 @@
package gnutls
/*
#include "_gnutls.h"
#cgo pkg-config: gnutls
*/
import "C"
import (
"errors"
"fmt"
"io"
"log"
"net"
"runtime"
"sync"
"time"
"unsafe"
)
const (
GNUTLS_NAME_DNS = 1
GNUTLS_X509_FMT_PEM = 1
GNUTLS_ALPN_MANDATORY = 1
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
)
// Conn gnutls TLS connection,
// use Listen, Dial, Server or Client create this object
type Conn struct {
c net.Conn
handshake bool
sess *C.struct_session
cservname *C.char
state *ConnectionState
cfg *Config
closed bool
lock *sync.Mutex
}
// Config gnutls TLS configure,
type Config struct {
ServerName string
Certificates []*Certificate
InsecureSkipVerify bool
NextProtos []string
}
// ConnectionState gnutls TLS connection state
type ConnectionState struct {
// SNI name client send
ServerName string
// selected ALPN protocl
NegotiatedProtocol string
HandshakeComplete bool
// TLS version number, ex: 0x303
Version uint16
// TLS version number, ex: TLS1.0
VersionName string
// peer's certificate
PeerCertificate *Certificate
}
type listener struct {
l net.Listener
c *Config
}
// Accept
func (l *listener) Accept() (net.Conn, error) {
c, err := l.l.Accept()
if err != nil {
return nil, err
}
return Server(c, l.c)
}
// Close
func (l *listener) Close() error {
return l.l.Close()
}
// Addr
func (l *listener) Addr() net.Addr {
return l.l.Addr()
}
// Dial dial to (network, addr) and create a gnutls Conn
func Dial(network, addr string, cfg *Config) (*Conn, error) {
c, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
return Client(c, cfg)
}
// Listen create a gnutls listener on (network, addr),
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
if cfg == nil {
return nil, errors.New("config is needed")
}
l, err := net.Listen(network, addr)
if err != nil {
return nil, err
}
return &listener{l, cfg}, nil
}
// Server create a server TLS Conn on c
func Server(c net.Conn, cfg *Config) (*Conn, error) {
if cfg == nil {
return nil, errors.New("config is needed")
}
var sess = C.init_gnutls_server_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_callback(sess)
if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err)
}
}
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil
}
// Client create a client TLS Conn on c
func Client(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_client_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_callback(sess)
if cfg != nil {
if cfg.ServerName != "" {
srvname := C.CString(cfg.ServerName)
//defer C.free(unsafe.Pointer(srvname))
conn.cservname = srvname
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
}
if !cfg.InsecureSkipVerify {
if conn.cservname != nil {
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
} else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
}
}
if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err)
}
}
} else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
}
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil
}
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
arg := make([](*C.char), 0)
for _, s := range cfg.NextProtos {
cbuf := C.CString(s)
defer C.free(unsafe.Pointer(cbuf))
arg = append(arg, (*C.char)(cbuf))
}
ret := C.alpn_set_protocols(sess,
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
if int(ret) < 0 {
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
}
return nil
}
// Handshake call handshake for TLS Conn,
// this function will call automatic on Read/Write, if not handshake yet
func (c *Conn) Handshake() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.handshake {
return nil
}
ret := C.handshake(c.sess)
if int(ret) < 0 {
return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret)))
}
c.handshake = true
//log.Println("handshake done")
return nil
}
// Read read application data from TLS connection
func (c *Conn) Read(buf []byte) (n int, err error) {
err = c.Handshake()
if err != nil {
return
}
if len(buf) == 0 {
return 0, nil
}
if c.closed {
return 0, io.EOF
}
//bufLen := len(buf)
//cbuf := C.malloc(C.size_t(bufLen))
//defer C.free(cbuf)
ret := C.gnutls_record_recv(c.sess.session,
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
if int(ret) < 0 {
return 0, fmt.Errorf("read error: %s",
C.GoString(C.gnutls_strerror(C.int(ret))))
}
if int(ret) == 0 {
return 0, io.EOF
}
n = int(ret)
//gobuf2 := C.GoBytes(cbuf, C.int(ret))
//copy(buf, gobuf2)
return n, nil
}
// Write write application data to TLS connection
func (c *Conn) Write(buf []byte) (n int, err error) {
err = c.Handshake()
if err != nil {
return
}
// user may call Write(nil) to do handshake
if len(buf) == 0 {
return 0, nil
}
if c.closed {
return 0, io.EOF
}
//cbuf := C.CBytes(buf)
//defer C.free(cbuf)
ret := C.gnutls_record_send(c.sess.session,
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
n = int(ret)
if n < 0 {
return 0, fmt.Errorf("write error: %s",
C.GoString(C.gnutls_strerror(C.int(ret))))
}
if int(ret) == 0 {
return 0, io.EOF
}
return n, nil
}
// Close close the TLS conn and destroy the tls context
func (c *Conn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return nil
}
C.gnutls_record_send(c.sess.session, nil, 0)
C.session_destroy(c.sess)
c.c.Close()
if c.cservname != nil {
C.free(unsafe.Pointer(c.cservname))
}
if c.state != nil && c.state.PeerCertificate != nil {
c.state.PeerCertificate.Free()
}
c.closed = true
return nil
}
func (c *Conn) free() {
//log.Println("free conn")
c.Close()
}
// SetWriteDeadline implements net.Conn
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.c.SetWriteDeadline(t)
}
// SetReadDeadline implements net.Conn
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
}
// RemoteAddr implements net.Conn
func (c *Conn) RemoteAddr() net.Addr {
return c.c.RemoteAddr()
}
// LocalAddr implements net.Conn
func (c *Conn) LocalAddr() net.Addr {
return c.c.LocalAddr()
}
// SetDeadline implements net.Conn
func (c *Conn) SetDeadline(t time.Time) error {
return c.c.SetDeadline(t)
}
// ConnectionState get TLS connection state
func (c *Conn) ConnectionState() *ConnectionState {
if c.state != nil {
return c.state
}
version :=
uint16(C.gnutls_protocol_get_version(c.sess.session))
versionname := C.GoString(
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
state := &ConnectionState{
NegotiatedProtocol: c.getAlpnSelectedProtocol(),
Version: version,
HandshakeComplete: c.handshake,
ServerName: c.getServerName(),
VersionName: versionname,
PeerCertificate: c.getPeerCertificate(),
}
c.state = state
return state
}
func (c *Conn) getPeerCertificate() *Certificate {
var size int
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
if st == nil {
return nil
}
cert := &Certificate{cert: st, certSize: C.int(size)}
runtime.SetFinalizer(cert, (*Certificate).free)
return cert
}
func (c *Conn) getAlpnSelectedProtocol() string {
cbuf := C.malloc(100)
defer C.free(cbuf)
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
if int(ret) < 0 {
return ""
}
alpnname := C.GoString((*C.char)(cbuf))
return alpnname
}
func (c *Conn) getServerName() string {
buflen := 100
nametype := GNUTLS_NAME_DNS
cbuf := C.malloc(C.size_t(buflen))
defer C.free(cbuf)
ret := C.gnutls_server_name_get(c.sess.session, cbuf,
(*C.size_t)(unsafe.Pointer(&buflen)),
(*C.uint)(unsafe.Pointer(&nametype)), 0)
if int(ret) < 0 {
return ""
}
name := C.GoString((*C.char)(cbuf))
return name
}
/*
var bufPool = &sync.Pool{
New: func() interface{} {
return make([]byte, 16*1024)
},
}
*/
// onDataReadCallback callback function for gnutls library want to read data from network
//
//export onDataReadCallback
func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
//log.Println("read addr ", uintptr(d))
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
buf := make([]byte, int(bufLen))
/*
_length := int(bufLen)
buf := bufPool.Get().([]byte)
if len(buf) < _length {
buf = make([]byte, _length)
}
defer bufPool.Put(buf)
*/
n, err := conn.c.Read(buf[0:])
if err != nil {
//log.Println(err)
// 0 indicates connection termination
return 0
}
//cbuf2 := C.CBytes(buf[:n])
// d := C.CString(string(buf[:n]))
//defer C.free(cbuf2)
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
return C.int(n)
}
// onDataWriteCallback callback function for gnutls library want to send data to network
//
//export onDataWriteCallback
func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
//log.Println("write addr ", uintptr(d), int(_l))
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
n, err := conn.c.Write(gobuf)
if err != nil {
//log.Println(err)
return -1
}
return C.int(n)
}
// onDataTimeoutRead callback function for timeout read
//
//export onDataTimeoutRead
func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
log.Println("timeout pull function")
return 0
}
// onCertSelectCallback callback function for ceritificate select,
// this function select certificate from Config.Certificates field,
//
// on server side, this function select the certificate depend on SNI what client send,
// if client not send SNI, select the Config.Certificates[0]
//
//export onCertSelectCallback
func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
servername := C.GoStringN(hostname, namelen)
//log.Println("go cert select callback ", servername)
conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
//log.Println(conn)
if int(namelen) == 0 && conn.cfg.Certificates != nil {
_cert := conn.cfg.Certificates[0]
*pcertLength = _cert.certSize
*cert = _cert.cert
*privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize)
return 0
}
for _, _cert := range conn.cfg.Certificates {
//log.Println(cert)
if _cert.matchName(servername) {
//log.Println("matched name ", _cert.names)
*pcertLength = _cert.certSize
*cert = _cert.cert
*privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize)
return 0
}
}
if conn.cfg.Certificates != nil {
_cert := conn.cfg.Certificates[0]
*pcertLength = _cert.certSize
*cert = _cert.cert
*privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize)
return 0
}
*pcertLength = 0
//log.Println("set pcert length 0")
return -1
}
func init() {
C.init_xcred()
C.init_priority_cache()
}
package gnutls
/*
#include "_gnutls.h"
#cgo pkg-config: gnutls
*/
import "C"
import (
"errors"
"fmt"
"io"
"log"
"net"
"runtime"
"sync"
"time"
"unsafe"
)
const (
GNUTLS_NAME_DNS = 1
GNUTLS_X509_FMT_PEM = 1
GNUTLS_ALPN_MANDATORY = 1
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
)
// Conn gnutls TLS connection,
// use Listen, Dial, Server or Client create this object
type Conn struct {
c net.Conn
handshake bool
sess *C.struct_session
cservname *C.char
state *ConnectionState
cfg *Config
closed bool
lock *sync.Mutex
}
// Config gnutls TLS configure,
type Config struct {
ServerName string
Certificates []*Certificate
InsecureSkipVerify bool
NextProtos []string
}
// ConnectionState gnutls TLS connection state
type ConnectionState struct {
// SNI name client send
ServerName string
// selected ALPN protocl
NegotiatedProtocol string
HandshakeComplete bool
// TLS version number, ex: 0x303
Version uint16
// TLS version number, ex: TLS1.0
VersionName string
// peer's certificate
PeerCertificate *Certificate
}
type listener struct {
l net.Listener
c *Config
}
// Accept
func (l *listener) Accept() (net.Conn, error) {
c, err := l.l.Accept()
if err != nil {
return nil, err
}
return Server(c, l.c)
}
// Close
func (l *listener) Close() error {
return l.l.Close()
}
// Addr
func (l *listener) Addr() net.Addr {
return l.l.Addr()
}
// Dial dial to (network, addr) and create a gnutls Conn
func Dial(network, addr string, cfg *Config) (*Conn, error) {
c, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
return Client(c, cfg)
}
// Listen create a gnutls listener on (network, addr),
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
if cfg == nil {
return nil, errors.New("config is needed")
}
l, err := net.Listen(network, addr)
if err != nil {
return nil, err
}
return &listener{l, cfg}, nil
}
// Server create a server TLS Conn on c
func Server(c net.Conn, cfg *Config) (*Conn, error) {
if cfg == nil {
return nil, errors.New("config is needed")
}
var sess = C.init_gnutls_server_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_callback(sess)
if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err)
}
}
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil
}
// Client create a client TLS Conn on c
func Client(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_client_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_callback(sess)
if cfg != nil {
if cfg.ServerName != "" {
srvname := C.CString(cfg.ServerName)
//defer C.free(unsafe.Pointer(srvname))
conn.cservname = srvname
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
}
if !cfg.InsecureSkipVerify {
if conn.cservname != nil {
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
} else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
}
}
if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err)
}
}
} else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
}
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil
}
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
arg := make([](*C.char), 0)
for _, s := range cfg.NextProtos {
cbuf := C.CString(s)
defer C.free(unsafe.Pointer(cbuf))
arg = append(arg, (*C.char)(cbuf))
}
ret := C.alpn_set_protocols(sess,
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
if int(ret) < 0 {
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
}
return nil
}
// Handshake call handshake for TLS Conn,
// this function will call automatic on Read/Write, if not handshake yet
func (c *Conn) Handshake() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.handshake {
return nil
}
ret := C.handshake(c.sess)
if int(ret) < 0 {
return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret)))
}
c.handshake = true
//log.Println("handshake done")
return nil
}
// Read read application data from TLS connection
func (c *Conn) Read(buf []byte) (n int, err error) {
err = c.Handshake()
if err != nil {
return
}
if len(buf) == 0 {
return 0, nil
}
if c.closed {
return 0, io.EOF
}
//bufLen := len(buf)
//cbuf := C.malloc(C.size_t(bufLen))
//defer C.free(cbuf)
ret := C.gnutls_record_recv(c.sess.session,
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
if int(ret) < 0 {
return 0, fmt.Errorf("read error: %s",
C.GoString(C.gnutls_strerror(C.int(ret))))
}
if int(ret) == 0 {
return 0, io.EOF
}
n = int(ret)
//gobuf2 := C.GoBytes(cbuf, C.int(ret))
//copy(buf, gobuf2)
return n, nil
}
// Write write application data to TLS connection
func (c *Conn) Write(buf []byte) (n int, err error) {
err = c.Handshake()
if err != nil {
return
}
// user may call Write(nil) to do handshake
if len(buf) == 0 {
return 0, nil
}
if c.closed {
return 0, io.EOF
}
//cbuf := C.CBytes(buf)
//defer C.free(cbuf)
ret := C.gnutls_record_send(c.sess.session,
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
n = int(ret)
if n < 0 {
return 0, fmt.Errorf("write error: %s",
C.GoString(C.gnutls_strerror(C.int(ret))))
}
if int(ret) == 0 {
return 0, io.EOF
}
return n, nil
}
// Close close the TLS conn and destroy the tls context
func (c *Conn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return nil
}
C.gnutls_record_send(c.sess.session, nil, 0)
C.session_destroy(c.sess)
c.c.Close()
if c.cservname != nil {
C.free(unsafe.Pointer(c.cservname))
}
if c.state != nil && c.state.PeerCertificate != nil {
c.state.PeerCertificate.Free()
}
c.closed = true
return nil
}
func (c *Conn) free() {
//log.Println("free conn")
c.Close()
}
// SetWriteDeadline implements net.Conn
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.c.SetWriteDeadline(t)
}
// SetReadDeadline implements net.Conn
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
}
// RemoteAddr implements net.Conn
func (c *Conn) RemoteAddr() net.Addr {
return c.c.RemoteAddr()
}
// LocalAddr implements net.Conn
func (c *Conn) LocalAddr() net.Addr {
return c.c.LocalAddr()
}
// SetDeadline implements net.Conn
func (c *Conn) SetDeadline(t time.Time) error {
return c.c.SetDeadline(t)
}
// ConnectionState get TLS connection state
func (c *Conn) ConnectionState() *ConnectionState {
if c.state != nil {
return c.state
}
version :=
uint16(C.gnutls_protocol_get_version(c.sess.session))
versionname := C.GoString(
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
state := &ConnectionState{
NegotiatedProtocol: c.getAlpnSelectedProtocol(),
Version: version,
HandshakeComplete: c.handshake,
ServerName: c.getServerName(),
VersionName: versionname,
PeerCertificate: c.getPeerCertificate(),
}
c.state = state
return state
}
func (c *Conn) getPeerCertificate() *Certificate {
var size int
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
if st == nil {
return nil
}
cert := &Certificate{cert: st, certSize: C.int(size)}
runtime.SetFinalizer(cert, (*Certificate).free)
return cert
}
func (c *Conn) getAlpnSelectedProtocol() string {
cbuf := C.malloc(100)
defer C.free(cbuf)
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
if int(ret) < 0 {
return ""
}
alpnname := C.GoString((*C.char)(cbuf))
return alpnname
}
func (c *Conn) getServerName() string {
buflen := 100
nametype := GNUTLS_NAME_DNS
cbuf := C.malloc(C.size_t(buflen))
defer C.free(cbuf)
ret := C.gnutls_server_name_get(c.sess.session, cbuf,
(*C.size_t)(unsafe.Pointer(&buflen)),
(*C.uint)(unsafe.Pointer(&nametype)), 0)
if int(ret) < 0 {
return ""
}
name := C.GoString((*C.char)(cbuf))
return name
}
/*
var bufPool = &sync.Pool{
New: func() interface{} {
return make([]byte, 16*1024)
},
}
*/
// onDataReadCallback callback function for gnutls library want to read data from network
//
//export onDataReadCallback
func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
//log.Println("read addr ", uintptr(d))
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
buf := make([]byte, int(bufLen))
/*
_length := int(bufLen)
buf := bufPool.Get().([]byte)
if len(buf) < _length {
buf = make([]byte, _length)
}
defer bufPool.Put(buf)
*/
n, err := conn.c.Read(buf[0:])
if err != nil {
//log.Println(err)
// 0 indicates connection termination
return 0
}
//cbuf2 := C.CBytes(buf[:n])
// d := C.CString(string(buf[:n]))
//defer C.free(cbuf2)
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
return C.int(n)
}
// onDataWriteCallback callback function for gnutls library want to send data to network
//
//export onDataWriteCallback
func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
//log.Println("write addr ", uintptr(d), int(_l))
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
n, err := conn.c.Write(gobuf)
if err != nil {
//log.Println(err)
return -1
}
return C.int(n)
}
// onDataTimeoutRead callback function for timeout read
//
//export onDataTimeoutRead
func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
log.Println("timeout pull function")
return 0
}
// onCertSelectCallback callback function for ceritificate select,
// this function select certificate from Config.Certificates field,
//
// on server side, this function select the certificate depend on SNI what client send,
// if client not send SNI, select the Config.Certificates[0]
//
//export onCertSelectCallback
func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
servername := C.GoStringN(hostname, namelen)
//log.Println("go cert select callback ", servername)
conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
//log.Println(conn)
if int(namelen) == 0 && conn.cfg.Certificates != nil {
_cert := conn.cfg.Certificates[0]
*pcertLength = _cert.certSize
*cert = _cert.cert
*privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize)
return 0
}
for _, _cert := range conn.cfg.Certificates {
//log.Println(cert)
if _cert.matchName(servername) {
//log.Println("matched name ", _cert.names)
*pcertLength = _cert.certSize
*cert = _cert.cert
*privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize)
return 0
}
}
if conn.cfg.Certificates != nil {
_cert := conn.cfg.Certificates[0]
*pcertLength = _cert.certSize
*cert = _cert.cert
*privkey = _cert.privkey
//log.Println("set pcert length ", _cert.certSize)
return 0
}
*pcertLength = 0
//log.Println("set pcert length 0")
return -1
}
func init() {
C.init_xcred()
C.init_priority_cache()
}

@ -1,431 +1,431 @@
package gnutls
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"runtime"
"testing"
"time"
)
func TestTLSClient(t *testing.T) {
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal("load certificate failed")
}
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
})
if err != nil {
t.Fatal("listen failed")
}
defer l.Close()
addr := l.Addr().String()
t.Log("test server listen on ", addr)
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
t.Logf("accept connection from %s", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
for {
buf := make([]byte, 4096)
n, err := c.Read(buf)
if err != nil {
t.Log("connection closed")
break
}
if _, err = c.Write(buf[:n]); err != nil {
break
}
}
}(c)
}
}()
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true})
if err != nil {
t.Fatal("gnutls dial ", err)
}
defer c.Close()
data := "hello, world"
if _, err = c.Write([]byte(data)); err != nil {
t.Fatal("gnutls write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("gnutls read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
}
func TestTLSServer(t *testing.T) {
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal(err)
}
l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: []*Certificate{cert},
})
if err != nil {
t.Fatal("gnutls listen ", err)
}
addr := l.Addr().String()
t.Log("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log("gnutls accept ", err)
break
}
t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsconn := c.(*Conn)
if err := tlsconn.Handshake(); err != nil {
t.Log(err)
return
}
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
t.Log("gnutls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
t.Log("gnutls write ", err)
break
}
}
}(c)
}
}()
c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSALPNServer(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2"
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal(err)
}
l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: []*Certificate{cert},
NextProtos: serveralpn,
})
if err != nil {
t.Fatal("gnutls listen ", err)
}
addr := l.Addr().String()
t.Log("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log("gnutls accept ", err)
break
}
t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsConn := c.(*Conn)
if err := tlsConn.Handshake(); err != nil {
t.Log(err)
return
}
connState := tlsConn.ConnectionState()
t.Logf("%+v", connState)
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
t.Log("gnutls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
t.Log("gnutls write ", err)
break
}
}
}(c)
}
}()
c, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
NextProtos: clientalpn,
})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Fatal(err)
}
connState := c.ConnectionState()
t.Logf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol)
}
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSALPNClient(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2"
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal("load key failed")
}
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: serveralpn,
})
if err != nil {
t.Fatal("tls listen ", err)
}
addr := l.Addr().String()
t.Log("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log("gnutls accept ", err)
break
}
t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsConn := c.(*tls.Conn)
if err := tlsConn.Handshake(); err != nil {
t.Log(err)
return
}
connState := tlsConn.ConnectionState()
t.Logf("%+v", connState)
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
t.Log("tls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
t.Log("tls write ", err)
break
}
}
}(c)
}
}()
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
ServerName: "localhost",
NextProtos: clientalpn,
})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Fatal(err)
}
connState := c.ConnectionState()
t.Logf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol)
}
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSServerSNI(t *testing.T) {
certificates := []*Certificate{}
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal("load key failed")
}
certificates = append(certificates, cert)
cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key")
if err != nil {
t.Fatal("load key failed")
}
certificates = append(certificates, cert)
cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key")
if err != nil {
t.Fatal("load key failed")
}
certificates = append(certificates, cert)
l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: certificates,
})
if err != nil {
t.Fatal(err)
}
defer l.Close()
addr := l.Addr().String()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log(err)
break
}
go func(c net.Conn) {
defer c.Close()
tlsconn := c.(*Conn)
if err := tlsconn.Handshake(); err != nil {
t.Log(err)
return
}
state := tlsconn.ConnectionState()
fmt.Fprintf(c, state.ServerName)
}(c)
}
}()
for _, cfg := range []struct {
serverName string
commonName string
}{
{"abc.com", "abc.com"},
{"example.com", "example.com"},
{"a.aaa.com", "*.aaa.com"},
{"b.aaa.com", "*.aaa.com"},
} {
conn, err := tls.Dial("tcp", addr, &tls.Config{
ServerName: cfg.serverName,
InsecureSkipVerify: true,
})
if err != nil {
t.Fatal(err)
}
state := conn.ConnectionState()
_commonName := state.PeerCertificates[0].Subject.CommonName
if _commonName != cfg.commonName {
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
}
buf := make([]byte, 100)
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
t.Error(err)
}
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
}
conn.Close()
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSGetPeerCert(t *testing.T) {
conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{
ServerName: "www.ratafee.nl",
})
if err != nil {
t.Fatal(err)
}
defer conn.Close()
//tlsconn := conn.(*Conn)
if err := conn.Handshake(); err != nil {
t.Fatal(err)
}
state := conn.ConnectionState()
for i := 0; i < int(state.PeerCertificate.certSize); i++ {
t.Log(state.PeerCertificate.getCertString(i, 1))
}
req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil)
req.Write(conn)
r := bufio.NewReader(conn)
resp, err := http.ReadResponse(r, req)
if err != nil {
t.Error(err)
}
var buf = new(bytes.Buffer)
resp.Write(buf)
t.Logf("%s", string(buf.Bytes()))
runtime.GC()
time.Sleep(1 * time.Second)
}
package gnutls
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"runtime"
"testing"
"time"
)
func TestTLSClient(t *testing.T) {
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal("load certificate failed")
}
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
})
if err != nil {
t.Fatal("listen failed")
}
defer l.Close()
addr := l.Addr().String()
t.Log("test server listen on ", addr)
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
t.Logf("accept connection from %s", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
for {
buf := make([]byte, 4096)
n, err := c.Read(buf)
if err != nil {
t.Log("connection closed")
break
}
if _, err = c.Write(buf[:n]); err != nil {
break
}
}
}(c)
}
}()
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true})
if err != nil {
t.Fatal("gnutls dial ", err)
}
defer c.Close()
data := "hello, world"
if _, err = c.Write([]byte(data)); err != nil {
t.Fatal("gnutls write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("gnutls read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
}
func TestTLSServer(t *testing.T) {
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal(err)
}
l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: []*Certificate{cert},
})
if err != nil {
t.Fatal("gnutls listen ", err)
}
addr := l.Addr().String()
t.Log("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log("gnutls accept ", err)
break
}
t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsconn := c.(*Conn)
if err := tlsconn.Handshake(); err != nil {
t.Log(err)
return
}
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
t.Log("gnutls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
t.Log("gnutls write ", err)
break
}
}
}(c)
}
}()
c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSALPNServer(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2"
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal(err)
}
l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: []*Certificate{cert},
NextProtos: serveralpn,
})
if err != nil {
t.Fatal("gnutls listen ", err)
}
addr := l.Addr().String()
t.Log("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log("gnutls accept ", err)
break
}
t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsConn := c.(*Conn)
if err := tlsConn.Handshake(); err != nil {
t.Log(err)
return
}
connState := tlsConn.ConnectionState()
t.Logf("%+v", connState)
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
t.Log("gnutls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
t.Log("gnutls write ", err)
break
}
}
}(c)
}
}()
c, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
NextProtos: clientalpn,
})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Fatal(err)
}
connState := c.ConnectionState()
t.Logf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol)
}
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSALPNClient(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2"
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal("load key failed")
}
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: serveralpn,
})
if err != nil {
t.Fatal("tls listen ", err)
}
addr := l.Addr().String()
t.Log("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log("gnutls accept ", err)
break
}
t.Log("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsConn := c.(*tls.Conn)
if err := tlsConn.Handshake(); err != nil {
t.Log(err)
return
}
connState := tlsConn.ConnectionState()
t.Logf("%+v", connState)
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
t.Log("tls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
t.Log("tls write ", err)
break
}
}
}(c)
}
}()
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
ServerName: "localhost",
NextProtos: clientalpn,
})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Fatal(err)
}
connState := c.ConnectionState()
t.Logf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol)
}
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSServerSNI(t *testing.T) {
certificates := []*Certificate{}
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal("load key failed")
}
certificates = append(certificates, cert)
cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key")
if err != nil {
t.Fatal("load key failed")
}
certificates = append(certificates, cert)
cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key")
if err != nil {
t.Fatal("load key failed")
}
certificates = append(certificates, cert)
l, err := Listen("tcp", "127.0.0.1:0", &Config{
Certificates: certificates,
})
if err != nil {
t.Fatal(err)
}
defer l.Close()
addr := l.Addr().String()
go func() {
for {
c, err := l.Accept()
if err != nil {
t.Log(err)
break
}
go func(c net.Conn) {
defer c.Close()
tlsconn := c.(*Conn)
if err := tlsconn.Handshake(); err != nil {
t.Log(err)
return
}
state := tlsconn.ConnectionState()
fmt.Fprintf(c, state.ServerName)
}(c)
}
}()
for _, cfg := range []struct {
serverName string
commonName string
}{
{"abc.com", "abc.com"},
{"example.com", "example.com"},
{"a.aaa.com", "*.aaa.com"},
{"b.aaa.com", "*.aaa.com"},
} {
conn, err := tls.Dial("tcp", addr, &tls.Config{
ServerName: cfg.serverName,
InsecureSkipVerify: true,
})
if err != nil {
t.Fatal(err)
}
state := conn.ConnectionState()
_commonName := state.PeerCertificates[0].Subject.CommonName
if _commonName != cfg.commonName {
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
}
buf := make([]byte, 100)
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
t.Error(err)
}
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
}
conn.Close()
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSGetPeerCert(t *testing.T) {
conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{
ServerName: "www.ratafee.nl",
})
if err != nil {
t.Fatal(err)
}
defer conn.Close()
//tlsconn := conn.(*Conn)
if err := conn.Handshake(); err != nil {
t.Fatal(err)
}
state := conn.ConnectionState()
for i := 0; i < int(state.PeerCertificate.certSize); i++ {
t.Log(state.PeerCertificate.getCertString(i, 1))
}
req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil)
req.Write(conn)
r := bufio.NewReader(conn)
resp, err := http.ReadResponse(r, req)
if err != nil {
t.Error(err)
}
var buf = new(bytes.Buffer)
resp.Write(buf)
t.Logf("%s", string(buf.Bytes()))
runtime.GC()
time.Sleep(1 * time.Second)
}

Loading…
Cancel
Save