diff --git a/_gnutls.h b/_gnutls.h index ebbda63..6e8eab6 100644 --- a/_gnutls.h +++ b/_gnutls.h @@ -35,4 +35,6 @@ gnutls_cipher_hd_t new_cipher(int cipher_type, char *key, int keylen, char *iv, gnutls_hash_hd_t new_hash(int t); +int alpn_set_protocols(struct session *sess, char **, int); +int alpn_get_selected_protocol(struct session *sess, char *buf); #endif \ No newline at end of file diff --git a/gnutls.c b/gnutls.c index 7015e6a..ee30939 100644 --- a/gnutls.c +++ b/gnutls.c @@ -145,3 +145,40 @@ gnutls_hash_hd_t new_hash(int t) gnutls_hash_init(&hash, t); return hash; } + +int alpn_set_protocols(struct session *sess, char **names, int namelen) +{ + gnutls_datum_t *t; + int ret; + int i; + + t = (gnutls_datum_t *)malloc(namelen * sizeof(gnutls_datum_t)); + for (i = 0; i < namelen; i++) + { + t[i].data = names[i]; + t[i].size = strlen(names[i]); + } + + ret = gnutls_alpn_set_protocols(sess->session, t, + namelen, + GNUTLS_ALPN_SERVER_PRECEDENCE); + free(t); + return ret; +} + +int alpn_get_selected_protocol(struct session *sess, char *buf) +{ + gnutls_datum_t p; + int ret; + memset(&p, 0, sizeof(gnutls_datum_t)); + ret = gnutls_alpn_get_selected_protocol(sess->session, &p); + if (ret < 0) + { + return ret; + } + strcpy(buf, p.data); + + // note: p.data is constant value, only valid during the session life + + return 0; +} \ No newline at end of file diff --git a/tls.go b/tls.go index 49781f7..53ffa97 100644 --- a/tls.go +++ b/tls.go @@ -15,8 +15,10 @@ import ( ) const ( - GNUTLS_NAME_DNS = 1 - GNUTLS_X509_FMT_PEM = 1 + GNUTLS_NAME_DNS = 1 + GNUTLS_X509_FMT_PEM = 1 + GNUTLS_ALPN_MANDATORY = 1 + GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1 ) // Conn tls connection for client @@ -25,6 +27,7 @@ type Conn struct { sess *C.struct_session handshake bool cservname *C.char + state *ConnectionState } // Config tls configure @@ -33,7 +36,22 @@ type Config struct { CrtFile string KeyFile string InsecureSkipVerify bool + NextProtos []string } + +// ConnectionState connection state +type ConnectionState struct { + // SNI name client send + ServerName string + // selected ALPN protocl + NegotiatedProtocol string + HandshakeComplete bool + // TLS version number, ex: 0x303 + Version uint16 + // TLS version number, ex: TLS1.0 + VersionName string +} + type listener struct { l net.Listener c *Config @@ -106,6 +124,11 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) { cerrstr := C.gnutls_strerror(ret) return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr)) } + if cfg.NextProtos != nil { + if err := setAlpnProtocols(sess, cfg); err != nil { + log.Println(err) + } + } return conn, nil } @@ -145,12 +168,35 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) { C.gnutls_session_set_verify_cert(sess.session, nil, 0) } } + + if cfg.NextProtos != nil { + if err := setAlpnProtocols(sess, cfg); err != nil { + log.Println(err) + } + } + } else { C.gnutls_session_set_verify_cert(sess.session, nil, 0) } return conn, nil } +func setAlpnProtocols(sess *C.struct_session, cfg *Config) error { + arg := make([](*C.char), 0) + for _, s := range cfg.NextProtos { + cbuf := C.CString(s) + defer C.free(unsafe.Pointer(cbuf)) + arg = append(arg, (*C.char)(cbuf)) + } + ret := C.alpn_set_protocols(sess, + (**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos))) + if int(ret) < 0 { + return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret))) + } + return nil + +} + // Handshake handshake tls func (c *Conn) Handshake() error { if c.handshake { @@ -257,6 +303,56 @@ func (c *Conn) SetDeadline(t time.Time) error { return c.c.SetDeadline(t) } +// ConnectionState report connection state +func (c *Conn) ConnectionState() *ConnectionState { + if c.state != nil { + return c.state + } + version := + uint16(C.gnutls_protocol_get_version(c.sess.session)) + + versionname := C.GoString( + C.gnutls_protocol_get_name(C.gnutls_protocol_t(version))) + + state := &ConnectionState{ + NegotiatedProtocol: c.getAlpnSelectedProtocol(), + Version: version, + HandshakeComplete: c.handshake, + ServerName: c.getServerName(), + VersionName: versionname, + } + c.state = state + return state +} + +func (c *Conn) getAlpnSelectedProtocol() string { + cbuf := C.malloc(100) + defer C.free(cbuf) + + ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf)) + if int(ret) < 0 { + return "" + } + alpnname := C.GoString((*C.char)(cbuf)) + return alpnname +} + +func (c *Conn) getServerName() string { + buflen := 100 + nametype := GNUTLS_NAME_DNS + cbuf := C.malloc(C.size_t(buflen)) + defer C.free(cbuf) + + ret := C.gnutls_server_name_get(c.sess.session, cbuf, + (*C.size_t)(unsafe.Pointer(&buflen)), + (*C.uint)(unsafe.Pointer(&nametype)), 0) + if int(ret) < 0 { + return "" + } + name := C.GoString((*C.char)(cbuf)) + return name +} + // DataRead c callback function for data read //export DataRead func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { diff --git a/tls_test.go b/tls_test.go index 83e9043..3c86c6f 100644 --- a/tls_test.go +++ b/tls_test.go @@ -120,3 +120,175 @@ func TestTLSServer(t *testing.T) { t.Errorf("need: %s, got: %s", data, string(buf[:n])) } } + +func TestTLSALPNServer(t *testing.T) { + serveralpn := []string{"a1", "a3", "a2"} + clientalpn := []string{"a0", "a2", "a5"} + expectedAlpn := "a2" + + l, err := Listen("tcp", "127.0.0.1:0", &Config{ + CrtFile: "testdata/server.crt", + KeyFile: "testdata/server.key", + NextProtos: serveralpn, + }) + if err != nil { + t.Fatal("gnutls listen ", err) + } + addr := l.Addr().String() + log.Println("test server listen on ", addr) + defer l.Close() + go func() { + for { + c, err := l.Accept() + if err != nil { + log.Println("gnutls accept ", err) + break + } + log.Println("accept connection from ", c.RemoteAddr()) + go func(c net.Conn) { + defer c.Close() + tlsConn := c.(*Conn) + if err := tlsConn.Handshake(); err != nil { + log.Println(err) + return + } + connState := tlsConn.ConnectionState() + log.Printf("%+v", connState) + buf := make([]byte, 4096) + for { + n, err := c.Read(buf[0:]) + if err != nil { + log.Println("gnutls read ", err) + break + } + if _, err := c.Write(buf[:n]); err != nil { + log.Println("gnutls write ", err) + break + } + } + }(c) + } + }() + + c, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + NextProtos: clientalpn, + }) + if err != nil { + t.Fatal("dial ", err) + } + defer c.Close() + + if err := c.Handshake(); err != nil { + t.Fatal(err) + } + connState := c.ConnectionState() + log.Printf("%+v", connState) + + if connState.NegotiatedProtocol != expectedAlpn { + t.Errorf("expected alpn %s, got %s", + expectedAlpn, connState.NegotiatedProtocol) + } + + data := "hello, world" + if _, err := c.Write([]byte(data)); err != nil { + t.Fatal("write ", err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if err != nil { + t.Fatal("read ", err) + } + if string(buf[:n]) != data { + t.Errorf("need: %s, got: %s", data, string(buf[:n])) + } +} + +func TestTLSALPNClient(t *testing.T) { + serveralpn := []string{"a1", "a3", "a2"} + clientalpn := []string{"a0", "a2", "a5"} + expectedAlpn := "a2" + + cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key") + if err != nil { + t.Fatal("load key failed") + } + + l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: serveralpn, + }) + + if err != nil { + t.Fatal("tls listen ", err) + } + addr := l.Addr().String() + log.Println("test server listen on ", addr) + defer l.Close() + go func() { + for { + c, err := l.Accept() + if err != nil { + log.Println("gnutls accept ", err) + break + } + log.Println("accept connection from ", c.RemoteAddr()) + go func(c net.Conn) { + defer c.Close() + tlsConn := c.(*tls.Conn) + if err := tlsConn.Handshake(); err != nil { + log.Println(err) + return + } + connState := tlsConn.ConnectionState() + log.Printf("%+v", connState) + buf := make([]byte, 4096) + for { + n, err := c.Read(buf[0:]) + if err != nil { + log.Println("tls read ", err) + break + } + if _, err := c.Write(buf[:n]); err != nil { + log.Println("tls write ", err) + break + } + } + }(c) + } + }() + + c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true, + ServerName: "localhost", + NextProtos: clientalpn, + }) + if err != nil { + t.Fatal("dial ", err) + } + defer c.Close() + + if err := c.Handshake(); err != nil { + t.Fatal(err) + } + connState := c.ConnectionState() + log.Printf("%+v", connState) + + if connState.NegotiatedProtocol != expectedAlpn { + t.Errorf("expected alpn %s, got %s", + expectedAlpn, connState.NegotiatedProtocol) + } + + data := "hello, world" + if _, err := c.Write([]byte(data)); err != nil { + t.Fatal("write ", err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if err != nil { + t.Fatal("read ", err) + } + if string(buf[:n]) != data { + t.Errorf("need: %s, got: %s", data, string(buf[:n])) + } +}