add lock to handshake and close

master
fangdingjun 6 years ago
parent a5ef54ff18
commit 50a41d72b7

@ -10,6 +10,7 @@ import (
"log"
"net"
"runtime"
"sync"
"time"
"unsafe"
)
@ -31,6 +32,7 @@ type Conn struct {
state *ConnectionState
cfg *Config
closed bool
lock *sync.Mutex
}
// Config gnutls TLS configure,
@ -104,7 +106,7 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) {
// NewServerConn create a server TLS Conn on c
func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_server_session()
conn := &Conn{c: c, sess: sess, cfg: cfg}
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
n := C.size_t(uintptr(unsafe.Pointer(conn)))
//log.Println("conn addr ", int(n))
C.set_data(sess, n)
@ -122,7 +124,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
// NewClientConn create a client TLS Conn on c
func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_client_session()
conn := &Conn{c: c, sess: sess, cfg: cfg}
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
n := C.size_t(uintptr(unsafe.Pointer(conn)))
//log.Println("conn addr ", int(n))
C.set_data(sess, n)
@ -176,6 +178,8 @@ func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
// Handshake call handshake for TLS Conn,
// this function will call automatic on Read/Write, if not handshake yet
func (c *Conn) Handshake() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.handshake {
return nil
}
@ -190,13 +194,10 @@ func (c *Conn) Handshake() error {
// Read read application data from TLS connection
func (c *Conn) Read(buf []byte) (n int, err error) {
if !c.handshake {
err = c.Handshake()
if err != nil {
return
}
c.handshake = true
}
bufLen := len(buf)
cbuf := C.malloc(C.size_t(bufLen))
@ -220,13 +221,10 @@ func (c *Conn) Read(buf []byte) (n int, err error) {
// Write write application data to TLS connection
func (c *Conn) Write(buf []byte) (n int, err error) {
if !c.handshake {
err = c.Handshake()
if err != nil {
return
}
c.handshake = true
}
cbuf := C.CBytes(buf)
defer C.free(cbuf)
@ -247,10 +245,12 @@ func (c *Conn) Write(buf []byte) (n int, err error) {
// Close close the TLS conn and destroy the tls context
func (c *Conn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return nil
}
//C.gnutls_record_send(c.sess.session, nil, 0)
C.gnutls_record_send(c.sess.session, nil, 0)
C.session_destroy(c.sess)
c.c.Close()
if c.cservname != nil {
@ -260,10 +260,12 @@ func (c *Conn) Close() error {
if c.state != nil && c.state.PeerCertificate != nil {
c.state.PeerCertificate.Free()
}
c.closed = true
return nil
}
func (c *Conn) free() {
//log.Println("free conn")
c.Close()
}

Loading…
Cancel
Save