diff --git a/main.go b/main.go index e7058cb..c542375 100644 --- a/main.go +++ b/main.go @@ -21,8 +21,8 @@ func makeServers(cfg conf) { if err != nil { log.Fatalf("parse %s, error %s", c.Listen, err) } - switch u.Scheme { + switch u.Scheme { case "ws": exists := false for i := 0; i < len(wsservers); i++ { @@ -35,7 +35,6 @@ func makeServers(cfg conf) { if !exists { wsservers = append(wsservers, wsServer{u.Host, []forwardRule{{u.Path, c.Remote}}}) } - case "tcp": tcpservers = append(tcpservers, tcpServer{u.Host, c.Remote}) default: @@ -46,6 +45,7 @@ func makeServers(cfg conf) { for _, srv := range wsservers { go srv.run() } + for _, srv := range tcpservers { go srv.run() } @@ -64,6 +64,7 @@ func main() { if err != nil { log.Fatal(err) } + var cfg conf if err := yaml.Unmarshal(data, &cfg); err != nil { log.Fatal(err) diff --git a/server.go b/server.go index 4f45531..c331bef 100644 --- a/server.go +++ b/server.go @@ -35,10 +35,12 @@ var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } + var dialer = &websocket.Dialer{} func forwardWS2WS(conn, conn1 *websocket.Conn) { ch := make(chan struct{}, 2) + go func() { for { t, data, err := conn.ReadMessage() @@ -54,6 +56,7 @@ func forwardWS2WS(conn, conn1 *websocket.Conn) { } ch <- struct{}{} }() + go func() { for { t, data, err := conn1.ReadMessage() @@ -68,14 +71,14 @@ func forwardWS2WS(conn, conn1 *websocket.Conn) { } } ch <- struct{}{} - }() + <-ch } func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) { - ch := make(chan struct{}, 2) + go func() { for { _, data, err := conn1.ReadMessage() @@ -83,6 +86,7 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) { log.Errorln(err) break } + _, err = conn2.Write(data) if err != nil { log.Errorln(err) @@ -91,14 +95,17 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) { } ch <- struct{}{} }() + go func() { buf := make([]byte, 1024) + for { n, err := conn2.Read(buf) if err != nil { log.Errorln(err) break } + err = conn1.WriteMessage(websocket.BinaryMessage, buf[:n]) if err != nil { log.Errorln(err) @@ -107,11 +114,13 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) { } ch <- struct{}{} }() + <-ch } func forwardTCP2TCP(c1, c2 net.Conn) { ch := make(chan struct{}, 2) + go func() { _, err := io.Copy(c1, c2) if err != nil { @@ -119,6 +128,7 @@ func forwardTCP2TCP(c1, c2 net.Conn) { } ch <- struct{}{} }() + go func() { _, err := io.Copy(c2, c1) if err != nil { @@ -126,6 +136,7 @@ func forwardTCP2TCP(c1, c2 net.Conn) { } ch <- struct{}{} }() + <-ch } @@ -142,6 +153,7 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "not found", http.StatusNotFound) return } + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Errorln(err) @@ -159,11 +171,12 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } resp.Body.Close() - defer conn1.Close() if resp.StatusCode != http.StatusSwitchingProtocols { log.Errorf("dial remote ws %d", resp.StatusCode) return } + defer conn1.Close() + forwardWS2WS(conn, conn1) return } @@ -175,6 +188,7 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } defer conn1.Close() + forwardWS2TCP(conn, conn1) return } @@ -188,6 +202,7 @@ func (srv *tcpServer) run() { return } defer l.Close() + for { conn, err := l.Accept() if err != nil { @@ -210,11 +225,12 @@ func (srv *tcpServer) serve(c net.Conn) { return } resp.Body.Close() - defer conn1.Close() if resp.StatusCode != http.StatusSwitchingProtocols { log.Errorf("dial remote ws %d", resp.StatusCode) return } + defer conn1.Close() + forwardWS2TCP(conn1, c) return } @@ -226,8 +242,10 @@ func (srv *tcpServer) serve(c net.Conn) { return } defer conn1.Close() + forwardTCP2TCP(c, conn1) return } + log.Errorf("unsupported scheme %s", u.Scheme) }