You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
272 lines
6.1 KiB
Go
272 lines
6.1 KiB
Go
7 years ago
|
package gnutls
|
||
|
|
||
|
/*
|
||
|
#include "_gnutls.h"
|
||
|
#cgo pkg-config: gnutls
|
||
|
*/
|
||
|
import "C"
|
||
|
import (
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"time"
	"unsafe"
)
|
||
|
|
||
|
// Conn tls connection for client
|
||
|
type Conn struct {
|
||
|
c net.Conn
|
||
|
sess *C.struct_session
|
||
|
handshake bool
|
||
|
}
|
||
|
|
||
|
// Config holds the TLS settings consumed by Dial, Listen, NewClientConn and
// NewServerConn.
type Config struct {
	// ServerName, when non-empty, is passed to the session by NewClientConn.
	ServerName string
	// CrtFile and KeyFile are file system paths (Listen checks them with
	// os.Stat); presumably the certificate and its private key.
	CrtFile, KeyFile string
}
|
||
|
type listener struct {
|
||
|
l net.Listener
|
||
|
c *Config
|
||
|
}
|
||
|
|
||
|
// Accept
|
||
|
func (l *listener) Accept() (net.Conn, error) {
|
||
|
c, err := l.l.Accept()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return NewServerConn(c, l.c)
|
||
|
}
|
||
|
|
||
|
// Close
|
||
|
func (l *listener) Close() error {
|
||
|
return l.l.Close()
|
||
|
}
|
||
|
|
||
|
// Addr
|
||
|
func (l *listener) Addr() net.Addr {
|
||
|
return l.l.Addr()
|
||
|
}
|
||
|
|
||
|
// Dial create a new connection
|
||
|
func Dial(network, addr string, cfg *Config) (*Conn, error) {
|
||
|
c, err := net.Dial(network, addr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return NewClientConn(c, cfg)
|
||
|
}
|
||
|
|
||
|
// Listen create a listener
|
||
|
func Listen(network, addr string, cfg *Config) (net.Listener, error) {
|
||
|
if cfg == nil {
|
||
|
return nil, fmt.Errorf("config is need")
|
||
|
}
|
||
|
if cfg.CrtFile == "" || cfg.KeyFile == "" {
|
||
|
return nil, fmt.Errorf("keyfile is needed")
|
||
|
}
|
||
|
if _, err := os.Stat(cfg.CrtFile); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if _, err := os.Stat(cfg.KeyFile); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
l, err := net.Listen(network, addr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &listener{l, cfg}, nil
|
||
|
}
|
||
|
|
||
|
// NewServerConn create a server Conn
|
||
|
func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
|
||
|
var sess = C.init_server_session()
|
||
|
conn := &Conn{c: c, sess: sess}
|
||
|
n := C.size_t(uintptr(unsafe.Pointer(conn)))
|
||
|
//log.Println("conn addr ", int(n))
|
||
|
C.set_data(sess, n)
|
||
|
C.set_callback(sess)
|
||
|
crtfile := C.CString(cfg.CrtFile)
|
||
|
keyfile := C.CString(cfg.KeyFile)
|
||
|
defer C.free(unsafe.Pointer(crtfile))
|
||
|
defer C.free(unsafe.Pointer(keyfile))
|
||
|
ret := C.set_keyfile(sess, crtfile, keyfile)
|
||
|
if int(ret) < 0 {
|
||
|
cerrstr := C.gnutls_strerror(ret)
|
||
|
return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr))
|
||
|
}
|
||
|
return conn, nil
|
||
|
}
|
||
|
|
||
|
// NewClientConn create a new gnutls connection
|
||
|
func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
|
||
|
var sess = C.init_client_session()
|
||
|
conn := &Conn{c: c, sess: sess}
|
||
|
n := C.size_t(uintptr(unsafe.Pointer(conn)))
|
||
|
//log.Println("conn addr ", int(n))
|
||
|
C.set_data(sess, n)
|
||
|
C.set_callback(sess)
|
||
|
if cfg != nil {
|
||
|
if cfg.ServerName != "" {
|
||
|
srvname := C.CString(cfg.ServerName)
|
||
|
defer C.free(unsafe.Pointer(srvname))
|
||
|
C.set_servername(sess, srvname, C.int(len(cfg.ServerName)))
|
||
|
}
|
||
|
|
||
|
if cfg.CrtFile != "" && cfg.KeyFile != "" {
|
||
|
crtfile := C.CString(cfg.CrtFile)
|
||
|
keyfile := C.CString(cfg.KeyFile)
|
||
|
defer C.free(unsafe.Pointer(crtfile))
|
||
|
defer C.free(unsafe.Pointer(keyfile))
|
||
|
ret := C.set_keyfile(sess, crtfile, keyfile)
|
||
|
if int(ret) < 0 {
|
||
|
return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(C.gnutls_strerror(ret)))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return conn, nil
|
||
|
}
|
||
|
|
||
|
// Handshake handshake tls
|
||
|
func (c *Conn) Handshake() error {
|
||
|
if c.handshake {
|
||
|
return nil
|
||
|
}
|
||
|
ret := C.handshake(c.sess)
|
||
|
if int(ret) < 0 {
|
||
|
return fmt.Errorf("handshake error")
|
||
|
}
|
||
|
c.handshake = true
|
||
|
//log.Println("handshake done")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Read
|
||
|
func (c *Conn) Read(buf []byte) (n int, err error) {
|
||
|
if !c.handshake {
|
||
|
err = c.Handshake()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
c.handshake = true
|
||
|
}
|
||
|
|
||
|
bufLen := len(buf)
|
||
|
cbuf := C.malloc(C.size_t(bufLen))
|
||
|
defer C.free(cbuf)
|
||
|
|
||
|
ret := C.read_application_data(c.sess, (*C.char)(cbuf), C.int(bufLen))
|
||
|
if int(ret) < 0 {
|
||
|
return 0, fmt.Errorf("read error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||
|
}
|
||
|
|
||
|
if int(ret) == 0 {
|
||
|
return 0, fmt.Errorf("connection closed")
|
||
|
}
|
||
|
|
||
|
n = int(ret)
|
||
|
gobuf2 := C.GoBytes(cbuf, ret)
|
||
|
copy(buf, gobuf2)
|
||
|
return n, nil
|
||
|
}
|
||
|
|
||
|
// Write
|
||
|
func (c *Conn) Write(buf []byte) (n int, err error) {
|
||
|
if !c.handshake {
|
||
|
err = c.Handshake()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
c.handshake = true
|
||
|
}
|
||
|
cbuf := C.CBytes(buf)
|
||
|
defer C.free(cbuf)
|
||
|
|
||
|
ret := C.write_application_data(c.sess, (*C.char)(cbuf), C.int(len(buf)))
|
||
|
n = int(ret)
|
||
|
|
||
|
if n < 0 {
|
||
|
return 0, fmt.Errorf("write error: %s", C.GoString(C.gnutls_strerror(ret)))
|
||
|
}
|
||
|
|
||
|
if int(ret) == 0 {
|
||
|
return 0, fmt.Errorf("connection closed")
|
||
|
}
|
||
|
|
||
|
return n, nil
|
||
|
}
|
||
|
|
||
|
// Close close the conn
|
||
|
func (c *Conn) Close() error {
|
||
|
C.session_destroy(c.sess)
|
||
|
c.c.Close()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SetWriteDeadline implements net.Conn
|
||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||
|
return c.c.SetWriteDeadline(t)
|
||
|
}
|
||
|
|
||
|
// SetReadDeadline implements net.Conn
|
||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||
|
return c.c.SetReadDeadline(t)
|
||
|
}
|
||
|
|
||
|
// RemoteAddr implements net.Conn
|
||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||
|
return c.c.RemoteAddr()
|
||
|
}
|
||
|
|
||
|
// LocalAddr implements net.Conn
|
||
|
func (c *Conn) LocalAddr() net.Addr {
|
||
|
return c.c.LocalAddr()
|
||
|
}
|
||
|
|
||
|
// SetDeadline implements net.Conn
|
||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||
|
return c.c.SetDeadline(t)
|
||
|
}
|
||
|
|
||
|
// DataRead c callback function for data read
|
||
|
//export DataRead
|
||
|
func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||
|
//log.Println("read addr ", uintptr(d))
|
||
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||
|
buf := make([]byte, int(bufLen))
|
||
|
n, err := conn.c.Read(buf)
|
||
|
if err != nil {
|
||
|
log.Println(err)
|
||
|
return -1
|
||
|
}
|
||
|
cbuf2 := C.CBytes(buf[:n])
|
||
|
// d := C.CString(string(buf[:n]))
|
||
|
defer C.free(cbuf2)
|
||
|
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(cbuf2), C.size_t(n))
|
||
|
return C.int(n)
|
||
|
}
|
||
|
|
||
|
// DataWrite c callback function for data write
|
||
|
//export DataWrite
|
||
|
func DataWrite(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {
|
||
|
//log.Println("write addr ", uintptr(d), int(_l))
|
||
|
conn := (*Conn)(unsafe.Pointer((uintptr(d))))
|
||
|
gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen)
|
||
|
n, err := conn.c.Write(gobuf)
|
||
|
if err != nil {
|
||
|
log.Println(err)
|
||
|
return -1
|
||
|
}
|
||
|
return C.int(n)
|
||
|
}
|
||
|
|
||
|
// DataTimeoutPull c callback function for timeout read
|
||
|
//export DataTimeoutPull
|
||
|
func DataTimeoutPull(d unsafe.Pointer, delay C.int) C.int {
|
||
|
log.Println("timeout pull function")
|
||
|
return 0
|
||
|
}
|