diff --git a/certificate.go b/certificate.go index 5550330..1f62d01 100644 --- a/certificate.go +++ b/certificate.go @@ -7,6 +7,7 @@ import "C" import ( "fmt" "log" + "runtime" "strings" "unsafe" ) @@ -32,6 +33,11 @@ func (c *Certificate) Free() { c.certSize = 0 } +func (c *Certificate) free() { + log.Println("free certificate") + c.Free() +} + func (c *Certificate) matchName(name string) bool { for _, n := range c.names { if n == name { @@ -181,5 +187,6 @@ func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) { certificate.privkey = privkey certificate.certSize = certSize certificate.buildNames() + runtime.SetFinalizer(certificate, (*Certificate).free) return certificate, nil } diff --git a/certificate_test.go b/certificate_test.go index f991e6d..f58fb7e 100644 --- a/certificate_test.go +++ b/certificate_test.go @@ -2,7 +2,9 @@ package gnutls import ( "log" + "runtime" "testing" + "time" ) func TestGetAltname(t *testing.T) { @@ -22,3 +24,13 @@ func TestGetAltname(t *testing.T) { //log.Println("flag 3: ", cert.getCertString(0, 3)) cert.Free() } + +func _loadCert(certfile, keyfile string) (*Certificate, error) { + return LoadX509KeyPair(certfile, keyfile) +} + +func TestCertGC(t *testing.T) { + _loadCert("testdata/server.crt", "testdata/server.key") + runtime.GC() + time.Sleep(1 * time.Second) +} diff --git a/cipher.go b/cipher.go index aacbb88..14feefe 100644 --- a/cipher.go +++ b/cipher.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "log" + "runtime" ) // CipherType cipher type @@ -83,7 +84,9 @@ func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) { log.Println("new cipher return nil") return nil, nil } - return &Cipher{c, t, blocksize}, nil + cipher := &Cipher{c, t, blocksize} + runtime.SetFinalizer(cipher, (*Cipher).free) + return cipher, nil } // Encrypt encrypt the buf and place the encrypted data in dst, @@ -136,10 +139,18 @@ func (c *Cipher) Decrypt(dst, buf []byte) error { // Close destroy the cipher context func (c *Cipher) Close() error { - C.gnutls_cipher_deinit(c.cipher) + if c.cipher != nil { + C.gnutls_cipher_deinit(c.cipher) + c.cipher = nil + } return nil } +func (c *Cipher) free() { + log.Println("free cipher") + c.Close() +} + // GetCipherKeySize get the cipher algorithm key length func GetCipherKeySize(t CipherType) int { return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) diff --git a/cipher_test.go b/cipher_test.go index a82eeb6..a6b3566 100644 --- a/cipher_test.go +++ b/cipher_test.go @@ -5,7 +5,9 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "runtime" "testing" + "time" ) func TestCipherSize(t *testing.T) { @@ -44,13 +46,13 @@ func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - defer c.Close() + //defer c.Close() c1, err := NewCipher(cipherName, key, iv) if err != nil { t.Fatal(err) } - defer c1.Close() + //defer c1.Close() data := make([]byte, blocksize*10) if c == nil { @@ -76,6 +78,8 @@ func TestEncryptDecrypt(t *testing.T) { if !bytes.Equal(dst, cdata) { t.Fatal("cipher text not equal to cypto/aes") } + runtime.GC() + time.Sleep(1 * time.Second) } func BenchmarkAESEncrypt(b *testing.B) { diff --git a/hash.go b/hash.go index 65776ea..5d83121 100644 --- a/hash.go +++ b/hash.go @@ -7,6 +7,8 @@ package gnutls import "C" import ( "fmt" + "log" + "runtime" ) // HashType hash type @@ -33,7 +35,9 @@ type Hash struct { func NewHash(t HashType) *Hash { h := C.new_hash(C.int(t)) hashOutLen := GetHashOutputLen(t) - return &Hash{h, t, C.int(hashOutLen)} + hash := &Hash{h, t, C.int(hashOutLen)} + runtime.SetFinalizer(hash, (*Hash).free) + return hash } // Write write data to hash context @@ -68,9 +72,16 @@ func (h *Hash) Sum(buf []byte) []byte { // Close destroy hash context func (h *Hash) Close() error { - C.gnutls_hash_deinit(h.hash, nil) + if h.hash != nil { + C.gnutls_hash_deinit(h.hash, nil) + h.hash = nil + } return nil } +func (h *Hash) free() { + log.Println("free hash") + h.Close() +} // GetHashOutputLen get the hash algorithm output length // diff --git a/hash_test.go b/hash_test.go index 50ebf0b..3d9ee1a 100644 --- a/hash_test.go +++ b/hash_test.go @@ -6,12 +6,14 @@ import ( "crypto/sha512" "encoding/hex" "log" + "runtime" "testing" + "time" ) func TestHashSHA(t *testing.T) { h := NewHash(GNUTLS_HASH_SHA512) - defer h.Close() + //defer h.Close() data := []byte("1234") @@ -24,6 +26,8 @@ func TestHashSHA(t *testing.T) { log.Printf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1)) t.Fatal("hash not equal") } + runtime.GC() + time.Sleep(1 * time.Second) } func BenchmarkHashSHA512(b *testing.B) { diff --git a/tls.go b/tls.go index f47f0de..db4767a 100644 --- a/tls.go +++ b/tls.go @@ -9,6 +9,7 @@ import ( "fmt" "log" "net" + "runtime" "time" "unsafe" ) @@ -29,6 +30,7 @@ type Conn struct { cservname *C.char state *ConnectionState cfg *Config + closed bool } // Config gnutls TLS configure, @@ -113,6 +115,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { log.Println(err) } } + runtime.SetFinalizer(conn, (*Conn).free) return conn, nil } @@ -150,6 +153,7 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { } else { C.gnutls_session_set_verify_cert(sess.session, nil, 0) } + runtime.SetFinalizer(conn, (*Conn).free) return conn, nil } @@ -243,7 +247,10 @@ func (c *Conn) Write(buf []byte) (n int, err error) { // Close close the TLS conn and destroy the tls context func (c *Conn) Close() error { - C.gnutls_record_send(c.sess.session, nil, 0) + if c.closed { + return nil + } + //C.gnutls_record_send(c.sess.session, nil, 0) C.session_destroy(c.sess) c.c.Close() if c.cservname != nil { @@ -256,6 +263,10 @@ func (c *Conn) Close() error { return nil } +func (c *Conn) free() { + c.Close() +} + // SetWriteDeadline implements net.Conn func (c *Conn) SetWriteDeadline(t time.Time) error { return c.c.SetWriteDeadline(t) @@ -310,7 +321,9 @@ func (c *Conn) getPeerCertificate() *Certificate { if st == nil { return nil } - return &Certificate{cert: st, certSize: C.int(size)} + cert := &Certificate{cert: st, certSize: C.int(size)} + runtime.SetFinalizer(cert, (*Certificate).free) + return cert } func (c *Conn) getAlpnSelectedProtocol() string { diff --git a/tls_test.go b/tls_test.go index f57cf12..3ffc3a4 100644 --- a/tls_test.go +++ b/tls_test.go @@ -10,7 +10,9 @@ import ( "net" "net/http" "os" + "runtime" "testing" + "time" ) func TestTLSClient(t *testing.T) { @@ -134,6 +136,8 @@ func TestTLSServer(t *testing.T) { if string(buf[:n]) != data { t.Errorf("need: %s, got: %s", data, string(buf[:n])) } + runtime.GC() + time.Sleep(1 * time.Second) } func TestTLSALPNServer(t *testing.T) { @@ -220,6 +224,8 @@ func TestTLSALPNServer(t *testing.T) { if string(buf[:n]) != data { t.Errorf("need: %s, got: %s", data, string(buf[:n])) } + runtime.GC() + time.Sleep(1 * time.Second) } func TestTLSALPNClient(t *testing.T) { @@ -308,6 +314,8 @@ func TestTLSALPNClient(t *testing.T) { if string(buf[:n]) != data { t.Errorf("need: %s, got: %s", data, string(buf[:n])) } + runtime.GC() + time.Sleep(1 * time.Second) } func TestTLSServerSNI(t *testing.T) { @@ -378,6 +386,8 @@ func TestTLSServerSNI(t *testing.T) { } conn.Close() } + runtime.GC() + time.Sleep(1 * time.Second) } func TestTLSGetPeerCert(t *testing.T) { @@ -405,4 +415,6 @@ func TestTLSGetPeerCert(t *testing.T) { t.Error(err) } resp.Write(os.Stdout) + runtime.GC() + time.Sleep(1 * time.Second) }