add lock to handshake and close

master
fangdingjun 6 years ago
parent a5ef54ff18
commit 50a41d72b7

@ -10,6 +10,7 @@ import (
"log" "log"
"net" "net"
"runtime" "runtime"
"sync"
"time" "time"
"unsafe" "unsafe"
) )
@ -31,6 +32,7 @@ type Conn struct {
state *ConnectionState state *ConnectionState
cfg *Config cfg *Config
closed bool closed bool
lock *sync.Mutex
} }
// Config gnutls TLS configure, // Config gnutls TLS configure,
@ -104,7 +106,7 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) {
// NewServerConn create a server TLS Conn on c // NewServerConn create a server TLS Conn on c
func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_server_session() var sess = C.init_gnutls_server_session()
conn := &Conn{c: c, sess: sess, cfg: cfg} conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
n := C.size_t(uintptr(unsafe.Pointer(conn))) n := C.size_t(uintptr(unsafe.Pointer(conn)))
//log.Println("conn addr ", int(n)) //log.Println("conn addr ", int(n))
C.set_data(sess, n) C.set_data(sess, n)
@ -122,7 +124,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
// NewClientConn create a client TLS Conn on c // NewClientConn create a client TLS Conn on c
func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_client_session() var sess = C.init_gnutls_client_session()
conn := &Conn{c: c, sess: sess, cfg: cfg} conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
n := C.size_t(uintptr(unsafe.Pointer(conn))) n := C.size_t(uintptr(unsafe.Pointer(conn)))
//log.Println("conn addr ", int(n)) //log.Println("conn addr ", int(n))
C.set_data(sess, n) C.set_data(sess, n)
@ -176,6 +178,8 @@ func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
// Handshake call handshake for TLS Conn, // Handshake call handshake for TLS Conn,
// this function will call automatic on Read/Write, if not handshake yet // this function will call automatic on Read/Write, if not handshake yet
func (c *Conn) Handshake() error { func (c *Conn) Handshake() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.handshake { if c.handshake {
return nil return nil
} }
@ -190,13 +194,10 @@ func (c *Conn) Handshake() error {
// Read read application data from TLS connection // Read read application data from TLS connection
func (c *Conn) Read(buf []byte) (n int, err error) { func (c *Conn) Read(buf []byte) (n int, err error) {
if !c.handshake {
err = c.Handshake() err = c.Handshake()
if err != nil { if err != nil {
return return
} }
c.handshake = true
}
bufLen := len(buf) bufLen := len(buf)
cbuf := C.malloc(C.size_t(bufLen)) cbuf := C.malloc(C.size_t(bufLen))
@ -220,13 +221,10 @@ func (c *Conn) Read(buf []byte) (n int, err error) {
// Write write application data to TLS connection // Write write application data to TLS connection
func (c *Conn) Write(buf []byte) (n int, err error) { func (c *Conn) Write(buf []byte) (n int, err error) {
if !c.handshake {
err = c.Handshake() err = c.Handshake()
if err != nil { if err != nil {
return return
} }
c.handshake = true
}
cbuf := C.CBytes(buf) cbuf := C.CBytes(buf)
defer C.free(cbuf) defer C.free(cbuf)
@ -247,10 +245,12 @@ func (c *Conn) Write(buf []byte) (n int, err error) {
// Close close the TLS conn and destroy the tls context // Close close the TLS conn and destroy the tls context
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed { if c.closed {
return nil return nil
} }
//C.gnutls_record_send(c.sess.session, nil, 0) C.gnutls_record_send(c.sess.session, nil, 0)
C.session_destroy(c.sess) C.session_destroy(c.sess)
c.c.Close() c.c.Close()
if c.cservname != nil { if c.cservname != nil {
@ -260,10 +260,12 @@ func (c *Conn) Close() error {
if c.state != nil && c.state.PeerCertificate != nil { if c.state != nil && c.state.PeerCertificate != nil {
c.state.PeerCertificate.Free() c.state.PeerCertificate.Free()
} }
c.closed = true
return nil return nil
} }
func (c *Conn) free() { func (c *Conn) free() {
//log.Println("free conn")
c.Close() c.Close()
} }

Loading…
Cancel
Save