You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

503 lines
11 KiB
C

#include "_gnutls.h"
/* Capacity of `buffer` below; the +1 presumably leaves room for a NUL. */
#define MAX_BUF 1024
/*
 * File-scope scratch state with external linkage.
 * NOTE(review): these are shared mutable globals. Any function that writes
 * them while several sessions are active concurrently will race; prefer
 * locals. Kept as-is because other translation units may reference them by
 * name -- confirm before removing.
 */
char buffer[MAX_BUF + 1], *desc;
gnutls_datum_t out;
int status;
int type;
/*
 * Process-wide certificate credentials and priority cache. They are created
 * lazily by init_xcred() / init_priority_cache() and attached to each
 * session by _init_session().
 */
static gnutls_certificate_credentials_t xcred;
static gnutls_priority_t priority_cache;
/* Forward declarations for functions defined later in this file.
 * NOTE(review): identifiers starting with '_' at file scope are reserved
 * for the implementation; consider renaming _init_session. */
int _init_session(struct session *);
int cert_select_callback(gnutls_session_t sess, const gnutls_datum_t *req_ca_dn,
int nreqs, const gnutls_pk_algorithm_t *pk_algos,
int pk_algos_length, gnutls_pcert_st **pcert,
unsigned int *pcert_length, gnutls_privkey_t *pkey);
/*
 * Allocate a zero-filled session wrapper and initialise it as a TLS client
 * that uses the shared priority cache and certificate credentials.
 *
 * Returns the new session, or NULL if allocation or gnutls_init() fails.
 * Release it with session_destroy().
 */
struct session *init_gnutls_client_session()
{
    /* calloc zero-fills the struct. The previous memset had its value and
     * size arguments swapped and cleared nothing, so sess->data (checked
     * against NULL in set_callback) started out as garbage. */
    struct session *sess = calloc(1, sizeof *sess);
    if (sess == NULL)
    {
        return NULL;
    }
    if (gnutls_init(&sess->session, GNUTLS_CLIENT) < 0)
    {
        /* sess->session is not usable; do not hand it to _init_session */
        free(sess);
        return NULL;
    }
    _init_session(sess);
    return sess;
}
/*
 * Allocate a zero-filled session wrapper and initialise it as a TLS server
 * that uses the shared priority cache and certificate credentials and does
 * not request a client certificate (GNUTLS_CERT_IGNORE).
 *
 * Returns the new session, or NULL if allocation or gnutls_init() fails.
 * Release it with session_destroy().
 */
struct session *init_gnutls_server_session()
{
    /* calloc zero-fills the struct. The previous memset had its value and
     * size arguments swapped and cleared nothing. */
    struct session *sess = calloc(1, sizeof *sess);
    if (sess == NULL)
    {
        return NULL;
    }
    if (gnutls_init(&sess->session, GNUTLS_SERVER) < 0)
    {
        free(sess);
        return NULL;
    }
    _init_session(sess);
    gnutls_certificate_server_set_request(sess->session, GNUTLS_CERT_IGNORE);
    return sess;
}
/*
 * Lazily create the process-wide priority cache from a NULL priority
 * string, which GnuTLS documents as selecting its default priorities.
 * Does nothing once the cache exists. The result of
 * gnutls_priority_init() is not checked.
 */
void init_priority_cache()
{
    if (priority_cache != NULL)
    {
        return;
    }
    gnutls_priority_init(&priority_cache, NULL, NULL);
}
/*
 * Lazily create the process-wide certificate credentials. They trust the
 * system CA store and pick the local certificate per handshake through
 * cert_select_callback. Does nothing once the credentials exist; results
 * of the GnuTLS calls are not checked.
 */
void init_xcred()
{
    if (xcred != NULL)
    {
        return;
    }
    gnutls_certificate_allocate_credentials(&xcred);
    gnutls_certificate_set_x509_system_trust(xcred);
    gnutls_certificate_set_retrieve_function2(xcred, cert_select_callback);
}
/*
 * Attach the shared priority cache and certificate credentials to
 * sess->session. This does not create them; init_priority_cache() and
 * init_xcred() must have been called beforehand.
 *
 * Both attachments are always attempted. Returns 0 on success, otherwise
 * the first negative GnuTLS error code (previously errors were dropped and
 * 0 was always returned).
 */
