go fmt
parent
2c4c546551
commit
e168de539b
@ -1,176 +1,176 @@
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include "_gnutls.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Certificate x509 certificate
|
||||
type Certificate struct {
|
||||
cert *C.gnutls_pcert_st
|
||||
privkey C.gnutls_privkey_t
|
||||
certSize C.int
|
||||
}
|
||||
|
||||
// Free free the certificate context
|
||||
func (c *Certificate) Free() {
|
||||
if c.cert != nil {
|
||||
C.free_cert_list(c.cert, c.certSize)
|
||||
}
|
||||
if c.privkey != nil {
|
||||
C.gnutls_privkey_deinit(c.privkey)
|
||||
}
|
||||
c.cert = nil
|
||||
c.privkey = nil
|
||||
c.certSize = 0
|
||||
}
|
||||
|
||||
func (c *Certificate) free() {
|
||||
//log.Println("free certificate")
|
||||
c.Free()
|
||||
}
|
||||
|
||||
func (c *Certificate) matchName(name string) bool {
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
ret := C.cert_check_hostname(c.cert, c.certSize, cname)
|
||||
if int(ret) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(ret)))
|
||||
return false
|
||||
}
|
||||
if int(ret) > 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CommonName get CN field in subject,
|
||||
//
|
||||
// example: subject C=xx,ST=bbb,CN=abc will return abc
|
||||
func (c *Certificate) CommonName() string {
|
||||
return c.commonName(0)
|
||||
}
|
||||
func (c *Certificate) commonName(index int) string {
|
||||
dnMap := make(map[string]string)
|
||||
dn := c.getDN(0)
|
||||
if dn != "" {
|
||||
dnFields := strings.Split(dn, ",")
|
||||
for _, d := range dnFields {
|
||||
dd := strings.Split(d, "=")
|
||||
k := dd[0]
|
||||
v := dd[1]
|
||||
dnMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if cn, ok := dnMap["CN"]; ok {
|
||||
return cn
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAltName get altname in certificate
|
||||
func (c *Certificate) GetAltName() string {
|
||||
return c.getAltName(0, 0)
|
||||
}
|
||||
func (c *Certificate) getAltName(index int, nameindex int) string {
|
||||
out := C.malloc(1024)
|
||||
defer C.free(out)
|
||||
size := C.get_pcert_alt_name(
|
||||
c.cert, C.int(index), C.int(nameindex), (*C.char)(out))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
|
||||
name := C.GoBytes(out, size)
|
||||
return string(name)
|
||||
}
|
||||
|
||||
//GetCertString return certificate info string in one line
|
||||
func (c *Certificate) GetCertString() string {
|
||||
return c.getCertString(0, 1)
|
||||
}
|
||||
|
||||
func (c *Certificate) getCertString(index int, flag int) string {
|
||||
out := C.malloc(4096)
|
||||
defer C.free(out)
|
||||
size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
s := C.GoBytes(out, size)
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// GetDN get the certificate subject, like O=st,C=aa,CN=localhost
|
||||
func (c *Certificate) GetDN() string {
|
||||
return c.getDN(0)
|
||||
}
|
||||
|
||||
func (c *Certificate) getDN(index int) string {
|
||||
cbuf := C.malloc(200)
|
||||
defer C.free(cbuf)
|
||||
size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
s := C.GoBytes(cbuf, size)
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost
|
||||
func (c *Certificate) GetIssuerDN() string {
|
||||
return c.getIssuerDN(0)
|
||||
}
|
||||
|
||||
func (c *Certificate) getIssuerDN(index int) string {
|
||||
cbuf := C.malloc(200)
|
||||
defer C.free(cbuf)
|
||||
size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
s := C.GoBytes(cbuf, size)
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// LoadX509KeyPair load certificate pair,
|
||||
// the return Certifciate must be freed by call Free(),
|
||||
func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
|
||||
_certfile := C.CString(certfile)
|
||||
_keyfile := C.CString(keyfile)
|
||||
|
||||
defer C.free(unsafe.Pointer(_certfile))
|
||||
defer C.free(unsafe.Pointer(_keyfile))
|
||||
|
||||
certificate := &Certificate{}
|
||||
var ret C.int
|
||||
var certSize C.int
|
||||
cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)),
|
||||
(*C.int)(unsafe.Pointer(&ret)))
|
||||
if int(ret) < 0 {
|
||||
return nil, fmt.Errorf("load cert failed: %s",
|
||||
C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret)))
|
||||
if int(ret) < 0 {
|
||||
return nil, fmt.Errorf("load privkey: %s",
|
||||
C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
certificate.cert = cert
|
||||
certificate.privkey = privkey
|
||||
certificate.certSize = certSize
|
||||
runtime.SetFinalizer(certificate, (*Certificate).free)
|
||||
return certificate, nil
|
||||
}
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include "_gnutls.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Certificate x509 certificate
|
||||
type Certificate struct {
|
||||
cert *C.gnutls_pcert_st
|
||||
privkey C.gnutls_privkey_t
|
||||
certSize C.int
|
||||
}
|
||||
|
||||
// Free free the certificate context
|
||||
func (c *Certificate) Free() {
|
||||
if c.cert != nil {
|
||||
C.free_cert_list(c.cert, c.certSize)
|
||||
}
|
||||
if c.privkey != nil {
|
||||
C.gnutls_privkey_deinit(c.privkey)
|
||||
}
|
||||
c.cert = nil
|
||||
c.privkey = nil
|
||||
c.certSize = 0
|
||||
}
|
||||
|
||||
func (c *Certificate) free() {
|
||||
//log.Println("free certificate")
|
||||
c.Free()
|
||||
}
|
||||
|
||||
func (c *Certificate) matchName(name string) bool {
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
ret := C.cert_check_hostname(c.cert, c.certSize, cname)
|
||||
if int(ret) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(ret)))
|
||||
return false
|
||||
}
|
||||
if int(ret) > 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CommonName get CN field in subject,
|
||||
//
|
||||
// example: subject C=xx,ST=bbb,CN=abc will return abc
|
||||
func (c *Certificate) CommonName() string {
|
||||
return c.commonName(0)
|
||||
}
|
||||
func (c *Certificate) commonName(index int) string {
|
||||
dnMap := make(map[string]string)
|
||||
dn := c.getDN(0)
|
||||
if dn != "" {
|
||||
dnFields := strings.Split(dn, ",")
|
||||
for _, d := range dnFields {
|
||||
dd := strings.Split(d, "=")
|
||||
k := dd[0]
|
||||
v := dd[1]
|
||||
dnMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if cn, ok := dnMap["CN"]; ok {
|
||||
return cn
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAltName get altname in certificate
|
||||
func (c *Certificate) GetAltName() string {
|
||||
return c.getAltName(0, 0)
|
||||
}
|
||||
func (c *Certificate) getAltName(index int, nameindex int) string {
|
||||
out := C.malloc(1024)
|
||||
defer C.free(out)
|
||||
size := C.get_pcert_alt_name(
|
||||
c.cert, C.int(index), C.int(nameindex), (*C.char)(out))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
|
||||
name := C.GoBytes(out, size)
|
||||
return string(name)
|
||||
}
|
||||
|
||||
//GetCertString return certificate info string in one line
|
||||
func (c *Certificate) GetCertString() string {
|
||||
return c.getCertString(0, 1)
|
||||
}
|
||||
|
||||
func (c *Certificate) getCertString(index int, flag int) string {
|
||||
out := C.malloc(4096)
|
||||
defer C.free(out)
|
||||
size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
s := C.GoBytes(out, size)
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// GetDN get the certificate subject, like O=st,C=aa,CN=localhost
|
||||
func (c *Certificate) GetDN() string {
|
||||
return c.getDN(0)
|
||||
}
|
||||
|
||||
func (c *Certificate) getDN(index int) string {
|
||||
cbuf := C.malloc(200)
|
||||
defer C.free(cbuf)
|
||||
size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
s := C.GoBytes(cbuf, size)
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost
|
||||
func (c *Certificate) GetIssuerDN() string {
|
||||
return c.getIssuerDN(0)
|
||||
}
|
||||
|
||||
func (c *Certificate) getIssuerDN(index int) string {
|
||||
cbuf := C.malloc(200)
|
||||
defer C.free(cbuf)
|
||||
size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
||||
if int(size) < 0 {
|
||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||
return ""
|
||||
}
|
||||
s := C.GoBytes(cbuf, size)
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// LoadX509KeyPair load certificate pair,
|
||||
// the return Certifciate must be freed by call Free(),
|
||||
func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
|
||||
_certfile := C.CString(certfile)
|
||||
_keyfile := C.CString(keyfile)
|
||||
|
||||
defer C.free(unsafe.Pointer(_certfile))
|
||||
defer C.free(unsafe.Pointer(_keyfile))
|
||||
|
||||
certificate := &Certificate{}
|
||||
var ret C.int
|
||||
var certSize C.int
|
||||
cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)),
|
||||
(*C.int)(unsafe.Pointer(&ret)))
|
||||
if int(ret) < 0 {
|
||||
return nil, fmt.Errorf("load cert failed: %s",
|
||||
C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret)))
|
||||
if int(ret) < 0 {
|
||||
return nil, fmt.Errorf("load privkey: %s",
|
||||
C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
certificate.cert = cert
|
||||
certificate.privkey = privkey
|
||||
certificate.certSize = certSize
|
||||
runtime.SetFinalizer(certificate, (*Certificate).free)
|
||||
return certificate, nil
|
||||
}
|
||||
|
@ -1,35 +1,35 @@
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAltname(t *testing.T) {
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("%+v", cert)
|
||||
t.Log(cert.GetAltName())
|
||||
t.Log(cert.GetCertString())
|
||||
t.Log(cert.GetDN())
|
||||
t.Log(cert.GetIssuerDN())
|
||||
t.Log("CN ", cert.CommonName())
|
||||
//t.Log("flag 0: ", cert.getCertString(0, 0))
|
||||
//t.Log("flag 1: ", cert.getCertString(0, 1))
|
||||
//t.Log("flag 2: ", cert.getCertString(0, 2))
|
||||
//t.Log("flag 3: ", cert.getCertString(0, 3))
|
||||
cert.Free()
|
||||
}
|
||||
|
||||
func _loadCert(certfile, keyfile string) (*Certificate, error) {
|
||||
return LoadX509KeyPair(certfile, keyfile)
|
||||
}
|
||||
|
||||
func TestCertGC(t *testing.T) {
|
||||
_loadCert("testdata/server.crt", "testdata/server.key")
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAltname(t *testing.T) {
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("%+v", cert)
|
||||
t.Log(cert.GetAltName())
|
||||
t.Log(cert.GetCertString())
|
||||
t.Log(cert.GetDN())
|
||||
t.Log(cert.GetIssuerDN())
|
||||
t.Log("CN ", cert.CommonName())
|
||||
//t.Log("flag 0: ", cert.getCertString(0, 0))
|
||||
//t.Log("flag 1: ", cert.getCertString(0, 1))
|
||||
//t.Log("flag 2: ", cert.getCertString(0, 2))
|
||||
//t.Log("flag 3: ", cert.getCertString(0, 3))
|
||||
cert.Free()
|
||||
}
|
||||
|
||||
func _loadCert(certfile, keyfile string) (*Certificate, error) {
|
||||
return LoadX509KeyPair(certfile, keyfile)
|
||||
}
|
||||
|
||||
func TestCertGC(t *testing.T) {
|
||||
_loadCert("testdata/server.crt", "testdata/server.key")
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
@ -1,167 +1,167 @@
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include "_gnutls.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// CipherType cipher type
|
||||
type CipherType int
|
||||
|
||||
const (
|
||||
GNUTLS_CIPHER_AES_128_CBC CipherType = 4
|
||||
GNUTLS_CIPHER_AES_256_CBC CipherType = 5
|
||||
GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
|
||||
GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
|
||||
GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
|
||||
GNUTLS_CIPHER_AES_192_CBC CipherType = 9
|
||||
GNUTLS_CIPHER_AES_128_GCM CipherType = 10
|
||||
GNUTLS_CIPHER_AES_256_GCM CipherType = 11
|
||||
GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
|
||||
GNUTLS_CIPHER_SALSA20_256 CipherType = 13
|
||||
GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
|
||||
GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
|
||||
GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
|
||||
GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
|
||||
GNUTLS_CIPHER_DES_CBC CipherType = 18
|
||||
GNUTLS_CIPHER_AES_128_CCM CipherType = 19
|
||||
GNUTLS_CIPHER_AES_256_CCM CipherType = 20
|
||||
GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
|
||||
GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
|
||||
GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrBlockSize wrong block size
|
||||
ErrBlockSize = errors.New("wrong block size")
|
||||
// ErrKeyLength wrong key length
|
||||
ErrKeyLength = errors.New("wrong key length")
|
||||
// ErrIVLength wrong iv length
|
||||
ErrIVLength = errors.New("wrong iv length")
|
||||
)
|
||||
|
||||
// Cipher gnutls cipher struct
|
||||
type Cipher struct {
|
||||
cipher C.gnutls_cipher_hd_t
|
||||
t CipherType
|
||||
blockSize int
|
||||
}
|
||||
|
||||
// NewCipher create a new cipher by give type, key, iv
|
||||
//
|
||||
// example:
|
||||
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
|
||||
//
|
||||
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
|
||||
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
|
||||
keysize := GetCipherKeySize(t)
|
||||
ivSize := GetCipherIVSize(t)
|
||||
blocksize := GetCipherBlockSize(t)
|
||||
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
|
||||
if len(key) != int(keysize) {
|
||||
return nil, ErrKeyLength
|
||||
}
|
||||
|
||||
if len(iv) != int(ivSize) {
|
||||
return nil, ErrIVLength
|
||||
}
|
||||
|
||||
ckey := C.CBytes(key)
|
||||
civ := C.CBytes(iv)
|
||||
|
||||
defer C.free(ckey)
|
||||
defer C.free(civ)
|
||||
|
||||
c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv)))
|
||||
if c == nil {
|
||||
log.Println("new cipher return nil")
|
||||
return nil, nil
|
||||
}
|
||||
cipher := &Cipher{c, t, blocksize}
|
||||
runtime.SetFinalizer(cipher, (*Cipher).free)
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypt the buf and place the encrypted data in dst,
|
||||
// the buf size must multiple of cipher's block size
|
||||
func (c *Cipher) Encrypt(dst, buf []byte) error {
|
||||
if len(buf)%c.blockSize != 0 {
|
||||
return ErrBlockSize
|
||||
}
|
||||
|
||||
cbuf := C.CBytes(buf)
|
||||
defer C.free(cbuf)
|
||||
|
||||
bufLen := C.size_t(len(buf))
|
||||
dstBuf := C.malloc(bufLen)
|
||||
|
||||
defer C.free(dstBuf)
|
||||
|
||||
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
||||
copy(dst, _buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt decrypt the buf and place the decrypted data in dst,
|
||||
// the buf size must multiple of cipher's block size
|
||||
func (c *Cipher) Decrypt(dst, buf []byte) error {
|
||||
if len(buf)%c.blockSize != 0 {
|
||||
return ErrBlockSize
|
||||
}
|
||||
|
||||
cbuf := C.CBytes(buf)
|
||||
defer C.free(cbuf)
|
||||
|
||||
bufLen := C.size_t(len(buf))
|
||||
dstBuf := C.malloc(C.size_t(len(buf)))
|
||||
|
||||
defer C.free(dstBuf)
|
||||
|
||||
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
||||
copy(dst, _buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close destroy the cipher context
|
||||
func (c *Cipher) Close() error {
|
||||
if c.cipher != nil {
|
||||
C.gnutls_cipher_deinit(c.cipher)
|
||||
c.cipher = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cipher) free() {
|
||||
//log.Println("free cipher")
|
||||
c.Close()
|
||||
}
|
||||
|
||||
// GetCipherKeySize get the cipher algorithm key length
|
||||
func GetCipherKeySize(t CipherType) int {
|
||||
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
|
||||
}
|
||||
|
||||
// GetCipherIVSize get the cipher algorithm iv length
|
||||
func GetCipherIVSize(t CipherType) int {
|
||||
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
|
||||
}
|
||||
|
||||
// GetCipherBlockSize get the cipher algorithm block size
|
||||
func GetCipherBlockSize(t CipherType) int {
|
||||
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
|
||||
}
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include "_gnutls.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// CipherType cipher type
|
||||
type CipherType int
|
||||
|
||||
const (
|
||||
GNUTLS_CIPHER_AES_128_CBC CipherType = 4
|
||||
GNUTLS_CIPHER_AES_256_CBC CipherType = 5
|
||||
GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
|
||||
GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
|
||||
GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
|
||||
GNUTLS_CIPHER_AES_192_CBC CipherType = 9
|
||||
GNUTLS_CIPHER_AES_128_GCM CipherType = 10
|
||||
GNUTLS_CIPHER_AES_256_GCM CipherType = 11
|
||||
GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
|
||||
GNUTLS_CIPHER_SALSA20_256 CipherType = 13
|
||||
GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
|
||||
GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
|
||||
GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
|
||||
GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
|
||||
GNUTLS_CIPHER_DES_CBC CipherType = 18
|
||||
GNUTLS_CIPHER_AES_128_CCM CipherType = 19
|
||||
GNUTLS_CIPHER_AES_256_CCM CipherType = 20
|
||||
GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
|
||||
GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
|
||||
GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrBlockSize wrong block size
|
||||
ErrBlockSize = errors.New("wrong block size")
|
||||
// ErrKeyLength wrong key length
|
||||
ErrKeyLength = errors.New("wrong key length")
|
||||
// ErrIVLength wrong iv length
|
||||
ErrIVLength = errors.New("wrong iv length")
|
||||
)
|
||||
|
||||
// Cipher gnutls cipher struct
|
||||
type Cipher struct {
|
||||
cipher C.gnutls_cipher_hd_t
|
||||
t CipherType
|
||||
blockSize int
|
||||
}
|
||||
|
||||
// NewCipher create a new cipher by give type, key, iv
|
||||
//
|
||||
// example:
|
||||
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
|
||||
//
|
||||
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
|
||||
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
|
||||
keysize := GetCipherKeySize(t)
|
||||
ivSize := GetCipherIVSize(t)
|
||||
blocksize := GetCipherBlockSize(t)
|
||||
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
|
||||
if len(key) != int(keysize) {
|
||||
return nil, ErrKeyLength
|
||||
}
|
||||
|
||||
if len(iv) != int(ivSize) {
|
||||
return nil, ErrIVLength
|
||||
}
|
||||
|
||||
ckey := C.CBytes(key)
|
||||
civ := C.CBytes(iv)
|
||||
|
||||
defer C.free(ckey)
|
||||
defer C.free(civ)
|
||||
|
||||
c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv)))
|
||||
if c == nil {
|
||||
log.Println("new cipher return nil")
|
||||
return nil, nil
|
||||
}
|
||||
cipher := &Cipher{c, t, blocksize}
|
||||
runtime.SetFinalizer(cipher, (*Cipher).free)
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypt the buf and place the encrypted data in dst,
|
||||
// the buf size must multiple of cipher's block size
|
||||
func (c *Cipher) Encrypt(dst, buf []byte) error {
|
||||
if len(buf)%c.blockSize != 0 {
|
||||
return ErrBlockSize
|
||||
}
|
||||
|
||||
cbuf := C.CBytes(buf)
|
||||
defer C.free(cbuf)
|
||||
|
||||
bufLen := C.size_t(len(buf))
|
||||
dstBuf := C.malloc(bufLen)
|
||||
|
||||
defer C.free(dstBuf)
|
||||
|
||||
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
||||
copy(dst, _buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt decrypt the buf and place the decrypted data in dst,
|
||||
// the buf size must multiple of cipher's block size
|
||||
func (c *Cipher) Decrypt(dst, buf []byte) error {
|
||||
if len(buf)%c.blockSize != 0 {
|
||||
return ErrBlockSize
|
||||
}
|
||||
|
||||
cbuf := C.CBytes(buf)
|
||||
defer C.free(cbuf)
|
||||
|
||||
bufLen := C.size_t(len(buf))
|
||||
dstBuf := C.malloc(C.size_t(len(buf)))
|
||||
|
||||
defer C.free(dstBuf)
|
||||
|
||||
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
||||
copy(dst, _buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close destroy the cipher context
|
||||
func (c *Cipher) Close() error {
|
||||
if c.cipher != nil {
|
||||
C.gnutls_cipher_deinit(c.cipher)
|
||||
c.cipher = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cipher) free() {
|
||||
//log.Println("free cipher")
|
||||
c.Close()
|
||||
}
|
||||
|
||||
// GetCipherKeySize get the cipher algorithm key length
|
||||
func GetCipherKeySize(t CipherType) int {
|
||||
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
|
||||
}
|
||||
|
||||
// GetCipherIVSize get the cipher algorithm iv length
|
||||
func GetCipherIVSize(t CipherType) int {
|
||||
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
|
||||
}
|
||||
|
||||
// GetCipherBlockSize get the cipher algorithm block size
|
||||
func GetCipherBlockSize(t CipherType) int {
|
||||
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
|
||||
}
|
||||
|
@ -1,130 +1,130 @@
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCipherSize(t *testing.T) {
|
||||
data := []struct {
|
||||
t CipherType
|
||||
bsize int
|
||||
isize int
|
||||
}{
|
||||
{GNUTLS_CIPHER_AES_128_CBC, 16, 16},
|
||||
{GNUTLS_CIPHER_AES_192_CBC, 24, 16},
|
||||
{GNUTLS_CIPHER_AES_256_CBC, 32, 16},
|
||||
}
|
||||
for _, d := range data {
|
||||
blocksize := GetCipherKeySize(d.t)
|
||||
if blocksize != d.bsize {
|
||||
t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize)
|
||||
}
|
||||
ivsize := GetCipherIVSize(d.t)
|
||||
if ivsize != d.isize {
|
||||
t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize)
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||
keysize := GetCipherKeySize(cipherName)
|
||||
ivsize := GetCipherIVSize(cipherName)
|
||||
blocksize := GetCipherBlockSize(cipherName)
|
||||
|
||||
key := make([]byte, keysize)
|
||||
iv := make([]byte, ivsize)
|
||||
rand.Reader.Read(key)
|
||||
rand.Reader.Read(iv)
|
||||
|
||||
c, err := NewCipher(cipherName, key, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//defer c.Close()
|
||||
|
||||
c1, err := NewCipher(cipherName, key, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//defer c1.Close()
|
||||
|
||||
data := make([]byte, blocksize*10)
|
||||
if c == nil {
|
||||
t.Fatal("new ciphoer failed")
|
||||
}
|
||||
cdata := make([]byte, len(data))
|
||||
err = c.Encrypt(cdata, data)
|
||||
if err != nil {
|
||||
t.Fatal("encrypt failed", err)
|
||||
}
|
||||
data1 := make([]byte, len(data))
|
||||
err = c1.Decrypt(data1, cdata)
|
||||
if err != nil {
|
||||
t.Fatal("decrypt failed", err)
|
||||
}
|
||||
if !bytes.Equal(data, data1) {
|
||||
t.Fatal("encrypt/decrypt failed", string(data), string(data1))
|
||||
}
|
||||
block, _ := aes.NewCipher(key)
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
dst := make([]byte, len(data))
|
||||
mode.CryptBlocks(dst, data)
|
||||
if !bytes.Equal(dst, cdata) {
|
||||
t.Fatal("cipher text not equal to cypto/aes")
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func BenchmarkAESEncrypt(b *testing.B) {
|
||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||
keysize := GetCipherKeySize(cipherName)
|
||||
ivsize := GetCipherIVSize(cipherName)
|
||||
blocksize := GetCipherBlockSize(cipherName)
|
||||
datalen := blocksize * 500
|
||||
|
||||
key := make([]byte, keysize)
|
||||
iv := make([]byte, ivsize)
|
||||
rand.Reader.Read(key)
|
||||
rand.Reader.Read(iv)
|
||||
buf := make([]byte, datalen)
|
||||
dst := make([]byte, datalen)
|
||||
for i := 0; i < b.N; i++ {
|
||||
c, err := NewCipher(cipherName, key, iv)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Encrypt(dst, buf)
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAESEncrypt2(b *testing.B) {
|
||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||
keysize := GetCipherKeySize(cipherName)
|
||||
ivsize := GetCipherIVSize(cipherName)
|
||||
blocksize := GetCipherBlockSize(cipherName)
|
||||
datalen := blocksize * 500
|
||||
|
||||
key := make([]byte, keysize)
|
||||
iv := make([]byte, ivsize)
|
||||
buf := make([]byte, datalen)
|
||||
|
||||
rand.Reader.Read(buf)
|
||||
rand.Reader.Read(key)
|
||||
rand.Reader.Read(iv)
|
||||
|
||||
dst := make([]byte, datalen)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
block, _ := aes.NewCipher(key)
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
mode.CryptBlocks(dst, buf)
|
||||
}
|
||||
}
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCipherSize(t *testing.T) {
|
||||
data := []struct {
|
||||
t CipherType
|
||||
bsize int
|
||||
isize int
|
||||
}{
|
||||
{GNUTLS_CIPHER_AES_128_CBC, 16, 16},
|
||||
{GNUTLS_CIPHER_AES_192_CBC, 24, 16},
|
||||
{GNUTLS_CIPHER_AES_256_CBC, 32, 16},
|
||||
}
|
||||
for _, d := range data {
|
||||
blocksize := GetCipherKeySize(d.t)
|
||||
if blocksize != d.bsize {
|
||||
t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize)
|
||||
}
|
||||
ivsize := GetCipherIVSize(d.t)
|
||||
if ivsize != d.isize {
|
||||
t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize)
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||
keysize := GetCipherKeySize(cipherName)
|
||||
ivsize := GetCipherIVSize(cipherName)
|
||||
blocksize := GetCipherBlockSize(cipherName)
|
||||
|
||||
key := make([]byte, keysize)
|
||||
iv := make([]byte, ivsize)
|
||||
rand.Reader.Read(key)
|
||||
rand.Reader.Read(iv)
|
||||
|
||||
c, err := NewCipher(cipherName, key, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//defer c.Close()
|
||||
|
||||
c1, err := NewCipher(cipherName, key, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//defer c1.Close()
|
||||
|
||||
data := make([]byte, blocksize*10)
|
||||
if c == nil {
|
||||
t.Fatal("new ciphoer failed")
|
||||
}
|
||||
cdata := make([]byte, len(data))
|
||||
err = c.Encrypt(cdata, data)
|
||||
if err != nil {
|
||||
t.Fatal("encrypt failed", err)
|
||||
}
|
||||
data1 := make([]byte, len(data))
|
||||
err = c1.Decrypt(data1, cdata)
|
||||
if err != nil {
|
||||
t.Fatal("decrypt failed", err)
|
||||
}
|
||||
if !bytes.Equal(data, data1) {
|
||||
t.Fatal("encrypt/decrypt failed", string(data), string(data1))
|
||||
}
|
||||
block, _ := aes.NewCipher(key)
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
dst := make([]byte, len(data))
|
||||
mode.CryptBlocks(dst, data)
|
||||
if !bytes.Equal(dst, cdata) {
|
||||
t.Fatal("cipher text not equal to cypto/aes")
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func BenchmarkAESEncrypt(b *testing.B) {
|
||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||
keysize := GetCipherKeySize(cipherName)
|
||||
ivsize := GetCipherIVSize(cipherName)
|
||||
blocksize := GetCipherBlockSize(cipherName)
|
||||
datalen := blocksize * 500
|
||||
|
||||
key := make([]byte, keysize)
|
||||
iv := make([]byte, ivsize)
|
||||
rand.Reader.Read(key)
|
||||
rand.Reader.Read(iv)
|
||||
buf := make([]byte, datalen)
|
||||
dst := make([]byte, datalen)
|
||||
for i := 0; i < b.N; i++ {
|
||||
c, err := NewCipher(cipherName, key, iv)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
c.Encrypt(dst, buf)
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAESEncrypt2(b *testing.B) {
|
||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||
keysize := GetCipherKeySize(cipherName)
|
||||
ivsize := GetCipherIVSize(cipherName)
|
||||
blocksize := GetCipherBlockSize(cipherName)
|
||||
datalen := blocksize * 500
|
||||
|
||||
key := make([]byte, keysize)
|
||||
iv := make([]byte, ivsize)
|
||||
buf := make([]byte, datalen)
|
||||
|
||||
rand.Reader.Read(buf)
|
||||
rand.Reader.Read(key)
|
||||
rand.Reader.Read(iv)
|
||||
|
||||
dst := make([]byte, datalen)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
block, _ := aes.NewCipher(key)
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
mode.CryptBlocks(dst, buf)
|
||||
}
|
||||
}
|
||||
|
@ -1,108 +1,108 @@
|
||||
// Package gnutls is a gnutls binding for golang.
|
||||
/*
|
||||
a limit set of api is supported.
|
||||
|
||||
TLS api is very similar to crypto/tls on standard library.
|
||||
|
||||
TLS client example:
|
||||
|
||||
addr := "127.0.0.1:9443"
|
||||
c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
log.Fatal("gnutls dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
data := "hello, world"
|
||||
if _, err = c.Write([]byte(data)); err != nil {
|
||||
log.Fatal("gnutls write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
log.Fatal("gnutls read ", err)
|
||||
}
|
||||
|
||||
TLS Server example:
|
||||
cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key")
|
||||
if err != nil{
|
||||
// handle error
|
||||
}
|
||||
l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{
|
||||
Certificates: []*gnutls.Certificate{cert}})
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
defer l.Close()
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Println("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
log.Println("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
log.Println("gnutls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
log.Println("gnutls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
|
||||
AES encrypt/decrypt example:
|
||||
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdefg123456789")
|
||||
c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer c1.Close()
|
||||
|
||||
data := []byte("1234012121212121")
|
||||
if c == nil {
|
||||
log.Fatal("new ciphoer failed")
|
||||
}
|
||||
|
||||
// encrypt
|
||||
dst := make([]byte, len(data))
|
||||
err := c.Encrypt(dst, data)
|
||||
if err != nil {
|
||||
log.Fatal("encrypt failed", err)
|
||||
}
|
||||
|
||||
// decrypt
|
||||
data1 := make([]byte, len(data))
|
||||
err := c1.Decrypt(data1, cdata)
|
||||
if err != nil {
|
||||
log.Fatal("decrypt failed", err)
|
||||
}
|
||||
|
||||
|
||||
Hash example:
|
||||
|
||||
h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512)
|
||||
defer h.Close()
|
||||
|
||||
data := []byte("1234")
|
||||
|
||||
h1 := h.Sum(data)
|
||||
|
||||
*/
|
||||
package gnutls
|
||||
// Package gnutls is a gnutls binding for golang.
|
||||
/*
|
||||
a limit set of api is supported.
|
||||
|
||||
TLS api is very similar to crypto/tls on standard library.
|
||||
|
||||
TLS client example:
|
||||
|
||||
addr := "127.0.0.1:9443"
|
||||
c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
log.Fatal("gnutls dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
data := "hello, world"
|
||||
if _, err = c.Write([]byte(data)); err != nil {
|
||||
log.Fatal("gnutls write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
log.Fatal("gnutls read ", err)
|
||||
}
|
||||
|
||||
TLS Server example:
|
||||
cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key")
|
||||
if err != nil{
|
||||
// handle error
|
||||
}
|
||||
l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{
|
||||
Certificates: []*gnutls.Certificate{cert}})
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
defer l.Close()
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Println("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
log.Println("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
log.Println("gnutls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
log.Println("gnutls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
|
||||
AES encrypt/decrypt example:
|
||||
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdefg123456789")
|
||||
c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer c1.Close()
|
||||
|
||||
data := []byte("1234012121212121")
|
||||
if c == nil {
|
||||
log.Fatal("new ciphoer failed")
|
||||
}
|
||||
|
||||
// encrypt
|
||||
dst := make([]byte, len(data))
|
||||
err := c.Encrypt(dst, data)
|
||||
if err != nil {
|
||||
log.Fatal("encrypt failed", err)
|
||||
}
|
||||
|
||||
// decrypt
|
||||
data1 := make([]byte, len(data))
|
||||
err := c1.Decrypt(data1, cdata)
|
||||
if err != nil {
|
||||
log.Fatal("decrypt failed", err)
|
||||
}
|
||||
|
||||
|
||||
Hash example:
|
||||
|
||||
h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512)
|
||||
defer h.Close()
|
||||
|
||||
data := []byte("1234")
|
||||
|
||||
h1 := h.Sum(data)
|
||||
|
||||
*/
|
||||
package gnutls
|
||||
|
@ -1,90 +1,90 @@
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "_gnutls.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// HashType hash type
|
||||
type HashType int
|
||||
|
||||
const (
|
||||
GNUTLS_HASH_MD5 HashType = 2
|
||||
GNUTLS_HASH_SHA1 HashType = 3
|
||||
GNUTLS_HASH_MD2 HashType = 5
|
||||
GNUTLS_HASH_SHA256 HashType = 6
|
||||
GNUTLS_HASH_SHA384 HashType = 7
|
||||
GNUTLS_HASH_SHA512 HashType = 8
|
||||
GNUTLS_HASH_SHA224 HashType = 9
|
||||
)
|
||||
|
||||
// Hash hash struct
|
||||
type Hash struct {
|
||||
hash C.gnutls_hash_hd_t
|
||||
t HashType
|
||||
hashLen C.int
|
||||
}
|
||||
|
||||
// NewHash new hash struct
|
||||
func NewHash(t HashType) *Hash {
|
||||
h := C.new_hash(C.int(t))
|
||||
hashOutLen := GetHashOutputLen(t)
|
||||
hash := &Hash{h, t, C.int(hashOutLen)}
|
||||
runtime.SetFinalizer(hash, (*Hash).free)
|
||||
return hash
|
||||
}
|
||||
|
||||
// Write write data to hash context
|
||||
func (h *Hash) Write(buf []byte) error {
|
||||
dataLen := len(buf)
|
||||
|
||||
cbuf := C.CBytes(buf)
|
||||
defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen))
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sum get hash result
|
||||
func (h *Hash) Sum(buf []byte) []byte {
|
||||
if buf != nil {
|
||||
h.Write(buf)
|
||||
}
|
||||
|
||||
dstBuf := C.malloc(C.size_t(h.hashLen))
|
||||
defer C.free(dstBuf)
|
||||
|
||||
C.gnutls_hash_output(h.hash, dstBuf)
|
||||
|
||||
gobuf := C.GoBytes(dstBuf, h.hashLen)
|
||||
|
||||
return gobuf
|
||||
}
|
||||
|
||||
// Close destroy hash context
|
||||
func (h *Hash) Close() error {
|
||||
if h.hash != nil {
|
||||
C.gnutls_hash_deinit(h.hash, nil)
|
||||
h.hash = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (h *Hash) free() {
|
||||
//log.Println("free hash")
|
||||
h.Close()
|
||||
}
|
||||
|
||||
// GetHashOutputLen get the hash algorithm output length
|
||||
//
|
||||
// example GNUTLS_MD5 is 16
|
||||
func GetHashOutputLen(t HashType) int {
|
||||
return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t)))
|
||||
}
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "_gnutls.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// HashType hash type
|
||||
type HashType int
|
||||
|
||||
const (
|
||||
GNUTLS_HASH_MD5 HashType = 2
|
||||
GNUTLS_HASH_SHA1 HashType = 3
|
||||
GNUTLS_HASH_MD2 HashType = 5
|
||||
GNUTLS_HASH_SHA256 HashType = 6
|
||||
GNUTLS_HASH_SHA384 HashType = 7
|
||||
GNUTLS_HASH_SHA512 HashType = 8
|
||||
GNUTLS_HASH_SHA224 HashType = 9
|
||||
)
|
||||
|
||||
// Hash hash struct
|
||||
type Hash struct {
|
||||
hash C.gnutls_hash_hd_t
|
||||
t HashType
|
||||
hashLen C.int
|
||||
}
|
||||
|
||||
// NewHash new hash struct
|
||||
func NewHash(t HashType) *Hash {
|
||||
h := C.new_hash(C.int(t))
|
||||
hashOutLen := GetHashOutputLen(t)
|
||||
hash := &Hash{h, t, C.int(hashOutLen)}
|
||||
runtime.SetFinalizer(hash, (*Hash).free)
|
||||
return hash
|
||||
}
|
||||
|
||||
// Write write data to hash context
|
||||
func (h *Hash) Write(buf []byte) error {
|
||||
dataLen := len(buf)
|
||||
|
||||
cbuf := C.CBytes(buf)
|
||||
defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen))
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sum get hash result
|
||||
func (h *Hash) Sum(buf []byte) []byte {
|
||||
if buf != nil {
|
||||
h.Write(buf)
|
||||
}
|
||||
|
||||
dstBuf := C.malloc(C.size_t(h.hashLen))
|
||||
defer C.free(dstBuf)
|
||||
|
||||
C.gnutls_hash_output(h.hash, dstBuf)
|
||||
|
||||
gobuf := C.GoBytes(dstBuf, h.hashLen)
|
||||
|
||||
return gobuf
|
||||
}
|
||||
|
||||
// Close destroy hash context
|
||||
func (h *Hash) Close() error {
|
||||
if h.hash != nil {
|
||||
C.gnutls_hash_deinit(h.hash, nil)
|
||||
h.hash = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (h *Hash) free() {
|
||||
//log.Println("free hash")
|
||||
h.Close()
|
||||
}
|
||||
|
||||
// GetHashOutputLen get the hash algorithm output length
|
||||
//
|
||||
// example GNUTLS_MD5 is 16
|
||||
func GetHashOutputLen(t HashType) int {
|
||||
return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t)))
|
||||
}
|
||||
|
@ -1,52 +1,52 @@
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHashSHA(t *testing.T) {
|
||||
h := NewHash(GNUTLS_HASH_SHA512)
|
||||
//defer h.Close()
|
||||
|
||||
data := []byte("1234")
|
||||
|
||||
h1 := h.Sum(data)
|
||||
|
||||
h3 := sha512.New()
|
||||
h3.Write(data)
|
||||
h4 := h3.Sum(nil)
|
||||
if !bytes.Equal(h4[:], h1) {
|
||||
t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
|
||||
t.Fatal("hash not equal")
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func BenchmarkHashSHA512(b *testing.B) {
|
||||
buf := make([]byte, 1000*1024)
|
||||
rand.Reader.Read(buf)
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := NewHash(GNUTLS_HASH_SHA512)
|
||||
h.Write(buf)
|
||||
h.Sum(nil)
|
||||
h.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHashSHA512s(b *testing.B) {
|
||||
buf := make([]byte, 1000*1024)
|
||||
rand.Reader.Read(buf)
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := sha512.New()
|
||||
h.Write(buf)
|
||||
h.Sum(nil)
|
||||
//h.Close()
|
||||
}
|
||||
}
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHashSHA(t *testing.T) {
|
||||
h := NewHash(GNUTLS_HASH_SHA512)
|
||||
//defer h.Close()
|
||||
|
||||
data := []byte("1234")
|
||||
|
||||
h1 := h.Sum(data)
|
||||
|
||||
h3 := sha512.New()
|
||||
h3.Write(data)
|
||||
h4 := h3.Sum(nil)
|
||||
if !bytes.Equal(h4[:], h1) {
|
||||
t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
|
||||
t.Fatal("hash not equal")
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func BenchmarkHashSHA512(b *testing.B) {
|
||||
buf := make([]byte, 1000*1024)
|
||||
rand.Reader.Read(buf)
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := NewHash(GNUTLS_HASH_SHA512)
|
||||
h.Write(buf)
|
||||
h.Sum(nil)
|
||||
h.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHashSHA512s(b *testing.B) {
|
||||
buf := make([]byte, 1000*1024)
|
||||
rand.Reader.Read(buf)
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := sha512.New()
|
||||
h.Write(buf)
|
||||
h.Sum(nil)
|
||||
//h.Close()
|
||||
}
|
||||
}
|
||||
|
@ -1,492 +1,492 @@
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include "_gnutls.h"
|
||||
#cgo pkg-config: gnutls
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
GNUTLS_NAME_DNS = 1
|
||||
GNUTLS_X509_FMT_PEM = 1
|
||||
GNUTLS_ALPN_MANDATORY = 1
|
||||
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
|
||||
)
|
||||
|
||||
// Conn gnutls TLS connection,
|
||||
// use Listen, Dial, Server or Client create this object
|
||||
type Conn struct {
|
||||
c net.Conn
|
||||
handshake bool
|
||||
sess *C.struct_session
|
||||
cservname *C.char
|
||||
state *ConnectionState
|
||||
cfg *Config
|
||||
closed bool
|
||||
lock *sync.Mutex
|
||||
}
|
||||
|
||||
// Config gnutls TLS configure,
|
||||
type Config struct {
|
||||
ServerName string
|
||||
Certificates []*Certificate
|
||||
InsecureSkipVerify bool
|
||||
NextProtos []string
|
||||
}
|
||||
|
||||
// ConnectionState gnutls TLS connection state
|
||||
type ConnectionState struct {
|
||||
// SNI name client send
|
||||
ServerName string
|
||||
// selected ALPN protocl
|
||||
NegotiatedProtocol string
|
||||
HandshakeComplete bool
|
||||
// TLS version number, ex: 0x303
|
||||
Version uint16
|
||||
// TLS version number, ex: TLS1.0
|
||||
VersionName string
|
||||
// peer's certificate
|
||||
PeerCertificate *Certificate
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
l net.Listener
|
||||
c *Config
|
||||
}
|
||||
|
||||
// Accept
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
c, err := l.l.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Server(c, l.c)
|
||||
}
|
||||
|
||||
// Close
|
||||
func (l *listener) Close() error {
|
||||
return l.l.Close()
|
||||
}
|
||||
|
||||
// Addr
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return l.l.Addr()
|
||||
}
|
||||
|
||||
// Dial dial to (network, addr) and create a gnutls Conn
|
||||
func Dial(network, addr string, cfg *Config) (*Conn, error) {
|
||||
c, err := net.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Client(c, cfg)
|
||||
}
|
||||
|
||||
// Listen create a gnutls listener on (network, addr),
|
||||
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is needed")
|
||||
}
|
||||
l, err := net.Listen(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &listener{l, cfg}, nil
|
||||
}
|
||||
|
||||
// Server create a server TLS Conn on c
|
||||
func Server(c net.Conn, cfg *Config) (*Conn, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is needed")
|
||||
}
|
||||
|
||||
var sess = C.init_gnutls_server_session()
|
||||
|
||||
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
||||
|
||||
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
||||
C.set_callback(sess)
|
||||
|
||||
if cfg.NextProtos != nil {
|
||||
if err := setAlpnProtocols(sess, cfg); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
runtime.SetFinalizer(conn, (*Conn).free)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Client create a client TLS Conn on c
|
||||
func Client(c net.Conn, cfg *Config) (*Conn, error) {
|
||||
var sess = C.init_gnutls_client_session()
|
||||
|
||||
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
||||
|
||||
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
||||
C.set_callback(sess)
|
||||
|
||||
if cfg != nil {
|
||||
if cfg.ServerName != "" {
|
||||
srvname := C.CString(cfg.ServerName)
|
||||
//defer C.free(unsafe.Pointer(srvname))
|
||||
conn.cservname = srvname
|
||||
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
|
||||
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
|
||||
}
|
||||
|
||||
if !cfg.InsecureSkipVerify {
|
||||
if conn.cservname != nil {
|
||||
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
|
||||
} else {
|
||||
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.NextProtos != nil {
|
||||
if err := setAlpnProtocols(sess, cfg); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
||||
}
|
||||
runtime.SetFinalizer(conn, (*Conn).free)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
|
||||
arg := make([](*C.char), 0)
|
||||
for _, s := range cfg.NextProtos {
|
||||
cbuf := C.CString(s)
|
||||
defer C.free(unsafe.Pointer(cbuf))
|
||||
arg = append(arg, (*C.char)(cbuf))
|
||||
}
|
||||
ret := C.alpn_set_protocols(sess,
|
||||
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Handshake call handshake for TLS Conn,
|
||||
// this function will call automatic on Read/Write, if not handshake yet
|
||||
func (c *Conn) Handshake() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
if c.handshake {
|
||||
return nil
|
||||
}
|
||||
ret := C.handshake(c.sess)
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
c.handshake = true
|
||||
//log.Println("handshake done")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read read application data from TLS connection
|
||||
func (c *Conn) Read(buf []byte) (n int, err error) {
|
||||
err = c.Handshake()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if c.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
//bufLen := len(buf)
|
||||
//cbuf := C.malloc(C.size_t(bufLen))
|
||||
//defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_record_recv(c.sess.session,
|
||||
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
||||
if int(ret) < 0 {
|
||||
return 0, fmt.Errorf("read error: %s",
|
||||
C.GoString(C.gnutls_strerror(C.int(ret))))
|
||||
}
|
||||
|
||||
if int(ret) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = int(ret)
|
||||
//gobuf2 := C.GoBytes(cbuf, C.int(ret))
|
||||
//copy(buf, gobuf2)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write write application data to TLS connection
|
||||
func (c *Conn) Write(buf []byte) (n int, err error) {
|
||||
err = c.Handshake()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// user may call Write(nil) to do handshake
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if c.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
//cbuf := C.CBytes(buf)
|
||||
//defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_record_send(c.sess.session,
|
||||
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
||||
n = int(ret)
|
||||
|
||||
if n < 0 {
|
||||
return 0, fmt.Errorf("write error: %s",
|
||||
C.GoString(C.gnutls_strerror(C.int(ret))))
|
||||
}
|
||||
|
||||
if int(ret) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close close the TLS conn and destroy the tls context
|
||||
func (c *Conn) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
C.gnutls_record_send(c.sess.session, nil, 0)
|
||||
C.session_destroy(c.sess)
|
||||
c.c.Close()
|
||||
if c.cservname != nil {
|
||||
C.free(unsafe.Pointer(c.cservname))
|
||||
}
|
||||
|
||||
if c.state != nil && c.state.PeerCertificate != nil {
|
||||
c.state.PeerCertificate.Free()
|
||||
}
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) free() {
|
||||
//log.Println("free conn")
|
||||
c.Close()
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements net.Conn
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.c.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.c.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// RemoteAddr implements net.Conn
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.c.RemoteAddr()
|
||||
}
|
||||
|
||||
// LocalAddr implements net.Conn
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.c.LocalAddr()
|
||||
}
|
||||
|
||||
// SetDeadline implements net.Conn
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return c.c.SetDeadline(t)
|
||||
}
|
||||
|
||||
// ConnectionState get TLS connection state
|
||||
func (c *Conn) ConnectionState() *ConnectionState {
|
||||
if c.state != nil {
|
||||
return c.state
|
||||
}
|
||||
version :=
|
||||
uint16(C.gnutls_protocol_get_version(c.sess.session))
|
||||
|
||||
versionname := C.GoString(
|
||||
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
|
||||
|
||||
state := &ConnectionState{
|
||||
NegotiatedProtocol: c.getAlpnSelectedProtocol(),
|
||||
Version: version,
|
||||
HandshakeComplete: c.handshake,
|
||||
ServerName: c.getServerName(),
|
||||
VersionName: versionname,
|
||||
PeerCertificate: c.getPeerCertificate(),
|
||||
}
|
||||
c.state = state
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *Conn) getPeerCertificate() *Certificate {
|
||||
var size int
|
||||
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
|
||||
if st == nil {
|
||||
return nil
|
||||
}
|
||||
cert := &Certificate{cert: st, certSize: C.int(size)}
|
||||
runtime.SetFinalizer(cert, (*Certificate).free)
|
||||
return cert
|
||||
}
|
||||
|
||||
func (c *Conn) getAlpnSelectedProtocol() string {
|
||||
cbuf := C.malloc(100)
|
||||
defer C.free(cbuf)
|
||||
|
||||
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
|
||||
if int(ret) < 0 {
|
||||
return ""
|
||||
}
|
||||
alpnname := C.GoString((*C.char)(cbuf))
|
||||
return alpnname
|
||||
}
|
||||
|
||||
func (c *Conn) getServerName() string {
|
||||
buflen := 100
|
||||
nametype := GNUTLS_NAME_DNS
|
||||
cbuf := C.malloc(C.size_t(buflen))
|
||||
defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_server_name_get(c.sess.session, cbuf,
|
||||
(*C.size_t)(unsafe.Pointer(&buflen)),
|
||||
(*C.uint)(unsafe.Pointer(&nametype)), 0)
|
||||
if int(ret) < 0 {
|
||||
return ""
|
||||
}
|
||||
name := C.GoString((*C.char)(cbuf))
|
||||
return name
|
||||
}
|
||||
|
||||
/*
|
||||
var bufPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 16*1024)
|
||||
},
|
||||
}
|
||||
*/
|
||||
|
||||
// onDataReadCallback callback function for gnutls library want to read data from network
|
||||
//
|
||||
//export onDataReadCallback
|
||||
func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||||
//log.Println("read addr ", uintptr(d))
|
||||
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||||
|
||||
buf := make([]byte, int(bufLen))
|
||||
/*
|
||||
_length := int(bufLen)
|
||||
buf := bufPool.Get().([]byte)
|
||||
if len(buf) < _length {
|
||||
buf = make([]byte, _length)
|
||||
}
|
||||
defer bufPool.Put(buf)
|
||||
*/
|
||||
|
||||
n, err := conn.c.Read(buf[0:])
|
||||
if err != nil {
|
||||
//log.Println(err)
|
||||
// 0 indicates connection termination
|
||||
return 0
|
||||
}
|
||||
//cbuf2 := C.CBytes(buf[:n])
|
||||
// d := C.CString(string(buf[:n]))
|
||||
//defer C.free(cbuf2)
|
||||
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
|
||||
return C.int(n)
|
||||
}
|
||||
|
||||
// onDataWriteCallback callback function for gnutls library want to send data to network
|
||||
//
|
||||
//export onDataWriteCallback
|
||||
func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||||
//log.Println("write addr ", uintptr(d), int(_l))
|
||||
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||||
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
|
||||
n, err := conn.c.Write(gobuf)
|
||||
if err != nil {
|
||||
//log.Println(err)
|
||||
return -1
|
||||
}
|
||||
return C.int(n)
|
||||
}
|
||||
|
||||
// onDataTimeoutRead callback function for timeout read
|
||||
//
|
||||
//export onDataTimeoutRead
|
||||
func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
|
||||
log.Println("timeout pull function")
|
||||
return 0
|
||||
}
|
||||
|
||||
// onCertSelectCallback callback function for ceritificate select,
|
||||
// this function select certificate from Config.Certificates field,
|
||||
//
|
||||
// on server side, this function select the certificate depend on SNI what client send,
|
||||
// if client not send SNI, select the Config.Certificates[0]
|
||||
//
|
||||
//export onCertSelectCallback
|
||||
func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
|
||||
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
|
||||
|
||||
servername := C.GoStringN(hostname, namelen)
|
||||
//log.Println("go cert select callback ", servername)
|
||||
conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
|
||||
//log.Println(conn)
|
||||
if int(namelen) == 0 && conn.cfg.Certificates != nil {
|
||||
_cert := conn.cfg.Certificates[0]
|
||||
*pcertLength = _cert.certSize
|
||||
*cert = _cert.cert
|
||||
*privkey = _cert.privkey
|
||||
//log.Println("set pcert length ", _cert.certSize)
|
||||
return 0
|
||||
}
|
||||
for _, _cert := range conn.cfg.Certificates {
|
||||
//log.Println(cert)
|
||||
if _cert.matchName(servername) {
|
||||
//log.Println("matched name ", _cert.names)
|
||||
*pcertLength = _cert.certSize
|
||||
*cert = _cert.cert
|
||||
*privkey = _cert.privkey
|
||||
//log.Println("set pcert length ", _cert.certSize)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if conn.cfg.Certificates != nil {
|
||||
_cert := conn.cfg.Certificates[0]
|
||||
*pcertLength = _cert.certSize
|
||||
*cert = _cert.cert
|
||||
*privkey = _cert.privkey
|
||||
//log.Println("set pcert length ", _cert.certSize)
|
||||
return 0
|
||||
}
|
||||
*pcertLength = 0
|
||||
//log.Println("set pcert length 0")
|
||||
return -1
|
||||
}
|
||||
|
||||
func init() {
|
||||
C.init_xcred()
|
||||
C.init_priority_cache()
|
||||
}
|
||||
package gnutls
|
||||
|
||||
/*
|
||||
#include "_gnutls.h"
|
||||
#cgo pkg-config: gnutls
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
GNUTLS_NAME_DNS = 1
|
||||
GNUTLS_X509_FMT_PEM = 1
|
||||
GNUTLS_ALPN_MANDATORY = 1
|
||||
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
|
||||
)
|
||||
|
||||
// Conn gnutls TLS connection,
|
||||
// use Listen, Dial, Server or Client create this object
|
||||
type Conn struct {
|
||||
c net.Conn
|
||||
handshake bool
|
||||
sess *C.struct_session
|
||||
cservname *C.char
|
||||
state *ConnectionState
|
||||
cfg *Config
|
||||
closed bool
|
||||
lock *sync.Mutex
|
||||
}
|
||||
|
||||
// Config gnutls TLS configure,
|
||||
type Config struct {
|
||||
ServerName string
|
||||
Certificates []*Certificate
|
||||
InsecureSkipVerify bool
|
||||
NextProtos []string
|
||||
}
|
||||
|
||||
// ConnectionState gnutls TLS connection state
|
||||
type ConnectionState struct {
|
||||
// SNI name client send
|
||||
ServerName string
|
||||
// selected ALPN protocl
|
||||
NegotiatedProtocol string
|
||||
HandshakeComplete bool
|
||||
// TLS version number, ex: 0x303
|
||||
Version uint16
|
||||
// TLS version number, ex: TLS1.0
|
||||
VersionName string
|
||||
// peer's certificate
|
||||
PeerCertificate *Certificate
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
l net.Listener
|
||||
c *Config
|
||||
}
|
||||
|
||||
// Accept
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
c, err := l.l.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Server(c, l.c)
|
||||
}
|
||||
|
||||
// Close
|
||||
func (l *listener) Close() error {
|
||||
return l.l.Close()
|
||||
}
|
||||
|
||||
// Addr
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return l.l.Addr()
|
||||
}
|
||||
|
||||
// Dial dial to (network, addr) and create a gnutls Conn
|
||||
func Dial(network, addr string, cfg *Config) (*Conn, error) {
|
||||
c, err := net.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Client(c, cfg)
|
||||
}
|
||||
|
||||
// Listen create a gnutls listener on (network, addr),
|
||||
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is needed")
|
||||
}
|
||||
l, err := net.Listen(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &listener{l, cfg}, nil
|
||||
}
|
||||
|
||||
// Server create a server TLS Conn on c
|
||||
func Server(c net.Conn, cfg *Config) (*Conn, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is needed")
|
||||
}
|
||||
|
||||
var sess = C.init_gnutls_server_session()
|
||||
|
||||
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
||||
|
||||
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
||||
C.set_callback(sess)
|
||||
|
||||
if cfg.NextProtos != nil {
|
||||
if err := setAlpnProtocols(sess, cfg); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
runtime.SetFinalizer(conn, (*Conn).free)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Client create a client TLS Conn on c
|
||||
func Client(c net.Conn, cfg *Config) (*Conn, error) {
|
||||
var sess = C.init_gnutls_client_session()
|
||||
|
||||
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
||||
|
||||
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
||||
C.set_callback(sess)
|
||||
|
||||
if cfg != nil {
|
||||
if cfg.ServerName != "" {
|
||||
srvname := C.CString(cfg.ServerName)
|
||||
//defer C.free(unsafe.Pointer(srvname))
|
||||
conn.cservname = srvname
|
||||
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
|
||||
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
|
||||
}
|
||||
|
||||
if !cfg.InsecureSkipVerify {
|
||||
if conn.cservname != nil {
|
||||
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
|
||||
} else {
|
||||
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.NextProtos != nil {
|
||||
if err := setAlpnProtocols(sess, cfg); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
||||
}
|
||||
runtime.SetFinalizer(conn, (*Conn).free)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
|
||||
arg := make([](*C.char), 0)
|
||||
for _, s := range cfg.NextProtos {
|
||||
cbuf := C.CString(s)
|
||||
defer C.free(unsafe.Pointer(cbuf))
|
||||
arg = append(arg, (*C.char)(cbuf))
|
||||
}
|
||||
ret := C.alpn_set_protocols(sess,
|
||||
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Handshake call handshake for TLS Conn,
|
||||
// this function will call automatic on Read/Write, if not handshake yet
|
||||
func (c *Conn) Handshake() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
if c.handshake {
|
||||
return nil
|
||||
}
|
||||
ret := C.handshake(c.sess)
|
||||
if int(ret) < 0 {
|
||||
return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||
}
|
||||
c.handshake = true
|
||||
//log.Println("handshake done")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read read application data from TLS connection
|
||||
func (c *Conn) Read(buf []byte) (n int, err error) {
|
||||
err = c.Handshake()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if c.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
//bufLen := len(buf)
|
||||
//cbuf := C.malloc(C.size_t(bufLen))
|
||||
//defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_record_recv(c.sess.session,
|
||||
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
||||
if int(ret) < 0 {
|
||||
return 0, fmt.Errorf("read error: %s",
|
||||
C.GoString(C.gnutls_strerror(C.int(ret))))
|
||||
}
|
||||
|
||||
if int(ret) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = int(ret)
|
||||
//gobuf2 := C.GoBytes(cbuf, C.int(ret))
|
||||
//copy(buf, gobuf2)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write write application data to TLS connection
|
||||
func (c *Conn) Write(buf []byte) (n int, err error) {
|
||||
err = c.Handshake()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// user may call Write(nil) to do handshake
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if c.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
//cbuf := C.CBytes(buf)
|
||||
//defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_record_send(c.sess.session,
|
||||
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
||||
n = int(ret)
|
||||
|
||||
if n < 0 {
|
||||
return 0, fmt.Errorf("write error: %s",
|
||||
C.GoString(C.gnutls_strerror(C.int(ret))))
|
||||
}
|
||||
|
||||
if int(ret) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close close the TLS conn and destroy the tls context
|
||||
func (c *Conn) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
C.gnutls_record_send(c.sess.session, nil, 0)
|
||||
C.session_destroy(c.sess)
|
||||
c.c.Close()
|
||||
if c.cservname != nil {
|
||||
C.free(unsafe.Pointer(c.cservname))
|
||||
}
|
||||
|
||||
if c.state != nil && c.state.PeerCertificate != nil {
|
||||
c.state.PeerCertificate.Free()
|
||||
}
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) free() {
|
||||
//log.Println("free conn")
|
||||
c.Close()
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements net.Conn
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.c.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.c.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// RemoteAddr implements net.Conn
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.c.RemoteAddr()
|
||||
}
|
||||
|
||||
// LocalAddr implements net.Conn
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.c.LocalAddr()
|
||||
}
|
||||
|
||||
// SetDeadline implements net.Conn
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return c.c.SetDeadline(t)
|
||||
}
|
||||
|
||||
// ConnectionState get TLS connection state
|
||||
func (c *Conn) ConnectionState() *ConnectionState {
|
||||
if c.state != nil {
|
||||
return c.state
|
||||
}
|
||||
version :=
|
||||
uint16(C.gnutls_protocol_get_version(c.sess.session))
|
||||
|
||||
versionname := C.GoString(
|
||||
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
|
||||
|
||||
state := &ConnectionState{
|
||||
NegotiatedProtocol: c.getAlpnSelectedProtocol(),
|
||||
Version: version,
|
||||
HandshakeComplete: c.handshake,
|
||||
ServerName: c.getServerName(),
|
||||
VersionName: versionname,
|
||||
PeerCertificate: c.getPeerCertificate(),
|
||||
}
|
||||
c.state = state
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *Conn) getPeerCertificate() *Certificate {
|
||||
var size int
|
||||
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
|
||||
if st == nil {
|
||||
return nil
|
||||
}
|
||||
cert := &Certificate{cert: st, certSize: C.int(size)}
|
||||
runtime.SetFinalizer(cert, (*Certificate).free)
|
||||
return cert
|
||||
}
|
||||
|
||||
func (c *Conn) getAlpnSelectedProtocol() string {
|
||||
cbuf := C.malloc(100)
|
||||
defer C.free(cbuf)
|
||||
|
||||
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
|
||||
if int(ret) < 0 {
|
||||
return ""
|
||||
}
|
||||
alpnname := C.GoString((*C.char)(cbuf))
|
||||
return alpnname
|
||||
}
|
||||
|
||||
func (c *Conn) getServerName() string {
|
||||
buflen := 100
|
||||
nametype := GNUTLS_NAME_DNS
|
||||
cbuf := C.malloc(C.size_t(buflen))
|
||||
defer C.free(cbuf)
|
||||
|
||||
ret := C.gnutls_server_name_get(c.sess.session, cbuf,
|
||||
(*C.size_t)(unsafe.Pointer(&buflen)),
|
||||
(*C.uint)(unsafe.Pointer(&nametype)), 0)
|
||||
if int(ret) < 0 {
|
||||
return ""
|
||||
}
|
||||
name := C.GoString((*C.char)(cbuf))
|
||||
return name
|
||||
}
|
||||
|
||||
/*
|
||||
var bufPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 16*1024)
|
||||
},
|
||||
}
|
||||
*/
|
||||
|
||||
// onDataReadCallback callback function for gnutls library want to read data from network
|
||||
//
|
||||
//export onDataReadCallback
|
||||
func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||||
//log.Println("read addr ", uintptr(d))
|
||||
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||||
|
||||
buf := make([]byte, int(bufLen))
|
||||
/*
|
||||
_length := int(bufLen)
|
||||
buf := bufPool.Get().([]byte)
|
||||
if len(buf) < _length {
|
||||
buf = make([]byte, _length)
|
||||
}
|
||||
defer bufPool.Put(buf)
|
||||
*/
|
||||
|
||||
n, err := conn.c.Read(buf[0:])
|
||||
if err != nil {
|
||||
//log.Println(err)
|
||||
// 0 indicates connection termination
|
||||
return 0
|
||||
}
|
||||
//cbuf2 := C.CBytes(buf[:n])
|
||||
// d := C.CString(string(buf[:n]))
|
||||
//defer C.free(cbuf2)
|
||||
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
|
||||
return C.int(n)
|
||||
}
|
||||
|
||||
// onDataWriteCallback callback function for gnutls library want to send data to network
|
||||
//
|
||||
//export onDataWriteCallback
|
||||
func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||||
//log.Println("write addr ", uintptr(d), int(_l))
|
||||
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||||
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
|
||||
n, err := conn.c.Write(gobuf)
|
||||
if err != nil {
|
||||
//log.Println(err)
|
||||
return -1
|
||||
}
|
||||
return C.int(n)
|
||||
}
|
||||
|
||||
// onDataTimeoutRead callback function for timeout read
|
||||
//
|
||||
//export onDataTimeoutRead
|
||||
func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
|
||||
log.Println("timeout pull function")
|
||||
return 0
|
||||
}
|
||||
|
||||
// onCertSelectCallback callback function for ceritificate select,
|
||||
// this function select certificate from Config.Certificates field,
|
||||
//
|
||||
// on server side, this function select the certificate depend on SNI what client send,
|
||||
// if client not send SNI, select the Config.Certificates[0]
|
||||
//
|
||||
//export onCertSelectCallback
|
||||
func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
|
||||
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
|
||||
|
||||
servername := C.GoStringN(hostname, namelen)
|
||||
//log.Println("go cert select callback ", servername)
|
||||
conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
|
||||
//log.Println(conn)
|
||||
if int(namelen) == 0 && conn.cfg.Certificates != nil {
|
||||
_cert := conn.cfg.Certificates[0]
|
||||
*pcertLength = _cert.certSize
|
||||
*cert = _cert.cert
|
||||
*privkey = _cert.privkey
|
||||
//log.Println("set pcert length ", _cert.certSize)
|
||||
return 0
|
||||
}
|
||||
for _, _cert := range conn.cfg.Certificates {
|
||||
//log.Println(cert)
|
||||
if _cert.matchName(servername) {
|
||||
//log.Println("matched name ", _cert.names)
|
||||
*pcertLength = _cert.certSize
|
||||
*cert = _cert.cert
|
||||
*privkey = _cert.privkey
|
||||
//log.Println("set pcert length ", _cert.certSize)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if conn.cfg.Certificates != nil {
|
||||
_cert := conn.cfg.Certificates[0]
|
||||
*pcertLength = _cert.certSize
|
||||
*cert = _cert.cert
|
||||
*privkey = _cert.privkey
|
||||
//log.Println("set pcert length ", _cert.certSize)
|
||||
return 0
|
||||
}
|
||||
*pcertLength = 0
|
||||
//log.Println("set pcert length 0")
|
||||
return -1
|
||||
}
|
||||
|
||||
func init() {
|
||||
C.init_xcred()
|
||||
C.init_priority_cache()
|
||||
}
|
||||
|
@ -1,431 +1,431 @@
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTLSClient(t *testing.T) {
|
||||
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal("load certificate failed")
|
||||
}
|
||||
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("listen failed")
|
||||
}
|
||||
defer l.Close()
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
t.Logf("accept connection from %s", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
for {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Log("connection closed")
|
||||
break
|
||||
}
|
||||
if _, err = c.Write(buf[:n]); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
t.Fatal("gnutls dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
data := "hello, world"
|
||||
if _, err = c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("gnutls write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("gnutls read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSServer(t *testing.T) {
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||
Certificates: []*Certificate{cert},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("gnutls listen ", err)
|
||||
}
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
t.Log("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsconn := c.(*Conn)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
t.Log("gnutls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
t.Log("gnutls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
t.Fatal("dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
data := "hello, world"
|
||||
if _, err := c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSALPNServer(t *testing.T) {
|
||||
serveralpn := []string{"a1", "a3", "a2"}
|
||||
clientalpn := []string{"a0", "a2", "a5"}
|
||||
expectedAlpn := "a2"
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||
Certificates: []*Certificate{cert},
|
||||
NextProtos: serveralpn,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("gnutls listen ", err)
|
||||
}
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
t.Log("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsConn := c.(*Conn)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
connState := tlsConn.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
t.Log("gnutls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
t.Log("gnutls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := tls.Dial("tcp", addr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
NextProtos: clientalpn,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if err := c.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
connState := c.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
|
||||
if connState.NegotiatedProtocol != expectedAlpn {
|
||||
t.Errorf("expected alpn %s, got %s",
|
||||
expectedAlpn, connState.NegotiatedProtocol)
|
||||
}
|
||||
|
||||
data := "hello, world"
|
||||
if _, err := c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSALPNClient(t *testing.T) {
|
||||
serveralpn := []string{"a1", "a3", "a2"}
|
||||
clientalpn := []string{"a0", "a2", "a5"}
|
||||
expectedAlpn := "a2"
|
||||
|
||||
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
|
||||
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
NextProtos: serveralpn,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("tls listen ", err)
|
||||
}
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
t.Log("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsConn := c.(*tls.Conn)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
connState := tlsConn.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
t.Log("tls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
t.Log("tls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
NextProtos: clientalpn,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if err := c.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
connState := c.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
|
||||
if connState.NegotiatedProtocol != expectedAlpn {
|
||||
t.Errorf("expected alpn %s, got %s",
|
||||
expectedAlpn, connState.NegotiatedProtocol)
|
||||
}
|
||||
|
||||
data := "hello, world"
|
||||
if _, err := c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSServerSNI(t *testing.T) {
|
||||
certificates := []*Certificate{}
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
|
||||
certificates = append(certificates, cert)
|
||||
cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
|
||||
certificates = append(certificates, cert)
|
||||
cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
certificates = append(certificates, cert)
|
||||
|
||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||
Certificates: certificates,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer l.Close()
|
||||
addr := l.Addr().String()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
break
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsconn := c.(*Conn)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
state := tlsconn.ConnectionState()
|
||||
fmt.Fprintf(c, state.ServerName)
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, cfg := range []struct {
|
||||
serverName string
|
||||
commonName string
|
||||
}{
|
||||
{"abc.com", "abc.com"},
|
||||
{"example.com", "example.com"},
|
||||
{"a.aaa.com", "*.aaa.com"},
|
||||
{"b.aaa.com", "*.aaa.com"},
|
||||
} {
|
||||
conn, err := tls.Dial("tcp", addr, &tls.Config{
|
||||
ServerName: cfg.serverName,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state := conn.ConnectionState()
|
||||
_commonName := state.PeerCertificates[0].Subject.CommonName
|
||||
if _commonName != cfg.commonName {
|
||||
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Error(err)
|
||||
}
|
||||
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
|
||||
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSGetPeerCert(t *testing.T) {
|
||||
conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{
|
||||
ServerName: "www.ratafee.nl",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
//tlsconn := conn.(*Conn)
|
||||
if err := conn.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state := conn.ConnectionState()
|
||||
for i := 0; i < int(state.PeerCertificate.certSize); i++ {
|
||||
t.Log(state.PeerCertificate.getCertString(i, 1))
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil)
|
||||
req.Write(conn)
|
||||
r := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(r, req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
var buf = new(bytes.Buffer)
|
||||
resp.Write(buf)
|
||||
t.Logf("%s", string(buf.Bytes()))
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
package gnutls
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTLSClient(t *testing.T) {
|
||||
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal("load certificate failed")
|
||||
}
|
||||
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("listen failed")
|
||||
}
|
||||
defer l.Close()
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
t.Logf("accept connection from %s", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
for {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Log("connection closed")
|
||||
break
|
||||
}
|
||||
if _, err = c.Write(buf[:n]); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
t.Fatal("gnutls dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
data := "hello, world"
|
||||
if _, err = c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("gnutls write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("gnutls read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSServer(t *testing.T) {
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||
Certificates: []*Certificate{cert},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("gnutls listen ", err)
|
||||
}
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
t.Log("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsconn := c.(*Conn)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
t.Log("gnutls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
t.Log("gnutls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
t.Fatal("dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
data := "hello, world"
|
||||
if _, err := c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSALPNServer(t *testing.T) {
|
||||
serveralpn := []string{"a1", "a3", "a2"}
|
||||
clientalpn := []string{"a0", "a2", "a5"}
|
||||
expectedAlpn := "a2"
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||
Certificates: []*Certificate{cert},
|
||||
NextProtos: serveralpn,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("gnutls listen ", err)
|
||||
}
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
t.Log("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsConn := c.(*Conn)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
connState := tlsConn.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
t.Log("gnutls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
t.Log("gnutls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := tls.Dial("tcp", addr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
NextProtos: clientalpn,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if err := c.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
connState := c.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
|
||||
if connState.NegotiatedProtocol != expectedAlpn {
|
||||
t.Errorf("expected alpn %s, got %s",
|
||||
expectedAlpn, connState.NegotiatedProtocol)
|
||||
}
|
||||
|
||||
data := "hello, world"
|
||||
if _, err := c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSALPNClient(t *testing.T) {
|
||||
serveralpn := []string{"a1", "a3", "a2"}
|
||||
clientalpn := []string{"a0", "a2", "a5"}
|
||||
expectedAlpn := "a2"
|
||||
|
||||
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
|
||||
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
NextProtos: serveralpn,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("tls listen ", err)
|
||||
}
|
||||
addr := l.Addr().String()
|
||||
t.Log("test server listen on ", addr)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log("gnutls accept ", err)
|
||||
break
|
||||
}
|
||||
t.Log("accept connection from ", c.RemoteAddr())
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsConn := c.(*tls.Conn)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
connState := tlsConn.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf[0:])
|
||||
if err != nil {
|
||||
t.Log("tls read ", err)
|
||||
break
|
||||
}
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
t.Log("tls write ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
NextProtos: clientalpn,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("dial ", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if err := c.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
connState := c.ConnectionState()
|
||||
t.Logf("%+v", connState)
|
||||
|
||||
if connState.NegotiatedProtocol != expectedAlpn {
|
||||
t.Errorf("expected alpn %s, got %s",
|
||||
expectedAlpn, connState.NegotiatedProtocol)
|
||||
}
|
||||
|
||||
data := "hello, world"
|
||||
if _, err := c.Write([]byte(data)); err != nil {
|
||||
t.Fatal("write ", err)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal("read ", err)
|
||||
}
|
||||
if string(buf[:n]) != data {
|
||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSServerSNI(t *testing.T) {
|
||||
certificates := []*Certificate{}
|
||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
|
||||
certificates = append(certificates, cert)
|
||||
cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
|
||||
certificates = append(certificates, cert)
|
||||
cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key")
|
||||
if err != nil {
|
||||
t.Fatal("load key failed")
|
||||
}
|
||||
certificates = append(certificates, cert)
|
||||
|
||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||
Certificates: certificates,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer l.Close()
|
||||
addr := l.Addr().String()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
break
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
tlsconn := c.(*Conn)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
state := tlsconn.ConnectionState()
|
||||
fmt.Fprintf(c, state.ServerName)
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, cfg := range []struct {
|
||||
serverName string
|
||||
commonName string
|
||||
}{
|
||||
{"abc.com", "abc.com"},
|
||||
{"example.com", "example.com"},
|
||||
{"a.aaa.com", "*.aaa.com"},
|
||||
{"b.aaa.com", "*.aaa.com"},
|
||||
} {
|
||||
conn, err := tls.Dial("tcp", addr, &tls.Config{
|
||||
ServerName: cfg.serverName,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state := conn.ConnectionState()
|
||||
_commonName := state.PeerCertificates[0].Subject.CommonName
|
||||
if _commonName != cfg.commonName {
|
||||
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
|
||||
}
|
||||
buf := make([]byte, 100)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Error(err)
|
||||
}
|
||||
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
|
||||
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTLSGetPeerCert(t *testing.T) {
|
||||
conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{
|
||||
ServerName: "www.ratafee.nl",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
//tlsconn := conn.(*Conn)
|
||||
if err := conn.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state := conn.ConnectionState()
|
||||
for i := 0; i < int(state.PeerCertificate.certSize); i++ {
|
||||
t.Log(state.PeerCertificate.getCertString(i, 1))
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil)
|
||||
req.Write(conn)
|
||||
r := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(r, req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
var buf = new(bytes.Buffer)
|
||||
resp.Write(buf)
|
||||
t.Logf("%s", string(buf.Bytes()))
|
||||
runtime.GC()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
Loading…
Reference in New Issue