use gnutls api to check hostname

master
fangdingjun 6 years ago
parent 3c2bd315dc
commit 9723226e55

@ -54,4 +54,5 @@ int get_cert_dn(gnutls_pcert_st *st, int index, char *out);
void free_cert_list(gnutls_pcert_st *st, int size);
gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length);
int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname);
#endif

@ -17,7 +17,6 @@ type Certificate struct {
cert *C.gnutls_pcert_st
privkey C.gnutls_privkey_t
certSize C.int
names []string
}
// Free free the certificate context
@ -39,31 +38,17 @@ func (c *Certificate) free() {
}
func (c *Certificate) matchName(name string) bool {
for _, n := range c.names {
if n == name {
return true
}
if strings.HasPrefix(n, "*") {
n1 := strings.Replace(n, "*.", "", 1)
if strings.HasSuffix(name, n1) {
return true
}
}
}
return false
}
func (c *Certificate) buildNames() {
if c.names != nil {
return
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
ret := C.cert_check_hostname(c.cert, c.certSize, cname)
if int(ret) < 0 {
log.Println(C.GoString(C.gnutls_strerror(ret)))
return false
}
c.names = []string{}
for i := 0; i < int(c.certSize); i++ {
cn := c.commonName(i)
if cn != "" {
c.names = append(c.names, cn)
}
if int(ret) > 0 {
return true
}
return false
}
// CommonName get CN field in subject,
@ -186,7 +171,6 @@ func LoadX509KeyPair(certfile, keyfile string) (*Certificate, error) {
certificate.cert = cert
certificate.privkey = privkey
certificate.certSize = certSize
certificate.buildNames()
runtime.SetFinalizer(certificate, (*Certificate).free)
return certificate, nil
}

@ -439,3 +439,29 @@ gnutls_pcert_st *get_peer_certificate(gnutls_session_t sess, int *pcert_length)
}
return st;
}
int cert_check_hostname(gnutls_pcert_st *st, int len, char *hostname)
{
int i;
int ret;
int allow = 0;
gnutls_x509_crt_t crt;
for (i = 0; i < len; i++)
{
gnutls_x509_crt_init(&crt);
ret = gnutls_pcert_export_x509((st + i), &crt);
if (ret < 0)
{
return ret;
}
ret = gnutls_x509_crt_check_hostname(crt, hostname);
if (ret != 0)
{
allow = 1;
gnutls_x509_crt_deinit(crt);
break;
}
gnutls_x509_crt_deinit(crt);
}
return allow;
}

@ -366,23 +366,34 @@ func TestTLSServerSNI(t *testing.T) {
}
}()
for _, servername := range []string{"abc.com", "example.com", "a.aaa.com", "b.aaa.com"} {
for _, cfg := range []struct {
serverName string
commonName string
}{
{"abc.com", "abc.com"},
{"example.com", "example.com"},
{"a.aaa.com", "*.aaa.com"},
{"b.aaa.com", "*.aaa.com"},
} {
conn, err := tls.Dial("tcp", addr, &tls.Config{
ServerName: servername,
ServerName: cfg.serverName,
InsecureSkipVerify: true,
})
if err != nil {
t.Fatal(err)
}
//state := conn.ConnectionState()
//log.Printf("%+v", state.PeerCertificates[0])
state := conn.ConnectionState()
_commonName := state.PeerCertificates[0].Subject.CommonName
if _commonName != cfg.commonName {
t.Errorf("expect: %s, got: %s", cfg.commonName, _commonName)
}
buf := make([]byte, 100)
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
t.Error(err)
}
if !bytes.Equal(buf[:n], []byte(servername)) {
t.Errorf("expect %s, got %s", servername, string(buf[:n]))
if !bytes.Equal(buf[:n], []byte(cfg.serverName)) {
t.Errorf("expect %s, got %s", cfg.serverName, string(buf[:n]))
}
conn.Close()
}

Loading…
Cancel
Save