|
|
@ -14,11 +14,17 @@ import (
|
|
|
|
"unsafe"
|
|
|
|
"unsafe"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
|
|
|
GNUTLS_NAME_DNS = 1
|
|
|
|
|
|
|
|
GNUTLS_X509_FMT_PEM = 1
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// Conn tls connection for client
|
|
|
|
// Conn tls connection for client
|
|
|
|
type Conn struct {
|
|
|
|
type Conn struct {
|
|
|
|
c net.Conn
|
|
|
|
c net.Conn
|
|
|
|
sess *C.struct_session
|
|
|
|
sess *C.struct_session
|
|
|
|
handshake bool
|
|
|
|
handshake bool
|
|
|
|
|
|
|
|
cservname *C.char
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Config tls configure
|
|
|
|
// Config tls configure
|
|
|
@ -26,6 +32,7 @@ type Config struct {
|
|
|
|
ServerName string
|
|
|
|
ServerName string
|
|
|
|
CrtFile string
|
|
|
|
CrtFile string
|
|
|
|
KeyFile string
|
|
|
|
KeyFile string
|
|
|
|
|
|
|
|
InsecureSkipVerify bool
|
|
|
|
}
|
|
|
|
}
|
|
|
|
type listener struct {
|
|
|
|
type listener struct {
|
|
|
|
l net.Listener
|
|
|
|
l net.Listener
|
|
|
@ -93,7 +100,8 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
keyfile := C.CString(cfg.KeyFile)
|
|
|
|
keyfile := C.CString(cfg.KeyFile)
|
|
|
|
defer C.free(unsafe.Pointer(crtfile))
|
|
|
|
defer C.free(unsafe.Pointer(crtfile))
|
|
|
|
defer C.free(unsafe.Pointer(keyfile))
|
|
|
|
defer C.free(unsafe.Pointer(keyfile))
|
|
|
|
ret := C.set_keyfile(sess, crtfile, keyfile)
|
|
|
|
ret := C.gnutls_certificate_set_x509_key_file(
|
|
|
|
|
|
|
|
sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM)
|
|
|
|
if int(ret) < 0 {
|
|
|
|
if int(ret) < 0 {
|
|
|
|
cerrstr := C.gnutls_strerror(ret)
|
|
|
|
cerrstr := C.gnutls_strerror(ret)
|
|
|
|
return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr))
|
|
|
|
return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr))
|
|
|
@ -112,8 +120,10 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
if cfg != nil {
|
|
|
|
if cfg != nil {
|
|
|
|
if cfg.ServerName != "" {
|
|
|
|
if cfg.ServerName != "" {
|
|
|
|
srvname := C.CString(cfg.ServerName)
|
|
|
|
srvname := C.CString(cfg.ServerName)
|
|
|
|
defer C.free(unsafe.Pointer(srvname))
|
|
|
|
//defer C.free(unsafe.Pointer(srvname))
|
|
|
|
C.set_servername(sess, srvname, C.int(len(cfg.ServerName)))
|
|
|
|
conn.cservname = srvname
|
|
|
|
|
|
|
|
C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
|
|
|
|
|
|
|
|
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if cfg.CrtFile != "" && cfg.KeyFile != "" {
|
|
|
|
if cfg.CrtFile != "" && cfg.KeyFile != "" {
|
|
|
@ -121,12 +131,23 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
keyfile := C.CString(cfg.KeyFile)
|
|
|
|
keyfile := C.CString(cfg.KeyFile)
|
|
|
|
defer C.free(unsafe.Pointer(crtfile))
|
|
|
|
defer C.free(unsafe.Pointer(crtfile))
|
|
|
|
defer C.free(unsafe.Pointer(keyfile))
|
|
|
|
defer C.free(unsafe.Pointer(keyfile))
|
|
|
|
ret := C.set_keyfile(sess, crtfile, keyfile)
|
|
|
|
ret := C.gnutls_certificate_set_x509_key_file(
|
|
|
|
|
|
|
|
sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM)
|
|
|
|
if int(ret) < 0 {
|
|
|
|
if int(ret) < 0 {
|
|
|
|
return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
|
|
|
return nil, fmt.Errorf("set keyfile failed: %s",
|
|
|
|
|
|
|
|
C.GoString(C.gnutls_strerror(ret)))
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if !cfg.InsecureSkipVerify {
|
|
|
|
|
|
|
|
if conn.cservname != nil {
|
|
|
|
|
|
|
|
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
C.gnutls_session_set_verify_cert(sess.session, nil, 0)
|
|
|
|
|
|
|
|
}
|
|
|
|
return conn, nil
|
|
|
|
return conn, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -144,7 +165,7 @@ func (c *Conn) Handshake() error {
|
|
|
|
return nil
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Read
|
|
|
|
// Read read data from tls connection
|
|
|
|
func (c *Conn) Read(buf []byte) (n int, err error) {
|
|
|
|
func (c *Conn) Read(buf []byte) (n int, err error) {
|
|
|
|
if !c.handshake {
|
|
|
|
if !c.handshake {
|
|
|
|
err = c.Handshake()
|
|
|
|
err = c.Handshake()
|
|
|
@ -158,9 +179,10 @@ func (c *Conn) Read(buf []byte) (n int, err error) {
|
|
|
|
cbuf := C.malloc(C.size_t(bufLen))
|
|
|
|
cbuf := C.malloc(C.size_t(bufLen))
|
|
|
|
defer C.free(cbuf)
|
|
|
|
defer C.free(cbuf)
|
|
|
|
|
|
|
|
|
|
|
|
ret := C.read_application_data(c.sess, (*C.char)(cbuf), C.int(bufLen))
|
|
|
|
ret := C.gnutls_record_recv(c.sess.session, cbuf, C.size_t(bufLen))
|
|
|
|
if int(ret) < 0 {
|
|
|
|
if int(ret) < 0 {
|
|
|
|
return 0, fmt.Errorf("read error: %s", C.GoString(C.gnutls_strerror(ret)))
|
|
|
|
return 0, fmt.Errorf("read error: %s",
|
|
|
|
|
|
|
|
C.GoString(C.gnutls_strerror(C.int(ret))))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if int(ret) == 0 {
|
|
|
|
if int(ret) == 0 {
|
|
|
@ -168,12 +190,12 @@ func (c *Conn) Read(buf []byte) (n int, err error) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
n = int(ret)
|
|
|
|
n = int(ret)
|
|
|
|
gobuf2 := C.GoBytes(cbuf, ret)
|
|
|
|
gobuf2 := C.GoBytes(cbuf, C.int(ret))
|
|
|
|
copy(buf, gobuf2)
|
|
|
|
copy(buf, gobuf2)
|
|
|
|
return n, nil
|
|
|
|
return n, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Write
|
|
|
|
// Write write data to tls connection
|
|
|
|
func (c *Conn) Write(buf []byte) (n int, err error) {
|
|
|
|
func (c *Conn) Write(buf []byte) (n int, err error) {
|
|
|
|
if !c.handshake {
|
|
|
|
if !c.handshake {
|
|
|
|
err = c.Handshake()
|
|
|
|
err = c.Handshake()
|
|
|
@ -185,11 +207,12 @@ func (c *Conn) Write(buf []byte) (n int, err error) {
|
|
|
|
cbuf := C.CBytes(buf)
|
|
|
|
cbuf := C.CBytes(buf)
|
|
|
|
defer C.free(cbuf)
|
|
|
|
defer C.free(cbuf)
|
|
|
|
|
|
|
|
|
|
|
|
ret := C.write_application_data(c.sess, (*C.char)(cbuf), C.int(len(buf)))
|
|
|
|
ret := C.gnutls_record_send(c.sess.session, cbuf, C.size_t(len(buf)))
|
|
|
|
n = int(ret)
|
|
|
|
n = int(ret)
|
|
|
|
|
|
|
|
|
|
|
|
if n < 0 {
|
|
|
|
if n < 0 {
|
|
|
|
return 0, fmt.Errorf("write error: %s", C.GoString(C.gnutls_strerror(ret)))
|
|
|
|
return 0, fmt.Errorf("write error: %s",
|
|
|
|
|
|
|
|
C.GoString(C.gnutls_strerror(C.int(ret))))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if int(ret) == 0 {
|
|
|
|
if int(ret) == 0 {
|
|
|
@ -199,10 +222,13 @@ func (c *Conn) Write(buf []byte) (n int, err error) {
|
|
|
|
return n, nil
|
|
|
|
return n, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Close close the conn
|
|
|
|
// Close close the conn and destroy the tls context
|
|
|
|
func (c *Conn) Close() error {
|
|
|
|
func (c *Conn) Close() error {
|
|
|
|
C.session_destroy(c.sess)
|
|
|
|
C.session_destroy(c.sess)
|
|
|
|
c.c.Close()
|
|
|
|
c.c.Close()
|
|
|
|
|
|
|
|
if c.cservname != nil {
|
|
|
|
|
|
|
|
C.free(unsafe.Pointer(c.cservname))
|
|
|
|
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|