diff --git a/_gnutls.h b/_gnutls.h index bc5048b..9bccd2f 100644 --- a/_gnutls.h +++ b/_gnutls.h @@ -55,4 +55,7 @@ int get_cert_dn(gnutls_pcert_st *st, int index, char *out); void free_cert_list(gnutls_pcert_st *st, int size); gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length); int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname); + +void init_priority_cache(); +void init_xcred(); #endif \ No newline at end of file diff --git a/gnutls.c b/gnutls.c index 412b830..e0082ad 100644 --- a/gnutls.c +++ b/gnutls.c @@ -6,6 +6,8 @@ gnutls_datum_t out; int status; int type; +static gnutls_certificate_credentials_t xcred; +static gnutls_priority_t priority_cache; int _init_session(struct session *); int cert_select_callback(gnutls_session_t sess, const gnutls_datum_t *req_ca_dn, int nreqs, const gnutls_pk_algorithm_t *pk_algos, @@ -36,13 +38,33 @@ struct session *init_gnutls_server_session() return sess; } +void init_priority_cache() +{ + if (priority_cache == NULL) + { + //printf("init priority cache\n"); + gnutls_priority_init(&priority_cache, NULL, NULL); + } +} + +void init_xcred() +{ + if (xcred == NULL) + { + //printf("init xcred\n"); + gnutls_certificate_allocate_credentials(&xcred); + gnutls_certificate_set_x509_system_trust(xcred); + gnutls_certificate_set_retrieve_function2(xcred, cert_select_callback); + } +} + int _init_session(struct session *sess) { - gnutls_certificate_allocate_credentials(&sess->xcred); - gnutls_certificate_set_x509_system_trust(sess->xcred); - gnutls_certificate_set_retrieve_function2(sess->xcred, cert_select_callback); - gnutls_set_default_priority(sess->session); - gnutls_credentials_set(sess->session, GNUTLS_CRD_CERTIFICATE, sess->xcred); + //init_xcred(); + //init_priority_cache(); + //gnutls_set_default_priority(sess->session); + gnutls_priority_set(sess->session, priority_cache); + gnutls_credentials_set(sess->session, GNUTLS_CRD_CERTIFICATE, xcred); return 0; } @@ -50,7 +72,6 @@ void session_destroy(struct session *sess) { gnutls_bye(sess->session, GNUTLS_SHUT_WR); gnutls_deinit(sess->session); - gnutls_certificate_free_credentials(sess->xcred); free(sess); } diff --git a/tls.go b/tls.go index d172f8b..b243ab3 100644 --- a/tls.go +++ b/tls.go @@ -107,11 +107,15 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) { // Server create a server TLS Conn on c func Server(c net.Conn, cfg *Config) (*Conn, error) { + if cfg == nil { + return nil, errors.New("config is needed") + } + var sess = C.init_gnutls_server_session() + conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} - n := C.size_t(uintptr(unsafe.Pointer(conn))) - //log.Println("conn addr ", int(n)) - C.set_data(sess, n) + + C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) C.set_callback(sess) if cfg.NextProtos != nil { @@ -126,11 +130,12 @@ func Server(c net.Conn, cfg *Config) (*Conn, error) { // Client create a client TLS Conn on c func Client(c net.Conn, cfg *Config) (*Conn, error) { var sess = C.init_gnutls_client_session() + conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} - n := C.size_t(uintptr(unsafe.Pointer(conn))) - //log.Println("conn addr ", int(n)) - C.set_data(sess, n) + + C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn)))) C.set_callback(sess) + if cfg != nil { if cfg.ServerName != "" { srvname := C.CString(cfg.ServerName) @@ -480,3 +485,8 @@ func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char, //log.Println("set pcert length 0") return -1 } + +func init() { + C.init_xcred() + C.init_priority_cache() +}