You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

394 lines
9.2 KiB
Go

package gnutls
/*
#include "_gnutls.h"
#cgo pkg-config: gnutls
*/
import "C"
import (
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"time"
	"unsafe"
)
// Go-side copies of the gnutls constants of the same name, used when calling
// gnutls through cgo.
const (
	// GNUTLS_NAME_DNS selects a DNS host name for server-name (SNI) calls.
	GNUTLS_NAME_DNS = 1
	// GNUTLS_X509_FMT_PEM selects PEM encoding when loading certificate/key files.
	GNUTLS_X509_FMT_PEM = 1
)

// ALPN option flags (bit values 1 and 2).
const (
	GNUTLS_ALPN_MANDATORY = 1 << iota
	GNUTLS_ALPN_SERVER_PRECEDENCE
)
// Conn is a TLS connection: a gnutls session layered over a plain net.Conn.
// Both NewClientConn and NewServerConn produce values of this type.
type Conn struct {
	c         net.Conn          // underlying transport connection
	sess      *C.struct_session // C-side session wrapper (session + credentials)
	handshake bool              // true once Handshake has succeeded
	cservname *C.char           // C copy of the client's ServerName, freed in Close
	state     *ConnectionState  // memoized result of ConnectionState
}
// Config carries the settings used to build client and server connections.
type Config struct {
	// ServerName is sent via SNI by clients and, unless InsecureSkipVerify is
	// set, used as the expected host name when verifying the peer certificate.
	ServerName string
	// CrtFile and KeyFile are PEM-encoded certificate and private-key files.
	// Servers require both; clients load them only when both are set.
	CrtFile, KeyFile string
	// InsecureSkipVerify turns off peer certificate verification on clients.
	InsecureSkipVerify bool
	// NextProtos lists the ALPN protocol names to configure on the session.
	NextProtos []string
}
// ConnectionState is a snapshot of what a connection has negotiated.
type ConnectionState struct {
	// ServerName is the SNI host name sent by the client; NegotiatedProtocol
	// is the ALPN protocol selected during the handshake ("" if none).
	ServerName, NegotiatedProtocol string
	// HandshakeComplete reports whether the handshake had finished when the
	// snapshot was taken.
	HandshakeComplete bool
	// Version is the value returned by gnutls_protocol_get_version, i.e. a
	// gnutls_protocol_t enum value, not the on-the-wire version such as 0x303.
	Version uint16
	// VersionName is gnutls' printable name for Version, e.g. "TLS1.2".
	VersionName string
}
// listener is a net.Listener that wraps every accepted connection in a
// server-side TLS session built from c.
type listener struct {
	l net.Listener
	c *Config
}

// Accept waits for the next connection on the wrapped listener and wraps it
// with NewServerConn.
//
// If the TLS session cannot be set up, the raw connection is closed so it
// does not leak, and a plain nil net.Conn is returned (never an interface
// holding a typed-nil *Conn, which would compare non-nil).
func (l *listener) Accept() (net.Conn, error) {
	raw, err := l.l.Accept()
	if err != nil {
		return nil, err
	}
	conn, err := NewServerConn(raw, l.c)
	if err != nil {
		_ = raw.Close() // best effort: the setup error is the one worth reporting
		return nil, err
	}
	return conn, nil
}

// Close closes the wrapped listener.
func (l *listener) Close() error {
	return l.l.Close()
}

