diff --git a/_gnutls.h b/_gnutls.h index 60991ad..bc5048b 100644 --- a/_gnutls.h +++ b/_gnutls.h @@ -54,4 +54,5 @@ int get_cert_dn(gnutls_pcert_st *st, int index, char *out); void free_cert_list(gnutls_pcert_st *st, int size); gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length); +int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname); #endif \ No newline at end of file diff --git a/certificate.go b/certificate.go index 1f62d01..36791ad 100644 --- a/certificate.go +++ b/certificate.go @@ -17,7 +17,6 @@ type Certificate struct { cert *C.gnutls_pcert_st privkey C.gnutls_privkey_t certSize C.int - names []string } // Free free the certificate context @@ -39,31 +38,17 @@ func (c *Certificate) free() { } func (c *Certificate) matchName(name string) bool { - for _, n := range c.names { - if n == name { - return true - } - if strings.HasPrefix(n, "*") { - n1 := strings.Replace(n, "*.", "", 1) - if strings.HasSuffix(name, n1) { - return true - } - } - } - return false -} - -func (c *Certificate) buildNames() { - if c.names != nil { - return + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + ret := C.cert_check_hostname(c.cert, c.certSize, cname) + if int(ret) < 0 { + log.Println(C.GoString(C.gnutls_strerror(ret))) + return false } - c.names = []string{} - for i := 0; i < int(c.certSize); i++ { - cn := c.commonName(i) - if cn != "" { - c.names = append(c.names, cn) - } + if int(ret) > 0 { + return true } + return false } // CommonName get CN field in subject, @@ -186,7 +171,6 @@ func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) { certificate.cert = cert certificate.privkey = privkey certificate.certSize = certSize - certificate.buildNames() runtime.SetFinalizer(certificate, (*Certificate).free) return certificate, nil } diff --git a/gnutls.c b/gnutls.c index 8ed4047..eebb724 100644 --- a/gnutls.c +++ b/gnutls.c @@ -438,4 +438,30 @@ gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length) } } return st; +} + +int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname) +{ + int i; + int ret; + int allow = 0; + gnutls_x509_crt_t crt; + for (i = 0; i < len; i++) + { + gnutls_x509_crt_init(&crt); + ret = gnutls_pcert_export_x509((st + i), &crt); + if (ret < 0) + { + return ret; + } + ret = gnutls_x509_crt_check_hostname(crt, hostname); + if (ret != 0) + { + allow = 1; + gnutls_x509_crt_deinit(crt); + break; + } + gnutls_x509_crt_deinit(crt); + } + return allow; } \ No newline at end of file diff --git a/tls_test.go b/tls_test.go index 3ffc3a4..10638cc 100644 --- a/tls_test.go +++ b/tls_test.go @@ -366,23 +366,34 @@ func TestTLSServerSNI(t *testing.T) { } }() - for _, servername := range []string{"abc.com", "example.com", "a.aaa.com", "b.aaa.com"} { + for _, cfg := range []struct { + serverName string + commonName string + }{ + {"abc.com", "abc.com"}, + {"example.com", "example.com"}, + {"a.aaa.com", "*.aaa.com"}, + {"b.aaa.com", "*.aaa.com"}, + } { conn, err := tls.Dial("tcp", addr, &tls.Config{ - ServerName: servername, + ServerName: cfg.serverName, InsecureSkipVerify: true, }) if err != nil { t.Fatal(err) } - //state := conn.ConnectionState() - //log.Printf("%+v", state.PeerCertificates[0]) + state := conn.ConnectionState() + _commonName := state.PeerCertificates[0].Subject.CommonName + if _commonName != cfg.commonName { + t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName) + } buf := make([]byte, 100) n, err := conn.Read(buf) if err != nil && err != io.EOF { t.Error(err) } - if !bytes.Equal(buf[:n], []byte(servername)) { - t.Errorf("expect %s, got %s", servername, string(buf[:n])) + if !bytes.Equal(buf[:n], []byte(cfg.serverName)) { + t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n])) } conn.Close() }