You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
1.6 KiB
Go
86 lines
1.6 KiB
Go
package nghttp2
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Transport the nghttp2 RoundTripper implement
|
|
type Transport struct {
|
|
TLSConfig *tls.Config
|
|
DialTLS func(network, addr string, cfg *tls.Config) (*tls.Conn, error)
|
|
cacheConn map[string]*Conn
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// RoundTrip send req and get res
|
|
func (tr *Transport) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
|
h2conn, err := tr.getConn(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return h2conn.RoundTrip(req)
|
|
}
|
|
|
|
func (tr *Transport) getConn(req *http.Request) (*Conn, error) {
|
|
tr.mu.Lock()
|
|
defer tr.mu.Unlock()
|
|
|
|
if tr.cacheConn == nil {
|
|
tr.cacheConn = map[string]*Conn{}
|
|
}
|
|
k := req.URL.Host
|
|
if c, ok := tr.cacheConn[k]; ok {
|
|
if c.CanTakeNewRequest() {
|
|
return c, nil
|
|
}
|
|
delete(tr.cacheConn, k)
|
|
c.Close()
|
|
}
|
|
c, err := tr.createConn(k)
|
|
if err == nil {
|
|
tr.cacheConn[k] = c
|
|
}
|
|
return c, err
|
|
}
|
|
|
|
func (tr *Transport) createConn(host string) (*Conn, error) {
|
|
dial := tls.Dial
|
|
if tr.DialTLS != nil {
|
|
dial = tr.DialTLS
|
|
}
|
|
cfg := tr.TLSConfig
|
|
if cfg == nil {
|
|
h, _, err := net.SplitHostPort(host)
|
|
if err != nil {
|
|
h = host
|
|
}
|
|
cfg = &tls.Config{
|
|
ServerName: h,
|
|
NextProtos: []string{"h2"},
|
|
}
|
|
}
|
|
if !strings.Contains(host, ":") {
|
|
host = fmt.Sprintf("%s:443", host)
|
|
}
|
|
conn, err := dial("tcp", host, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err = conn.Handshake(); err != nil {
|
|
return nil, err
|
|
}
|
|
state := conn.ConnectionState()
|
|
if state.NegotiatedProtocol != "h2" {
|
|
conn.Close()
|
|
return nil, errors.New("http2 is not supported")
|
|
}
|
|
|
|
return Client(conn)
|
|
}
|