add finalizer to make sure release memory

master
fangdingjun 6 years ago
parent 4e5708fbe4
commit a5ef54ff18

@ -7,6 +7,7 @@ import "C"
import ( import (
"fmt" "fmt"
"log" "log"
"runtime"
"strings" "strings"
"unsafe" "unsafe"
) )
@ -32,6 +33,11 @@ func (c *Certificate) Free() {
c.certSize = 0 c.certSize = 0
} }
func (c *Certificate) free() {
log.Println("free certificate")
c.Free()
}
func (c *Certificate) matchName(name string) bool { func (c *Certificate) matchName(name string) bool {
for _, n := range c.names { for _, n := range c.names {
if n == name { if n == name {
@ -181,5 +187,6 @@ func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
certificate.privkey = privkey certificate.privkey = privkey
certificate.certSize = certSize certificate.certSize = certSize
certificate.buildNames() certificate.buildNames()
runtime.SetFinalizer(certificate, (*Certificate).free)
return certificate, nil return certificate, nil
} }

@ -2,7 +2,9 @@ package gnutls
import ( import (
"log" "log"
"runtime"
"testing" "testing"
"time"
) )
func TestGetAltname(t *testing.T) { func TestGetAltname(t *testing.T) {
@ -22,3 +24,13 @@ func TestGetAltname(t *testing.T) {
//log.Println("flag 3: ", cert.getCertString(0, 3)) //log.Println("flag 3: ", cert.getCertString(0, 3))
cert.Free() cert.Free()
} }
func _loadCert(certfile, keyfile string) (*Certificate, error) {
return LoadX509KeyPair(certfile, keyfile)
}
func TestCertGC(t *testing.T) {
_loadCert("testdata/server.crt", "testdata/server.key")
runtime.GC()
time.Sleep(1 * time.Second)
}

@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"runtime"
) )
// CipherType cipher type // CipherType cipher type
@ -83,7 +84,9 @@ func NewCipher(t CipherType, key []byte, iv []byte) (*Cipher, error) {
log.Println("new cipher return nil") log.Println("new cipher return nil")
return nil, nil return nil, nil
} }
return &Cipher{c, t, blocksize}, nil cipher := &Cipher{c, t, blocksize}
runtime.SetFinalizer(cipher, (*Cipher).free)
return cipher, nil
} }
// Encrypt encrypt the buf and place the encrypted data in dst, // Encrypt encrypt the buf and place the encrypted data in dst,
@ -136,10 +139,18 @@ func (c *Cipher) Decrypt(dst, buf []byte) error {
// Close destroy the cipher context // Close destroy the cipher context
func (c *Cipher) Close() error { func (c *Cipher) Close() error {
if c.cipher != nil {
C.gnutls_cipher_deinit(c.cipher) C.gnutls_cipher_deinit(c.cipher)
c.cipher = nil
}
return nil return nil
} }
func (c *Cipher) free() {
log.Println("free cipher")
c.Close()
}
// GetCipherKeySize get the cipher algorithm key length // GetCipherKeySize get the cipher algorithm key length
func GetCipherKeySize(t CipherType) int { func GetCipherKeySize(t CipherType) int {
return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t))) return int(C.gnutls_cipher_get_key_size(C.gnutls_cipher_algorithm_t(t)))

