diff --git a/cipher.go b/cipher.go index 9a3e946..aacbb88 100644 --- a/cipher.go +++ b/cipher.go @@ -6,52 +6,70 @@ package gnutls */ import "C" import ( + "errors" "fmt" "log" ) +// CipherType cipher type +type CipherType int + const ( - GNUTLS_CIPHER_AES_128_CBC = 4 - GNUTLS_CIPHER_AES_256_CBC = 5 - GNUTLS_CIPHER_ARCFOUR_40 = 6 - GNUTLS_CIPHER_CAMELLIA_128_CBC = 7 - GNUTLS_CIPHER_CAMELLIA_256_CBC = 8 - GNUTLS_CIPHER_AES_192_CBC = 9 - GNUTLS_CIPHER_AES_128_GCM = 10 - GNUTLS_CIPHER_AES_256_GCM = 11 - GNUTLS_CIPHER_CAMELLIA_192_CBC = 12 - GNUTLS_CIPHER_SALSA20_256 = 13 - GNUTLS_CIPHER_ESTREAM_SALSA20_256 = 14 - GNUTLS_CIPHER_CAMELLIA_128_GCM = 15 - GNUTLS_CIPHER_CAMELLIA_256_GCM = 16 - GNUTLS_CIPHER_RC2_40_CBC = 17 - GNUTLS_CIPHER_DES_CBC = 18 - GNUTLS_CIPHER_AES_128_CCM = 19 - GNUTLS_CIPHER_AES_256_CCM = 20 - GNUTLS_CIPHER_AES_128_CCM_8 = 21 - GNUTLS_CIPHER_AES_256_CCM_8 = 22 - GNUTLS_CIPHER_CHACHA20_POLY1305 = 23 + GNUTLS_CIPHER_AES_128_CBC CipherType = 4 + GNUTLS_CIPHER_AES_256_CBC CipherType = 5 + GNUTLS_CIPHER_ARCFOUR_40 CipherType = 6 + GNUTLS_CIPHER_CAMELLIA_128_CBC CipherType = 7 + GNUTLS_CIPHER_CAMELLIA_256_CBC CipherType = 8 + GNUTLS_CIPHER_AES_192_CBC CipherType = 9 + GNUTLS_CIPHER_AES_128_GCM CipherType = 10 + GNUTLS_CIPHER_AES_256_GCM CipherType = 11 + GNUTLS_CIPHER_CAMELLIA_192_CBC CipherType = 12 + GNUTLS_CIPHER_SALSA20_256 CipherType = 13 + GNUTLS_CIPHER_ESTREAM_SALSA20_256 CipherType = 14 + GNUTLS_CIPHER_CAMELLIA_128_GCM CipherType = 15 + GNUTLS_CIPHER_CAMELLIA_256_GCM CipherType = 16 + GNUTLS_CIPHER_RC2_40_CBC CipherType = 17 + GNUTLS_CIPHER_DES_CBC CipherType = 18 + GNUTLS_CIPHER_AES_128_CCM CipherType = 19 + GNUTLS_CIPHER_AES_256_CCM CipherType = 20 + GNUTLS_CIPHER_AES_128_CCM_8 CipherType = 21 + GNUTLS_CIPHER_AES_256_CCM_8 CipherType = 22 + GNUTLS_CIPHER_CHACHA20_POLY1305 CipherType = 23 +) + +var ( + // ErrBlockSize wrong block size + ErrBlockSize = errors.New("wrong block size") + // ErrKeyLength wrong key length + ErrKeyLength = errors.New("wrong key length") + // ErrIVLength wrong iv length + ErrIVLength = errors.New("wrong iv length") ) -// Cipher cipher +// Cipher gnutls cipher struct type Cipher struct { cipher C.gnutls_cipher_hd_t - t int + t CipherType blockSize int } -// NewCipher create cipher -func NewCipher(t int, key []byte, iv []byte) (*Cipher, error) { +// NewCipher create a new cipher by give type, key, iv +// +// example: +// NewCipher(GNUTLS_CIPHER_AES_128_CBC, []byte("1234567890abcdef"), []byte("abcdef0123456789")) +// +// you can use GetCipherKeySize, GetCipherBlockSize, GetCipherIVSize to determine the given cipher 's key, block, iv size +func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) { keysize := GetCipherKeySize(t) ivSize := GetCipherIVSize(t) blocksize := GetCipherBlockSize(t) //log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize)) if len(key) != int(keysize) { - return nil, fmt.Errorf("wrong key size") + return nil, ErrKeyLength } if len(iv) != int(ivSize) { - return nil, fmt.Errorf("wrong iv size") + return nil, ErrIVLength } ckey := C.CBytes(key) @@ -68,10 +86,11 @@ func NewCipher(t int, key []byte, iv []byte) (*Cipher, error) { return &Cipher{c, t, blocksize}, nil } -// Encrypt encrypt -func (c *Cipher) Encrypt(buf []byte) ([]byte, error) { +// Encrypt encrypt the buf and place the encrypted data in dst, +// the buf size must multiple of cipher's block size +func (c *Cipher) Encrypt(dst, buf []byte) error { if len(buf)%c.blockSize != 0 { - return nil, fmt.Errorf("wrong block size") + return ErrBlockSize } cbuf := C.CBytes(buf) @@ -84,15 +103,18 @@ func (c *Cipher) Encrypt(buf []byte) ([]byte, error) { ret := C.gnutls_cipher_encrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) if int(ret) < 0 { - return nil, fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret))) + return fmt.Errorf("encrypt error: %s", C.GoString(C.gnutls_strerror(ret))) } - return C.GoBytes(dstBuf, C.int(bufLen)), nil + _buf := C.GoBytes(dstBuf, C.int(bufLen)) + copy(dst, _buf) + return nil } -// Decrypt decrypt -func (c *Cipher) Decrypt(buf []byte) ([]byte, error) { +// Decrypt decrypt the buf and place the decrypted data in dst, +// the buf size must multiple of cipher's block size +func (c *Cipher) Decrypt(dst, buf []byte) error { if len(buf)%c.blockSize != 0 { - return nil, fmt.Errorf("wrong block size") + return ErrBlockSize } cbuf := C.CBytes(buf) @@ -105,28 +127,30 @@ func (c *Cipher) Decrypt(buf []byte) ([]byte, error) { ret := C.gnutls_cipher_decrypt2(c.cipher, cbuf, bufLen, dstBuf, bufLen) if int(ret) < 0 { - return nil, fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret))) + return fmt.Errorf("decrypt error: %s", C.GoString(C.gnutls_strerror(ret))) } - return C.GoBytes(dstBuf, C.int(bufLen)), nil + _buf := C.GoBytes(dstBuf, C.int(bufLen)) + copy(dst, _buf) + return nil } -// Close destroy the cipher +// Close destroy the cipher context func (c *Cipher) Close() error { C.gnutls_cipher_deinit(c.cipher) return nil } // GetCipherKeySize get the cipher algorithm key length -func GetCipherKeySize(t int) int { +func GetCipherKeySize(t CipherType) int { return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) } // GetCipherIVSize get the cipher algorithm iv length -func GetCipherIVSize(t int) int { +func GetCipherIVSize(t CipherType) int { return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t))) } // GetCipherBlockSize get the cipher algorithm block size -func GetCipherBlockSize(t int) int { +func GetCipherBlockSize(t CipherType) int { return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t))) } diff --git a/cipher_test.go b/cipher_test.go index bd7458c..a82eeb6 100644 --- a/cipher_test.go +++ b/cipher_test.go @@ -10,7 +10,7 @@ import ( func TestCipherSize(t *testing.T) { data := []struct { - t int + t CipherType bsize int isize int }{ @@ -56,11 +56,13 @@ func TestEncryptDecrypt(t *testing.T) { if c == nil { t.Fatal("new ciphoer failed") } - cdata, err := c.Encrypt(data) + cdata := make([]byte, len(data)) + err = c.Encrypt(cdata, data) if err != nil { t.Fatal("encrypt failed", err) } - data1, err := c1.Decrypt(cdata) + data1 := make([]byte, len(data)) + err = c1.Decrypt(data1, cdata) if err != nil { t.Fatal("decrypt failed", err) } @@ -88,13 +90,13 @@ func BenchmarkAESEncrypt(b *testing.B) { rand.Reader.Read(key) rand.Reader.Read(iv) buf := make([]byte, datalen) - + dst := make([]byte, datalen) for i := 0; i < b.N; i++ { c, err := NewCipher(cipherName, key, iv) if err != nil { b.Fatal(err) } - c.Encrypt(buf) + c.Encrypt(dst, buf) c.Close() } }