change cipher function signature

master
fangdingjun 6 years ago
parent e95a82779c
commit 2737d50d0c

@ -6,52 +6,70 @@ package gnutls
*/ */
import "C" import "C"
import ( import (
"errors"
"fmt" "fmt"
"log" "log"
) )
// CipherType cipher type
type CipherType int
const ( const (
GNUTLS_CIPHER_AES_128_CBC = 4 GNUTLS_CIPHER_AES_128_CBC CipherType = 4
GNUTLS_CIPHER_AES_256_CBC = 5 GNUTLS_CIPHER_AES_256_CBC CipherType = 5
GNUTLS_CIPHER_ARCFOUR_40 = 6 GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6
GNUTLS_CIPHER_CAMELLIA_128_CBC = 7 GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7
GNUTLS_CIPHER_CAMELLIA_256_CBC = 8 GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8
GNUTLS_CIPHER_AES_192_CBC = 9 GNUTLS_CIPHER_AES_192_CBC CipherType = 9
GNUTLS_CIPHER_AES_128_GCM = 10 GNUTLS_CIPHER_AES_128_GCM CipherType = 10
GNUTLS_CIPHER_AES_256_GCM = 11 GNUTLS_CIPHER_AES_256_GCM CipherType = 11
GNUTLS_CIPHER_CAMELLIA_192_CBC = 12 GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12
GNUTLS_CIPHER_SALSA20_256 = 13 GNUTLS_CIPHER_SALSA20_256 CipherType = 13
GNUTLS_CIPHER_ESTREAM_SALSA20_256 = 14 GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14
GNUTLS_CIPHER_CAMELLIA_128_GCM = 15 GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15
GNUTLS_CIPHER_CAMELLIA_256_GCM = 16 GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16
GNUTLS_CIPHER_RC2_40_CBC = 17 GNUTLS_CIPHER_RC2_40_CBC CipherType = 17
GNUTLS_CIPHER_DES_CBC = 18 GNUTLS_CIPHER_DES_CBC CipherType = 18
GNUTLS_CIPHER_AES_128_CCM = 19 GNUTLS_CIPHER_AES_128_CCM CipherType = 19
GNUTLS_CIPHER_AES_256_CCM = 20 GNUTLS_CIPHER_AES_256_CCM CipherType = 20
GNUTLS_CIPHER_AES_128_CCM_8 = 21 GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21
GNUTLS_CIPHER_AES_256_CCM_8 = 22 GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22
GNUTLS_CIPHER_CHACHA20_POLY1305 = 23 GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23
)
var (
// ErrBlockSize wrong block size
ErrBlockSize = errors.New("wrong block size")
// ErrKeyLength wrong key length
ErrKeyLength = errors.New("wrong key length")
// ErrIVLength wrong iv length
ErrIVLength = errors.New("wrong iv length")
) )
// Cipher cipher // Cipher gnutls cipher struct
type Cipher struct { type Cipher struct {
cipher C.gnutls_cipher_hd_t cipher C.gnutls_cipher_hd_t
t int t CipherType
blockSize int blockSize int
} }
// NewCipher create cipher // NewCipher create a new cipher by give type, key, iv
func NewCipher(t int, key []byte, iv []byte) (*Cipher, error) { //
// example:
// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789"))
//
// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size
func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
keysize := GetCipherKeySize(t) keysize := GetCipherKeySize(t)
ivSize := GetCipherIVSize(t) ivSize := GetCipherIVSize(t)
blocksize := GetCipherBlockSize(t) blocksize := GetCipherBlockSize(t)
//log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize)) //log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize))
if len(key) != int(keysize) { if len(key) != int(keysize) {
return nil, fmt.Errorf("wrong key size") return nil, ErrKeyLength
} }
if len(iv) != int(ivSize) { if len(iv) != int(ivSize) {
return nil, fmt.Errorf("wrong iv size") return nil, ErrIVLength
} }
ckey := C.CBytes(key) ckey := C.CBytes(key)
@ -68,10 +86,11 @@ func NewCipher(t int, key []byte, iv []byte) (*Cipher, error) {
return &Cipher{c, t, blocksize}, nil return &Cipher{c, t, blocksize}, nil
} }
// Encrypt encrypt // Encrypt encrypt the buf and place the encrypted data in dst,
func (c *Cipher) Encrypt(buf []byte) ([]byte, error) { // the buf size must multiple of cipher's block size
func (c *Cipher) Encrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 { if len(buf)%c.blockSize != 0 {
return nil, fmt.Errorf("wrong block size") return ErrBlockSize
} }
cbuf := C.CBytes(buf) cbuf := C.CBytes(buf)
@ -84,15 +103,18 @@ func (c *Cipher) Encrypt(buf []byte) ([]byte, error) {
ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 { if int(ret) < 0 {
return nil, fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret))) return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
} }
return C.GoBytes(dstBuf, C.int(bufLen)), nil _buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf)
return nil
} }
// Decrypt decrypt // Decrypt decrypt the buf and place the decrypted data in dst,
func (c *Cipher) Decrypt(buf []byte) ([]byte, error) { // the buf size must multiple of cipher's block size
func (c *Cipher) Decrypt(dst, buf []byte) error {
if len(buf)%c.blockSize != 0 { if len(buf)%c.blockSize != 0 {
return nil, fmt.Errorf("wrong block size") return ErrBlockSize
} }
cbuf := C.CBytes(buf) cbuf := C.CBytes(buf)
@ -105,28 +127,30 @@ func (c *Cipher) Decrypt(buf []byte) ([]byte, error) {
ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen)
if int(ret) < 0 { if int(ret) < 0 {
return nil, fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret))) return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret)))
} }
return C.GoBytes(dstBuf, C.int(bufLen)), nil _buf := C.GoBytes(dstBuf, C.int(bufLen))
copy(dst, _buf)
return nil
} }
// Close destroy the cipher // Close destroy the cipher context
func (c *Cipher) Close() error { func (c *Cipher) Close() error {
C.gnutls_cipher_deinit(c.cipher) C.gnutls_cipher_deinit(c.cipher)
return nil return nil
} }
// GetCipherKeySize get the cipher algorithm key length // GetCipherKeySize get the cipher algorithm key length
func GetCipherKeySize(t int) int { func GetCipherKeySize(t CipherType) int {
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))
} }
// GetCipherIVSize get the cipher algorithm iv length // GetCipherIVSize get the cipher algorithm iv length
func GetCipherIVSize(t int) int { func GetCipherIVSize(t CipherType) int {
return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t))) return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t)))
} }
// GetCipherBlockSize get the cipher algorithm block size // GetCipherBlockSize get the cipher algorithm block size
func GetCipherBlockSize(t int) int { func GetCipherBlockSize(t CipherType) int {
return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t))) return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t)))
} }