int _init_session(struct session *sess)
{
    int prio_ret = gnutls_priority_set(sess->session, priority_cache);
    int cred_ret = gnutls_credentials_set(sess->session, GNUTLS_CRD_CERTIFICATE, xcred);

    if (prio_ret < 0)
    {
        return prio_ret;
    }
    return cred_ret < 0 ? cred_ret : 0;
}
/*
 * Shut down the write side of the TLS connection (GNUTLS_SHUT_WR), then
 * release the GnuTLS session and the wrapper. Passing NULL is a no-op,
 * matching free() semantics, since the init functions may return NULL.
 */
void session_destroy(struct session *sess)
{
    if (sess == NULL)
    {
        return;
    }
    /* best effort: the result of gnutls_bye() is not checked */
    gnutls_bye(sess->session, GNUTLS_SHUT_WR);
    gnutls_deinit(sess->session);
    free(sess);
}
/*
 * Certificate-retrieve callback registered by init_xcred().
 *
 * Reads the SNI host name sent by the peer, if any. It then forwards that
 * name, the session user pointer, and GnuTLS's output pointers for the
 * certificate chain and private key to the Go callback onCertSelectCallback.
 * req_ca_dn, nreqs, pk_algos and pk_algos_length are unused.
 *
 * Returns -1 if sess or its user pointer is NULL, otherwise whatever
 * onCertSelectCallback returns. When no usable SNI name is available, an
 * empty name with length 0 is passed.
 */
int cert_select_callback(gnutls_session_t sess, const gnutls_datum_t *req_ca_dn,
                         int nreqs, const gnutls_pk_algorithm_t *pk_algos,
                         int pk_algos_length, gnutls_pcert_st **pcert,
                         unsigned int *pcert_length, gnutls_privkey_t *pkey)
{
    /* DNS names are at most 253 octets; 256 leaves room for the NUL. */
    char hostname[256];
    size_t namesize = sizeof hostname;
    unsigned int nametype = GNUTLS_NAME_DNS;
    int namelen = 0;
    void *ptr;
    int ret;

    if (sess == NULL)
    {
        return -1;
    }
    ptr = gnutls_session_get_ptr(sess);
    if (ptr == NULL)
    {
        return -1;
    }

    hostname[0] = '\0';
    /* gnutls_server_name_get() takes size_t * and unsigned int *. The old
     * code cast an int* to size_t*, which on LP64 read uninitialised bytes
     * as the buffer size and wrote 8 bytes into a 4-byte object. */
    ret = gnutls_server_name_get(sess, hostname, &namesize, &nametype, 0);
    if (ret >= 0 && nametype == GNUTLS_NAME_DNS)
    {
        namelen = (int)namesize;
    }
    else
    {
        /* no SNI (or unusable): report an empty name rather than failing */
        hostname[0] = '\0';
        namelen = 0;
    }
    return onCertSelectCallback(ptr, hostname, namelen, pcert_length, pcert, pkey);
}
/*
 * GnuTLS pull (read) callback installed by set_callback(). It passes the
 * transport pointer, destination buffer and requested length straight to
 * onDataReadCallback (presumably Go-exported) and returns its result.
 */
ssize_t pull_function(gnutls_transport_ptr_t ptr, void *data, size_t len)
{
    ssize_t nread = onDataReadCallback(ptr, data, len);
    return nread;
}
/*
 * GnuTLS pull-timeout callback installed by set_callback(). It delegates to
 * onDataTimeoutRead with the transport pointer and the timeout in ms, and
 * returns its result unchanged.
 */
int pull_timeout_function(gnutls_transport_ptr_t ptr, unsigned int ms)
{
    int result = onDataTimeoutRead(ptr, ms);
    return result;
}
/*
 * GnuTLS push (write) callback installed by set_callback(). It forwards the
 * outgoing bytes to onDataWriteCallback and returns its result.
 */
ssize_t push_function(gnutls_transport_ptr_t ptr, const void *data, size_t len)
{
    /* The callback's parameter is a non-const char *, so const is cast away
     * here. NOTE(review): assumes the callback only reads the bytes. */
    char *bytes = (char *)data;
    ssize_t written = onDataWriteCallback(ptr, bytes, len);
    return written;
}
/*
 * Store an opaque caller handle, passed as an integer, on the session
 * wrapper. set_callback() later installs it as both the GnuTLS transport
 * pointer and the session user pointer.
 */
