From 2e02d7a93c9737bda95f1227c5c7c9a5244640b0 Mon Sep 17 00:00:00 2001 From: fangdingjun Date: Thu, 28 Jun 2018 09:52:47 +0800 Subject: [PATCH] fix server name check issue --- tls.go | 58 ++++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/tls.go b/tls.go index 562fb62..49781f7 100644 --- a/tls.go +++ b/tls.go @@ -14,18 +14,25 @@ import ( "unsafe" ) +const ( + GNUTLS_NAME_DNS = 1 + GNUTLS_X509_FMT_PEM = 1 +) + // Conn tls connection for client type Conn struct { c net.Conn sess *C.struct_session handshake bool + cservname *C.char } // Config tls configure type Config struct { - ServerName string - CrtFile string - KeyFile string + ServerName string + CrtFile string + KeyFile string + InsecureSkipVerify bool } type listener struct { l net.Listener @@ -93,7 +100,8 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { keyfile := C.CString(cfg.KeyFile) defer C.free(unsafe.Pointer(crtfile)) defer C.free(unsafe.Pointer(keyfile)) - ret := C.set_keyfile(sess, crtfile, keyfile) + ret := C.gnutls_certificate_set_x509_key_file( + sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM) if int(ret) < 0 { cerrstr := C.gnutls_strerror(ret) return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr)) @@ -112,8 +120,10 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { if cfg != nil { if cfg.ServerName != "" { srvname := C.CString(cfg.ServerName) - defer C.free(unsafe.Pointer(srvname)) - C.set_servername(sess, srvname, C.int(len(cfg.ServerName))) + //defer C.free(unsafe.Pointer(srvname)) + conn.cservname = srvname + C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS, + unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName))) } if cfg.CrtFile != "" && cfg.KeyFile != "" { @@ -121,11 +131,22 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { keyfile := C.CString(cfg.KeyFile) defer C.free(unsafe.Pointer(crtfile)) defer C.free(unsafe.Pointer(keyfile)) - ret := C.set_keyfile(sess, crtfile, keyfile) + ret := C.gnutls_certificate_set_x509_key_file( + sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM) if int(ret) < 0 { - return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(C.gnutls_strerror(ret))) + return nil, fmt.Errorf("set keyfile failed: %s", + C.GoString(C.gnutls_strerror(ret))) } } + if !cfg.InsecureSkipVerify { + if conn.cservname != nil { + C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0) + } else { + C.gnutls_session_set_verify_cert(sess.session, nil, 0) + } + } + } else { + C.gnutls_session_set_verify_cert(sess.session, nil, 0) } return conn, nil } @@ -144,7 +165,7 @@ func (c *Conn) Handshake() error { return nil } -// Read +// Read read data from tls connection func (c *Conn) Read(buf []byte) (n int, err error) { if !c.handshake { err = c.Handshake() @@ -158,9 +179,10 @@ func (c *Conn) Read(buf []byte) (n int, err error) { cbuf := C.malloc(C.size_t(bufLen)) defer C.free(cbuf) - ret := C.read_application_data(c.sess, (*C.char)(cbuf), C.int(bufLen)) + ret := C.gnutls_record_recv(c.sess.session, cbuf, C.size_t(bufLen)) if int(ret) < 0 { - return 0, fmt.Errorf("read error: %s", C.GoString(C.gnutls_strerror(ret))) + return 0, fmt.Errorf("read error: %s", + C.GoString(C.gnutls_strerror(C.int(ret)))) } if int(ret) == 0 { @@ -168,12 +190,12 @@ func (c *Conn) Read(buf []byte) (n int, err error) { } n = int(ret) - gobuf2 := C.GoBytes(cbuf, ret) + gobuf2 := C.GoBytes(cbuf, C.int(ret)) copy(buf, gobuf2) return n, nil } -// Write +// Write write data to tls connection func (c *Conn) Write(buf []byte) (n int, err error) { if !c.handshake { err = c.Handshake() @@ -185,11 +207,12 @@ func (c *Conn) Write(buf []byte) (n int, err error) { cbuf := C.CBytes(buf) defer C.free(cbuf) - ret := C.write_application_data(c.sess, (*C.char)(cbuf), C.int(len(buf))) + ret := C.gnutls_record_send(c.sess.session, cbuf, C.size_t(len(buf))) n = int(ret) if n < 0 { - return 0, fmt.Errorf("write error: %s", C.GoString(C.gnutls_strerror(ret))) + return 0, fmt.Errorf("write error: %s", + C.GoString(C.gnutls_strerror(C.int(ret)))) } if int(ret) == 0 { @@ -199,10 +222,13 @@ func (c *Conn) Write(buf []byte) (n int, err error) { return n, nil } -// Close close the conn +// Close close the conn and destroy the tls context func (c *Conn) Close() error { C.session_destroy(c.sess) c.c.Close() + if c.cservname != nil { + C.free(unsafe.Pointer(c.cservname)) + } return nil }