diff --git a/_gnutls.h b/_gnutls.h index 6e8eab6..462aa63 100644 --- a/_gnutls.h +++ b/_gnutls.h @@ -2,6 +2,7 @@ #define _GNUTLS_H #include #include +#include #include #include #include @@ -10,13 +11,12 @@ struct session { gnutls_session_t session; gnutls_certificate_credentials_t xcred; - int handshake; void *data; }; -extern int DataRead(void *, char *, int); -extern int DataWrite(void *, char *, int); -extern int DataTimeoutPull(void *, int); +extern int OnOnDataReadCallbackCallback(void *, char *, int); +extern int OnDataWriteCallback(void *, char *, int); +extern int OnDataTimeoutRead(void *, int); struct session *init_client_session(); struct session *init_server_session(); @@ -31,10 +31,27 @@ int set_callback(struct session *sess); void session_destroy(struct session *); +int OnCertSelectCallback(void *ptr, char *hostname, int namelen, + int *pcert_length, gnutls_pcert_st **cert, gnutls_privkey_t *privke); + gnutls_cipher_hd_t new_cipher(int cipher_type, char *key, int keylen, char *iv, int ivlen); gnutls_hash_hd_t new_hash(int t); int alpn_set_protocols(struct session *sess, char **, int); int alpn_get_selected_protocol(struct session *sess, char *buf); + +gnutls_privkey_t load_privkey(char *keyfile, int *); +gnutls_pcert_st *load_cert_list(char *certfile, int *, int *); + +int get_pcert_alt_name(gnutls_pcert_st *st, int index, int nameindex, char *out); + +int get_cert_str(gnutls_pcert_st *st, int index, int flag, char *out); + +int get_cert_issuer_dn(gnutls_pcert_st *st, int index, char *out); + +int get_cert_dn(gnutls_pcert_st *st, int index, char *out); + +void free_cert_list(gnutls_pcert_st *st, int size); +gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length); #endif \ No newline at end of file diff --git a/gnutls.c b/gnutls.c index ee30939..65b2669 100644 --- a/gnutls.c +++ b/gnutls.c @@ -7,6 +7,10 @@ int status; int type; int _init_session(struct session *); +int cert_select_callback(gnutls_session_t sess, const gnutls_datum_t *req_ca_dn, + int nreqs, const gnutls_pk_algorithm_t *pk_algos, + int pk_algos_length, gnutls_pcert_st **pcert, + unsigned int *pcert_length, gnutls_privkey_t *pkey); struct session *init_client_session() { @@ -36,33 +40,69 @@ int _init_session(struct session *sess) { gnutls_certificate_allocate_credentials(&sess->xcred); gnutls_certificate_set_x509_system_trust(sess->xcred); + gnutls_certificate_set_retrieve_function2(sess->xcred, cert_select_callback); gnutls_set_default_priority(sess->session); gnutls_credentials_set(sess->session, GNUTLS_CRD_CERTIFICATE, sess->xcred); - return 0; } void session_destroy(struct session *sess) { - gnutls_bye(sess->session, GNUTLS_SHUT_WR); + gnutls_bye(sess->session, GNUTLS_SHUT_RDWR); gnutls_deinit(sess->session); gnutls_certificate_free_credentials(sess->xcred); free(sess); } +int cert_select_callback(gnutls_session_t sess, const gnutls_datum_t *req_ca_dn, + int nreqs, const gnutls_pk_algorithm_t *pk_algos, + int pk_algos_length, gnutls_pcert_st **pcert, + unsigned int *pcert_length, gnutls_privkey_t *pkey) +{ + char hostname[100]; + int namelen = 100; + int type = GNUTLS_NAME_DNS; + int ret; + void *ptr; + + //printf("cert_select callback\n"); + if (sess == NULL) + { + //printf("session is NULL\n"); + return -1; + } + ptr = gnutls_session_get_ptr(sess); + if (ptr == NULL) + { + //printf("ptr is NULL\n"); + return -1; + } + ret = gnutls_server_name_get(sess, hostname, (size_t *)(&namelen), &type, 0); + if (ret < 0) + { + //printf("get server name error: %s\n", gnutls_strerror(ret)); + namelen = 0; + //return -1; + } + //printf("call go callback\n"); + ret = OnCertSelectCallback(ptr, hostname, namelen, pcert_length, pcert, pkey); + //printf("after callback pcert_length %d, pcert 0x%x, pkey 0x%x\n", *pcert_length, pcert, pkey); + return ret; +} + ssize_t pull_function(gnutls_transport_ptr_t ptr, void *data, size_t len) { - return DataRead(ptr, data, len); + return OnOnDataReadCallbackCallback(ptr, data, len); } int pull_timeout_function(gnutls_transport_ptr_t ptr, unsigned int ms) { - return DataTimeoutPull(ptr, ms); + return OnDataTimeoutRead(ptr, ms); } ssize_t push_function(gnutls_transport_ptr_t ptr, const void *data, size_t len) { - return DataWrite(ptr, (char *)data, len); + return OnDataWriteCallback(ptr, (char *)data, len); } void set_data(struct session *sess, size_t data) @@ -72,15 +112,12 @@ void set_data(struct session *sess, size_t data) int handshake(struct session *sess) { - if (sess->handshake > 0) - { - return 0; - } int ret; do { ret = gnutls_handshake(sess->session); + //printf("handshake ret %d\n", ret); } while (ret < 0 && gnutls_error_is_fatal(ret) == 0); if (ret < 0) @@ -108,9 +145,11 @@ int set_callback(struct session *sess) { if (sess->data == NULL) { + printf("set callback failed\n"); return -1; } gnutls_transport_set_ptr(sess->session, sess->data); + gnutls_session_set_ptr(sess->session, sess->data); gnutls_transport_set_pull_function(sess->session, pull_function); gnutls_transport_set_push_function(sess->session, push_function); gnutls_transport_set_pull_timeout_function(sess->session, pull_timeout_function); @@ -181,4 +220,222 @@ int alpn_get_selected_protocol(struct session *sess, char *buf) // note: p.data is constant value, only valid during the session life return 0; +} + +void free_cert_list(gnutls_pcert_st *st, int size) +{ + int i; + gnutls_pcert_st *st1; + for (i = 0; i < size; i++) + { + st1 = st + i; + gnutls_pcert_deinit(st1); + } + free(st); +} + +gnutls_pcert_st *load_cert_list(char *certfile, int *cert_size, int *retcode) +{ + gnutls_datum_t data; + int maxsize = 10; + int ret; + gnutls_pcert_st *st = malloc(10 * sizeof(gnutls_pcert_st)); + ret = gnutls_load_file(certfile, &data); + if (ret < 0) + { + //printf("load file failed: %s", gnutls_strerror(ret)); + *retcode = ret; + return NULL; + } + ret = gnutls_pcert_list_import_x509_raw( + st, &maxsize, &data, GNUTLS_X509_FMT_PEM, 0); + if (ret < 0) + { + gnutls_free(data.data); + //printf("import certificate failed: %s", gnutls_strerror(ret)); + *retcode = ret; + return NULL; + } + gnutls_free(data.data); + *cert_size = maxsize; + *retcode = 0; + return st; +} + +gnutls_privkey_t load_privkey(char *keyfile, int *retcode) +{ + gnutls_privkey_t privkey; + gnutls_datum_t data; + int ret; + ret = gnutls_load_file(keyfile, &data); + if (ret < 0) + { + //printf("load file failed: %s", gnutls_strerror(ret)); + *retcode = ret; + return NULL; + } + gnutls_privkey_init(&privkey); + ret = gnutls_privkey_import_x509_raw( + privkey, &data, GNUTLS_X509_FMT_PEM, NULL, 0); + if (ret < 0) + { + //printf("import privkey failed: %s", gnutls_strerror(ret)); + *retcode = ret; + gnutls_free(data.data); + return NULL; + } + gnutls_free(data.data); + *retcode = 0; + return privkey; +} + +int get_pcert_alt_name( + gnutls_pcert_st *st, int index, int nameindex, char *out) +{ + gnutls_x509_crt_t crt; + int ret; + char data[1024]; + size_t size = 1024; + gnutls_pcert_st *st1 = st + index; + ret = gnutls_x509_crt_init(&crt); + if (ret < 0) + { + return ret; + } + ret = gnutls_pcert_export_x509(st1, &crt); + if (ret < 0) + { + gnutls_x509_crt_deinit(crt); + return ret; + } + ret = gnutls_x509_crt_get_subject_alt_name( + crt, nameindex, (void *)data, &size, NULL); + if (ret < 0) + { + gnutls_x509_crt_deinit(crt); + return ret; + } + gnutls_x509_crt_deinit(crt); + memcpy(out, data, size); + return size; +} + +int get_cert_str(gnutls_pcert_st *st, int index, int flag, char *out) +{ + gnutls_x509_crt_t crt; + int ret; + gnutls_datum_t data; + gnutls_pcert_st *st1 = st + index; + ret = gnutls_x509_crt_init(&crt); + if (ret < 0) + { + return ret; + } + ret = gnutls_pcert_export_x509(st1, &crt); + if (ret < 0) + { + gnutls_x509_crt_deinit(crt); + return ret; + } + ret = gnutls_x509_crt_print(crt, flag, &data); + if (ret < 0) + { + return ret; + } + memcpy(out, data.data, data.size); + gnutls_free(data.data); + gnutls_x509_crt_deinit(crt); + return data.size; +} + +int get_cert_dn(gnutls_pcert_st *st, int index, char *out) +{ + + gnutls_x509_crt_t crt; + int ret; + char data[200]; + size_t size = 200; + gnutls_pcert_st *st1 = st + index; + ret = gnutls_x509_crt_init(&crt); + if (ret < 0) + { + return ret; + } + ret = gnutls_pcert_export_x509(st1, &crt); + if (ret < 0) + { + gnutls_x509_crt_deinit(crt); + return ret; + } + ret = gnutls_x509_crt_get_dn(crt, data, &size); + if (ret < 0) + { + gnutls_x509_crt_deinit(crt); + return ret; + } + memcpy(out, data, size); + return size; +} + +int get_cert_issuer_dn(gnutls_pcert_st *st, int index, char *out) +{ + + gnutls_x509_crt_t crt; + int ret; + char data[200]; + size_t size = 200; + gnutls_pcert_st *st1 = st + index; + ret = gnutls_x509_crt_init(&crt); + if (ret < 0) + { + return ret; + } + ret = gnutls_pcert_export_x509(st1, &crt); + if (ret < 0) + { + gnutls_x509_crt_deinit(crt); + return ret; + } + ret = gnutls_x509_crt_get_issuer_dn(crt, data, &size); + if (ret < 0) + { + gnutls_x509_crt_deinit(crt); + return ret; + } + memcpy(out, data, size); + return size; +} + +gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length) +{ + const gnutls_datum_t *raw_certs; + const gnutls_datum_t *d; + gnutls_pcert_st *st, *st1; + int ret; + int i; + *pcert_length = 0; + raw_certs = gnutls_certificate_get_peers(sess, pcert_length); + if (pcert_length == NULL) + { + printf("pcert length is NULL\n"); + return NULL; + } + if (*pcert_length == 0) + { + printf("pcert length is 0\n"); + return NULL; + } + //printf("pcert length %d\n", *pcert_length); + st = malloc((*pcert_length) * sizeof(gnutls_pcert_st)); + for (i = 0; i < *pcert_length; i++) + { + st1 = st + i; + d = raw_certs + i; + ret = gnutls_pcert_import_x509_raw(st1, d, GNUTLS_X509_FMT_DER, 0); + if (ret < 0) + { + printf("import cert failed: %s\n", gnutls_strerror(ret)); + } + } + return st; } \ No newline at end of file diff --git a/testdata/server.crt b/testdata/server.crt index ba731b0..15020aa 100644 --- a/testdata/server.crt +++ b/testdata/server.crt @@ -1,21 +1,23 @@ -----BEGIN CERTIFICATE----- -MIIDYDCCAkigAwIBAgIJAJ92ThBK0H0ZMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV -BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTgwNjI3MDMwMjUwWhcNMjgwNjI0MDMwMjUwWjBF -MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 -ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEAufMjhIyoWzqqRKbglkfCWqSbU1N6SPaSjBmVQCfFGzUZx3DHWvemJ8Ip -2Nb9Gx9fI/GiRHs1K6noWZjJzsFZNdsoAADg+vXhS4MEwsItcRnKFgbmBtjKPbQy -jHnndAy8Wwe73NXTy7oBSbd5CZogvblLjfndSUIhCXVv7+PFHjfG78LEL7Dp1i1A -96SV2YuVYHr6FIS6C5FA0FGtGpXUC265jUN+sI88ONXMc/7zQ+4VWggB+Kq3B7uR -4DjTtxAC0X58AdWJm3yH6nTJLcsfOVQs1u7Mg97aUpRo9osw7ZMcEEchOw85L0Rl -K4vlZDnB2xD+8S8RRQi+Y/4GKPAihwIDAQABo1MwUTAdBgNVHQ4EFgQUGY2IxTxn -3tKlRsUnko97fF4X8WYwHwYDVR0jBBgwFoAUGY2IxTxn3tKlRsUnko97fF4X8WYw -DwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAHpvWRFV+Cen2VQdX -OO1YYufqcGPmTxo0XmksOUQsd+vwR0HSjCK1oC0PlloviFvCXn9p4ZvwVK1NdK19 -RUkiXPAIw5QmXeNQgCJRv7jhObIuqfKVGfZuH8PdPMcPzU+PxgaZ/gnaW62AAFJA -frZXwMhr/Ar+CvH4NSZfOxBF4LOWM0eVYfxUOFq+qvYptdSxXyK1kuBxFXOVNEIw -CWp75uzvZ1gNaUrdfCNoYMV2qeyuL8gNfRhExrDfjn4ouOZq0Lh8FHCnGdoZo15D -ZHkLUqdKnOXtHKdnLcxStF6orKF7I3f3fp5kyRM19uZzLV07zdawVJSDbaO2cJz1 -Qff5gg== +MIIDzDCCArSgAwIBAgIJAInocGP04UR9MA0GCSqGSIb3DQEBCwUAMHsxCzAJBgNV +BAYTAkNOMQswCQYDVQQIDAJYWjELMAkGA1UEBwwCc3oxEzARBgNVBAoMCnN6Y2h1 +emhvbmcxDzANBgNVBAsMBnJkIGRwdDEQMA4GA1UEAwwHYWJjLmNvbTEaMBgGCSqG +SIb3DQEJARYLYWFhQGFiYy5jb20wHhcNMTgwNzA1MDczMjU4WhcNMTgwODA0MDcz +MjU4WjB7MQswCQYDVQQGEwJDTjELMAkGA1UECAwCWFoxCzAJBgNVBAcMAnN6MRMw +EQYDVQQKDApzemNodXpob25nMQ8wDQYDVQQLDAZyZCBkcHQxEDAOBgNVBAMMB2Fi +Yy5jb20xGjAYBgkqhkiG9w0BCQEWC2FhYUBhYmMuY29tMIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEAufMjhIyoWzqqRKbglkfCWqSbU1N6SPaSjBmVQCfF +GzUZx3DHWvemJ8Ip2Nb9Gx9fI/GiRHs1K6noWZjJzsFZNdsoAADg+vXhS4MEwsIt +cRnKFgbmBtjKPbQyjHnndAy8Wwe73NXTy7oBSbd5CZogvblLjfndSUIhCXVv7+PF +HjfG78LEL7Dp1i1A96SV2YuVYHr6FIS6C5FA0FGtGpXUC265jUN+sI88ONXMc/7z +Q+4VWggB+Kq3B7uR4DjTtxAC0X58AdWJm3yH6nTJLcsfOVQs1u7Mg97aUpRo9osw +7ZMcEEchOw85L0RlK4vlZDnB2xD+8S8RRQi+Y/4GKPAihwIDAQABo1MwUTAdBgNV +HQ4EFgQUGY2IxTxn3tKlRsUnko97fF4X8WYwHwYDVR0jBBgwFoAUGY2IxTxn3tKl +RsUnko97fF4X8WYwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEA +PQnnEF/nzpIfc4x/kC9w5ynWeyQBOnqYnUH1lgauHEbpzImKvaf0RcsWouP0y4hr +goVYX92RinSvH331r0A9HgHeiobMZG7cuV9jS4EUcz+HjM2FVQIBCitx4YbY6Ylg +PS4wBlKcSJ5q6mFZyLizbxA88WATbFnLnC5yQfs5WydbCWk+wlR7DKjbWvoQ3B8P +VneicIhDT0IseG7Fz+FiXlVD8ihgDbH5lctYmEAU7RIgtcVgIxMkeE2vcugyUb+E +KGs02MrRtsUr4/roDDXnBjPv8W0OXp/CN4Ng0Oc0YgODD/H4Bk5sPt6SrgrU1Ddo +dA9dGTT5aTst1KyRtIyAkg== -----END CERTIFICATE----- diff --git a/testdata/server2.crt b/testdata/server2.crt new file mode 100644 index 0000000..472224d --- /dev/null +++ b/testdata/server2.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDjDCCAnSgAwIBAgIJAPls60eMgWdFMA0GCSqGSIb3DQEBCwUAMFsxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQxFDASBgNVBAMMC2V4YW1wbGUuY29tMB4XDTE4MDcwNjAx +MjY1OFoXDTE5MDcwNjAxMjY1OFowWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNv +bWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIG +A1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQCiZBJGAAZJLK7MK1jrMnHLOJvVXosfJrpGtadwQObORljD3tgfsAOZopvM23aj +sTwQJttemtceWKYNMi5TX9+M9/hN3Wug3Mm3DiTPthmtEEoJ12h5MuS/6d1yyT2D +zwZxouh0ocg9kIlB701DpLfkvLxVhMK+Zlh/JeWlnWG4k1rLIwivSW4pabWb7fgM +p6Oppw1XaFurQh9tnpZO3cSGLY2PsM525xwdG5xPLEUHRLbPamR150m3K0R7wMB1 +MeCrM3jwW3AjQC/4kHrVQQagCwoF4jMZ9FH8nyci8aAar3N9lNB6QbOatZvODT8l +SgfzHBstA2lL3Vhork3wgMcnAgMBAAGjUzBRMB0GA1UdDgQWBBS77XPkKRzbf6xi +xS/hO1vxGPgzAzAfBgNVHSMEGDAWgBS77XPkKRzbf6xixS/hO1vxGPgzAzAPBgNV +HRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCg4j544c1wi0fxq6wq4xGE +kuPLIzK17VRqAJPHPYG0ruMVVwAWLG+5hE4IowZQIGrDM2akJUfbqpNkHldof+mC +jXEtWozU+/1cvmaez4+hAyq/KN4mfcvRcXVxhD6X4mawBDFfEHt9Fd0IYpgJTkcD +m4M9D0lUvoRfaKHt1vEkXiTLRM6Xbn96Jp+wLiXcIiX9ZwXe9NoVtMJikivoYNlP +L0/v0XAiN2CgjCPER5ipqs8mKshYQMyb/Le2J05cgHZ02Wg3xvlqxO/2k9vLBtw+ +Qaik9QX4Xb7YVwUBGQWDqxK1m/+FxtMpralaigUxVyOYS6CFF1RFcbcptd0jTtJS +-----END CERTIFICATE----- diff --git a/testdata/server2.key b/testdata/server2.key new file mode 100644 index 0000000..2ac21e3 --- /dev/null +++ b/testdata/server2.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAomQSRgAGSSyuzCtY6zJxyzib1V6LHya6RrWncEDmzkZYw97Y +H7ADmaKbzNt2o7E8ECbbXprXHlimDTIuU1/fjPf4Td1roNzJtw4kz7YZrRBKCddo +eTLkv+ndcsk9g88GcaLodKHIPZCJQe9NQ6S35Ly8VYTCvmZYfyXlpZ1huJNayyMI +r0luKWm1m+34DKejqacNV2hbq0IfbZ6WTt3Ehi2Nj7DOduccHRucTyxFB0S2z2pk +dedJtytEe8DAdTHgqzN48FtwI0Av+JB61UEGoAsKBeIzGfRR/J8nIvGgGq9zfZTQ +ekGzmrWbzg0/JUoH8xwbLQNpS91YaK5N8IDHJwIDAQABAoIBABp1X5zUKDIH+7r8 +XRKFN5E6+fj73IMI1lTrCAr8KB73y0KurlwLW6rOmb/5Cg2FtRmUmy2A4QfqvbNs +t6uR9WSMioJ1TzH4h00yGsFVFD3kZ4vO8xC8QBUcz54CN+mf85bUSjemnG+beyGp +EdexoNy9+5mbdfd7yXN+AzrGt8NQzIhAMJU2hYPSTMnYoxqENA7Xuocz54LXtfq/ +SwLQcCddRH+6nGWuqtsMLlyw0+c3nOyZXbl0ngmPK2jmZsE8NBXXTiLjeUVP24rl +FwH+A6bVtSLjuuuiP69oPkKe3KNVD/KpdHlaH/6Qb6pulU5Vb1F0rugk7w23IKaa +bnkap2ECgYEAz7pdNTaNR+14uovVDVPciTZyDpHsOl+Kn8vGGZtw2eEkYLyOXjOr ++FUmciwFdam4eCDIu4fhpCjuwfCHWrVYzQsRNNNgSbrssu4jj7/Dxd/w51d224eq +EszJC95GgqhHi13/I6zgcxeZBQc5bcW0bFlfaD8PVzIiTaS2LQJUDHkCgYEAyCCg +Z+NlJv48PwDQuxSuLSAFQMZisy/UP4+wvj6igBtKxc5ork06Em4sFJZYOzOm5rDI +QEkFa0sUKVtgJgio2zA8QUeRtMwhBopi5jcTucDid6H6z2lQQbyFRbr60vo0bTyP +ahX13J4lBcIGy0MJdgj6l6yy1M5ZslqZZ0hjSJ8CgYAtxCqa+bzg1wIdX4d+Gzbg +iD1S1nWMWtZo5HVt2OBhMIhaQ9C+EnZWDTSePPKq/Mymstpm7sYY6+fGlN7Nblz1 +N/X/hH6XX/acaXkuR5qzcuZZodyO+3HOGI5G7h7s1HSG0RvQWVtOICnXgML3W3Kn +2Hz7s8EGfgYuwxZcDkJESQKBgDsR5kbDx8eKox21j+aoZADNwr8rz1Y0d+GK+BAv +TKejZp6cHinUgZ+PBVPOTJys0kalR3YyF3dj4b+TSP3w8GZCBob+KPPEjLrxfd+V +wizB0hadqPovi5DbpDrAxeggEflsNqiJcth7lVHtwzObxd8hJ1Y9k0tc3PzX4Q+r +PiLhAoGAGG1/MbkqY5w0sYjk/eFwL0Elcn2nRthlY0dsywTzXRrMrUBXEJArGOL8 +mjI2EWGZBOWQ7wEqpD3LsmprkSsQJmrsdKnrmWzO52cSpiB3/NKWbVEoF6kNlI+c +PuKo/iSVEeRKWRCbjtdfV3jleSw8XUXH9W9388iVOsC/t+lZy9o= +-----END RSA PRIVATE KEY----- diff --git a/testdata/server3.crt b/testdata/server3.crt new file mode 100644 index 0000000..45d0cb6 --- /dev/null +++ b/testdata/server3.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDiDCCAnCgAwIBAgIJAN+27ugRe7kwMA0GCSqGSIb3DQEBCwUAMFkxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCSouYWFhLmNvbTAeFw0xODA3MDYwMTQ1 +MDNaFw0xODA4MDUwMTQ1MDNaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21l +LVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNV +BAMMCSouYWFhLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANCS +OnOOSOwuLnukitv1itAsNokLuV+YnEwZCRD3X8yLhhEVPftpo0DobOO3Lg38JmWg +zKQem41nE8ruMEAUZSlPHeHKc+yHbqWGBpXlcxXihyVWouZaKT7YFQp9E8CRSscS +9yeXRkKfH7Yh6660OysqZfCAhjKxUCvvp1W3sWqLx8ssdGcuLD+CaOGJMv66n3nS +qfFZXE8H7vdB+TFUlvNcRgAoSxV36zS2i3HcUv+SUGXoupAJG3l23x3UhZky7Yo7 +qJkT681O8zuhAv/tZPqv5PHLOl9t4mYc4jX+TZ2ZNLckZ+VjNPPVn920/fZ1cvio +GuBjekzPewXE6O8NaH0CAwEAAaNTMFEwHQYDVR0OBBYEFNC8YhToPT55/xEziMnR +zbVOM/o+MB8GA1UdIwQYMBaAFNC8YhToPT55/xEziMnRzbVOM/o+MA8GA1UdEwEB +/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAFccqqBIdCwMjbxcaFY6vwHUSdid +YrVcQSHy6w7hUG9Una7HAqyGrK0Uwe+njR8eE0Xs7Q+QviQDijgC4FHfkrEPFZYG +9yVAAcwNl2NxuKAwHtao/P7JwBCsuGamDOxq/C9nN5M9USiBnBytYGv1a3CYEYl9 +tmjegWe2lN9eyqtbqTH218tMa/avUzOkkRkPKuik1BIFDCkzbfqnlI6uUpzdKqAR +jkj9MHtFjJ5SxtW8ak9Lo78zawbfdLuec9NfJZaTZqIy969+0apYCQ66wUBMwa8u +u3rVWdkWDEoFx5KnEUchKD1/pgI+mU45kqay1kxW3A9Qvt1NNK9hVEGcCGw= +-----END CERTIFICATE----- diff --git a/testdata/server3.key b/testdata/server3.key new file mode 100644 index 0000000..0ceaef5 --- /dev/null +++ b/testdata/server3.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA0JI6c45I7C4ue6SK2/WK0Cw2iQu5X5icTBkJEPdfzIuGERU9 ++2mjQOhs47cuDfwmZaDMpB6bjWcTyu4wQBRlKU8d4cpz7IdupYYGleVzFeKHJVai +5lopPtgVCn0TwJFKxxL3J5dGQp8ftiHrrrQ7Kypl8ICGMrFQK++nVbexaovHyyx0 +Zy4sP4Jo4Yky/rqfedKp8VlcTwfu90H5MVSW81xGAChLFXfrNLaLcdxS/5JQZei6 +kAkbeXbfHdSFmTLtijuomRPrzU7zO6EC/+1k+q/k8cs6X23iZhziNf5NnZk0tyRn +5WM089Wf3bT99nVy+Kga4GN6TM97BcTo7w1ofQIDAQABAoIBABfeZCoNQnMk5dTo +g6ugcf5Y0PTmDpTOFjTvOIZkiOYToYga8jjvYetvreZxdCfNj2dZ+5Fcn1iTT0SN +9Z+fteQAVd8dGB8dcKAosmA9HeqHPggb3hCWiNnUSLQmlDgZaIFXvkkdmsDNDQf+ +4cXgglTySTA4xSLP/+jHSFMa7obOu2EWkJMYWg1/jWT0F1xIZo6ca3OSrtm65F1b +fMTL6U9+hQzbFIEzqAulL8RejCSz4vRBy6zUBWdQ6OBip+SiKjO0ArDavOFfw9Qq +0X9Wkd1Mm/fytFrZbtIClpTN1196OMmDuabVBowkLKg6pi6gAp/YzPvaTMmTZPGc +42PJ19ECgYEA8eBcRIJ+/7d/ym7ro4jGWwvB39eUB7qvP5wknKakG9PRZcSwJTjU +Rcw8dREdsUutHkrW+6G/EXO7SsEr1XUvuEkX1QA9ecE+9p1tBGuTEwufZZ/vcScz +3zHPUnOKjtRUAgi1S+X/t/qjlp0zkJ6RMz1NWeSBol0+ULol03X/yTMCgYEA3MAD +Dqi26sutDOLLj7iXMPB9ZkuIWpKmByNqBUmKLgCOIlXzzPgTsooFpQCL0uFcIMLX +26ZYtWT0O/sNtpT97EgqBn6TE0oGzEmZPtAm8UXQsWiRgGTy0lgpxQdAIS6GCtKZ +DbX6lV+Apu8nx2lt9+uSBYBYDzxt2oaKkdTl548CgYEAnldpL7RaHV9sOgKJIiKM +79dvkPEYyEKPLU4zrZVtw4XUBBJR2dwtIpBEV8LftEw9RaJqwIovgeZIivSQlInF +tperEVa55/X5GQsP7h/aRVKLg8TCxEmMrKV3+psG7t/TKw22Wbx0vmVHKHc65YbY +uTl9ZMaxkrAF9mUWFCugSn8CgYApkYrB1uli+2mh1I9KiBMIZzDl83FAxP64t2V+ +i2OW2Anr002umkRzSWRYtuqdkkxb6vTk6sUnm1QWe2cQq6vJM6meQXWGm1j+XRmb +Z2z94Ay1a6CCkf/btjhfXscnuHALV670kwEV4b8DMGPIPEU1+0kq+gkbDWEOVml/ +npyQZwKBgQDFrKiiawiykpqTPaXabOf6hv3WkAUMxlqeDHDKoLLU2oM/UmU/mDA1 +RXn8lN6nMRTFTPFsMdPhKiTV5l0fq8LQQmaWeyLMjQBAUMlcLsdoDjltjLZa4dt+ +tVKFQ6jAbzkuuhUM1NlENTDNW5flRksJy52077yXsWJUEK6RUYYpqQ== +-----END RSA PRIVATE KEY----- diff --git a/tls.go b/tls.go index 53ffa97..ce35209 100644 --- a/tls.go +++ b/tls.go @@ -9,7 +9,6 @@ import ( "fmt" "log" "net" - "os" "time" "unsafe" ) @@ -24,17 +23,17 @@ const ( // Conn tls connection for client type Conn struct { c net.Conn - sess *C.struct_session handshake bool + sess *C.struct_session cservname *C.char state *ConnectionState + cfg *Config } // Config tls configure type Config struct { ServerName string - CrtFile string - KeyFile string + Certificates []*Certificate InsecureSkipVerify bool NextProtos []string } @@ -49,7 +48,8 @@ type ConnectionState struct { // TLS version number, ex: 0x303 Version uint16 // TLS version number, ex: TLS1.0 - VersionName string + VersionName string + PeerCertificate *Certificate } type listener struct { @@ -90,15 +90,6 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) { if cfg == nil { return nil, fmt.Errorf("config is need") } - if cfg.CrtFile == "" || cfg.KeyFile == "" { - return nil, fmt.Errorf("keyfile is needed") - } - if _, err := os.Stat(cfg.CrtFile); err != nil { - return nil, err - } - if _, err := os.Stat(cfg.KeyFile); err != nil { - return nil, err - } l, err := net.Listen(network, addr) if err != nil { return nil, err @@ -109,21 +100,12 @@ func Listen(network, addr string, cfg *Config) (net.Listener, error) { // NewServerConn create a server Conn func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { var sess = C.init_server_session() - conn := &Conn{c: c, sess: sess} + conn := &Conn{c: c, sess: sess, cfg: cfg} n := C.size_t(uintptr(unsafe.Pointer(conn))) //log.Println("conn addr ", int(n)) C.set_data(sess, n) C.set_callback(sess) - crtfile := C.CString(cfg.CrtFile) - keyfile := C.CString(cfg.KeyFile) - defer C.free(unsafe.Pointer(crtfile)) - defer C.free(unsafe.Pointer(keyfile)) - ret := C.gnutls_certificate_set_x509_key_file( - sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM) - if int(ret) < 0 { - cerrstr := C.gnutls_strerror(ret) - return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr)) - } + if cfg.NextProtos != nil { if err := setAlpnProtocols(sess, cfg); err != nil { log.Println(err) @@ -135,7 +117,7 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { // NewClientConn create a new gnutls connection func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { var sess = C.init_client_session() - conn := &Conn{c: c, sess: sess} + conn := &Conn{c: c, sess: sess, cfg: cfg} n := C.size_t(uintptr(unsafe.Pointer(conn))) //log.Println("conn addr ", int(n)) C.set_data(sess, n) @@ -149,18 +131,6 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { unsafe.Pointer(srvname), C.size_t(len(cfg.ServerName))) } - if cfg.CrtFile != "" && cfg.KeyFile != "" { - crtfile := C.CString(cfg.CrtFile) - keyfile := C.CString(cfg.KeyFile) - defer C.free(unsafe.Pointer(crtfile)) - defer C.free(unsafe.Pointer(keyfile)) - ret := C.gnutls_certificate_set_x509_key_file( - sess.xcred, crtfile, keyfile, GNUTLS_X509_FMT_PEM) - if int(ret) < 0 { - return nil, fmt.Errorf("set keyfile failed: %s", - C.GoString(C.gnutls_strerror(ret))) - } - } if !cfg.InsecureSkipVerify { if conn.cservname != nil { C.gnutls_session_set_verify_cert(sess.session, conn.cservname, 0) @@ -270,11 +240,16 @@ func (c *Conn) Write(buf []byte) (n int, err error) { // Close close the conn and destroy the tls context func (c *Conn) Close() error { + C.gnutls_record_send(c.sess.session, nil, 0) C.session_destroy(c.sess) c.c.Close() if c.cservname != nil { C.free(unsafe.Pointer(c.cservname)) } + + if c.state != nil && c.state.PeerCertificate != nil { + c.state.PeerCertificate.Free() + } return nil } @@ -320,11 +295,21 @@ func (c *Conn) ConnectionState() *ConnectionState { HandshakeComplete: c.handshake, ServerName: c.getServerName(), VersionName: versionname, + PeerCertificate: c.getPeerCertificate(), } c.state = state return state } +func (c *Conn) getPeerCertificate() *Certificate { + var size int + st := C.get_peer_certificate(c.sess.session, (*C.int)(unsafe.Pointer(&size))) + if st == nil { + return nil + } + return &Certificate{cert: st, certSize: C.int(size)} +} + func (c *Conn) getAlpnSelectedProtocol() string { cbuf := C.malloc(100) defer C.free(cbuf) @@ -353,9 +338,9 @@ func (c *Conn) getServerName() string { return name } -// DataRead c callback function for data read -//export DataRead -func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { +// OnDataReadCallback c callback function for data read +//export OnDataReadCallback +func OnDataReadCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { //log.Println("read addr ", uintptr(d)) conn := (*Conn)(unsafe.Pointer((uintptr(d)))) buf := make([]byte, int(bufLen)) @@ -371,9 +356,9 @@ func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { return C.int(n) } -// DataWrite c callback function for data write -//export DataWrite -func DataWrite(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { +// OnDataWriteCallback c callback function for data write +//export OnDataWriteCallback +func OnDataWriteCallback(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { //log.Println("write addr ", uintptr(d), int(_l)) conn := (*Conn)(unsafe.Pointer((uintptr(d)))) gobuf := C.GoBytes(unsafe.Pointer(cbuf), bufLen) @@ -385,9 +370,50 @@ func DataWrite(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { return C.int(n) } -// DataTimeoutPull c callback function for timeout read -//export DataTimeoutPull -func DataTimeoutPull(d unsafe.Pointer, delay C.int) C.int { +// OnDataTimeoutRead c callback function for timeout read +//export OnDataTimeoutRead +func OnDataTimeoutRead(d unsafe.Pointer, delay C.int) C.int { log.Println("timeout pull function") return 0 } + +// OnCertSelectCallback callback function for ceritificate select +//export OnCertSelectCallback +func OnCertSelectCallback(ptr unsafe.Pointer, hostname *C.char, + namelen C.int, pcertLength *C.int, cert **C.gnutls_pcert_st, privkey *C.gnutls_privkey_t) C.int { + + servername := C.GoStringN(hostname, namelen) + //log.Println("go cert select callback ", servername) + conn := (*Conn)(unsafe.Pointer((uintptr(ptr)))) + //log.Println(conn) + if int(namelen) == 0 && conn.cfg.Certificates != nil { + _cert := conn.cfg.Certificates[0] + *pcertLength = _cert.certSize + *cert = _cert.cert + *privkey = _cert.privkey + //log.Println("set pcert length ", _cert.certSize) + return 0 + } + for _, _cert := range conn.cfg.Certificates { + //log.Println(cert) + if _cert.matchName(servername) { + //log.Println("matched name ", _cert.names) + *pcertLength = _cert.certSize + *cert = _cert.cert + *privkey = _cert.privkey + //log.Println("set pcert length ", _cert.certSize) + return 0 + } + } + if conn.cfg.Certificates != nil { + _cert := conn.cfg.Certificates[0] + *pcertLength = _cert.certSize + *cert = _cert.cert + *privkey = _cert.privkey + //log.Println("set pcert length ", _cert.certSize) + return 0 + } + *pcertLength = 0 + //log.Println("set pcert length 0") + return -1 +} diff --git a/tls_test.go b/tls_test.go index 3c86c6f..f57cf12 100644 --- a/tls_test.go +++ b/tls_test.go @@ -1,9 +1,15 @@ package gnutls import ( + "bufio" + "bytes" "crypto/tls" + "fmt" + "io" "log" "net" + "net/http" + "os" "testing" ) @@ -66,8 +72,13 @@ func TestTLSClient(t *testing.T) { } func TestTLSServer(t *testing.T) { + cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal(err) + } l, err := Listen("tcp", "127.0.0.1:0", &Config{ - CrtFile: "testdata/server.crt", KeyFile: "testdata/server.key"}) + Certificates: []*Certificate{cert}, + }) if err != nil { t.Fatal("gnutls listen ", err) } @@ -84,7 +95,11 @@ func TestTLSServer(t *testing.T) { log.Println("accept connection from ", c.RemoteAddr()) go func(c net.Conn) { defer c.Close() - + tlsconn := c.(*Conn) + if err := tlsconn.Handshake(); err != nil { + log.Println(err) + return + } buf := make([]byte, 4096) for { n, err := c.Read(buf[0:]) @@ -125,11 +140,13 @@ func TestTLSALPNServer(t *testing.T) { serveralpn := []string{"a1", "a3", "a2"} clientalpn := []string{"a0", "a2", "a5"} expectedAlpn := "a2" - + cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal(err) + } l, err := Listen("tcp", "127.0.0.1:0", &Config{ - CrtFile: "testdata/server.crt", - KeyFile: "testdata/server.key", - NextProtos: serveralpn, + Certificates: []*Certificate{cert}, + NextProtos: serveralpn, }) if err != nil { t.Fatal("gnutls listen ", err) @@ -292,3 +309,100 @@ func TestTLSALPNClient(t *testing.T) { t.Errorf("need: %s, got: %s", data, string(buf[:n])) } } + +func TestTLSServerSNI(t *testing.T) { + certificates := []*Certificate{} + cert, err := LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal("load key failed") + } + + certificates = append(certificates, cert) + cert, err = LoadX509KeyPair("testdata/server2.crt", "testdata/server2.key") + if err != nil { + t.Fatal("load key failed") + } + + certificates = append(certificates, cert) + cert, err = LoadX509KeyPair("testdata/server3.crt", "testdata/server3.key") + if err != nil { + t.Fatal("load key failed") + } + certificates = append(certificates, cert) + + l, err := Listen("tcp", "127.0.0.1:0", &Config{ + Certificates: certificates, + }) + if err != nil { + t.Fatal(err) + } + defer l.Close() + addr := l.Addr().String() + go func() { + for { + c, err := l.Accept() + if err != nil { + log.Println(err) + break + } + go func(c net.Conn) { + defer c.Close() + tlsconn := c.(*Conn) + if err := tlsconn.Handshake(); err != nil { + log.Println(err) + return + } + state := tlsconn.ConnectionState() + fmt.Fprintf(c, state.ServerName) + }(c) + } + }() + + for _, servername := range []string{"abc.com", "example.com", "a.aaa.com", "b.aaa.com"} { + conn, err := tls.Dial("tcp", addr, &tls.Config{ + ServerName: servername, + InsecureSkipVerify: true, + }) + if err != nil { + t.Fatal(err) + } + //state := conn.ConnectionState() + //log.Printf("%+v", state.PeerCertificates[0]) + buf := make([]byte, 100) + n, err := conn.Read(buf) + if err != nil && err != io.EOF { + t.Error(err) + } + if !bytes.Equal(buf[:n], []byte(servername)) { + t.Errorf("expect %s, got %s", servername, string(buf[:n])) + } + conn.Close() + } +} + +func TestTLSGetPeerCert(t *testing.T) { + conn, err := Dial("tcp", "www.ratafee.nl:443", &Config{ + ServerName: "www.ratafee.nl", + }) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + //tlsconn := conn.(*Conn) + if err := conn.Handshake(); err != nil { + t.Fatal(err) + } + state := conn.ConnectionState() + for i := 0; i < int(state.PeerCertificate.certSize); i++ { + log.Println(state.PeerCertificate.getCertString(i, 1)) + } + + req, _ := http.NewRequest("GET", "https://www.ratafee.nl/httpbin/ip", nil) + req.Write(conn) + r := bufio.NewReader(conn) + resp, err := http.ReadResponse(r, req) + if err != nil { + t.Error(err) + } + resp.Write(os.Stdout) +}