diff --git a/cipher.go b/cipher.go index 2e2c3e9..9a3e946 100644 --- a/cipher.go +++ b/cipher.go @@ -35,17 +35,23 @@ const ( // Cipher cipher type Cipher struct { - cipher C.gnutls_cipher_hd_t - t int + cipher C.gnutls_cipher_hd_t + t int + blockSize int } // NewCipher create cipher func NewCipher(t int, key []byte, iv []byte) (*Cipher, error) { - ivSize := C.cipher_get_block_size(C.int(t)) - blockSize := C.cipher_get_iv_size(C.int(t)) - if len(key) != int(blockSize) || len(iv) != int(ivSize) { + keysize := GetCipherKeySize(t) + ivSize := GetCipherIVSize(t) + blocksize := GetCipherBlockSize(t) + //log.Printf("block size: %d, iv size: %d", int(ivSize), int(blockSize)) + if len(key) != int(keysize) { + return nil, fmt.Errorf("wrong key size") + } - return nil, fmt.Errorf("wrong block/iv size") + if len(iv) != int(ivSize) { + return nil, fmt.Errorf("wrong iv size") } ckey := C.CBytes(key) @@ -59,13 +65,12 @@ func NewCipher(t int, key []byte, iv []byte) (*Cipher, error) { log.Println("new cipher return nil") return nil, nil } - return &Cipher{c, t}, nil + return &Cipher{c, t, blocksize}, nil } // Encrypt encrypt func (c *Cipher) Encrypt(buf []byte) ([]byte, error) { - blockSize := C.cipher_get_iv_size(C.int(c.t)) - if len(buf)%int(blockSize) != 0 { + if len(buf)%c.blockSize != 0 { return nil, fmt.Errorf("wrong block size") } @@ -86,8 +91,7 @@ func (c *Cipher) Encrypt(buf []byte) ([]byte, error) { // Decrypt decrypt func (c *Cipher) Decrypt(buf []byte) ([]byte, error) { - blockSize := C.cipher_get_iv_size(C.int(c.t)) - if len(buf)%int(blockSize) != 0 { + if len(buf)%c.blockSize != 0 { return nil, fmt.Errorf("wrong block size") } @@ -111,3 +115,18 @@ func (c *Cipher) Close() error { C.gnutls_cipher_deinit(c.cipher) return nil } + +// GetCipherKeySize get the cipher algorithm key length +func GetCipherKeySize(t int) int { + return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) +} + +// GetCipherIVSize get the cipher algorithm iv length +func GetCipherIVSize(t int) int { + return int(C.gnutls_cipher_get_iv_size(C.gnutls_cipher_algorithm_t(t))) +} + +// GetCipherBlockSize get the cipher algorithm block size +func GetCipherBlockSize(t int) int { + return int(C.gnutls_cipher_get_block_size(C.gnutls_cipher_algorithm_t(t))) +}