add finalizer to make sure release memory

master
fangdingjun 6 years ago
parent 4e5708fbe4
commit a5ef54ff18

@ -7,6 +7,7 @@ import "C"
import (
"fmt"
"log"
"runtime"
"strings"
"unsafe"
)
@ -32,6 +33,11 @@ func (c *Certificate) Free() {
c.certSize = 0
}
func (c *Certificate) free() {
log.Println("free certificate")
c.Free()
}
func (c *Certificate) matchName(name string) bool {
for _, n := range c.names {
if n == name {
@ -181,5 +187,6 @@ func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
certificate.privkey = privkey
certificate.certSize = certSize
certificate.buildNames()
runtime.SetFinalizer(certificate, (*Certificate).free)
return certificate, nil
}

@ -2,7 +2,9 @@ package gnutls
import (
"log"
"runtime"
"testing"
"time"
)
func TestGetAltname(t *testing.T) {
@ -22,3 +24,13 @@ func TestGetAltname(t *testing.T) {
//log.Println("flag 3: ", cert.getCertString(0, 3))
cert.Free()
}
func _loadCert(certfile, keyfile string) (*Certificate, error) {
return LoadX509KeyPair(certfile, keyfile)
}
func TestCertGC(t *testing.T) {
_loadCert("testdata/server.crt", "testdata/server.key")
runtime.GC()
time.Sleep(1 * time.Second)
}

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"log"
"runtime"
)
// CipherType cipher type
@ -83,7 +84,9 @@ func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
log.Println("new cipher return nil")
return nil, nil
}
return &Cipher{c, t, blocksize}, nil
cipher := &Cipher{c, t, blocksize}
runtime.SetFinalizer(cipher, (*Cipher).free)
return cipher, nil
}
// Encrypt encrypt the buf and place the encrypted data in dst,
@ -136,10 +139,18 @@ func (c *Cipher) Decrypt(dst, buf []byte) error {
// Close destroy the cipher context
func (c *Cipher) Close() error {
if c.cipher != nil {
C.gnutls_cipher_deinit(c.cipher)
c.cipher = nil
}
return nil
}
func (c *Cipher) free() {
log.Println("free cipher")
c.Close()
}
// GetCipherKeySize get the cipher algorithm key length
func GetCipherKeySize(t CipherType) int {
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))

@ -5,7 +5,9 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"runtime"
"testing"
"time"
)
func TestCipherSize(t *testing.T) {
@ -44,13 +46,13 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer c.Close()
//defer c.Close()
c1, err := NewCipher(cipherName, key, iv)
if err != nil {
t.Fatal(err)
}
defer c1.Close()
//defer c1.Close()
data := make([]byte, blocksize*10)
if c == nil {
@ -76,6 +78,8 @@ func TestEncryptDecrypt(t *testing.T) {
if !bytes.Equal(dst, cdata) {
t.Fatal("cipher text not equal to cypto/aes")
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func BenchmarkAESEncrypt(b *testing.B) {

@ -7,6 +7,8 @@ package gnutls
import "C"
import (
"fmt"
"log"
"runtime"
)
// HashType hash type
@ -33,7 +35,9 @@ type Hash struct {
func NewHash(t HashType) *Hash {
h := C.new_hash(C.int(t))
hashOutLen := GetHashOutputLen(t)
return &Hash{h, t, C.int(hashOutLen)}
hash := &Hash{h, t, C.int(hashOutLen)}
runtime.SetFinalizer(hash, (*Hash).free)
return hash
}
// Write write data to hash context
@ -68,9 +72,16 @@ func (h *Hash) Sum(buf []byte) []byte {
// Close destroy hash context
func (h *Hash) Close() error {
if h.hash != nil {
C.gnutls_hash_deinit(h.hash, nil)
h.hash = nil
}
return nil
}
func (h *Hash) free() {
log.Println("free hash")
h.Close()
}
// GetHashOutputLen get the hash algorithm output length
//

@ -6,12 +6,14 @@ import (
"crypto/sha512"
"encoding/hex"
"log"
"runtime"
"testing"
"time"
)
func TestHashSHA(t *testing.T) {
h := NewHash(GNUTLS_HASH_SHA512)
defer h.Close()
//defer h.Close()
data := []byte("1234")
@ -24,6 +26,8 @@ func TestHashSHA(t *testing.T) {
log.Printf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
t.Fatal("hash not equal")
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func BenchmarkHashSHA512(b *testing.B) {

@ -9,6 +9,7 @@ import (
"fmt"
"log"
"net"
"runtime"
"time"
"unsafe"
)
@ -29,6 +30,7 @@ type Conn struct {
cservname *C.char
state *ConnectionState
cfg *Config
closed bool
}
// Config gnutls TLS configure,
@ -113,6 +115,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
log.Println(err)
}
}
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil
}
@ -150,6 +153,7 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
} else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
}
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil
}
@ -243,7 +247,10 @@ func (c *Conn) Write(buf []byte) (n int, err error) {
// Close close the TLS conn and destroy the tls context
func (c *Conn) Close() error {
C.gnutls_record_send(c.sess.session, nil, 0)
if c.closed {
return nil
}
//C.gnutls_record_send(c.sess.session, nil, 0)
C.session_destroy(c.sess)
c.c.Close()
if c.cservname != nil {
@ -256,6 +263,10 @@ func (c *Conn) Close() error {
return nil
}
func (c *Conn) free() {
c.Close()
}
// SetWriteDeadline implements net.Conn
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.c.SetWriteDeadline(t)
@ -310,7 +321,9 @@ func (c *Conn) getPeerCertificate() *Certificate {
if st == nil {
return nil
}
return &Certificate{cert: st, certSize: C.int(size)}
cert := &Certificate{cert: st, certSize: C.int(size)}
runtime.SetFinalizer(cert, (*Certificate).free)
return cert
}
func (c *Conn) getAlpnSelectedProtocol() string {

@ -10,7 +10,9 @@ import (
"net"
"net/http"
"os"
"runtime"
"testing"
"time"
)
func TestTLSClient(t *testing.T) {
@ -134,6 +136,8 @@ func TestTLSServer(t *testing.T) {
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSALPNServer(t *testing.T) {
@ -220,6 +224,8 @@ func TestTLSALPNServer(t *testing.T) {
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSALPNClient(t *testing.T) {
@ -308,6 +314,8 @@ func TestTLSALPNClient(t *testing.T) {
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSServerSNI(t *testing.T) {
@ -378,6 +386,8 @@ func TestTLSServerSNI(t *testing.T) {
}
conn.Close()
}
runtime.GC()
time.Sleep(1 * time.Second)
}
func TestTLSGetPeerCert(t *testing.T) {
@ -405,4 +415,6 @@ func TestTLSGetPeerCert(t *testing.T) {
t.Error(err)
}
resp.Write(os.Stdout)
runtime.GC()
time.Sleep(1 * time.Second)
}

Loading…
Cancel
Save