|
|
@ -9,7 +9,6 @@ import (
|
|
|
|
"fmt"
|
|
|
|
"fmt"
|
|
|
|
"log"
|
|
|
|
"log"
|
|
|
|
"net"
|
|
|
|
"net"
|
|
|
|
"os"
|
|
|
|
|
|
|
|
"time"
|
|
|
|
"time"
|
|
|
|
"unsafe"
|
|
|
|
"unsafe"
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -24,17 +23,17 @@ const (
|
|
|
|
// Conn tls connection for client
|
|
|
|
// Conn tls connection for client
|
|
|
|
type Conn struct {
|
|
|
|
type Conn struct {
|
|
|
|
c net.Conn
|
|
|
|
c net.Conn
|
|
|
|
sess *C.struct_session
|
|
|
|
|
|
|
|
handshake bool
|
|
|
|
handshake bool
|
|
|
|
|
|
|
|
sess *C.struct_session
|
|
|
|
cservname *C.char
|
|
|
|
cservname *C.char
|
|
|
|
state *ConnectionState
|
|
|
|
state *ConnectionState
|
|
|
|
|
|
|
|
cfg *Config
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Config tls configure
|
|
|
|
// Config tls configure
|
|
|
|
type Config struct {
|
|
|
|
type Config struct {
|
|
|
|
ServerName string
|
|
|
|
ServerName string
|
|
|
|
CrtFile string
|
|
|
|
Certificates []*Certificate
|
|
|
|
KeyFile string
|
|
|
|
|
|
|
|
InsecureSkipVerify bool
|
|
|
|
InsecureSkipVerify bool
|
|
|
|
NextProtos []string
|
|
|
|
NextProtos []string
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -50,6 +49,7 @@ type ConnectionState struct {
|
|
|
|
Version uint16
|
|
|
|
Version uint16
|
|
|
|
// TLS version number, ex: TLS1.0
|
|
|
|
// TLS version number, ex: TLS1.0
|
|
|
|
VersionName string
|
|
|
|
VersionName string
|
|
|
|
|
|
|
|
PeerCertificate *Certificate
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
type listener struct {
|
|
|
|
type listener struct {
|
|
|
@ -90,15 +90,6 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) {
|
|
|
|
if cfg == nil {
|
|
|
|
if cfg == nil {
|
|
|
|
return nil, fmt.Errorf("config is need")
|
|
|
|
return nil, fmt.Errorf("config is need")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if cfg.CrtFile == "" || cfg.KeyFile == "" {
|
|
|
|
|
|
|
|
return nil, fmt.Errorf("keyfile is needed")
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if _, err := os.Stat(cfg.CrtFile); err != nil {
|
|
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if _, err := os.Stat(cfg.KeyFile); err != nil {
|
|
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
l, err := net.Listen(network, addr)
|
|
|
|
l, err := net.Listen(network, addr)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
return nil, err
|
|
|
@ -109,21 +100,12 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) {
|
|
|
|
// NewServerConn create a server Conn
|
|
|
|
// NewServerConn create a server Conn
|
|
|
|
func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
var sess = C.init_server_session()
|
|
|
|
var sess = C.init_server_session()
|
|
|
|
conn := &Conn{c: c, sess: sess}
|
|
|
|
conn := &Conn{c: c, sess: sess, cfg: cfg}
|
|
|
|
n := C.size_t(uintptr(unsafe.Pointer(conn)))
|
|
|
|
n := C.size_t(uintptr(unsafe.Pointer(conn)))
|
|
|
|
//log.Println("conn addr ", int(n))
|
|
|
|
//log.Println("conn addr ", int(n))
|
|
|
|
C.set_data(sess, n)
|
|
|
|
C.set_data(sess, n)
|
|
|
|
C.set_callback(sess)
|
|
|
|
C.set_callback(sess)
|
|
|
|
crtfile := C.CString(cfg.CrtFile)
|
|
|
|
|
|
|
|
keyfile := C.CString(cfg.KeyFile)
|
|
|
|
|
|
|
|
defer C.free(unsafe.Pointer(crtfile))
|
|
|
|
|
|
|
|
defer C.free(unsafe.Pointer(keyfile))
|
|
|
|
|
|
|
|
ret := C.gnutls_certificate_set_x509_key_file(
|
|
|
|
|
|
|
|
sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM)
|
|
|
|
|
|
|
|
if int(ret) < 0 {
|
|
|
|
|
|
|
|
cerrstr := C.gnutls_strerror(ret)
|
|
|
|
|
|
|
|
return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if cfg.NextProtos != nil {
|
|
|
|
if cfg.NextProtos != nil {
|
|
|
|
if err := setAlpnProtocols(sess, cfg); err != nil {
|
|
|
|
if err := setAlpnProtocols(sess, cfg); err != nil {
|
|
|
|
log.Println(err)
|
|
|
|
log.Println(err)
|
|
|
@ -135,7 +117,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
// NewClientConn create a new gnutls connection
|
|
|
|
// NewClientConn create a new gnutls connection
|
|
|
|
func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
var sess = C.init_client_session()
|
|
|
|
var sess = C.init_client_session()
|
|
|
|
conn := &Conn{c: c, sess: sess}
|
|
|
|
conn := &Conn{c: c, sess: sess, cfg: cfg}
|
|
|
|
n := C.size_t(uintptr(unsafe.Pointer(conn)))
|
|
|
|
n := C.size_t(uintptr(unsafe.Pointer(conn)))
|
|
|
|
//log.Println("conn addr ", int(n))
|
|
|
|
//log.Println("conn addr ", int(n))
|
|
|
|
C.set_data(sess, n)
|
|
|
|
C.set_data(sess, n)
|
|
|
@ -149,18 +131,6 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
|
|
|
|
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
|
|
|
|
unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if cfg.CrtFile != "" && cfg.KeyFile != "" {
|
|
|
|
|
|
|
|
crtfile := C.CString(cfg.CrtFile)
|
|
|
|
|
|
|
|
keyfile := C.CString(cfg.KeyFile)
|
|
|
|
|
|
|
|
defer C.free(unsafe.Pointer(crtfile))
|
|
|
|
|
|
|
|
defer C.free(unsafe.Pointer(keyfile))
|
|
|
|
|
|
|
|
ret := C.gnutls_certificate_set_x509_key_file(
|
|
|
|
|
|
|
|
sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM)
|
|
|
|
|
|
|
|
if int(ret) < 0 {
|
|
|
|
|
|
|
|
return nil, fmt.Errorf("set keyfile failed: %s",
|
|
|
|
|
|
|
|
C.GoString(C.gnutls_strerror(ret)))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if !cfg.InsecureSkipVerify {
|
|
|
|
if !cfg.InsecureSkipVerify {
|
|
|
|
if conn.cservname != nil {
|
|
|
|
if conn.cservname != nil {
|
|
|
|
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
|
|
|
|
C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
|
|
|
@ -270,11 +240,16 @@ func (c *Conn) Write(buf []byte) (n int, err error) {
|
|
|
|
|
|
|
|
|
|
|
|
// Close close the conn and destroy the tls context
|
|
|
|
// Close close the conn and destroy the tls context
|
|
|
|
func (c *Conn) Close() error {
|
|
|
|
func (c *Conn) Close() error {
|
|
|
|
|
|
|
|
C.gnutls_record_send(c.sess.session, nil, 0)
|
|
|
|
C.session_destroy(c.sess)
|
|
|
|
C.session_destroy(c.sess)
|
|
|
|
c.c.Close()
|
|
|
|
c.c.Close()
|
|
|
|
if c.cservname != nil {
|
|
|
|
if c.cservname != nil {
|
|
|
|
C.free(unsafe.Pointer(c.cservname))
|
|
|
|
C.free(unsafe.Pointer(c.cservname))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if c.state != nil && c.state.PeerCertificate != nil {
|
|
|
|
|
|
|
|
c.state.PeerCertificate.Free()
|
|
|
|
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -320,11 +295,21 @@ func (c *Conn) ConnectionState() *ConnectionState {
|
|
|
|
HandshakeComplete: c.handshake,
|
|
|
|
HandshakeComplete: c.handshake,
|
|
|
|
ServerName: c.getServerName(),
|
|
|
|
ServerName: c.getServerName(),
|
|
|
|
VersionName: versionname,
|
|
|
|
VersionName: versionname,
|
|
|
|
|
|
|
|
PeerCertificate: c.getPeerCertificate(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
c.state = state
|
|
|
|
c.state = state
|
|
|
|
return state
|
|
|
|
return state
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (c *Conn) getPeerCertificate() *Certificate {
|
|
|
|
|
|
|
|
var size int
|
|
|
|
|
|
|
|
st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size)))
|
|
|
|
|
|
|
|
if st == nil {
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return &Certificate{cert: st, certSize: C.int(size)}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *Conn) getAlpnSelectedProtocol() string {
|
|
|
|
func (c *Conn) getAlpnSelectedProtocol() string {
|
|
|
|
cbuf := C.malloc(100)
|
|
|
|
cbuf := C.malloc(100)
|
|
|
|
defer C.free(cbuf)
|
|
|
|
defer C.free(cbuf)
|
|
|
@ -353,9 +338,9 @@ func (c *Conn) getServerName() string {
|
|
|
|
return name
|
|
|
|
return name
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// DataRead c callback function for data read
|
|
|
|
// OnDataReadCallback c callback function for data read
|
|
|
|
//export DataRead
|
|
|
|
//export OnDataReadCallback
|
|
|
|
func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
|
|
|
func OnDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
|
|
|
//log.Println("read addr ", uintptr(d))
|
|
|
|
//log.Println("read addr ", uintptr(d))
|
|
|
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
|
|
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
|
|
|
buf := make([]byte, int(bufLen))
|
|
|
|
buf := make([]byte, int(bufLen))
|
|
|
@ -371,9 +356,9 @@ func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
|
|
|
return C.int(n)
|
|
|
|
return C.int(n)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// DataWrite c callback function for data write
|
|
|
|
// OnDataWriteCallback c callback function for data write
|
|
|
|
//export DataWrite
|
|
|
|
//export OnDataWriteCallback
|
|
|
|
func DataWrite(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
|
|
|
func OnDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
|
|
|
//log.Println("write addr ", uintptr(d), int(_l))
|
|
|
|
//log.Println("write addr ", uintptr(d), int(_l))
|
|
|
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
|
|
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
|
|
|
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
|
|
|
|
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
|
|
|
@ -385,9 +370,50 @@ func DataWrite(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
|
|
|
return C.int(n)
|
|
|
|
return C.int(n)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// DataTimeoutPull c callback function for timeout read
|
|
|
|
// OnDataTimeoutRead c callback function for timeout read
|
|
|
|
//export DataTimeoutPull
|
|
|
|
//export OnDataTimeoutRead
|
|
|
|
func DataTimeoutPull(d unsafe.Pointer, delay C.int) C.int {
|
|
|
|
func OnDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int {
|
|
|
|
log.Println("timeout pull function")
|
|
|
|
log.Println("timeout pull function")
|
|
|
|
return 0
|
|
|
|
return 0
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// OnCertSelectCallback callback function for ceritificate select
|
|
|
|
|
|
|
|
//export OnCertSelectCallback
|
|
|
|
|
|
|
|
func OnCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
|
|
|
|
|
|
|
|
namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
servername := C.GoStringN(hostname, namelen)
|
|
|
|
|
|
|
|
//log.Println("go cert select callback ", servername)
|
|
|
|
|
|
|
|
conn := (*Conn)(unsafe.Pointer((uintptr(ptr))))
|
|
|
|
|
|
|
|
//log.Println(conn)
|
|
|
|
|
|
|
|
if int(namelen) == 0 && conn.cfg.Certificates != nil {
|
|
|
|
|
|
|
|
_cert := conn.cfg.Certificates[0]
|
|
|
|
|
|
|
|
*pcertLength = _cert.certSize
|
|
|
|
|
|
|
|
*cert = _cert.cert
|
|
|
|
|
|
|
|
*privkey = _cert.privkey
|
|
|
|
|
|
|
|
//log.Println("set pcert length ", _cert.certSize)
|
|
|
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, _cert := range conn.cfg.Certificates {
|
|
|
|
|
|
|
|
//log.Println(cert)
|
|
|
|
|
|
|
|
if _cert.matchName(servername) {
|
|
|
|
|
|
|
|
//log.Println("matched name ", _cert.names)
|
|
|
|
|
|
|
|
*pcertLength = _cert.certSize
|
|
|
|
|
|
|
|
*cert = _cert.cert
|
|
|
|
|
|
|
|
*privkey = _cert.privkey
|
|
|
|
|
|
|
|
//log.Println("set pcert length ", _cert.certSize)
|
|
|
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if conn.cfg.Certificates != nil {
|
|
|
|
|
|
|
|
_cert := conn.cfg.Certificates[0]
|
|
|
|
|
|
|
|
*pcertLength = _cert.certSize
|
|
|
|
|
|
|
|
*cert = _cert.cert
|
|
|
|
|
|
|
|
*privkey = _cert.privkey
|
|
|
|
|
|
|
|
//log.Println("set pcert length ", _cert.certSize)
|
|
|
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
*pcertLength = 0
|
|
|
|
|
|
|
|
//log.Println("set pcert length 0")
|
|
|
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
}
|
|
|
|