diff --git a/tls.go b/tls.go index db4767a..7718547 100644 --- a/tls.go +++ b/tls.go @@ -10,6 +10,7 @@ import ( "log" "net" "runtime" + "sync" "time" "unsafe" ) @@ -31,6 +32,7 @@ type Conn struct { state *ConnectionState cfg *Config closed bool + lock *sync.Mutex } // Config gnutls TLS configure, @@ -104,7 +106,7 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) { // NewServerConn create a server TLS Conn on c func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { var sess = C.init_gnutls_server_session() - conn := &Conn{c: c, sess: sess, cfg: cfg} + conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} n := C.size_t(uintptr(unsafe.Pointer(conn))) //log.Println("conn addr ", int(n)) C.set_data(sess, n) @@ -122,7 +124,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { // NewClientConn create a client TLS Conn on c func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { var sess = C.init_gnutls_client_session() - conn := &Conn{c: c, sess: sess, cfg: cfg} + conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} n := C.size_t(uintptr(unsafe.Pointer(conn))) //log.Println("conn addr ", int(n)) C.set_data(sess, n) @@ -176,6 +178,8 @@ func setAlpnProtocols(sess *C.struct_session, cfg *Config) error { // Handshake call handshake for TLS Conn, // this function will call automatic on Read/Write, if not handshake yet func (c *Conn) Handshake() error { + c.lock.Lock() + defer c.lock.Unlock() if c.handshake { return nil } @@ -190,12 +194,9 @@ func (c *Conn) Handshake() error { // Read read application data from TLS connection func (c *Conn) Read(buf []byte) (n int, err error) { - if !c.handshake { - err = c.Handshake() - if err != nil { - return - } - c.handshake = true + err = c.Handshake() + if err != nil { + return } bufLen := len(buf) @@ -220,12 +221,9 @@ func (c *Conn) Read(buf []byte) (n int, err error) { // Write write application data to TLS connection func (c *Conn) Write(buf []byte) (n int, err error) { - if !c.handshake { - err = c.Handshake() - if err != nil { - return - } - c.handshake = true + err = c.Handshake() + if err != nil { + return } cbuf := C.CBytes(buf) defer C.free(cbuf) @@ -247,10 +245,12 @@ func (c *Conn) Write(buf []byte) (n int, err error) { // Close close the TLS conn and destroy the tls context func (c *Conn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() if c.closed { return nil } - //C.gnutls_record_send(c.sess.session, nil, 0) + C.gnutls_record_send(c.sess.session, nil, 0) C.session_destroy(c.sess) c.c.Close() if c.cservname != nil { @@ -260,10 +260,12 @@ func (c *Conn) Close() error { if c.state != nil && c.state.PeerCertificate != nil { c.state.PeerCertificate.Free() } + c.closed = true return nil } func (c *Conn) free() { + //log.Println("free conn") c.Close() }