use global certificate cred and priority cache

master
fangdingjun 6 years ago
parent 15fd652c74
commit 2c4c546551

@ -55,4 +55,7 @@ int get_cert_dn(gnutls_pcert_st *st, int index, char *out);
void free_cert_list(gnutls_pcert_st *st, int size); void free_cert_list(gnutls_pcert_st *st, int size);
gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length); gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length);
int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname); int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname);
void init_priority_cache();
void init_xcred();
#endif #endif

@ -6,6 +6,8 @@ gnutls_datum_t out;
int status; int status;
int type; int type;
static gnutls_certificate_credentials_t xcred;
static gnutls_priority_t priority_cache;
int _init_session(struct session *); int _init_session(struct session *);
int cert_select_callback(gnutls_session_t sess, const gnutls_datum_t *req_ca_dn, int cert_select_callback(gnutls_session_t sess, const gnutls_datum_t *req_ca_dn,
int nreqs, const gnutls_pk_algorithm_t *pk_algos, int nreqs, const gnutls_pk_algorithm_t *pk_algos,
@ -36,13 +38,33 @@ struct session *init_gnutls_server_session()
return sess; return sess;
} }
void init_priority_cache()
{
if (priority_cache == NULL)
{
//printf("init priority cache\n");
gnutls_priority_init(&priority_cache, NULL, NULL);
}
}
void init_xcred()
{
if (xcred == NULL)
{
//printf("init xcred\n");
gnutls_certificate_allocate_credentials(&xcred);
gnutls_certificate_set_x509_system_trust(xcred);
gnutls_certificate_set_retrieve_function2(xcred, cert_select_callback);
}
}
int _init_session(struct session *sess) int _init_session(struct session *sess)
{ {
gnutls_certificate_allocate_credentials(&sess->xcred); //init_xcred();
gnutls_certificate_set_x509_system_trust(sess->xcred); //init_priority_cache();
gnutls_certificate_set_retrieve_function2(sess->xcred, cert_select_callback); //gnutls_set_default_priority(sess->session);
gnutls_set_default_priority(sess->session); gnutls_priority_set(sess->session, priority_cache);
gnutls_credentials_set(sess->session, GNUTLS_CRD_CERTIFICATE, sess->xcred); gnutls_credentials_set(sess->session, GNUTLS_CRD_CERTIFICATE, xcred);
return 0; return 0;
} }
@ -50,7 +72,6 @@ void session_destroy(struct session *sess)
{ {
gnutls_bye(sess->session, GNUTLS_SHUT_WR); gnutls_bye(sess->session, GNUTLS_SHUT_WR);
gnutls_deinit(sess->session); gnutls_deinit(sess->session);
gnutls_certificate_free_credentials(sess->xcred);
free(sess); free(sess);
} }

@ -107,11 +107,15 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) {
// Server create a server TLS Conn on c // Server create a server TLS Conn on c
func Server(c net.Conn, cfg *Config) (*Conn, error) { func Server(c net.Conn, cfg *Config) (*Conn, error) {
if cfg == nil {
return nil, errors.New("config is needed")
}
var sess = C.init_gnutls_server_session() var sess = C.init_gnutls_server_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
n := C.size_t(uintptr(unsafe.Pointer(conn)))
//log.Println("conn addr ", int(n)) C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_data(sess, n)
C.set_callback(sess) C.set_callback(sess)
if cfg.NextProtos != nil { if cfg.NextProtos != nil {
@ -126,11 +130,12 @@ func Server(c net.Conn, cfg *Config) (*Conn, error) {
// Client create a client TLS Conn on c // Client create a client TLS Conn on c
func Client(c net.Conn, cfg *Config) (*Conn, error) { func Client(c net.Conn, cfg *Config) (*Conn, error) {
var sess = C.init_gnutls_client_session() var sess = C.init_gnutls_client_session()
conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)} conn := &Conn{c: c, sess: sess, cfg: cfg, lock: new(sync.Mutex)}
n := C.size_t(uintptr(unsafe.Pointer(conn)))
//log.Println("conn addr ", int(n)) C.set_data(sess, C.size_t(uintptr(unsafe.Pointer(conn))))
C.set_data(sess, n)
C.set_callback(sess) C.set_callback(sess)
if cfg != nil { if cfg != nil {
if cfg.ServerName != "" { if cfg.ServerName != "" {
srvname := C.CString(cfg.ServerName) srvname := C.CString(cfg.ServerName)
@ -480,3 +485,8 @@ func onCertSelectCallback(ptr unsafe.Pointer, hostname *C.char,
//log.Println("set pcert length 0") //log.Println("set pcert length 0")
return -1 return -1
} }
func init() {
C.init_xcred()
C.init_priority_cache()
}

Loading…
Cancel
Save