master
dingjun 4 years ago
parent 89cacc7f0f
commit 78a61c187e

@ -21,8 +21,8 @@ func makeServers(cfg conf) {
if err != nil { if err != nil {
log.Fatalf("parse %s, error %s", c.Listen, err) log.Fatalf("parse %s, error %s", c.Listen, err)
} }
switch u.Scheme {
switch u.Scheme {
case "ws": case "ws":
exists := false exists := false
for i := 0; i < len(wsservers); i++ { for i := 0; i < len(wsservers); i++ {
@ -35,7 +35,6 @@ func makeServers(cfg conf) {
if !exists { if !exists {
wsservers = append(wsservers, wsServer{u.Host, []forwardRule{{u.Path, c.Remote}}}) wsservers = append(wsservers, wsServer{u.Host, []forwardRule{{u.Path, c.Remote}}})
} }
case "tcp": case "tcp":
tcpservers = append(tcpservers, tcpServer{u.Host, c.Remote}) tcpservers = append(tcpservers, tcpServer{u.Host, c.Remote})
default: default:
@ -46,6 +45,7 @@ func makeServers(cfg conf) {
for _, srv := range wsservers { for _, srv := range wsservers {
go srv.run() go srv.run()
} }
for _, srv := range tcpservers { for _, srv := range tcpservers {
go srv.run() go srv.run()
} }
@ -64,6 +64,7 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
var cfg conf var cfg conf
if err := yaml.Unmarshal(data, &cfg); err != nil { if err := yaml.Unmarshal(data, &cfg); err != nil {
log.Fatal(err) log.Fatal(err)

@ -35,10 +35,12 @@ var upgrader = websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
} }
var dialer = &websocket.Dialer{} var dialer = &websocket.Dialer{}
func forwardWS2WS(conn, conn1 *websocket.Conn) { func forwardWS2WS(conn, conn1 *websocket.Conn) {
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
go func() { go func() {
for { for {
t, data, err := conn.ReadMessage() t, data, err := conn.ReadMessage()
@ -54,6 +56,7 @@ func forwardWS2WS(conn, conn1 *websocket.Conn) {
} }
ch <- struct{}{} ch <- struct{}{}
}() }()
go func() { go func() {
for { for {
t, data, err := conn1.ReadMessage() t, data, err := conn1.ReadMessage()
@ -68,14 +71,14 @@ func forwardWS2WS(conn, conn1 *websocket.Conn) {
} }
} }
ch <- struct{}{} ch <- struct{}{}
}() }()
<-ch <-ch
} }
func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) { func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
go func() { go func() {
for { for {
_, data, err := conn1.ReadMessage() _, data, err := conn1.ReadMessage()
@ -83,6 +86,7 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
log.Errorln(err) log.Errorln(err)
break break
} }
_, err = conn2.Write(data) _, err = conn2.Write(data)
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
@ -91,14 +95,17 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
} }
ch <- struct{}{} ch <- struct{}{}
}() }()
go func() { go func() {
buf := make([]byte, 1024) buf := make([]byte, 1024)
for { for {
n, err := conn2.Read(buf) n, err := conn2.Read(buf)
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
break break
} }
err = conn1.WriteMessage(websocket.BinaryMessage, buf[:n]) err = conn1.WriteMessage(websocket.BinaryMessage, buf[:n])
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
@ -107,11 +114,13 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
} }
ch <- struct{}{} ch <- struct{}{}
}() }()
<-ch <-ch
} }
func forwardTCP2TCP(c1, c2 net.Conn) { func forwardTCP2TCP(c1, c2 net.Conn) {
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
go func() { go func() {
_, err := io.Copy(c1, c2) _, err := io.Copy(c1, c2)
if err != nil { if err != nil {
@ -119,6 +128,7 @@ func forwardTCP2TCP(c1, c2 net.Conn) {
} }
ch <- struct{}{} ch <- struct{}{}
}() }()
go func() { go func() {
_, err := io.Copy(c2, c1) _, err := io.Copy(c2, c1)
if err != nil { if err != nil {
@ -126,6 +136,7 @@ func forwardTCP2TCP(c1, c2 net.Conn) {
} }
ch <- struct{}{} ch <- struct{}{}
}() }()
<-ch <-ch
} }
@ -142,6 +153,7 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
return return
} }
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
@ -159,11 +171,12 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
resp.Body.Close() resp.Body.Close()
defer conn1.Close()
if resp.StatusCode != http.StatusSwitchingProtocols { if resp.StatusCode != http.StatusSwitchingProtocols {
log.Errorf("dial remote ws %d", resp.StatusCode) log.Errorf("dial remote ws %d", resp.StatusCode)
return return
} }
defer conn1.Close()
forwardWS2WS(conn, conn1) forwardWS2WS(conn, conn1)
return return
} }
@ -175,6 +188,7 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
defer conn1.Close() defer conn1.Close()
forwardWS2TCP(conn, conn1) forwardWS2TCP(conn, conn1)
return return
} }
@ -188,6 +202,7 @@ func (srv *tcpServer) run() {
return return
} }
defer l.Close() defer l.Close()
for { for {
conn, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {
@ -210,11 +225,12 @@ func (srv *tcpServer) serve(c net.Conn) {
return return
} }
resp.Body.Close() resp.Body.Close()
defer conn1.Close()
if resp.StatusCode != http.StatusSwitchingProtocols { if resp.StatusCode != http.StatusSwitchingProtocols {
log.Errorf("dial remote ws %d", resp.StatusCode) log.Errorf("dial remote ws %d", resp.StatusCode)
return return
} }
defer conn1.Close()
forwardWS2TCP(conn1, c) forwardWS2TCP(conn1, c)
return return
} }
@ -226,8 +242,10 @@ func (srv *tcpServer) serve(c net.Conn) {
return return
} }
defer conn1.Close() defer conn1.Close()
forwardTCP2TCP(c, conn1) forwardTCP2TCP(c, conn1)
return return
} }
log.Errorf("unsupported scheme %s", u.Scheme) log.Errorf("unsupported scheme %s", u.Scheme)
} }

Loading…
Cancel
Save