go fmt
parent
2c4c546551
commit
e168de539b
@ -1,176 +1,176 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#include "_gnutls.h"
|
#include "_gnutls.h"
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Certificate x509 certificate
|
// Certificate x509 certificate
|
||||||
type Certificate struct {
|
type Certificate struct {
|
||||||
cert *C.gnutls_pcert_st
|
cert *C.gnutls_pcert_st
|
||||||
privkey C.gnutls_privkey_t
|
privkey C.gnutls_privkey_t
|
||||||
certSize C.int
|
certSize C.int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Free free the certificate context
|
// Free free the certificate context
|
||||||
func (c *Certificate) Free() {
|
func (c *Certificate) Free() {
|
||||||
if c.cert != nil {
|
if c.cert != nil {
|
||||||
C.free_cert_list(c.cert, c.certSize)
|
C.free_cert_list(c.cert, c.certSize)
|
||||||
}
|
}
|
||||||
if c.privkey != nil {
|
if c.privkey != nil {
|
||||||
C.gnutls_privkey_deinit(c.privkey)
|
C.gnutls_privkey_deinit(c.privkey)
|
||||||
}
|
}
|
||||||
c.cert = nil
|
c.cert = nil
|
||||||
c.privkey = nil
|
c.privkey = nil
|
||||||
c.certSize = 0
|
c.certSize = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Certificate) free() {
|
func (c *Certificate) free() {
|
||||||
//log.Println("free certificate")
|
//log.Println("free certificate")
|
||||||
c.Free()
|
c.Free()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Certificate) matchName(name string) bool {
|
func (c *Certificate) matchName(name string) bool {
|
||||||
cname := C.CString(name)
|
cname := C.CString(name)
|
||||||
defer C.free(unsafe.Pointer(cname))
|
defer C.free(unsafe.Pointer(cname))
|
||||||
ret := C.cert_check_hostname(c.cert, c.certSize, cname)
|
ret := C.cert_check_hostname(c.cert, c.certSize, cname)
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
log.Println(C.GoString(C.gnutls_strerror(ret)))
|
log.Println(C.GoString(C.gnutls_strerror(ret)))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if int(ret) > 0 {
|
if int(ret) > 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// CommonName get CN field in subject,
|
// CommonName get CN field in subject,
|
||||||
//
|
//
|
||||||
// example: subject C=xx,ST=bbb,CN=abc will return abc
|
// example: subject C=xx,ST=bbb,CN=abc will return abc
|
||||||
func (c *Certificate) CommonName() string {
|
func (c *Certificate) CommonName() string {
|
||||||
return c.commonName(0)
|
return c.commonName(0)
|
||||||
}
|
}
|
||||||
func (c *Certificate) commonName(index int) string {
|
func (c *Certificate) commonName(index int) string {
|
||||||
dnMap := make(map[string]string)
|
dnMap := make(map[string]string)
|
||||||
dn := c.getDN(0)
|
dn := c.getDN(0)
|
||||||
if dn != "" {
|
if dn != "" {
|
||||||
dnFields := strings.Split(dn, ",")
|
dnFields := strings.Split(dn, ",")
|
||||||
for _, d := range dnFields {
|
for _, d := range dnFields {
|
||||||
dd := strings.Split(d, "=")
|
dd := strings.Split(d, "=")
|
||||||
k := dd[0]
|
k := dd[0]
|
||||||
v := dd[1]
|
v := dd[1]
|
||||||
dnMap[k] = v
|
dnMap[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cn, ok := dnMap["CN"]; ok {
|
if cn, ok := dnMap["CN"]; ok {
|
||||||
return cn
|
return cn
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAltName get altname in certificate
|
// GetAltName get altname in certificate
|
||||||
func (c *Certificate) GetAltName() string {
|
func (c *Certificate) GetAltName() string {
|
||||||
return c.getAltName(0, 0)
|
return c.getAltName(0, 0)
|
||||||
}
|
}
|
||||||
func (c *Certificate) getAltName(index int, nameindex int) string {
|
func (c *Certificate) getAltName(index int, nameindex int) string {
|
||||||
out := C.malloc(1024)
|
out := C.malloc(1024)
|
||||||
defer C.free(out)
|
defer C.free(out)
|
||||||
size := C.get_pcert_alt_name(
|
size := C.get_pcert_alt_name(
|
||||||
c.cert, C.int(index), C.int(nameindex), (*C.char)(out))
|
c.cert, C.int(index), C.int(nameindex), (*C.char)(out))
|
||||||
if int(size) < 0 {
|
if int(size) < 0 {
|
||||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
name := C.GoBytes(out, size)
|
name := C.GoBytes(out, size)
|
||||||
return string(name)
|
return string(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
//GetCertString return certificate info string in one line
|
//GetCertString return certificate info string in one line
|
||||||
func (c *Certificate) GetCertString() string {
|
func (c *Certificate) GetCertString() string {
|
||||||
return c.getCertString(0, 1)
|
return c.getCertString(0, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Certificate) getCertString(index int, flag int) string {
|
func (c *Certificate) getCertString(index int, flag int) string {
|
||||||
out := C.malloc(4096)
|
out := C.malloc(4096)
|
||||||
defer C.free(out)
|
defer C.free(out)
|
||||||
size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out))
|
size := C.get_cert_str(c.cert, C.int(index), C.int(flag), (*C.char)(out))
|
||||||
if int(size) < 0 {
|
if int(size) < 0 {
|
||||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
s := C.GoBytes(out, size)
|
s := C.GoBytes(out, size)
|
||||||
return string(s)
|
return string(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDN get the certificate subject, like O=st,C=aa,CN=localhost
|
// GetDN get the certificate subject, like O=st,C=aa,CN=localhost
|
||||||
func (c *Certificate) GetDN() string {
|
func (c *Certificate) GetDN() string {
|
||||||
return c.getDN(0)
|
return c.getDN(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Certificate) getDN(index int) string {
|
func (c *Certificate) getDN(index int) string {
|
||||||
cbuf := C.malloc(200)
|
cbuf := C.malloc(200)
|
||||||
defer C.free(cbuf)
|
defer C.free(cbuf)
|
||||||
size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
size := C.get_cert_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
||||||
if int(size) < 0 {
|
if int(size) < 0 {
|
||||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
s := C.GoBytes(cbuf, size)
|
s := C.GoBytes(cbuf, size)
|
||||||
return string(s)
|
return string(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost
|
// GetIssuerDN get the certificate issuer's subject, like O=st,C=ac,CN=localhost
|
||||||
func (c *Certificate) GetIssuerDN() string {
|
func (c *Certificate) GetIssuerDN() string {
|
||||||
return c.getIssuerDN(0)
|
return c.getIssuerDN(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Certificate) getIssuerDN(index int) string {
|
func (c *Certificate) getIssuerDN(index int) string {
|
||||||
cbuf := C.malloc(200)
|
cbuf := C.malloc(200)
|
||||||
defer C.free(cbuf)
|
defer C.free(cbuf)
|
||||||
size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
size := C.get_cert_issuer_dn(c.cert, C.int(index), (*C.char)(cbuf))
|
||||||
if int(size) < 0 {
|
if int(size) < 0 {
|
||||||
log.Println(C.GoString(C.gnutls_strerror(size)))
|
log.Println(C.GoString(C.gnutls_strerror(size)))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
s := C.GoBytes(cbuf, size)
|
s := C.GoBytes(cbuf, size)
|
||||||
return string(s)
|
return string(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadX509KeyPair load certificate pair,
|
// LoadX509KeyPair load certificate pair,
|
||||||
// the return Certifciate must be freed by call Free(),
|
// the return Certifciate must be freed by call Free(),
|
||||||
func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
|
func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
|
||||||
_certfile := C.CString(certfile)
|
_certfile := C.CString(certfile)
|
||||||
_keyfile := C.CString(keyfile)
|
_keyfile := C.CString(keyfile)
|
||||||
|
|
||||||
defer C.free(unsafe.Pointer(_certfile))
|
defer C.free(unsafe.Pointer(_certfile))
|
||||||
defer C.free(unsafe.Pointer(_keyfile))
|
defer C.free(unsafe.Pointer(_keyfile))
|
||||||
|
|
||||||
certificate := &Certificate{}
|
certificate := &Certificate{}
|
||||||
var ret C.int
|
var ret C.int
|
||||||
var certSize C.int
|
var certSize C.int
|
||||||
cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)),
|
cert := C.load_cert_list(_certfile, (*C.int)(unsafe.Pointer(&certSize)),
|
||||||
(*C.int)(unsafe.Pointer(&ret)))
|
(*C.int)(unsafe.Pointer(&ret)))
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return nil, fmt.Errorf("load cert failed: %s",
|
return nil, fmt.Errorf("load cert failed: %s",
|
||||||
C.GoString(C.gnutls_strerror(ret)))
|
C.GoString(C.gnutls_strerror(ret)))
|
||||||
}
|
}
|
||||||
privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret)))
|
privkey := C.load_privkey(_keyfile, (*C.int)(unsafe.Pointer(&ret)))
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return nil, fmt.Errorf("load privkey: %s",
|
return nil, fmt.Errorf("load privkey: %s",
|
||||||
C.GoString(C.gnutls_strerror(ret)))
|
C.GoString(C.gnutls_strerror(ret)))
|
||||||
}
|
}
|
||||||
certificate.cert = cert
|
certificate.cert = cert
|
||||||
certificate.privkey = privkey
|
certificate.privkey = privkey
|
||||||
certificate.certSize = certSize
|
certificate.certSize = certSize
|
||||||
runtime.SetFinalizer(certificate, (*Certificate).free)
|
runtime.SetFinalizer(certificate, (*Certificate).free)
|
||||||
return certificate, nil
|
return certificate, nil
|
||||||
}
|
}
|
||||||
|
@ -1,35 +1,35 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetAltname(t *testing.T) {
|
func TestGetAltname(t *testing.T) {
|
||||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
t.Logf("%+v", cert)
|
t.Logf("%+v", cert)
|
||||||
t.Log(cert.GetAltName())
|
t.Log(cert.GetAltName())
|
||||||
t.Log(cert.GetCertString())
|
t.Log(cert.GetCertString())
|
||||||
t.Log(cert.GetDN())
|
t.Log(cert.GetDN())
|
||||||
t.Log(cert.GetIssuerDN())
|
t.Log(cert.GetIssuerDN())
|
||||||
t.Log("CN ", cert.CommonName())
|
t.Log("CN ", cert.CommonName())
|
||||||
//t.Log("flag 0: ", cert.getCertString(0, 0))
|
//t.Log("flag 0: ", cert.getCertString(0, 0))
|
||||||
//t.Log("flag 1: ", cert.getCertString(0, 1))
|
//t.Log("flag 1: ", cert.getCertString(0, 1))
|
||||||
//t.Log("flag 2: ", cert.getCertString(0, 2))
|
//t.Log("flag 2: ", cert.getCertString(0, 2))
|
||||||
//t.Log("flag 3: ", cert.getCertString(0, 3))
|
//t.Log("flag 3: ", cert.getCertString(0, 3))
|
||||||
cert.Free()
|
cert.Free()
|
||||||
}
|
}
|
||||||
|
|
||||||
func _loadCert(certfile, keyfile string) (*Certificate, error) {
|
func _loadCert(certfile, keyfile string) (*Certificate, error) {
|
||||||
return LoadX509KeyPair(certfile, keyfile)
|
return LoadX509KeyPair(certfile, keyfile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCertGC(t *testing.T) {
|
func TestCertGC(t *testing.T) {
|
||||||
_loadCert("testdata/server.crt", "testdata/server.key")
|
_loadCert("testdata/server.crt", "testdata/server.key")
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
@ -1,167 +1,167 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#include "_gnutls.h"
|
#include "_gnutls.h"
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CipherType cipher type
|
// CipherType cipher type
|
||||||
type CipherType int
|
type CipherType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
GNUTLS_CIPHER_AES_128_CBC CipherType = 4
|
GNUTLS_CIPHER_AES_128_CBC CipherType = 4
|
||||||
GNUTLS_CIPHER_AES_256_CBC CipherType = 5
|
GNUTLS_CIPHER_AES_256_CBC CipherType = 5
|
||||||
GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
|
GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
|
||||||
GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
|
GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
|
||||||
GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
|
GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
|
||||||
GNUTLS_CIPHER_AES_192_CBC CipherType = 9
|
GNUTLS_CIPHER_AES_192_CBC CipherType = 9
|
||||||
GNUTLS_CIPHER_AES_128_GCM CipherType = 10
|
GNUTLS_CIPHER_AES_128_GCM CipherType = 10
|
||||||
GNUTLS_CIPHER_AES_256_GCM CipherType = 11
|
GNUTLS_CIPHER_AES_256_GCM CipherType = 11
|
||||||
GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
|
GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
|
||||||
GNUTLS_CIPHER_SALSA20_256 CipherType = 13
|
GNUTLS_CIPHER_SALSA20_256 CipherType = 13
|
||||||
GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
|
GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
|
||||||
GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
|
GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
|
||||||
GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
|
GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
|
||||||
GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
|
GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
|
||||||
GNUTLS_CIPHER_DES_CBC CipherType = 18
|
GNUTLS_CIPHER_DES_CBC CipherType = 18
|
||||||
GNUTLS_CIPHER_AES_128_CCM CipherType = 19
|
GNUTLS_CIPHER_AES_128_CCM CipherType = 19
|
||||||
GNUTLS_CIPHER_AES_256_CCM CipherType = 20
|
GNUTLS_CIPHER_AES_256_CCM CipherType = 20
|
||||||
GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
|
GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
|
||||||
GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
|
GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
|
||||||
GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
|
GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrBlockSize wrong block size
|
// ErrBlockSize wrong block size
|
||||||
ErrBlockSize = errors.New("wrong block size")
|
ErrBlockSize = errors.New("wrong block size")
|
||||||
// ErrKeyLength wrong key length
|
// ErrKeyLength wrong key length
|
||||||
ErrKeyLength = errors.New("wrong key length")
|
ErrKeyLength = errors.New("wrong key length")
|
||||||
// ErrIVLength wrong iv length
|
// ErrIVLength wrong iv length
|
||||||
ErrIVLength = errors.New("wrong iv length")
|
ErrIVLength = errors.New("wrong iv length")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Cipher gnutls cipher struct
|
// Cipher gnutls cipher struct
|
||||||
type Cipher struct {
|
type Cipher struct {
|
||||||
cipher C.gnutls_cipher_hd_t
|
cipher C.gnutls_cipher_hd_t
|
||||||
t CipherType
|
t CipherType
|
||||||
blockSize int
|
blockSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCipher create a new cipher by give type, key, iv
|
// NewCipher create a new cipher by give type, key, iv
|
||||||
//
|
//
|
||||||
// example:
|
// example:
|
||||||
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
|
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
|
||||||
//
|
//
|
||||||
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
|
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
|
||||||
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
|
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
|
||||||
keysize := GetCipherKeySize(t)
|
keysize := GetCipherKeySize(t)
|
||||||
ivSize := GetCipherIVSize(t)
|
ivSize := GetCipherIVSize(t)
|
||||||
blocksize := GetCipherBlockSize(t)
|
blocksize := GetCipherBlockSize(t)
|
||||||
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
|
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
|
||||||
if len(key) != int(keysize) {
|
if len(key) != int(keysize) {
|
||||||
return nil, ErrKeyLength
|
return nil, ErrKeyLength
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(iv) != int(ivSize) {
|
if len(iv) != int(ivSize) {
|
||||||
return nil, ErrIVLength
|
return nil, ErrIVLength
|
||||||
}
|
}
|
||||||
|
|
||||||
ckey := C.CBytes(key)
|
ckey := C.CBytes(key)
|
||||||
civ := C.CBytes(iv)
|
civ := C.CBytes(iv)
|
||||||
|
|
||||||
defer C.free(ckey)
|
defer C.free(ckey)
|
||||||
defer C.free(civ)
|
defer C.free(civ)
|
||||||
|
|
||||||
c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv)))
|
c := C.new_cipher(C.int(t), (*C.char)(ckey), C.int(len(key)), (*C.char)(civ), C.int(len(iv)))
|
||||||
if c == nil {
|
if c == nil {
|
||||||
log.Println("new cipher return nil")
|
log.Println("new cipher return nil")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
cipher := &Cipher{c, t, blocksize}
|
cipher := &Cipher{c, t, blocksize}
|
||||||
runtime.SetFinalizer(cipher, (*Cipher).free)
|
runtime.SetFinalizer(cipher, (*Cipher).free)
|
||||||
return cipher, nil
|
return cipher, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt encrypt the buf and place the encrypted data in dst,
|
// Encrypt encrypt the buf and place the encrypted data in dst,
|
||||||
// the buf size must multiple of cipher's block size
|
// the buf size must multiple of cipher's block size
|
||||||
func (c *Cipher) Encrypt(dst, buf []byte) error {
|
func (c *Cipher) Encrypt(dst, buf []byte) error {
|
||||||
if len(buf)%c.blockSize != 0 {
|
if len(buf)%c.blockSize != 0 {
|
||||||
return ErrBlockSize
|
return ErrBlockSize
|
||||||
}
|
}
|
||||||
|
|
||||||
cbuf := C.CBytes(buf)
|
cbuf := C.CBytes(buf)
|
||||||
defer C.free(cbuf)
|
defer C.free(cbuf)
|
||||||
|
|
||||||
bufLen := C.size_t(len(buf))
|
bufLen := C.size_t(len(buf))
|
||||||
dstBuf := C.malloc(bufLen)
|
dstBuf := C.malloc(bufLen)
|
||||||
|
|
||||||
defer C.free(dstBuf)
|
defer C.free(dstBuf)
|
||||||
|
|
||||||
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||||
}
|
}
|
||||||
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
||||||
copy(dst, _buf)
|
copy(dst, _buf)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt decrypt the buf and place the decrypted data in dst,
|
// Decrypt decrypt the buf and place the decrypted data in dst,
|
||||||
// the buf size must multiple of cipher's block size
|
// the buf size must multiple of cipher's block size
|
||||||
func (c *Cipher) Decrypt(dst, buf []byte) error {
|
func (c *Cipher) Decrypt(dst, buf []byte) error {
|
||||||
if len(buf)%c.blockSize != 0 {
|
if len(buf)%c.blockSize != 0 {
|
||||||
return ErrBlockSize
|
return ErrBlockSize
|
||||||
}
|
}
|
||||||
|
|
||||||
cbuf := C.CBytes(buf)
|
cbuf := C.CBytes(buf)
|
||||||
defer C.free(cbuf)
|
defer C.free(cbuf)
|
||||||
|
|
||||||
bufLen := C.size_t(len(buf))
|
bufLen := C.size_t(len(buf))
|
||||||
dstBuf := C.malloc(C.size_t(len(buf)))
|
dstBuf := C.malloc(C.size_t(len(buf)))
|
||||||
|
|
||||||
defer C.free(dstBuf)
|
defer C.free(dstBuf)
|
||||||
|
|
||||||
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||||
}
|
}
|
||||||
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
_buf := C.GoBytes(dstBuf, C.int(bufLen))
|
||||||
copy(dst, _buf)
|
copy(dst, _buf)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close destroy the cipher context
|
// Close destroy the cipher context
|
||||||
func (c *Cipher) Close() error {
|
func (c *Cipher) Close() error {
|
||||||
if c.cipher != nil {
|
if c.cipher != nil {
|
||||||
C.gnutls_cipher_deinit(c.cipher)
|
C.gnutls_cipher_deinit(c.cipher)
|
||||||
c.cipher = nil
|
c.cipher = nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cipher) free() {
|
func (c *Cipher) free() {
|
||||||
//log.Println("free cipher")
|
//log.Println("free cipher")
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCipherKeySize get the cipher algorithm key length
|
// GetCipherKeySize get the cipher algorithm key length
|
||||||
func GetCipherKeySize(t CipherType) int {
|
func GetCipherKeySize(t CipherType) int {
|
||||||
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
|
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCipherIVSize get the cipher algorithm iv length
|
// GetCipherIVSize get the cipher algorithm iv length
|
||||||
func GetCipherIVSize(t CipherType) int {
|
func GetCipherIVSize(t CipherType) int {
|
||||||
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
|
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCipherBlockSize get the cipher algorithm block size
|
// GetCipherBlockSize get the cipher algorithm block size
|
||||||
func GetCipherBlockSize(t CipherType) int {
|
func GetCipherBlockSize(t CipherType) int {
|
||||||
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
|
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
|
||||||
}
|
}
|
||||||
|
@ -1,130 +1,130 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCipherSize(t *testing.T) {
|
func TestCipherSize(t *testing.T) {
|
||||||
data := []struct {
|
data := []struct {
|
||||||
t CipherType
|
t CipherType
|
||||||
bsize int
|
bsize int
|
||||||
isize int
|
isize int
|
||||||
}{
|
}{
|
||||||
{GNUTLS_CIPHER_AES_128_CBC, 16, 16},
|
{GNUTLS_CIPHER_AES_128_CBC, 16, 16},
|
||||||
{GNUTLS_CIPHER_AES_192_CBC, 24, 16},
|
{GNUTLS_CIPHER_AES_192_CBC, 24, 16},
|
||||||
{GNUTLS_CIPHER_AES_256_CBC, 32, 16},
|
{GNUTLS_CIPHER_AES_256_CBC, 32, 16},
|
||||||
}
|
}
|
||||||
for _, d := range data {
|
for _, d := range data {
|
||||||
blocksize := GetCipherKeySize(d.t)
|
blocksize := GetCipherKeySize(d.t)
|
||||||
if blocksize != d.bsize {
|
if blocksize != d.bsize {
|
||||||
t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize)
|
t.Errorf("%d block size expect: %d, got: %d", d.t, d.bsize, blocksize)
|
||||||
}
|
}
|
||||||
ivsize := GetCipherIVSize(d.t)
|
ivsize := GetCipherIVSize(d.t)
|
||||||
if ivsize != d.isize {
|
if ivsize != d.isize {
|
||||||
t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize)
|
t.Errorf("%d iv size expect: %d, got: %d", d.t, d.bsize, ivsize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func TestEncryptDecrypt(t *testing.T) {
|
func TestEncryptDecrypt(t *testing.T) {
|
||||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||||
keysize := GetCipherKeySize(cipherName)
|
keysize := GetCipherKeySize(cipherName)
|
||||||
ivsize := GetCipherIVSize(cipherName)
|
ivsize := GetCipherIVSize(cipherName)
|
||||||
blocksize := GetCipherBlockSize(cipherName)
|
blocksize := GetCipherBlockSize(cipherName)
|
||||||
|
|
||||||
key := make([]byte, keysize)
|
key := make([]byte, keysize)
|
||||||
iv := make([]byte, ivsize)
|
iv := make([]byte, ivsize)
|
||||||
rand.Reader.Read(key)
|
rand.Reader.Read(key)
|
||||||
rand.Reader.Read(iv)
|
rand.Reader.Read(iv)
|
||||||
|
|
||||||
c, err := NewCipher(cipherName, key, iv)
|
c, err := NewCipher(cipherName, key, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
//defer c.Close()
|
//defer c.Close()
|
||||||
|
|
||||||
c1, err := NewCipher(cipherName, key, iv)
|
c1, err := NewCipher(cipherName, key, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
//defer c1.Close()
|
//defer c1.Close()
|
||||||
|
|
||||||
data := make([]byte, blocksize*10)
|
data := make([]byte, blocksize*10)
|
||||||
if c == nil {
|
if c == nil {
|
||||||
t.Fatal("new ciphoer failed")
|
t.Fatal("new ciphoer failed")
|
||||||
}
|
}
|
||||||
cdata := make([]byte, len(data))
|
cdata := make([]byte, len(data))
|
||||||
err = c.Encrypt(cdata, data)
|
err = c.Encrypt(cdata, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("encrypt failed", err)
|
t.Fatal("encrypt failed", err)
|
||||||
}
|
}
|
||||||
data1 := make([]byte, len(data))
|
data1 := make([]byte, len(data))
|
||||||
err = c1.Decrypt(data1, cdata)
|
err = c1.Decrypt(data1, cdata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("decrypt failed", err)
|
t.Fatal("decrypt failed", err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(data, data1) {
|
if !bytes.Equal(data, data1) {
|
||||||
t.Fatal("encrypt/decrypt failed", string(data), string(data1))
|
t.Fatal("encrypt/decrypt failed", string(data), string(data1))
|
||||||
}
|
}
|
||||||
block, _ := aes.NewCipher(key)
|
block, _ := aes.NewCipher(key)
|
||||||
mode := cipher.NewCBCEncrypter(block, iv)
|
mode := cipher.NewCBCEncrypter(block, iv)
|
||||||
dst := make([]byte, len(data))
|
dst := make([]byte, len(data))
|
||||||
mode.CryptBlocks(dst, data)
|
mode.CryptBlocks(dst, data)
|
||||||
if !bytes.Equal(dst, cdata) {
|
if !bytes.Equal(dst, cdata) {
|
||||||
t.Fatal("cipher text not equal to cypto/aes")
|
t.Fatal("cipher text not equal to cypto/aes")
|
||||||
}
|
}
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAESEncrypt(b *testing.B) {
|
func BenchmarkAESEncrypt(b *testing.B) {
|
||||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||||
keysize := GetCipherKeySize(cipherName)
|
keysize := GetCipherKeySize(cipherName)
|
||||||
ivsize := GetCipherIVSize(cipherName)
|
ivsize := GetCipherIVSize(cipherName)
|
||||||
blocksize := GetCipherBlockSize(cipherName)
|
blocksize := GetCipherBlockSize(cipherName)
|
||||||
datalen := blocksize * 500
|
datalen := blocksize * 500
|
||||||
|
|
||||||
key := make([]byte, keysize)
|
key := make([]byte, keysize)
|
||||||
iv := make([]byte, ivsize)
|
iv := make([]byte, ivsize)
|
||||||
rand.Reader.Read(key)
|
rand.Reader.Read(key)
|
||||||
rand.Reader.Read(iv)
|
rand.Reader.Read(iv)
|
||||||
buf := make([]byte, datalen)
|
buf := make([]byte, datalen)
|
||||||
dst := make([]byte, datalen)
|
dst := make([]byte, datalen)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c, err := NewCipher(cipherName, key, iv)
|
c, err := NewCipher(cipherName, key, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
c.Encrypt(dst, buf)
|
c.Encrypt(dst, buf)
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAESEncrypt2(b *testing.B) {
|
func BenchmarkAESEncrypt2(b *testing.B) {
|
||||||
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
cipherName := GNUTLS_CIPHER_AES_256_CBC
|
||||||
keysize := GetCipherKeySize(cipherName)
|
keysize := GetCipherKeySize(cipherName)
|
||||||
ivsize := GetCipherIVSize(cipherName)
|
ivsize := GetCipherIVSize(cipherName)
|
||||||
blocksize := GetCipherBlockSize(cipherName)
|
blocksize := GetCipherBlockSize(cipherName)
|
||||||
datalen := blocksize * 500
|
datalen := blocksize * 500
|
||||||
|
|
||||||
key := make([]byte, keysize)
|
key := make([]byte, keysize)
|
||||||
iv := make([]byte, ivsize)
|
iv := make([]byte, ivsize)
|
||||||
buf := make([]byte, datalen)
|
buf := make([]byte, datalen)
|
||||||
|
|
||||||
rand.Reader.Read(buf)
|
rand.Reader.Read(buf)
|
||||||
rand.Reader.Read(key)
|
rand.Reader.Read(key)
|
||||||
rand.Reader.Read(iv)
|
rand.Reader.Read(iv)
|
||||||
|
|
||||||
dst := make([]byte, datalen)
|
dst := make([]byte, datalen)
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
block, _ := aes.NewCipher(key)
|
block, _ := aes.NewCipher(key)
|
||||||
mode := cipher.NewCBCEncrypter(block, iv)
|
mode := cipher.NewCBCEncrypter(block, iv)
|
||||||
mode.CryptBlocks(dst, buf)
|
mode.CryptBlocks(dst, buf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,108 +1,108 @@
|
|||||||
// Package gnutls is a gnutls binding for golang.
|
// Package gnutls is a gnutls binding for golang.
|
||||||
/*
|
/*
|
||||||
a limit set of api is supported.
|
a limit set of api is supported.
|
||||||
|
|
||||||
TLS api is very similar to crypto/tls on standard library.
|
TLS api is very similar to crypto/tls on standard library.
|
||||||
|
|
||||||
TLS client example:
|
TLS client example:
|
||||||
|
|
||||||
addr := "127.0.0.1:9443"
|
addr := "127.0.0.1:9443"
|
||||||
c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true})
|
c, err := gnutls.Dial("tcp", addr, &gnutls.Config{ServerName: "localhost",InsecureSkipVerify: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("gnutls dial ", err)
|
log.Fatal("gnutls dial ", err)
|
||||||
}
|
}
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
data := "hello, world"
|
data := "hello, world"
|
||||||
if _, err = c.Write([]byte(data)); err != nil {
|
if _, err = c.Write([]byte(data)); err != nil {
|
||||||
log.Fatal("gnutls write ", err)
|
log.Fatal("gnutls write ", err)
|
||||||
}
|
}
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
n, err := c.Read(buf)
|
n, err := c.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("gnutls read ", err)
|
log.Fatal("gnutls read ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
TLS Server example:
|
TLS Server example:
|
||||||
cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key")
|
cert, err := gnutls.LoadX509KeyPair("testdata/server/crt", "testdata/server.key")
|
||||||
if err != nil{
|
if err != nil{
|
||||||
// handle error
|
// handle error
|
||||||
}
|
}
|
||||||
l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{
|
l, err := gnults.Listen("tcp", "127.0.0.1:9443", &gnutls.Config{
|
||||||
Certificates: []*gnutls.Certificate{cert}})
|
Certificates: []*gnutls.Certificate{cert}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// handle error
|
// handle error
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("gnutls accept ", err)
|
log.Println("gnutls accept ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
log.Println("accept connection from ", c.RemoteAddr())
|
log.Println("accept connection from ", c.RemoteAddr())
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
for {
|
for {
|
||||||
n, err := c.Read(buf[0:])
|
n, err := c.Read(buf[0:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("gnutls read ", err)
|
log.Println("gnutls read ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if _, err := c.Write(buf[:n]); err != nil {
|
if _, err := c.Write(buf[:n]); err != nil {
|
||||||
log.Println("gnutls write ", err)
|
log.Println("gnutls write ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c)
|
}(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
AES encrypt/decrypt example:
|
AES encrypt/decrypt example:
|
||||||
|
|
||||||
key := []byte("0123456789abcdef")
|
key := []byte("0123456789abcdef")
|
||||||
iv := []byte("abcdefg123456789")
|
iv := []byte("abcdefg123456789")
|
||||||
c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
c, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
c1, err := gnutls.NewCipher(gnutls.GNUTLS_CIPHER_AES_128_CBC, key, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer c1.Close()
|
defer c1.Close()
|
||||||
|
|
||||||
data := []byte("1234012121212121")
|
data := []byte("1234012121212121")
|
||||||
if c == nil {
|
if c == nil {
|
||||||
log.Fatal("new ciphoer failed")
|
log.Fatal("new ciphoer failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// encrypt
|
// encrypt
|
||||||
dst := make([]byte, len(data))
|
dst := make([]byte, len(data))
|
||||||
err := c.Encrypt(dst, data)
|
err := c.Encrypt(dst, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("encrypt failed", err)
|
log.Fatal("encrypt failed", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// decrypt
|
// decrypt
|
||||||
data1 := make([]byte, len(data))
|
data1 := make([]byte, len(data))
|
||||||
err := c1.Decrypt(data1, cdata)
|
err := c1.Decrypt(data1, cdata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("decrypt failed", err)
|
log.Fatal("decrypt failed", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Hash example:
|
Hash example:
|
||||||
|
|
||||||
h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512)
|
h := gnutls.NewHash(gnutls.GNUTLS_HASH_SHA512)
|
||||||
defer h.Close()
|
defer h.Close()
|
||||||
|
|
||||||
data := []byte("1234")
|
data := []byte("1234")
|
||||||
|
|
||||||
h1 := h.Sum(data)
|
h1 := h.Sum(data)
|
||||||
|
|
||||||
*/
|
*/
|
||||||
package gnutls
|
package gnutls
|
||||||
|
@ -1,90 +1,90 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include "_gnutls.h"
|
#include "_gnutls.h"
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HashType hash type
|
// HashType hash type
|
||||||
type HashType int
|
type HashType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
GNUTLS_HASH_MD5 HashType = 2
|
GNUTLS_HASH_MD5 HashType = 2
|
||||||
GNUTLS_HASH_SHA1 HashType = 3
|
GNUTLS_HASH_SHA1 HashType = 3
|
||||||
GNUTLS_HASH_MD2 HashType = 5
|
GNUTLS_HASH_MD2 HashType = 5
|
||||||
GNUTLS_HASH_SHA256 HashType = 6
|
GNUTLS_HASH_SHA256 HashType = 6
|
||||||
GNUTLS_HASH_SHA384 HashType = 7
|
GNUTLS_HASH_SHA384 HashType = 7
|
||||||
GNUTLS_HASH_SHA512 HashType = 8
|
GNUTLS_HASH_SHA512 HashType = 8
|
||||||
GNUTLS_HASH_SHA224 HashType = 9
|
GNUTLS_HASH_SHA224 HashType = 9
|
||||||
)
|
)
|
||||||
|
|
||||||
// Hash hash struct
|
// Hash hash struct
|
||||||
type Hash struct {
|
type Hash struct {
|
||||||
hash C.gnutls_hash_hd_t
|
hash C.gnutls_hash_hd_t
|
||||||
t HashType
|
t HashType
|
||||||
hashLen C.int
|
hashLen C.int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHash new hash struct
|
// NewHash new hash struct
|
||||||
func NewHash(t HashType) *Hash {
|
func NewHash(t HashType) *Hash {
|
||||||
h := C.new_hash(C.int(t))
|
h := C.new_hash(C.int(t))
|
||||||
hashOutLen := GetHashOutputLen(t)
|
hashOutLen := GetHashOutputLen(t)
|
||||||
hash := &Hash{h, t, C.int(hashOutLen)}
|
hash := &Hash{h, t, C.int(hashOutLen)}
|
||||||
runtime.SetFinalizer(hash, (*Hash).free)
|
runtime.SetFinalizer(hash, (*Hash).free)
|
||||||
return hash
|
return hash
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write write data to hash context
|
// Write write data to hash context
|
||||||
func (h *Hash) Write(buf []byte) error {
|
func (h *Hash) Write(buf []byte) error {
|
||||||
dataLen := len(buf)
|
dataLen := len(buf)
|
||||||
|
|
||||||
cbuf := C.CBytes(buf)
|
cbuf := C.CBytes(buf)
|
||||||
defer C.free(cbuf)
|
defer C.free(cbuf)
|
||||||
|
|
||||||
ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen))
|
ret := C.gnutls_hash(h.hash, cbuf, C.size_t(dataLen))
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
return fmt.Errorf("hash failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sum get hash result
|
// Sum get hash result
|
||||||
func (h *Hash) Sum(buf []byte) []byte {
|
func (h *Hash) Sum(buf []byte) []byte {
|
||||||
if buf != nil {
|
if buf != nil {
|
||||||
h.Write(buf)
|
h.Write(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
dstBuf := C.malloc(C.size_t(h.hashLen))
|
dstBuf := C.malloc(C.size_t(h.hashLen))
|
||||||
defer C.free(dstBuf)
|
defer C.free(dstBuf)
|
||||||
|
|
||||||
C.gnutls_hash_output(h.hash, dstBuf)
|
C.gnutls_hash_output(h.hash, dstBuf)
|
||||||
|
|
||||||
gobuf := C.GoBytes(dstBuf, h.hashLen)
|
gobuf := C.GoBytes(dstBuf, h.hashLen)
|
||||||
|
|
||||||
return gobuf
|
return gobuf
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close destroy hash context
|
// Close destroy hash context
|
||||||
func (h *Hash) Close() error {
|
func (h *Hash) Close() error {
|
||||||
if h.hash != nil {
|
if h.hash != nil {
|
||||||
C.gnutls_hash_deinit(h.hash, nil)
|
C.gnutls_hash_deinit(h.hash, nil)
|
||||||
h.hash = nil
|
h.hash = nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (h *Hash) free() {
|
func (h *Hash) free() {
|
||||||
//log.Println("free hash")
|
//log.Println("free hash")
|
||||||
h.Close()
|
h.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHashOutputLen get the hash algorithm output length
|
// GetHashOutputLen get the hash algorithm output length
|
||||||
//
|
//
|
||||||
// example GNUTLS_MD5 is 16
|
// example GNUTLS_MD5 is 16
|
||||||
func GetHashOutputLen(t HashType) int {
|
func GetHashOutputLen(t HashType) int {
|
||||||
return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t)))
|
return int(C.gnutls_hash_get_len(C.gnutls_digest_algorithm_t(t)))
|
||||||
}
|
}
|
||||||
|
@ -1,52 +1,52 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha512"
|
"crypto/sha512"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHashSHA(t *testing.T) {
|
func TestHashSHA(t *testing.T) {
|
||||||
h := NewHash(GNUTLS_HASH_SHA512)
|
h := NewHash(GNUTLS_HASH_SHA512)
|
||||||
//defer h.Close()
|
//defer h.Close()
|
||||||
|
|
||||||
data := []byte("1234")
|
data := []byte("1234")
|
||||||
|
|
||||||
h1 := h.Sum(data)
|
h1 := h.Sum(data)
|
||||||
|
|
||||||
h3 := sha512.New()
|
h3 := sha512.New()
|
||||||
h3.Write(data)
|
h3.Write(data)
|
||||||
h4 := h3.Sum(nil)
|
h4 := h3.Sum(nil)
|
||||||
if !bytes.Equal(h4[:], h1) {
|
if !bytes.Equal(h4[:], h1) {
|
||||||
t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
|
t.Logf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
|
||||||
t.Fatal("hash not equal")
|
t.Fatal("hash not equal")
|
||||||
}
|
}
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkHashSHA512(b *testing.B) {
|
func BenchmarkHashSHA512(b *testing.B) {
|
||||||
buf := make([]byte, 1000*1024)
|
buf := make([]byte, 1000*1024)
|
||||||
rand.Reader.Read(buf)
|
rand.Reader.Read(buf)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
h := NewHash(GNUTLS_HASH_SHA512)
|
h := NewHash(GNUTLS_HASH_SHA512)
|
||||||
h.Write(buf)
|
h.Write(buf)
|
||||||
h.Sum(nil)
|
h.Sum(nil)
|
||||||
h.Close()
|
h.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkHashSHA512s(b *testing.B) {
|
func BenchmarkHashSHA512s(b *testing.B) {
|
||||||
buf := make([]byte, 1000*1024)
|
buf := make([]byte, 1000*1024)
|
||||||
rand.Reader.Read(buf)
|
rand.Reader.Read(buf)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
h := sha512.New()
|
h := sha512.New()
|
||||||
h.Write(buf)
|
h.Write(buf)
|
||||||
h.Sum(nil)
|
h.Sum(nil)
|
||||||
//h.Close()
|
//h.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,492 +1,492 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#include "_gnutls.h"
|
#include "_gnutls.h"
|
||||||
#cgo pkg-config: gnutls
|
#cgo pkg-config: gnutls
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
GNUTLS_NAME_DNS = 1
|
GNUTLS_NAME_DNS = 1
|
||||||
GNUTLS_X509_FMT_PEM = 1
|
GNUTLS_X509_FMT_PEM = 1
|
||||||
GNUTLS_ALPN_MANDATORY = 1
|
GNUTLS_ALPN_MANDATORY = 1
|
||||||
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
|
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conn gnutls TLS connection,
|
// Conn gnutls TLS connection,
|
||||||
// use Listen, Dial, Server or Client create this object
|
// use Listen, Dial, Server or Client create this object
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
c net.Conn
|
c net.Conn
|
||||||
handshake bool
|
handshake bool
|
||||||
sess *C.struct_session
|
sess *C.struct_session
|
||||||
cservname *C.char
|
cservname *C.char
|
||||||
state *ConnectionState
|
state *ConnectionState
|
||||||
cfg *Config
|
cfg *Config
|
||||||
closed bool
|
closed bool
|
||||||
lock *sync.Mutex
|
lock *sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config gnutls TLS configure,
|
// Config gnutls TLS configure,
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ServerName string
|
ServerName string
|
||||||
Certificates []*Certificate
|
Certificates []*Certificate
|
||||||
InsecureSkipVerify bool
|
InsecureSkipVerify bool
|
||||||
NextProtos []string
|
NextProtos []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionState gnutls TLS connection state
|
// ConnectionState gnutls TLS connection state
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
// SNI name client send
|
// SNI name client send
|
||||||
ServerName string
|
ServerName string
|
||||||
// selected ALPN protocl
|
// selected ALPN protocl
|
||||||
NegotiatedProtocol string
|
NegotiatedProtocol string
|
||||||
HandshakeComplete bool
|
HandshakeComplete bool
|
||||||
// TLS version number, ex: 0x303
|
// TLS version number, ex: 0x303
|
||||||
Version uint16
|
Version uint16
|
||||||
// TLS version number, ex: TLS1.0
|
// TLS version number, ex: TLS1.0
|
||||||
VersionName string
|
VersionName string
|
||||||
// peer's certificate
|
// peer's certificate
|
||||||
PeerCertificate *Certificate
|
PeerCertificate *Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
type listener struct {
|
type listener struct {
|
||||||
l net.Listener
|
l net.Listener
|
||||||
c *Config
|
c *Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept
|
// Accept
|
||||||
func (l *listener) Accept() (net.Conn, error) {
|
func (l *listener) Accept() (net.Conn, error) {
|
||||||
c, err := l.l.Accept()
|
c, err := l.l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return Server(c, l.c)
|
return Server(c, l.c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close
|
// Close
|
||||||
func (l *listener) Close() error {
|
func (l *listener) Close() error {
|
||||||
return l.l.Close()
|
return l.l.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Addr
|
// Addr
|
||||||
func (l *listener) Addr() net.Addr {
|
func (l *listener) Addr() net.Addr {
|
||||||
return l.l.Addr()
|
return l.l.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial dial to (network, addr) and create a gnutls Conn
|
// Dial dial to (network, addr) and create a gnutls Conn
|
||||||
func Dial(network, addr string, cfg *Config) (*Conn, error) {
|
func Dial(network, addr string, cfg *Config) (*Conn, error) {
|
||||||
c, err := net.Dial(network, addr)
|
c, err := net.Dial(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return Client(c, cfg)
|
return Client(c, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen create a gnutls listener on (network, addr),
|
// Listen create a gnutls listener on (network, addr),
|
||||||
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
|
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil, errors.New("config is needed")
|
return nil, errors.New("config is needed")
|
||||||
}
|
}
|
||||||
l, err := net.Listen(network, addr)
|
l, err := net.Listen(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &listener{l, cfg}, nil
|
return &listener{l, cfg}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server create a server TLS Conn on c
|
// Server create a server TLS Conn on c
|
||||||
func Server(c net.Conn, cfg *Config) (*Conn, error) {
|
func Server(c net.Conn, cfg *Config) (*Conn, error) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil, errors.New("config is needed")
|
return nil, errors.New("config is needed")
|
||||||
}
|
}
|
||||||
|
|
||||||
var sess = C.init_gnutls_server_session()
|
var sess = C.init_gnutls_server_session()
|
||||||
|
|
||||||
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
||||||
|
|
||||||
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
||||||
C.set_callback(sess)
|
C.set_callback(sess)
|
||||||
|
|
||||||
if cfg.NextProtos != nil {
|
if cfg.NextProtos != nil {
|
||||||
if err := setAlpnProtocols(sess, cfg); err != nil {
|
if err := setAlpnProtocols(sess, cfg); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(conn, (*Conn).free)
|
runtime.SetFinalizer(conn, (*Conn).free)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client create a client TLS Conn on c
|
// Client create a client TLS Conn on c
|
||||||
func Client(c net.Conn, cfg *Config) (*Conn, error) {
|
func Client(c net.Conn, cfg *Config) (*Conn, error) {
|
||||||
var sess = C.init_gnutls_client_session()
|
var sess = C.init_gnutls_client_session()
|
||||||
|
|
||||||
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
|
||||||
|
|
||||||
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
|
||||||
C.set_callback(sess)
|
C.set_callback(sess)
|
||||||
|
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
if cfg.ServerName != "" {
|
if cfg.ServerName != "" {
|
||||||
srvname := C.CString(cfg.ServerName)
|
srvname := C.CString(cfg.ServerName)
|
||||||
//defer C.free(unsafe.Pointer(srvname))
|
//defer C.free(unsafe.Pointer(srvname))
|
||||||
conn.cservname = srvname
|
conn.cservname = srvname
|
||||||
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
|
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
|
||||||
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
|
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cfg.InsecureSkipVerify {
|
if !cfg.InsecureSkipVerify {
|
||||||
if conn.cservname != nil {
|
if conn.cservname != nil {
|
||||||
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
|
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
|
||||||
} else {
|
} else {
|
||||||
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.NextProtos != nil {
|
if cfg.NextProtos != nil {
|
||||||
if err := setAlpnProtocols(sess, cfg); err != nil {
|
if err := setAlpnProtocols(sess, cfg); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(conn, (*Conn).free)
|
runtime.SetFinalizer(conn, (*Conn).free)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
|
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
|
||||||
arg := make([](*C.char), 0)
|
arg := make([](*C.char), 0)
|
||||||
for _, s := range cfg.NextProtos {
|
for _, s := range cfg.NextProtos {
|
||||||
cbuf := C.CString(s)
|
cbuf := C.CString(s)
|
||||||
defer C.free(unsafe.Pointer(cbuf))
|
defer C.free(unsafe.Pointer(cbuf))
|
||||||
arg = append(arg, (*C.char)(cbuf))
|
arg = append(arg, (*C.char)(cbuf))
|
||||||
}
|
}
|
||||||
ret := C.alpn_set_protocols(sess,
|
ret := C.alpn_set_protocols(sess,
|
||||||
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
|
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake call handshake for TLS Conn,
|
// Handshake call handshake for TLS Conn,
|
||||||
// this function will call automatic on Read/Write, if not handshake yet
|
// this function will call automatic on Read/Write, if not handshake yet
|
||||||
func (c *Conn) Handshake() error {
|
func (c *Conn) Handshake() error {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
if c.handshake {
|
if c.handshake {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
ret := C.handshake(c.sess)
|
ret := C.handshake(c.sess)
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret)))
|
return fmt.Errorf("handshake error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||||||
}
|
}
|
||||||
c.handshake = true
|
c.handshake = true
|
||||||
//log.Println("handshake done")
|
//log.Println("handshake done")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read read application data from TLS connection
|
// Read read application data from TLS connection
|
||||||
func (c *Conn) Read(buf []byte) (n int, err error) {
|
func (c *Conn) Read(buf []byte) (n int, err error) {
|
||||||
err = c.Handshake()
|
err = c.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
if c.closed {
|
if c.closed {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
//bufLen := len(buf)
|
//bufLen := len(buf)
|
||||||
//cbuf := C.malloc(C.size_t(bufLen))
|
//cbuf := C.malloc(C.size_t(bufLen))
|
||||||
//defer C.free(cbuf)
|
//defer C.free(cbuf)
|
||||||
|
|
||||||
ret := C.gnutls_record_recv(c.sess.session,
|
ret := C.gnutls_record_recv(c.sess.session,
|
||||||
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return 0, fmt.Errorf("read error: %s",
|
return 0, fmt.Errorf("read error: %s",
|
||||||
C.GoString(C.gnutls_strerror(C.int(ret))))
|
C.GoString(C.gnutls_strerror(C.int(ret))))
|
||||||
}
|
}
|
||||||
|
|
||||||
if int(ret) == 0 {
|
if int(ret) == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
n = int(ret)
|
n = int(ret)
|
||||||
//gobuf2 := C.GoBytes(cbuf, C.int(ret))
|
//gobuf2 := C.GoBytes(cbuf, C.int(ret))
|
||||||
//copy(buf, gobuf2)
|
//copy(buf, gobuf2)
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write write application data to TLS connection
|
// Write write application data to TLS connection
|
||||||
func (c *Conn) Write(buf []byte) (n int, err error) {
|
func (c *Conn) Write(buf []byte) (n int, err error) {
|
||||||
err = c.Handshake()
|
err = c.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// user may call Write(nil) to do handshake
|
// user may call Write(nil) to do handshake
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
if c.closed {
|
if c.closed {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
//cbuf := C.CBytes(buf)
|
//cbuf := C.CBytes(buf)
|
||||||
//defer C.free(cbuf)
|
//defer C.free(cbuf)
|
||||||
|
|
||||||
ret := C.gnutls_record_send(c.sess.session,
|
ret := C.gnutls_record_send(c.sess.session,
|
||||||
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
unsafe.Pointer(&buf[0]), C.size_t(len(buf)))
|
||||||
n = int(ret)
|
n = int(ret)
|
||||||
|
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
return 0, fmt.Errorf("write error: %s",
|
return 0, fmt.Errorf("write error: %s",
|
||||||
C.GoString(C.gnutls_strerror(C.int(ret))))
|
C.GoString(C.gnutls_strerror(C.int(ret))))
|
||||||
}
|
}
|
||||||
|
|
||||||
if int(ret) == 0 {
|
if int(ret) == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close close the TLS conn and destroy the tls context
|
// Close close the TLS conn and destroy the tls context
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
if c.closed {
|
if c.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
C.gnutls_record_send(c.sess.session, nil, 0)
|
C.gnutls_record_send(c.sess.session, nil, 0)
|
||||||
C.session_destroy(c.sess)
|
C.session_destroy(c.sess)
|
||||||
c.c.Close()
|
c.c.Close()
|
||||||
if c.cservname != nil {
|
if c.cservname != nil {
|
||||||
C.free(unsafe.Pointer(c.cservname))
|
C.free(unsafe.Pointer(c.cservname))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.state != nil && c.state.PeerCertificate != nil {
|
if c.state != nil && c.state.PeerCertificate != nil {
|
||||||
c.state.PeerCertificate.Free()
|
c.state.PeerCertificate.Free()
|
||||||
}
|
}
|
||||||
c.closed = true
|
c.closed = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) free() {
|
func (c *Conn) free() {
|
||||||
//log.Println("free conn")
|
//log.Println("free conn")
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWriteDeadline implements net.Conn
|
// SetWriteDeadline implements net.Conn
|
||||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
return c.c.SetWriteDeadline(t)
|
return c.c.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetReadDeadline implements net.Conn
|
// SetReadDeadline implements net.Conn
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
return c.c.SetReadDeadline(t)
|
return c.c.SetReadDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteAddr implements net.Conn
|
// RemoteAddr implements net.Conn
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
return c.c.RemoteAddr()
|
return c.c.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr implements net.Conn
|
// LocalAddr implements net.Conn
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
return c.c.LocalAddr()
|
return c.c.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDeadline implements net.Conn
|
// SetDeadline implements net.Conn
|
||||||
func (c *Conn) SetDeadline(t time.Time) error {
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
return c.c.SetDeadline(t)
|
return c.c.SetDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionState get TLS connection state
|
// ConnectionState get TLS connection state
|
||||||
func (c *Conn) ConnectionState() *ConnectionState {
|
func (c *Conn) ConnectionState() *ConnectionState {
|
||||||
if c.state != nil {
|
if c.state != nil {
|
||||||
return c.state
|
return c.state
|
||||||
}
|
}
|
||||||
version :=
|
version :=
|
||||||
uint16(C.gnutls_protocol_get_version(c.sess.session))
|
uint16(C.gnutls_protocol_get_version(c.sess.session))
|
||||||
|
|
||||||
versionname := C.GoString(
|
versionname := C.GoString(
|
||||||
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
|
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
|
||||||
|
|
||||||
state := &ConnectionState{
|
state := &ConnectionState{
|
||||||
NegotiatedProtocol: c.getAlpnSelectedProtocol(),
|
NegotiatedProtocol: c.getAlpnSelectedProtocol(),
|
||||||
Version: version,
|
Version: version,
|
||||||
HandshakeComplete: c.handshake,
|
HandshakeComplete: c.handshake,
|
||||||
ServerName: c.getServerName(),
|
ServerName: c.getServerName(),
|
||||||
VersionName: versionname,
|
VersionName: versionname,
|
||||||
PeerCertificate: c.getPeerCertificate(),
|
PeerCertificate: c.getPeerCertificate(),
|
||||||
}
|
}
|
||||||
c.state = state
|
c.state = state
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) getPeerCertificate() *Certificate {
|
func (c *Conn) getPeerCertificate() *Certificate {
|
||||||
var size int
|
var size int
|
||||||
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
|
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
|
||||||
if st == nil {
|
if st == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
cert := &Certificate{cert: st, certSize: C.int(size)}
|
cert := &Certificate{cert: st, certSize: C.int(size)}
|
||||||
runtime.SetFinalizer(cert, (*Certificate).free)
|
runtime.SetFinalizer(cert, (*Certificate).free)
|
||||||
return cert
|
return cert
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) getAlpnSelectedProtocol() string {
|
func (c *Conn) getAlpnSelectedProtocol() string {
|
||||||
cbuf := C.malloc(100)
|
cbuf := C.malloc(100)
|
||||||
defer C.free(cbuf)
|
defer C.free(cbuf)
|
||||||
|
|
||||||
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
|
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
alpnname := C.GoString((*C.char)(cbuf))
|
alpnname := C.GoString((*C.char)(cbuf))
|
||||||
return alpnname
|
return alpnname
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) getServerName() string {
|
func (c *Conn) getServerName() string {
|
||||||
buflen := 100
|
buflen := 100
|
||||||
nametype := GNUTLS_NAME_DNS
|
nametype := GNUTLS_NAME_DNS
|
||||||
cbuf := C.malloc(C.size_t(buflen))
|
cbuf := C.malloc(C.size_t(buflen))
|
||||||
defer C.free(cbuf)
|
defer C.free(cbuf)
|
||||||
|
|
||||||
ret := C.gnutls_server_name_get(c.sess.session, cbuf,
|
ret := C.gnutls_server_name_get(c.sess.session, cbuf,
|
||||||
(*C.size_t)(unsafe.Pointer(&buflen)),
|
(*C.size_t)(unsafe.Pointer(&buflen)),
|
||||||
(*C.uint)(unsafe.Pointer(&nametype)), 0)
|
(*C.uint)(unsafe.Pointer(&nametype)), 0)
|
||||||
if int(ret) < 0 {
|
if int(ret) < 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
name := C.GoString((*C.char)(cbuf))
|
name := C.GoString((*C.char)(cbuf))
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
var bufPool = &sync.Pool{
|
var bufPool = &sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
return make([]byte, 16*1024)
|
return make([]byte, 16*1024)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// onDataReadCallback callback function for gnutls library want to read data from network
|
// onDataReadCallback callback function for gnutls library want to read data from network
|
||||||
//
|
//
|
||||||
//export onDataReadCallback
|
//export onDataReadCallback
|
||||||
func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
func onDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||||||
//log.Println("read addr ", uintptr(d))
|
//log.Println("read addr ", uintptr(d))
|
||||||
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||||||
|
|
||||||
buf := make([]byte, int(bufLen))
|
buf := make([]byte, int(bufLen))
|
||||||
/*
|
/*
|
||||||
_length := int(bufLen)
|
_length := int(bufLen)
|
||||||
buf := bufPool.Get().([]byte)
|
buf := bufPool.Get().([]byte)
|
||||||
if len(buf) < _length {
|
if len(buf) < _length {
|
||||||
buf = make([]byte, _length)
|
buf = make([]byte, _length)
|
||||||
}
|
}
|
||||||
defer bufPool.Put(buf)
|
defer bufPool.Put(buf)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
n, err := conn.c.Read(buf[0:])
|
n, err := conn.c.Read(buf[0:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//log.Println(err)
|
//log.Println(err)
|
||||||
// 0 indicates connection termination
|
// 0 indicates connection termination
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
//cbuf2 := C.CBytes(buf[:n])
|
//cbuf2 := C.CBytes(buf[:n])
|
||||||
// d := C.CString(string(buf[:n]))
|
// d := C.CString(string(buf[:n]))
|
||||||
//defer C.free(cbuf2)
|
//defer C.free(cbuf2)
|
||||||
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
|
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
|
||||||
return C.int(n)
|
return C.int(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onDataWriteCallback callback function for gnutls library want to send data to network
|
// onDataWriteCallback callback function for gnutls library want to send data to network
|
||||||
//
|
//
|
||||||
//export onDataWriteCallback
|
//export onDataWriteCallback
|
||||||
func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
func onDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||||||
//log.Println("write addr ", uintptr(d), int(_l))
|
//log.Println("write addr ", uintptr(d), int(_l))
|
||||||
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||||||
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
|
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
|
||||||
n, err := conn.c.Write(gobuf)
|
n, err := conn.c.Write(gobuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//log.Println(err)
|
//log.Println(err)
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
return C.int(n)
|
return C.int(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onDataTimeoutRead callback function for timeout read
|
// onDataTimeoutRead callback function for timeout read
|
||||||
//
|
//
|
||||||
//export onDataTimeoutRead
|
//export onDataTimeoutRead
|
||||||
func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
|
func onDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
|
||||||
log.Println("timeout pull function")
|
log.Println("timeout pull function")
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// onCertSelectCallback callback function for ceritificate select,
|
// onCertSelectCallback callback function for ceritificate select,
|
||||||
// this function select certificate from Config.Certificates field,
|
// this function select certificate from Config.Certificates field,
|
||||||
//
|
//
|
||||||
// on server side, this function select the certificate depend on SNI what client send,
|
// on server side, this function select the certificate depend on SNI what client send,
|
||||||
// if client not send SNI, select the Config.Certificates[0]
|
// if client not send SNI, select the Config.Certificates[0]
|
||||||
//
|
//
|
||||||
//export onCertSelectCallback
|
//export onCertSelectCallback
|
||||||
func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
|
func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
|
||||||
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
|
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
|
||||||
|
|
||||||
servername := C.GoStringN(hostname, namelen)
|
servername := C.GoStringN(hostname, namelen)
|
||||||
//log.Println("go cert select callback ", servername)
|
//log.Println("go cert select callback ", servername)
|
||||||
conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
|
conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
|
||||||
//log.Println(conn)
|
//log.Println(conn)
|
||||||
if int(namelen) == 0 && conn.cfg.Certificates != nil {
|
if int(namelen) == 0 && conn.cfg.Certificates != nil {
|
||||||
_cert := conn.cfg.Certificates[0]
|
_cert := conn.cfg.Certificates[0]
|
||||||
*pcertLength = _cert.certSize
|
*pcertLength = _cert.certSize
|
||||||
*cert = _cert.cert
|
*cert = _cert.cert
|
||||||
*privkey = _cert.privkey
|
*privkey = _cert.privkey
|
||||||
//log.Println("set pcert length ", _cert.certSize)
|
//log.Println("set pcert length ", _cert.certSize)
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
for _, _cert := range conn.cfg.Certificates {
|
for _, _cert := range conn.cfg.Certificates {
|
||||||
//log.Println(cert)
|
//log.Println(cert)
|
||||||
if _cert.matchName(servername) {
|
if _cert.matchName(servername) {
|
||||||
//log.Println("matched name ", _cert.names)
|
//log.Println("matched name ", _cert.names)
|
||||||
*pcertLength = _cert.certSize
|
*pcertLength = _cert.certSize
|
||||||
*cert = _cert.cert
|
*cert = _cert.cert
|
||||||
*privkey = _cert.privkey
|
*privkey = _cert.privkey
|
||||||
//log.Println("set pcert length ", _cert.certSize)
|
//log.Println("set pcert length ", _cert.certSize)
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if conn.cfg.Certificates != nil {
|
if conn.cfg.Certificates != nil {
|
||||||
_cert := conn.cfg.Certificates[0]
|
_cert := conn.cfg.Certificates[0]
|
||||||
*pcertLength = _cert.certSize
|
*pcertLength = _cert.certSize
|
||||||
*cert = _cert.cert
|
*cert = _cert.cert
|
||||||
*privkey = _cert.privkey
|
*privkey = _cert.privkey
|
||||||
//log.Println("set pcert length ", _cert.certSize)
|
//log.Println("set pcert length ", _cert.certSize)
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
*pcertLength = 0
|
*pcertLength = 0
|
||||||
//log.Println("set pcert length 0")
|
//log.Println("set pcert length 0")
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
C.init_xcred()
|
C.init_xcred()
|
||||||
C.init_priority_cache()
|
C.init_priority_cache()
|
||||||
}
|
}
|
||||||
|
@ -1,431 +1,431 @@
|
|||||||
package gnutls
|
package gnutls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTLSClient(t *testing.T) {
|
func TestTLSClient(t *testing.T) {
|
||||||
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("load certificate failed")
|
t.Fatal("load certificate failed")
|
||||||
}
|
}
|
||||||
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("listen failed")
|
t.Fatal("listen failed")
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
addr := l.Addr().String()
|
addr := l.Addr().String()
|
||||||
t.Log("test server listen on ", addr)
|
t.Log("test server listen on ", addr)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
t.Logf("accept connection from %s", c.RemoteAddr())
|
t.Logf("accept connection from %s", c.RemoteAddr())
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
n, err := c.Read(buf)
|
n, err := c.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("connection closed")
|
t.Log("connection closed")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if _, err = c.Write(buf[:n]); err != nil {
|
if _, err = c.Write(buf[:n]); err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c)
|
}(c)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true})
|
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("gnutls dial ", err)
|
t.Fatal("gnutls dial ", err)
|
||||||
}
|
}
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
data := "hello, world"
|
data := "hello, world"
|
||||||
if _, err = c.Write([]byte(data)); err != nil {
|
if _, err = c.Write([]byte(data)); err != nil {
|
||||||
t.Fatal("gnutls write ", err)
|
t.Fatal("gnutls write ", err)
|
||||||
}
|
}
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
n, err := c.Read(buf)
|
n, err := c.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("gnutls read ", err)
|
t.Fatal("gnutls read ", err)
|
||||||
}
|
}
|
||||||
if string(buf[:n]) != data {
|
if string(buf[:n]) != data {
|
||||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSServer(t *testing.T) {
|
func TestTLSServer(t *testing.T) {
|
||||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||||
Certificates: []*Certificate{cert},
|
Certificates: []*Certificate{cert},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("gnutls listen ", err)
|
t.Fatal("gnutls listen ", err)
|
||||||
}
|
}
|
||||||
addr := l.Addr().String()
|
addr := l.Addr().String()
|
||||||
t.Log("test server listen on ", addr)
|
t.Log("test server listen on ", addr)
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("gnutls accept ", err)
|
t.Log("gnutls accept ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
t.Log("accept connection from ", c.RemoteAddr())
|
t.Log("accept connection from ", c.RemoteAddr())
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
tlsconn := c.(*Conn)
|
tlsconn := c.(*Conn)
|
||||||
if err := tlsconn.Handshake(); err != nil {
|
if err := tlsconn.Handshake(); err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
for {
|
for {
|
||||||
n, err := c.Read(buf[0:])
|
n, err := c.Read(buf[0:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("gnutls read ", err)
|
t.Log("gnutls read ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if _, err := c.Write(buf[:n]); err != nil {
|
if _, err := c.Write(buf[:n]); err != nil {
|
||||||
t.Log("gnutls write ", err)
|
t.Log("gnutls write ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c)
|
}(c)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
|
c, err := tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("dial ", err)
|
t.Fatal("dial ", err)
|
||||||
}
|
}
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
data := "hello, world"
|
data := "hello, world"
|
||||||
if _, err := c.Write([]byte(data)); err != nil {
|
if _, err := c.Write([]byte(data)); err != nil {
|
||||||
t.Fatal("write ", err)
|
t.Fatal("write ", err)
|
||||||
}
|
}
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
n, err := c.Read(buf)
|
n, err := c.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("read ", err)
|
t.Fatal("read ", err)
|
||||||
}
|
}
|
||||||
if string(buf[:n]) != data {
|
if string(buf[:n]) != data {
|
||||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||||
}
|
}
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSALPNServer(t *testing.T) {
|
func TestTLSALPNServer(t *testing.T) {
|
||||||
serveralpn := []string{"a1", "a3", "a2"}
|
serveralpn := []string{"a1", "a3", "a2"}
|
||||||
clientalpn := []string{"a0", "a2", "a5"}
|
clientalpn := []string{"a0", "a2", "a5"}
|
||||||
expectedAlpn := "a2"
|
expectedAlpn := "a2"
|
||||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||||
Certificates: []*Certificate{cert},
|
Certificates: []*Certificate{cert},
|
||||||
NextProtos: serveralpn,
|
NextProtos: serveralpn,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("gnutls listen ", err)
|
t.Fatal("gnutls listen ", err)
|
||||||
}
|
}
|
||||||
addr := l.Addr().String()
|
addr := l.Addr().String()
|
||||||
t.Log("test server listen on ", addr)
|
t.Log("test server listen on ", addr)
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("gnutls accept ", err)
|
t.Log("gnutls accept ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
t.Log("accept connection from ", c.RemoteAddr())
|
t.Log("accept connection from ", c.RemoteAddr())
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
tlsConn := c.(*Conn)
|
tlsConn := c.(*Conn)
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
connState := tlsConn.ConnectionState()
|
connState := tlsConn.ConnectionState()
|
||||||
t.Logf("%+v", connState)
|
t.Logf("%+v", connState)
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
for {
|
for {
|
||||||
n, err := c.Read(buf[0:])
|
n, err := c.Read(buf[0:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("gnutls read ", err)
|
t.Log("gnutls read ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if _, err := c.Write(buf[:n]); err != nil {
|
if _, err := c.Write(buf[:n]); err != nil {
|
||||||
t.Log("gnutls write ", err)
|
t.Log("gnutls write ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c)
|
}(c)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := tls.Dial("tcp", addr, &tls.Config{
|
c, err := tls.Dial("tcp", addr, &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
ServerName: "localhost",
|
ServerName: "localhost",
|
||||||
NextProtos: clientalpn,
|
NextProtos: clientalpn,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("dial ", err)
|
t.Fatal("dial ", err)
|
||||||
}
|
}
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
if err := c.Handshake(); err != nil {
|
if err := c.Handshake(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
connState := c.ConnectionState()
|
connState := c.ConnectionState()
|
||||||
t.Logf("%+v", connState)
|
t.Logf("%+v", connState)
|
||||||
|
|
||||||
if connState.NegotiatedProtocol != expectedAlpn {
|
if connState.NegotiatedProtocol != expectedAlpn {
|
||||||
t.Errorf("expected alpn %s, got %s",
|
t.Errorf("expected alpn %s, got %s",
|
||||||
expectedAlpn, connState.NegotiatedProtocol)
|
expectedAlpn, connState.NegotiatedProtocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := "hello, world"
|
data := "hello, world"
|
||||||
if _, err := c.Write([]byte(data)); err != nil {
|
if _, err := c.Write([]byte(data)); err != nil {
|
||||||
t.Fatal("write ", err)
|
t.Fatal("write ", err)
|
||||||
}
|
}
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
n, err := c.Read(buf)
|
n, err := c.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("read ", err)
|
t.Fatal("read ", err)
|
||||||
}
|
}
|
||||||
if string(buf[:n]) != data {
|
if string(buf[:n]) != data {
|
||||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||||
}
|
}
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSALPNClient(t *testing.T) {
|
func TestTLSALPNClient(t *testing.T) {
|
||||||
serveralpn := []string{"a1", "a3", "a2"}
|
serveralpn := []string{"a1", "a3", "a2"}
|
||||||
clientalpn := []string{"a0", "a2", "a5"}
|
clientalpn := []string{"a0", "a2", "a5"}
|
||||||
expectedAlpn := "a2"
|
expectedAlpn := "a2"
|
||||||
|
|
||||||
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("load key failed")
|
t.Fatal("load key failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
NextProtos: serveralpn,
|
NextProtos: serveralpn,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("tls listen ", err)
|
t.Fatal("tls listen ", err)
|
||||||
}
|
}
|
||||||
addr := l.Addr().String()
|
addr := l.Addr().String()
|
||||||
t.Log("test server listen on ", addr)
|
t.Log("test server listen on ", addr)
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("gnutls accept ", err)
|
t.Log("gnutls accept ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
t.Log("accept connection from ", c.RemoteAddr())
|
t.Log("accept connection from ", c.RemoteAddr())
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
tlsConn := c.(*tls.Conn)
|
tlsConn := c.(*tls.Conn)
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
connState := tlsConn.ConnectionState()
|
connState := tlsConn.ConnectionState()
|
||||||
t.Logf("%+v", connState)
|
t.Logf("%+v", connState)
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
for {
|
for {
|
||||||
n, err := c.Read(buf[0:])
|
n, err := c.Read(buf[0:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("tls read ", err)
|
t.Log("tls read ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if _, err := c.Write(buf[:n]); err != nil {
|
if _, err := c.Write(buf[:n]); err != nil {
|
||||||
t.Log("tls write ", err)
|
t.Log("tls write ", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c)
|
}(c)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
|
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
|
||||||
ServerName: "localhost",
|
ServerName: "localhost",
|
||||||
NextProtos: clientalpn,
|
NextProtos: clientalpn,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("dial ", err)
|
t.Fatal("dial ", err)
|
||||||
}
|
}
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
if err := c.Handshake(); err != nil {
|
if err := c.Handshake(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
connState := c.ConnectionState()
|
connState := c.ConnectionState()
|
||||||
t.Logf("%+v", connState)
|
t.Logf("%+v", connState)
|
||||||
|
|
||||||
if connState.NegotiatedProtocol != expectedAlpn {
|
if connState.NegotiatedProtocol != expectedAlpn {
|
||||||
t.Errorf("expected alpn %s, got %s",
|
t.Errorf("expected alpn %s, got %s",
|
||||||
expectedAlpn, connState.NegotiatedProtocol)
|
expectedAlpn, connState.NegotiatedProtocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := "hello, world"
|
data := "hello, world"
|
||||||
if _, err := c.Write([]byte(data)); err != nil {
|
if _, err := c.Write([]byte(data)); err != nil {
|
||||||
t.Fatal("write ", err)
|
t.Fatal("write ", err)
|
||||||
}
|
}
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
n, err := c.Read(buf)
|
n, err := c.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("read ", err)
|
t.Fatal("read ", err)
|
||||||
}
|
}
|
||||||
if string(buf[:n]) != data {
|
if string(buf[:n]) != data {
|
||||||
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
|
||||||
}
|
}
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSServerSNI(t *testing.T) {
|
func TestTLSServerSNI(t *testing.T) {
|
||||||
certificates := []*Certificate{}
|
certificates := []*Certificate{}
|
||||||
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("load key failed")
|
t.Fatal("load key failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
certificates = append(certificates, cert)
|
certificates = append(certificates, cert)
|
||||||
cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key")
|
cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("load key failed")
|
t.Fatal("load key failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
certificates = append(certificates, cert)
|
certificates = append(certificates, cert)
|
||||||
cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key")
|
cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("load key failed")
|
t.Fatal("load key failed")
|
||||||
}
|
}
|
||||||
certificates = append(certificates, cert)
|
certificates = append(certificates, cert)
|
||||||
|
|
||||||
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
l, err := Listen("tcp", "127.0.0.1:0", &Config{
|
||||||
Certificates: certificates,
|
Certificates: certificates,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
addr := l.Addr().String()
|
addr := l.Addr().String()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
go func(c net.Conn) {
|
go func(c net.Conn) {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
tlsconn := c.(*Conn)
|
tlsconn := c.(*Conn)
|
||||||
if err := tlsconn.Handshake(); err != nil {
|
if err := tlsconn.Handshake(); err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
state := tlsconn.ConnectionState()
|
state := tlsconn.ConnectionState()
|
||||||
fmt.Fprintf(c, state.ServerName)
|
fmt.Fprintf(c, state.ServerName)
|
||||||
}(c)
|
}(c)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, cfg := range []struct {
|
for _, cfg := range []struct {
|
||||||
serverName string
|
serverName string
|
||||||
commonName string
|
commonName string
|
||||||
}{
|
}{
|
||||||
{"abc.com", "abc.com"},
|
{"abc.com", "abc.com"},
|
||||||
{"example.com", "example.com"},
|
{"example.com", "example.com"},
|
||||||
{"a.aaa.com", "*.aaa.com"},
|
{"a.aaa.com", "*.aaa.com"},
|
||||||
{"b.aaa.com", "*.aaa.com"},
|
{"b.aaa.com", "*.aaa.com"},
|
||||||
} {
|
} {
|
||||||
conn, err := tls.Dial("tcp", addr, &tls.Config{
|
conn, err := tls.Dial("tcp", addr, &tls.Config{
|
||||||
ServerName: cfg.serverName,
|
ServerName: cfg.serverName,
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
state := conn.ConnectionState()
|
state := conn.ConnectionState()
|
||||||
_commonName := state.PeerCertificates[0].Subject.CommonName
|
_commonName := state.PeerCertificates[0].Subject.CommonName
|
||||||
if _commonName != cfg.commonName {
|
if _commonName != cfg.commonName {
|
||||||
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
|
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
|
||||||
}
|
}
|
||||||
buf := make([]byte, 100)
|
buf := make([]byte, 100)
|
||||||
n, err := conn.Read(buf)
|
n, err := conn.Read(buf)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
|
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
|
||||||
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
|
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSGetPeerCert(t *testing.T) {
|
func TestTLSGetPeerCert(t *testing.T) {
|
||||||
conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{
|
conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{
|
||||||
ServerName: "www.ratafee.nl",
|
ServerName: "www.ratafee.nl",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
//tlsconn := conn.(*Conn)
|
//tlsconn := conn.(*Conn)
|
||||||
if err := conn.Handshake(); err != nil {
|
if err := conn.Handshake(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
state := conn.ConnectionState()
|
state := conn.ConnectionState()
|
||||||
for i := 0; i < int(state.PeerCertificate.certSize); i++ {
|
for i := 0; i < int(state.PeerCertificate.certSize); i++ {
|
||||||
t.Log(state.PeerCertificate.getCertString(i, 1))
|
t.Log(state.PeerCertificate.getCertString(i, 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil)
|
req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil)
|
||||||
req.Write(conn)
|
req.Write(conn)
|
||||||
r := bufio.NewReader(conn)
|
r := bufio.NewReader(conn)
|
||||||
resp, err := http.ReadResponse(r, req)
|
resp, err := http.ReadResponse(r, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
var buf = new(bytes.Buffer)
|
var buf = new(bytes.Buffer)
|
||||||
resp.Write(buf)
|
resp.Write(buf)
|
||||||
t.Logf("%s", string(buf.Bytes()))
|
t.Logf("%s", string(buf.Bytes()))
|
||||||
runtime.GC()
|
runtime.GC()
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue