add alpn support

master
fangdingjun 6 years ago
parent a1a67d272c
commit 555759c8d9

@ -35,4 +35,6 @@ gnutls_cipher_hd_t new_cipher(int cipher_type, char *key, int keylen, char *iv,
gnutls_hash_hd_t new_hash(int t); gnutls_hash_hd_t new_hash(int t);
int alpn_set_protocols(struct session *sess, char **, int);
int alpn_get_selected_protocol(struct session *sess, char *buf);
#endif #endif

@ -145,3 +145,40 @@ gnutls_hash_hd_t new_hash(int t)
gnutls_hash_init(&hash, t); gnutls_hash_init(&hash, t);
return hash; return hash;
} }
int alpn_set_protocols(struct session *sess, char **names, int namelen)
{
gnutls_datum_t *t;
int ret;
int i;
t = (gnutls_datum_t *)malloc(namelen * sizeof(gnutls_datum_t));
for (i = 0; i < namelen; i++)
{
t[i].data = names[i];
t[i].size = strlen(names[i]);
}
ret = gnutls_alpn_set_protocols(sess->session, t,
namelen,
GNUTLS_ALPN_SERVER_PRECEDENCE);
free(t);
return ret;
}
int alpn_get_selected_protocol(struct session *sess, char *buf)
{
gnutls_datum_t p;
int ret;
memset(&p, 0, sizeof(gnutls_datum_t));
ret = gnutls_alpn_get_selected_protocol(sess->session, &p);
if (ret < 0)
{
return ret;
}
strcpy(buf, p.data);
// note: p.data is constant value, only valid during the session life
return 0;
}