void set_data(struct session *sess, size_t data)
{
    void *handle = (void *)data;
    sess->data = handle;
}
int handshake(struct session *sess)
{
int ret;
do
{
ret = gnutls_handshake(sess->session);
//printf("handshake ret %d\n", ret);
} while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
if (ret < 0)
{
if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR)
{
// check certificate verification status
type = gnutls_certificate_type_get(sess->session);
status = gnutls_session_get_verify_cert_status(sess->session);
gnutls_certificate_verification_status_print(status,
type, &out, 0);
printf("cert verify output: %s\n", out.data);
gnutls_free(out.data);
}
//fprintf(stderr, "*** Handshake failed: %s\n", gnutls_strerror(ret));
} /*else{
desc = gnutls_session_get_desc(sess->session);
printf("- Session info: %s\n", desc);
gnutls_free(desc);
}*/
return ret;
}
int set_callback(struct session *sess)
{
if (sess->data == NULL)
{
printf("set callback failed\n");
return -1;
}
gnutls_transport_set_ptr(sess->session, sess->data);
gnutls_session_set_ptr(sess->session, sess->data);
gnutls_transport_set_pull_function(sess->session, pull_function);
gnutls_transport_set_push_function(sess->session, push_function);
gnutls_transport_set_pull_timeout_function(sess->session, pull_timeout_function);
return 0;
}
gnutls_cipher_hd_t new_cipher(int cipher_type, char *key, int keylen, char *iv, int ivlen)
{
gnutls_cipher_hd_t handle;
gnutls_datum_t _key;
gnutls_datum_t _iv;
_key.data = key;
_key.size = keylen;
_iv.data = iv;
_iv.size = ivlen;
int ret = gnutls_cipher_init(&handle, cipher_type, &_key, &_iv);
if (ret < 0)
{
printf("new cipher: %s\n", gnutls_strerror(ret));
return NULL;
}
//printf("new cipher done\n");
//cipher->handle = handle;
return handle;
}
gnutls_hash_hd_t new_hash(int t)
{
gnutls_hash_hd_t hash;
gnutls_hash_init(&hash, t);
return hash;
}
int alpn_set_protocols(struct session *sess, char **names, int namelen)
{
gnutls_datum_t *t;
int ret;
int i;
t = (gnutls_datum_t *)malloc(namelen * sizeof(gnutls_datum_t));
for (i = 0; i < namelen; i++)
{
t[i].data = names[i];
t[i].size = strlen(names[i]);
}
ret = gnutls_alpn_set_protocols(sess->session, t,
namelen,
GNUTLS_ALPN_SERVER_PRECEDENCE);
free(t);
return ret;
}
int alpn_get_selected_protocol(struct session *sess, char *buf)
{
gnutls_datum_t p;
int ret;
memset(&p, 0, sizeof(gnutls_datum_t));
ret = gnutls_alpn_get_selected_protocol(sess->session, &p);
if (ret < 0)
{
return ret;
}
strcpy(buf, p.data);
// note: p.data is constant value, only valid during the session life
return 0;
}
/*
 * Deinitialise the first `size` certificates in st with gnutls_pcert_deinit()
 * and free() the array itself. The array must be heap-allocated, e.g. as
 * returned by load_cert_list() or get_peer_certificate().
 */
void free_cert_list(gnutls_pcert_st *st, int size)
{
    int idx = 0;

    while (idx < size)
    {
        gnutls_pcert_deinit(&st[idx]);
        idx++;
    }
    free(st);
}
gnutls_pcert_st *load_cert_list(char *certfile, int *cert_size, int *retcode)
{
gnutls_datum_t data;
int maxsize = 10;
int ret;
gnutls_pcert_st *st = malloc(10 * sizeof(gnutls_pcert_st));
ret = gnutls_load_file(certfile, &data);
if (ret < 0)
{
//printf("load file failed: %s", gnutls_strerror(ret));
*retcode = ret;
free(st);
return NULL;
}
ret = gnutls_pcert_list_import_x509_raw(
st, &maxsize, &data, GNUTLS_X509_FMT_PEM, 0);
if (ret < 0)
{
gnutls_free(data.data);
//printf("import certificate failed: %s", gnutls_strerror(ret));
*retcode = ret;
free(st);
return NULL;
}
gnutls_free(data.data);
*cert_size = maxsize;
*retcode = 0;
return st;
}
/*
 * Load a PEM-encoded, unencrypted (NULL password) private key from keyfile.
 *
 * On success, returns the key and sets *retcode to 0. On failure, returns
 * NULL and stores the negative GnuTLS error in *retcode. Release the key
 * with gnutls_privkey_deinit().
 */