// Addr returns the wrapped listener's network address.
func (l *listener) Addr() net.Addr {
	return l.l.Addr()
}
// Dial connects to addr on the named network and wraps the connection in a
// client-side TLS session configured from cfg (which may be nil).
//
// The handshake is not performed here; it runs on the first Read or Write,
// or on an explicit Handshake call. If the TLS session cannot be set up, the
// freshly dialed connection is closed before the error is returned.
func Dial(network, addr string, cfg *Config) (*Conn, error) {
	raw, err := net.Dial(network, addr)
	if err != nil {
		return nil, err
	}
	conn, err := NewClientConn(raw, cfg)
	if err != nil {
		_ = raw.Close() // don't leak the transport when TLS setup fails
		return nil, err
	}
	return conn, nil
}
// Listen announces on the local network address and returns a listener whose
// Accept wraps each connection in a server-side TLS session.
//
// cfg must be non-nil, and both CrtFile and KeyFile must be set and exist.
// Only their existence is checked here; the files are actually loaded by
// NewServerConn on every Accept.
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is needed")
	}
	if cfg.CrtFile == "" || cfg.KeyFile == "" {
		return nil, fmt.Errorf("certificate and key files are needed")
	}
	// Fail fast on missing files instead of failing every Accept later.
	for _, path := range []string{cfg.CrtFile, cfg.KeyFile} {
		if _, err := os.Stat(path); err != nil {
			return nil, err
		}
	}
	l, err := net.Listen(network, addr)
	if err != nil {
		return nil, err
	}
	return &listener{l: l, c: cfg}, nil
}
// NewServerConn wraps c in a server-side gnutls session configured from cfg.
//
// cfg must be non-nil; its CrtFile/KeyFile are loaded as a PEM certificate
// and key. ALPN setup failures are only logged. On error the gnutls session is
// destroyed again; c itself is left open for the caller to close.
func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is needed")
	}
	sess := C.init_server_session()
	conn := &Conn{c: c, sess: sess}
	// NOTE(review): the Go pointer is handed to C as an integer, presumably so
	// the DataRead/DataWrite callbacks can recover conn. cgo's rules forbid C
	// from retaining Go pointers; runtime/cgo.Handle is the sanctioned way.
	C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
	C.set_callback(sess)

	crtfile := C.CString(cfg.CrtFile)
	defer C.free(unsafe.Pointer(crtfile))
	keyfile := C.CString(cfg.KeyFile)
	defer C.free(unsafe.Pointer(keyfile))
	ret := C.gnutls_certificate_set_x509_key_file(
		sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM)
	if int(ret) < 0 {
		err := fmt.Errorf("set keyfile failed: %s", C.GoString(C.gnutls_strerror(ret)))
		// conn is never returned, so nothing else would ever destroy the session.
		C.session_destroy(sess)
		return nil, err
	}

	// len check (not != nil): an empty non-nil slice has nothing to configure.
	if len(cfg.NextProtos) > 0 {
		if err := setAlpnProtocols(sess, cfg); err != nil {
			log.Println(err)
		}
	}
	return conn, nil
}
// NewClientConn wraps c in a client-side gnutls session configured from cfg.
//
// With a nil cfg the peer certificate is verified without a host-name check.
// Otherwise:
//   - ServerName, if set, is sent via SNI and, unless InsecureSkipVerify is
//     true, used as the expected host name during verification;
//   - CrtFile/KeyFile, when both are set, load a PEM client certificate;
//   - NextProtos configures ALPN (failures are only logged).
//
// On error the gnutls session and the C copy of ServerName are released;
// c itself is left open for the caller to close.
func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
	sess := C.init_client_session()
	conn := &Conn{c: c, sess: sess}
	// NOTE(review): the Go pointer is handed to C as an integer, presumably so
	// the DataRead/DataWrite callbacks can recover conn. cgo's rules forbid C
	// from retaining Go pointers; runtime/cgo.Handle is the sanctioned way.
	C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
	C.set_callback(sess)

	if cfg == nil {
		C.gnutls_session_set_verify_cert(sess.session, nil, 0)
		return conn, nil
	}

	if cfg.ServerName != "" {
		// Not freed here: gnutls_session_set_verify_cert below keeps this
		// pointer for the session's lifetime, so Close releases it.
		srvname := C.CString(cfg.ServerName)
		conn.cservname = srvname
		C.gnutls_server_name_set(sess.session, GNUTLS_NAME_DNS,
			unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName)))
	}

	if cfg.CrtFile != "" && cfg.KeyFile != "" {
		crtfile := C.CString(cfg.CrtFile)
		defer C.free(unsafe.Pointer(crtfile))
		keyfile := C.CString(cfg.KeyFile)
		defer C.free(unsafe.Pointer(keyfile))
		ret := C.gnutls_certificate_set_x509_key_file(
			sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM)
		if int(ret) < 0 {
			err := fmt.Errorf("set keyfile failed: %s",
				C.GoString(C.gnutls_strerror(ret)))
			// conn is never returned, so release its C resources here; the
			// session goes first because it may still reference cservname.
			C.session_destroy(sess)
			if conn.cservname != nil {
				C.free(unsafe.Pointer(conn.cservname))
			}
			return nil, err
		}
	}

	if !cfg.InsecureSkipVerify {
		// cservname is nil when no ServerName was given: the chain is then
		// verified without a host-name check.
		C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0)
	}

	// len check (not != nil): an empty non-nil slice has nothing to configure.
	if len(cfg.NextProtos) > 0 {
		if err := setAlpnProtocols(sess, cfg); err != nil {
			log.Println(err)
		}
	}
	return conn, nil
}
// setAlpnProtocols passes cfg.NextProtos to the C helper alpn_set_protocols
// for sess. An empty list is a no-op: previously it indexed element 0 of an
// empty slice and panicked. It returns the gnutls error text on failure.
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
	if len(cfg.NextProtos) == 0 {
		return nil
	}
	protos := make([]*C.char, 0, len(cfg.NextProtos))
	for _, p := range cfg.NextProtos {
		cp := C.CString(p)
		// Deferred rather than freed per iteration: every string must stay
		// valid until alpn_set_protocols has returned.
		defer C.free(unsafe.Pointer(cp))
		protos = append(protos, cp)
	}
	// The slice holds only C pointers, so passing its Go backing array to C
	// is allowed by the cgo pointer rules.
	ret := C.alpn_set_protocols(sess, &protos[0], C.int(len(protos)))
	if int(ret) < 0 {
		return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
	}
	return nil
}
// Handshake performs the TLS handshake through the C helper handshake().
// Once it has succeeded, further calls return nil without doing anything; a
// failed attempt leaves the connection marked as not handshaken.
func (c *Conn) Handshake() error {
	if !c.handshake {
		if rc := C.handshake(c.sess); int(rc) < 0 {
			return fmt.Errorf("handshake error")
		}
		c.handshake = true
	}
	return nil
}
// Read reads decrypted application data into buf. It performs the handshake
// first if it has not completed yet.
//
// When gnutls_record_recv reports end of stream (returns 0), Read returns
// io.EOF, as io.Reader and net.Conn callers (io.Copy, bufio, net/http)
// expect. A zero-length buf returns 0, nil after the handshake. Negative
// gnutls results are returned as errors carrying gnutls' message.
func (c *Conn) Read(buf []byte) (n int, err error) {
	if err = c.Handshake(); err != nil {
		return 0, err
	}
	if len(buf) == 0 {
		return 0, nil
	}
	// Receive into C memory, then copy out; C.malloc never returns nil.
	cbuf := C.malloc(C.size_t(len(buf)))
	defer C.free(cbuf)
	ret := C.gnutls_record_recv(c.sess.session, cbuf, C.size_t(len(buf)))
	switch {
	case ret < 0:
		return 0, fmt.Errorf("read error: %s",
			C.GoString(C.gnutls_strerror(C.int(ret))))
	case ret == 0:
		return 0, io.EOF
	}
	n = copy(buf, C.GoBytes(cbuf, C.int(ret)))
	return n, nil
}
// Write encrypts and sends all of buf, performing the handshake first if it
// has not completed yet.
//
// gnutls_record_send may accept fewer bytes than requested (at most one
// record per call), so Write loops until buf is fully sent. This keeps the
// io.Writer contract that n < len(buf) only comes with a non-nil error. On
// failure n is the number of bytes already handed to gnutls. An empty buf
// returns 0, nil.
func (c *Conn) Write(buf []byte) (n int, err error) {
	if err = c.Handshake(); err != nil {
		return 0, err
	}
	if len(buf) == 0 {
		return 0, nil
	}
	cbuf := C.CBytes(buf)
	defer C.free(cbuf)
	for n < len(buf) {
		rest := unsafe.Pointer(uintptr(cbuf) + uintptr(n))
		ret := C.gnutls_record_send(c.sess.session, rest, C.size_t(len(buf)-n))
		if ret < 0 {
			return n, fmt.Errorf("write error: %s",
				C.GoString(C.gnutls_strerror(C.int(ret))))
		}
		if ret == 0 {
			// No progress on a non-empty send; bail out instead of spinning.
			return n, fmt.Errorf("connection closed")
		}
		n += int(ret)
	}
	return n, nil
}
// Close destroys the gnutls session, frees the C copy of the server name and
// closes the underlying connection, returning that connection's Close error.
//
// It is safe to call more than once. The C resources are released only on
// the first call, and their fields are cleared so a repeated Close cannot
// double-free them.
// NOTE(review): no gnutls_bye call is made here; confirm whether
// session_destroy sends close_notify to the peer.
func (c *Conn) Close() error {
	if c.sess != nil {
		C.session_destroy(c.sess)
		c.sess = nil
	}
	// Freed only after the session is gone, because the session may still
	// reference this string (see NewClientConn).
	if c.cservname != nil {
		C.free(unsafe.Pointer(c.cservname))
		c.cservname = nil
	}
	return c.c.Close()
}
// SetWriteDeadline implements net.Conn by forwarding the write deadline to
// the underlying transport connection.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	transport := c.c
	return transport.SetWriteDeadline(t)
}
// SetReadDeadline implements net.Conn by forwarding the read deadline to the
// underlying transport connection.
func (c *Conn) SetReadDeadline(t time.Time) error {
	transport := c.c
	return transport.SetReadDeadline(t)
}
// RemoteAddr implements net.Conn; it reports the peer address of the
// underlying transport connection.
func (c *Conn) RemoteAddr() net.Addr {
	transport := c.c
	return transport.RemoteAddr()
}
// LocalAddr implements net.Conn; it reports the local address of the
// underlying transport connection.
func (c *Conn) LocalAddr() net.Addr {
	transport := c.c
	return transport.LocalAddr()
}
// SetDeadline implements net.Conn by forwarding both the read and write
// deadline to the underlying transport connection.
func (c *Conn) SetDeadline(t time.Time) error {
	transport := c.c
	return transport.SetDeadline(t)
}
// ConnectionState reports the negotiated protocol version, ALPN protocol and
// SNI name of the connection.
//
// The result is memoized only once the handshake has completed. Before that
// point these values are not yet negotiated, and caching them would freeze
// HandshakeComplete=false and empty fields for the rest of the connection.
// Callers share the returned pointer once it is cached.
func (c *Conn) ConnectionState() *ConnectionState {
	if c.state != nil {
		return c.state
	}
	version := uint16(C.gnutls_protocol_get_version(c.sess.session))
	state := &ConnectionState{
		ServerName:         c.getServerName(),
		NegotiatedProtocol: c.getAlpnSelectedProtocol(),
		HandshakeComplete:  c.handshake,
		Version:            version,
		VersionName: C.GoString(
			C.gnutls_protocol_get_name(C.gnutls_protocol_t(version))),
	}
	if c.handshake {
		c.state = state
	}
	return state
}
// getAlpnSelectedProtocol returns the ALPN protocol selected for the session,
// or "" when the C helper alpn_get_selected_protocol reports an error.
func (c *Conn) getAlpnSelectedProtocol() string {
	// ALPN protocol names can be up to 255 bytes (RFC 7301). The buffer is
	// sized for that plus a NUL; the previous 100 bytes could be too small.
	// NOTE(review): the C helper takes no length argument; confirm it
	// NUL-terminates and bounds its copy.
	const bufSize = 256
	cbuf := C.malloc(bufSize)
	defer C.free(cbuf)
	if ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf)); int(ret) < 0 {
		return ""
	}
	return C.GoString((*C.char)(cbuf))
}
// getServerName returns the first server name (index 0) recorded in the
// session via gnutls_server_name_get, or "" if there is none or the call
// fails.
func (c *Conn) getServerName() string {
	// DNS names are at most 253 bytes; gnutls requires the buffer to be
	// strictly larger than the name because it appends a NUL terminator.
	const bufSize = 256
	cbuf := C.malloc(bufSize)
	defer C.free(cbuf)
	// Use the exact C types for the out-parameters. Previously a Go int was
	// reinterpreted as *C.uint, which has a different width on 64-bit targets.
	nameLen := C.size_t(bufSize)
	var nameType C.uint
	ret := C.gnutls_server_name_get(c.sess.session, cbuf, &nameLen, &nameType, 0)
	if ret < 0 {
		return ""
	}
	return C.GoString((*C.char)(cbuf))
}
// DataRead is the transport pull callback exported to C. d is presumably the
// *Conn registered through set_data (see NewClientConn/NewServerConn). Up to
// bufLen bytes are read from the underlying net.Conn into cbuf.
//
// It returns the number of bytes copied. It returns 0 at end of stream
// (io.EOF), so gnutls sees an orderly EOF rather than a transport error, and
// -1 on any other error, which is logged.
//
//export DataRead
func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
	conn := (*Conn)(d)
	buf := make([]byte, int(bufLen))
	n, err := conn.c.Read(buf)
	if n > 0 {
		// io.Reader may return data together with an error; deliver the data
		// now instead of dropping it. Copy straight from the Go buffer (no
		// pointers inside, so passing it to C is allowed).
		C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&buf[0]), C.size_t(n))
		return C.int(n)
	}
	if err == nil || err == io.EOF {
		return 0
	}
	log.Println(err)
	return -1
}
// DataWrite is the transport push callback exported to C. d is presumably the
// *Conn registered through set_data (see NewClientConn/NewServerConn). The
// bufLen bytes at cbuf are written to the underlying net.Conn.
//
// It returns the number of bytes written. Errors are logged. If some bytes
// went out before the error, that count is still returned so gnutls does not
// treat already-sent data as unsent; otherwise it returns -1.
//
//export DataWrite
func DataWrite(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
	conn := (*Conn)(d)
	data := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
	n, err := conn.c.Write(data)
	if err != nil {
		log.Println(err)
		if n <= 0 {
			return -1
		}
	}
	return C.int(n)
}
// DataTimeoutPull is the pull-timeout callback exported to C. It logs each
// invocation and always returns 0, ignoring d and delay.
// NOTE(review): in gnutls' pull-timeout convention 0 means "no data within
// the timeout"; confirm this stub is intentional.
//
//export DataTimeoutPull
func DataTimeoutPull(d unsafe.Pointer, delay C.int) C.int {
	const noData C.int = 0
	log.Println("timeout pull function")
	return noData
}