@ -17,6 +17,8 @@ import (
const ( const (
GNUTLS_NAME_DNS = 1 GNUTLS_NAME_DNS = 1
GNUTLS_X509_FMT_PEM = 1 GNUTLS_X509_FMT_PEM = 1
GNUTLS_ALPN_MANDATORY = 1
GNUTLS_ALPN_SERVER_PRECEDENCE = 1 << 1
) )
// Conn tls connection for client // Conn tls connection for client
@ -25,6 +27,7 @@ type Conn struct {
sess *C.struct_session sess *C.struct_session
handshake bool handshake bool
cservname *C.char cservname *C.char
state *ConnectionState
} }
// Config tls configure // Config tls configure
@ -33,7 +36,22 @@ type Config struct {
CrtFile string CrtFile string
KeyFile string KeyFile string
InsecureSkipVerify bool InsecureSkipVerify bool
NextProtos []string
} }
// ConnectionState connection state
type ConnectionState struct {
// SNI name client send
ServerName string
// selected ALPN protocl
NegotiatedProtocol string
HandshakeComplete bool
// TLS version number, ex: 0x303
Version uint16
// TLS version number, ex: TLS1.0
VersionName string
}
type listener struct { type listener struct {
l net.Listener l net.Listener
c *Config c *Config
@ -106,6 +124,11 @@ func NewServerConn(c net.Conn, cfg *Config) (*Conn, error) {
cerrstr := C.gnutls_strerror(ret) cerrstr := C.gnutls_strerror(ret)
return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr)) return nil, fmt.Errorf("set keyfile failed: %s", C.GoString(cerrstr))
} }
if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err)
}
}
return conn, nil return conn, nil
} }
@ -145,12 +168,35 @@ func NewClientConn(c net.Conn, cfg *Config) (*Conn, error) {
C.gnutls_session_set_verify_cert(sess.session, nil, 0) C.gnutls_session_set_verify_cert(sess.session, nil, 0)
} }
} }
if cfg.NextProtos != nil {
if err := setAlpnProtocols(sess, cfg); err != nil {
log.Println(err)
}
}
} else { } else {
C.gnutls_session_set_verify_cert(sess.session, nil, 0) C.gnutls_session_set_verify_cert(sess.session, nil, 0)
} }
return conn, nil return conn, nil
} }
func setAlpnProtocols(sess *C.struct_session, cfg *Config) error {
arg := make([](*C.char), 0)
for _, s := range cfg.NextProtos {
cbuf := C.CString(s)
defer C.free(unsafe.Pointer(cbuf))
arg = append(arg, (*C.char)(cbuf))
}
ret := C.alpn_set_protocols(sess,
(**C.char)(unsafe.Pointer(&arg[0])), C.int(len(cfg.NextProtos)))
if int(ret) < 0 {
return fmt.Errorf("set alpn failed: %s", C.GoString(C.gnutls_strerror(ret)))
}
return nil
}
// Handshake handshake tls // Handshake handshake tls
func (c *Conn) Handshake() error { func (c *Conn) Handshake() error {
if c.handshake { if c.handshake {
@ -257,6 +303,56 @@ func (c *Conn) SetDeadline(t time.Time) error {
return c.c.SetDeadline(t) return c.c.SetDeadline(t)
} }
// ConnectionState report connection state
func (c *Conn) ConnectionState() *ConnectionState {
if c.state != nil {
return c.state
}
version :=
uint16(C.gnutls_protocol_get_version(c.sess.session))
versionname := C.GoString(
C.gnutls_protocol_get_name(C.gnutls_protocol_t(version)))
state := &ConnectionState{
NegotiatedProtocol: c.getAlpnSelectedProtocol(),
Version: version,
HandshakeComplete: c.handshake,
ServerName: c.getServerName(),
VersionName: versionname,
}
c.state = state
return state
}
func (c *Conn) getAlpnSelectedProtocol() string {
cbuf := C.malloc(100)
defer C.free(cbuf)
ret := C.alpn_get_selected_protocol(c.sess, (*C.char)(cbuf))
if int(ret) < 0 {
return ""
}
alpnname := C.GoString((*C.char)(cbuf))
return alpnname
}
func (c *Conn) getServerName() string {
buflen := 100
nametype := GNUTLS_NAME_DNS
cbuf := C.malloc(C.size_t(buflen))
defer C.free(cbuf)
ret := C.gnutls_server_name_get(c.sess.session, cbuf,
(*C.size_t)(unsafe.Pointer(&buflen)),
(*C.uint)(unsafe.Pointer(&nametype)), 0)
if int(ret) < 0 {
return ""
}
name := C.GoString((*C.char)(cbuf))
return name
}
// DataRead c callback function for data read // DataRead c callback function for data read
//export DataRead //export DataRead
func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int { func DataRead(d unsafe.Pointer, cbuf *C.char, bufLen C.int) C.int {

@ -120,3 +120,175 @@ func TestTLSServer(t *testing.T) {
t.Errorf("need: %s, got: %s", data, string(buf[:n])) t.Errorf("need: %s, got: %s", data, string(buf[:n]))
} }
} }
func TestTLSALPNServer(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2"
l, err := Listen("tcp", "127.0.0.1:0", &Config{
CrtFile: "testdata/server.crt",
KeyFile: "testdata/server.key",
NextProtos: serveralpn,
})
if err != nil {
t.Fatal("gnutls listen ", err)
}
addr := l.Addr().String()
log.Println("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
log.Println("gnutls accept ", err)
break
}
log.Println("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsConn := c.(*Conn)
if err := tlsConn.Handshake(); err != nil {
log.Println(err)
return
}
connState := tlsConn.ConnectionState()
log.Printf("%+v", connState)
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
log.Println("gnutls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
log.Println("gnutls write ", err)
break
}
}
}(c)
}
}()
c, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
NextProtos: clientalpn,
})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Fatal(err)
}
connState := c.ConnectionState()
log.Printf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol)
}
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
}
func TestTLSALPNClient(t *testing.T) {
serveralpn := []string{"a1", "a3", "a2"}
clientalpn := []string{"a0", "a2", "a5"}
expectedAlpn := "a2"
cert, err := tls.LoadX509KeyPair("testdata/server.crt", "testdata/server.key")
if err != nil {
t.Fatal("load key failed")
}
l, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: serveralpn,
})
if err != nil {
t.Fatal("tls listen ", err)
}
addr := l.Addr().String()
log.Println("test server listen on ", addr)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
log.Println("gnutls accept ", err)
break
}
log.Println("accept connection from ", c.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
tlsConn := c.(*tls.Conn)
if err := tlsConn.Handshake(); err != nil {
log.Println(err)
return
}
connState := tlsConn.ConnectionState()
log.Printf("%+v", connState)
buf := make([]byte, 4096)
for {
n, err := c.Read(buf[0:])
if err != nil {
log.Println("tls read ", err)
break
}
if _, err := c.Write(buf[:n]); err != nil {
log.Println("tls write ", err)
break
}
}
}(c)
}
}()
c, err := Dial("tcp", addr, &Config{InsecureSkipVerify: true,
ServerName: "localhost",
NextProtos: clientalpn,
})
if err != nil {
t.Fatal("dial ", err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Fatal(err)
}
connState := c.ConnectionState()
log.Printf("%+v", connState)
if connState.NegotiatedProtocol != expectedAlpn {
t.Errorf("expected alpn %s, got %s",
expectedAlpn, connState.NegotiatedProtocol)
}
data := "hello, world"
if _, err := c.Write([]byte(data)); err != nil {
t.Fatal("write ", err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if err != nil {
t.Fatal("read ", err)
}
if string(buf[:n]) != data {
t.Errorf("need: %s, got: %s", data, string(buf[:n]))
}
}

Loading…
Cancel
Save