gnutls_privkey_t load_privkey(char *keyfile, int *retcode)
{
    gnutls_privkey_t privkey = NULL;
    gnutls_datum_t data = {NULL, 0};
    int ret;

    ret = gnutls_load_file(keyfile, &data);
    if (ret < 0)
    {
        *retcode = ret;
        return NULL;
    }
    /* Check init: importing into a handle whose init failed is invalid. */
    ret = gnutls_privkey_init(&privkey);
    if (ret < 0)
    {
        gnutls_free(data.data);
        *retcode = ret;
        return NULL;
    }
    ret = gnutls_privkey_import_x509_raw(privkey, &data,
                                         GNUTLS_X509_FMT_PEM, NULL, 0);
    /* NOTE(review): the PEM key bytes are freed without being wiped first. */
    gnutls_free(data.data);
    if (ret < 0)
    {
        gnutls_privkey_deinit(privkey);
        *retcode = ret;
        return NULL;
    }
    *retcode = 0;
    return privkey;
}
/*
 * Copy subject-alternative-name entry `nameindex` of certificate st[index]
 * into out.
 *
 * Exactly the size reported by gnutls_x509_crt_get_subject_alt_name() is
 * copied; no terminator is appended. out must have room for up to 1024
 * bytes. Returns the number of bytes copied, or a negative GnuTLS error.
 */
int get_pcert_alt_name(
    gnutls_pcert_st *st, int index, int nameindex, char *out)
{
    gnutls_x509_crt_t crt = NULL;
    char data[1024];
    size_t size = sizeof data;
    int ret;

    /* gnutls_pcert_export_x509() initialises crt itself. Calling
     * gnutls_x509_crt_init() first, as before, leaked that handle. */
    ret = gnutls_pcert_export_x509(st + index, &crt);
    if (ret < 0)
    {
        return ret;
    }
    ret = gnutls_x509_crt_get_subject_alt_name(
        crt, (unsigned int)nameindex, data, &size, NULL);
    gnutls_x509_crt_deinit(crt);
    if (ret < 0)
    {
        return ret;
    }
    memcpy(out, data, size);
    return (int)size;
}
/*
 * Render certificate st[index] as text with gnutls_x509_crt_print(), using
 * print format `flag` (a gnutls_certificate_print_formats_t value), and copy
 * it into out without a terminator.
 *
 * Returns the number of bytes copied, or a negative GnuTLS error.
 * NOTE(review): out's capacity is not known here; the caller must size it
 * for the whole printout.
 */
int get_cert_str(gnutls_pcert_st *st, int index, int flag, char *out)
{
    gnutls_x509_crt_t crt = NULL;
    gnutls_datum_t text = {NULL, 0};
    int ret;

    /* gnutls_pcert_export_x509() initialises crt itself. Calling
     * gnutls_x509_crt_init() first, as before, leaked that handle. */
    ret = gnutls_pcert_export_x509(st + index, &crt);
    if (ret < 0)
    {
        return ret;
    }
    ret = gnutls_x509_crt_print(crt, (gnutls_certificate_print_formats_t)flag, &text);
    gnutls_x509_crt_deinit(crt);
    if (ret < 0)
    {
        return ret;
    }
    memcpy(out, text.data, text.size);
    ret = (int)text.size;
    gnutls_free(text.data);
    return ret;
}
/*
 * Copy the subject DN of certificate st[index], as produced by
 * gnutls_x509_crt_get_dn(), into out.
 *
 * Exactly the reported size is copied; no terminator is appended. out must
 * have room for up to 200 bytes. Returns the number of bytes copied, or a
 * negative GnuTLS error (DNs longer than the 200-byte buffer fail).
 */
