use gnutls api to check hostname

master
fangdingjun 6 years ago
parent 3c2bd315dc
commit 9723226e55

@ -54,4 +54,5 @@ int get_cert_dn(gnutls_pcert_st *st, int index, char *out);
void free_cert_list(gnutls_pcert_st *st, int size); void free_cert_list(gnutls_pcert_st *st, int size);
gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length); gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length);
int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname);
#endif #endif

@ -17,7 +17,6 @@ type Certificate struct {
cert *C.gnutls_pcert_st cert *C.gnutls_pcert_st
privkey C.gnutls_privkey_t privkey C.gnutls_privkey_t
certSize C.int certSize C.int
names []string
} }
// Free free the certificate context // Free free the certificate context
@ -39,31 +38,17 @@ func (c *Certificate) free() {
} }
func (c *Certificate) matchName(name string) bool { func (c *Certificate) matchName(name string) bool {
for _, n := range c.names { cname := C.CString(name)
if n == name { defer C.free(unsafe.Pointer(cname))
return true ret := C.cert_check_hostname(c.cert, c.certSize, cname)
} if int(ret) < 0 {
if strings.HasPrefix(n, "*") { log.Println(C.GoString(C.gnutls_strerror(ret)))
n1 := strings.Replace(n, "*.", "", 1) return false
if strings.HasSuffix(name, n1) {
return true
}
}
}
return false
}
func (c *Certificate) buildNames() {
if c.names != nil {
return
} }
c.names = []string{} if int(ret) > 0 {
for i := 0; i < int(c.certSize); i++ { return true
cn := c.commonName(i)
if cn != "" {
c.names = append(c.names, cn)
}
} }
return false
} }
// CommonName get CN field in subject, // CommonName get CN field in subject,
@ -186,7 +171,6 @@ func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
certificate.cert = cert certificate.cert = cert
certificate.privkey = privkey certificate.privkey = privkey
certificate.certSize = certSize certificate.certSize = certSize
certificate.buildNames()
runtime.SetFinalizer(certificate, (*Certificate).free) runtime.SetFinalizer(certificate, (*Certificate).free)
return certificate, nil return certificate, nil
} }

@ -439,3 +439,29 @@ gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length)
} }
return st; return st;
} }
int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname)
{
int i;
int ret;
int allow = 0;
gnutls_x509_crt_t crt;
for (i = 0; i < len; i++)
{
gnutls_x509_crt_init(&crt);
ret = gnutls_pcert_export_x509((st + i), &crt);
if (ret < 0)
{
return ret;
}
ret = gnutls_x509_crt_check_hostname(crt, hostname);
if (ret != 0)
{
allow = 1;
gnutls_x509_crt_deinit(crt);
break;
}
gnutls_x509_crt_deinit(crt);
}
return allow;
}

@ -366,23 +366,34 @@ func TestTLSServerSNI(t *testing.T) {
} }
}() }()
for _, servername := range []string{"abc.com", "example.com", "a.aaa.com", "b.aaa.com"} { for _, cfg := range []struct {
serverName string
commonName string
}{
{"abc.com", "abc.com"},
{"example.com", "example.com"},
{"a.aaa.com", "*.aaa.com"},
{"b.aaa.com", "*.aaa.com"},
} {
conn, err := tls.Dial("tcp", addr, &tls.Config{ conn, err := tls.Dial("tcp", addr, &tls.Config{
ServerName: servername, ServerName: cfg.serverName,
InsecureSkipVerify: true, InsecureSkipVerify: true,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
//state := conn.ConnectionState() state := conn.ConnectionState()
//log.Printf("%+v", state.PeerCertificates[0]) _commonName := state.PeerCertificates[0].Subject.CommonName
if _commonName != cfg.commonName {
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
}
buf := make([]byte, 100) buf := make([]byte, 100)
n, err := conn.Read(buf) n, err := conn.Read(buf)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
t.Error(err) t.Error(err)
} }
if !bytes.Equal(buf[:n], []byte(servername)) { if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
t.Errorf("expect %s, got %s", servername, string(buf[:n])) t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
} }
conn.Close() conn.Close()
} }

Loading…
Cancel
Save