@ -10,7 +10,7 @@ import (
func TestCipherSize(t *testing.T) { func TestCipherSize(t *testing.T) {
data := []struct { data := []struct {
t int t CipherType
bsize int bsize int
isize int isize int
}{ }{
@ -56,11 +56,13 @@ func TestEncryptDecrypt(t *testing.T) {
if c == nil { if c == nil {
t.Fatal("new ciphoer failed") t.Fatal("new ciphoer failed")
} }
cdata, err := c.Encrypt(data) cdata := make([]byte, len(data))
err = c.Encrypt(cdata, data)
if err != nil { if err != nil {
t.Fatal("encrypt failed", err) t.Fatal("encrypt failed", err)
} }
data1, err := c1.Decrypt(cdata) data1 := make([]byte, len(data))
err = c1.Decrypt(data1, cdata)
if err != nil { if err != nil {
t.Fatal("decrypt failed", err) t.Fatal("decrypt failed", err)
} }
@ -88,13 +90,13 @@ func BenchmarkAESEncrypt(b *testing.B) {
rand.Reader.Read(key) rand.Reader.Read(key)
rand.Reader.Read(iv) rand.Reader.Read(iv)
buf := make([]byte, datalen) buf := make([]byte, datalen)
dst := make([]byte, datalen)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c, err := NewCipher(cipherName, key, iv) c, err := NewCipher(cipherName, key, iv)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
c.Encrypt(buf) c.Encrypt(dst, buf)
c.Close() c.Close()
} }
} }

Loading…
Cancel
Save