int get_cert_dn(gnutls_pcert_st *st, int index, char *out)
{
    gnutls_x509_crt_t crt = NULL;
    char data[200];
    size_t size = sizeof data;
    int ret;

    /* gnutls_pcert_export_x509() initialises crt itself. Calling
     * gnutls_x509_crt_init() first, as before, leaked that handle. */
    ret = gnutls_pcert_export_x509(st + index, &crt);
    if (ret < 0)
    {
        return ret;
    }
    ret = gnutls_x509_crt_get_dn(crt, data, &size);
    gnutls_x509_crt_deinit(crt);
    if (ret < 0)
    {
        return ret;
    }
    memcpy(out, data, size);
    return (int)size;
}
/*
 * Copy the issuer DN of certificate st[index], as produced by
 * gnutls_x509_crt_get_issuer_dn(), into out.
 *
 * Exactly the reported size is copied; no terminator is appended. out must
 * have room for up to 200 bytes. Returns the number of bytes copied, or a
 * negative GnuTLS error (DNs longer than the 200-byte buffer fail).
 */
int get_cert_issuer_dn(gnutls_pcert_st *st, int index, char *out)
{
    gnutls_x509_crt_t crt = NULL;
    char data[200];
    size_t size = sizeof data;
    int ret;

    /* gnutls_pcert_export_x509() initialises crt itself. Calling
     * gnutls_x509_crt_init() first, as before, leaked that handle. */
    ret = gnutls_pcert_export_x509(st + index, &crt);
    if (ret < 0)
    {
        return ret;
    }
    ret = gnutls_x509_crt_get_issuer_dn(crt, data, &size);
    gnutls_x509_crt_deinit(crt);
    if (ret < 0)
    {
        return ret;
    }
    memcpy(out, data, size);
    return (int)size;
}
/*
 * Import the DER certificate list the peer presented on sess into a newly
 * allocated gnutls_pcert_st array, and store its length in *pcert_length.
 *
 * Returns NULL, with *pcert_length set to 0 when pcert_length is non-NULL,
 * in these cases: pcert_length is NULL, the peer sent no certificates,
 * allocation fails, or any certificate fails to import (the import error is
 * printed). Release a non-NULL result with free_cert_list().
 */
gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length)
{
    const gnutls_datum_t *raw_certs;
    unsigned int count = 0;
    gnutls_pcert_st *st;
    unsigned int i;
    int ret;

    /* Check before the first dereference; the old check came after it. */
    if (pcert_length == NULL)
    {
        return NULL;
    }
    *pcert_length = 0;

    raw_certs = gnutls_certificate_get_peers(sess, &count);
    if (raw_certs == NULL || count == 0)
    {
        return NULL;
    }
    st = malloc(count * sizeof *st);
    if (st == NULL)
    {
        return NULL;
    }
    for (i = 0; i < count; i++)
    {
        ret = gnutls_pcert_import_x509_raw(&st[i], &raw_certs[i],
                                           GNUTLS_X509_FMT_DER, 0);
        if (ret < 0)
        {
            printf("import cert failed: %s\n", gnutls_strerror(ret));
            /* Entries from i onward are uninitialised. Previously the list
             * was returned anyway and free_cert_list() later deinitialised
             * garbage. Release only what was imported. */
            free_cert_list(st, (int)i);
            return NULL;
        }
    }
    *pcert_length = (int)count;
    return st;
}
/*
 * Report whether any of the first len certificates in st matches hostname,
 * according to gnutls_x509_crt_check_hostname().
 *
 * Returns 1 on the first match and 0 if none matches. Returns a negative
 * GnuTLS error if a certificate cannot be exported; scanning stops there.
 *
 * NOTE(review): every certificate in the list is tried, not only the leaf
 * (st[0]). If st is a peer-supplied chain, a matching non-leaf certificate
 * is accepted. Confirm callers only pass lists they control, or restrict
 * the check to st[0]. A negative return is also non-zero, so callers must
 * not treat "non-zero" as "allowed".
 */
int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname)
{
    int i;

    for (i = 0; i < len; i++)
    {
        gnutls_x509_crt_t crt = NULL;
        unsigned int match;
        /* gnutls_pcert_export_x509() initialises crt itself. Calling
         * gnutls_x509_crt_init() first, as before, leaked a handle per
         * iteration. */
        int ret = gnutls_pcert_export_x509(st + i, &crt);

        if (ret < 0)
        {
            return ret;
        }
        match = gnutls_x509_crt_check_hostname(crt, hostname);
        gnutls_x509_crt_deinit(crt);
        if (match != 0)
        {
            return 1;
        }
    }
    return 0;
}