@ -5,7 +5,9 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"runtime"
"testing" "testing"
"time"
) )
func TestCipherSize(t *testing.T) { func TestCipherSize(t *testing.T) {
@ -44,13 +46,13 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer c.Close() //defer c.Close()
c1, err := NewCipher(cipherName, key, iv) c1, err := NewCipher(cipherName, key, iv)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer c1.Close() //defer c1.Close()
data := make([]byte, blocksize*10) data := make([]byte, blocksize*10)
if c == nil { if c == nil {
@ -76,6 +78,8 @@ func TestEncryptDecrypt(t *testing.T) {
if !bytes.Equal(dst, cdata) { if !bytes.Equal(dst, cdata) {
t.Fatal("cipher text not equal to cypto/aes") t.Fatal("cipher text not equal to cypto/aes")
} }
runtime.GC()
time.Sleep(1 * time.Second)
} }
func BenchmarkAESEncrypt(b *testing.B) { func BenchmarkAESEncrypt(b *testing.B) {

@ -7,6 +7,8 @@ package gnutls
import "C" import "C"
import ( import (
"fmt" "fmt"
"log"
"runtime"
) )
// HashType hash type // HashType hash type
@ -33,7 +35,9 @@ type Hash struct {
func NewHash(t HashType) *Hash { func NewHash(t HashType) *Hash {
h := C.new_hash(C.int(t)) h := C.new_hash(C.int(t))
hashOutLen := GetHashOutputLen(t) hashOutLen := GetHashOutputLen(t)
return &Hash{h, t, C.int(hashOutLen)} hash := &Hash{h, t, C.int(hashOutLen)}
runtime.SetFinalizer(hash, (*Hash).free)
return hash
} }
// Write write data to hash context // Write write data to hash context
@ -68,9 +72,16 @@ func (h *Hash) Sum(buf []byte) []byte {
// Close destroy hash context // Close destroy hash context
func (h *Hash) Close() error { func (h *Hash) Close() error {
if h.hash != nil {
C.gnutls_hash_deinit(h.hash, nil) C.gnutls_hash_deinit(h.hash, nil)
h.hash = nil
}
return nil return nil
} }
func (h *Hash) free() {
log.Println("free hash")
h.Close()
}
// GetHashOutputLen get the hash algorithm output length // GetHashOutputLen get the hash algorithm output length
// //

@ -6,12 +6,14 @@ import (
"crypto/sha512" "crypto/sha512"
"encoding/hex" "encoding/hex"
"log" "log"
"runtime"
"testing" "testing"
"time"
) )
func TestHashSHA(t *testing.T) { func TestHashSHA(t *testing.T) {
h := NewHash(GNUTLS_HASH_SHA512) h := NewHash(GNUTLS_HASH_SHA512)
defer h.Close() //defer h.Close()
data := []byte("1234") data := []byte("1234")
@ -24,6 +26,8 @@ func TestHashSHA(t *testing.T) {
log.Printf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1)) log.Printf("\n%s\n%s", hex.EncodeToString(h4[:]), hex.EncodeToString(h1))
t.Fatal("hash not equal") t.Fatal("hash not equal")
} }
runtime.GC()
time.Sleep(1 * time.Second)
} }
func BenchmarkHashSHA512(b *testing.B) { func BenchmarkHashSHA512(b *testing.B) {

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"runtime"
"time" "time"
"unsafe" "unsafe"
) )
@ -29,6 +30,7 @@ type Conn struct {
cservname *C.char cservname *C.char
state *ConnectionState state *ConnectionState
cfg *Config cfg *Config
closed bool
} }
// Config gnutls TLS configure, // Config gnutls TLS configure,
@ -113,6 +115,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
log.Println(err) log.Println(err)
} }
} }
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil return conn, nil
} }
@ -150,6 +153,7 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
} else { } else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0) C.gnutls_session_set_verify_cert(sess.session, nil, 0)
} }
runtime.SetFinalizer(conn, (*Conn).free)
return conn, nil return conn, nil
} }
@ -243,7 +247,10 @@ func (c *Conn) Write(buf []byte) (n int, err error) {
// Close close the TLS conn and destroy the tls context // Close close the TLS conn and destroy the tls context
func (c *Conn) Close() error { func (c *Conn) Close() error {
C.gnutls_record_send(c.sess.session, nil, 0) if c.closed {
return nil
}
//C.gnutls_record_send(c.sess.session, nil, 0)
C.session_destroy(c.sess) C.session_destroy(c.sess)
c.c.Close() c.c.Close()
if c.cservname != nil { if c.cservname != nil {
@ -256,6 +263,10 @@ func (c *Conn) Close() error {
return nil return nil
} }
func (c *Conn) free() {
c.Close()
}
// SetWriteDeadline implements net.Conn // SetWriteDeadline implements net.Conn
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.c.SetWriteDeadline(t) return c.c.SetWriteDeadline(t)
@ -310,7 +321,9 @@ func (c *Conn) getPeerCertificate() *Certificate {
if st == nil { if st == nil {
return nil return nil
} }
return &Certificate{cert: st, certSize: C.int(size)} cert := &Certificate{cert: st, certSize: C.int(size)}
runtime.SetFinalizer(cert, (*Certificate).free)
return cert
} }
func (c *Conn) getAlpnSelectedProtocol() string { func (c *Conn) getAlpnSelectedProtocol() string {

@ -10,7 +10,9 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"runtime"
"testing" "testing"
"time"
) )
func TestTLSClient(t *testing.T) { func TestTLSClient(t *testing.T) {
@ -134,6 +136,8 @@ func TestTLSServer(t *testing.T) {
if string(buf[:n]) != data { if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
runtime.GC()
time.Sleep(1 * time.Second)
} }
func TestTLSALPNServer(t *testing.T) { func TestTLSALPNServer(t *testing.T) {
@ -220,6 +224,8 @@ func TestTLSALPNServer(t *testing.T) {
if string(buf[:n]) != data { if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
runtime.GC()
time.Sleep(1 * time.Second)
} }
func TestTLSALPNClient(t *testing.T) { func TestTLSALPNClient(t *testing.T) {
@ -308,6 +314,8 @@ func TestTLSALPNClient(t *testing.T) {
if string(buf[:n]) != data { if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
runtime.GC()
time.Sleep(1 * time.Second)
} }
func TestTLSServerSNI(t *testing.T) { func TestTLSServerSNI(t *testing.T) {
@ -378,6 +386,8 @@ func TestTLSServerSNI(t *testing.T) {
} }
conn.Close() conn.Close()
} }
runtime.GC()
time.Sleep(1 * time.Second)
} }
func TestTLSGetPeerCert(t *testing.T) { func TestTLSGetPeerCert(t *testing.T) {
@ -405,4 +415,6 @@ func TestTLSGetPeerCert(t *testing.T) {
t.Error(err) t.Error(err)
} }
resp.Write(os.Stdout) resp.Write(os.Stdout)
runtime.GC()
time.Sleep(1 * time.Second)
} }

Loading…
Cancel
Save