From 9a7e25c4081fb01f864e67b177765dcdc4dfbe53 Mon Sep 17 00:00:00 2001 From: dingjun Date: Mon, 6 Jul 2020 19:19:38 +0800 Subject: [PATCH] split file --- main.go | 40 ------------ server.go | 164 ++++++++++---------------------------------------- tcp_server.go | 69 +++++++++++++++++++++ ws_server.go | 88 +++++++++++++++++++++++++++ 4 files changed, 189 insertions(+), 172 deletions(-) create mode 100644 tcp_server.go create mode 100644 ws_server.go diff --git a/main.go b/main.go index c542375..d627947 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "flag" "io/ioutil" - "net/url" "os" "os/signal" "syscall" @@ -12,45 +11,6 @@ import ( "gopkg.in/yaml.v2" ) -func makeServers(cfg conf) { - var wsservers = []wsServer{} - var tcpservers = []tcpServer{} - - for _, c := range cfg.ProxyConfig { - u, err := url.Parse(c.Listen) - if err != nil { - log.Fatalf("parse %s, error %s", c.Listen, err) - } - - switch u.Scheme { - case "ws": - exists := false - for i := 0; i < len(wsservers); i++ { - if wsservers[i].addr == u.Host { - exists = true - wsservers[i].rule = append(wsservers[i].rule, forwardRule{u.Path, c.Remote}) - break - } - } - if !exists { - wsservers = append(wsservers, wsServer{u.Host, []forwardRule{{u.Path, c.Remote}}}) - } - case "tcp": - tcpservers = append(tcpservers, tcpServer{u.Host, c.Remote}) - default: - log.Fatalf("unsupported scheme %s", u.Scheme) - } - } - - for _, srv := range wsservers { - go srv.run() - } - - for _, srv := range tcpservers { - go srv.run() - } -} - func main() { var cfgfile string var logfile string diff --git a/server.go b/server.go index c331bef..ee37215 100644 --- a/server.go +++ b/server.go @@ -3,52 +3,23 @@ package main import ( "io" "net" - "net/http" "net/url" log "github.com/fangdingjun/go-log/v5" "github.com/gorilla/websocket" ) -type forwardRule struct { - local string - remote string -} - -type wsServer struct { - addr string - rule []forwardRule -} - -type tcpServer struct { - addr string - remote string -} - -func (wss *wsServer) run() { - if err := http.ListenAndServe(wss.addr, wss); err != nil { - log.Errorln(err) - } -} - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - -var dialer = &websocket.Dialer{} - -func forwardWS2WS(conn, conn1 *websocket.Conn) { +func forwardWS2WS(conn1, conn2 *websocket.Conn) { ch := make(chan struct{}, 2) go func() { for { - t, data, err := conn.ReadMessage() + t, data, err := conn1.ReadMessage() if err != nil { log.Errorln(err) break } - err = conn1.WriteMessage(t, data) + err = conn2.WriteMessage(t, data) if err != nil { log.Errorln(err) break @@ -59,12 +30,12 @@ func forwardWS2WS(conn, conn1 *websocket.Conn) { go func() { for { - t, data, err := conn1.ReadMessage() + t, data, err := conn2.ReadMessage() if err != nil { log.Errorln(err) break } - err = conn.WriteMessage(t, data) + err = conn1.WriteMessage(t, data) if err != nil { log.Errorln(err) break @@ -140,112 +111,41 @@ func forwardTCP2TCP(c1, c2 net.Conn) { <-ch } -func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - p := r.URL.Path - remote := "" - for _, ru := range wss.rule { - if ru.local == p { - remote = ru.remote - } - } - - if remote == "" { - http.Error(w, "not found", http.StatusNotFound) - return - } - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Errorln(err) - http.Error(w, "bad request", http.StatusBadRequest) - return - } - defer conn.Close() - - u, _ := url.Parse(remote) - - if u.Scheme == "ws" || u.Scheme == "wss" { - conn1, resp, err := dialer.Dial(remote, nil) - if err != nil { - log.Errorln(err) - return - } - resp.Body.Close() - if resp.StatusCode != http.StatusSwitchingProtocols { - log.Errorf("dial remote ws %d", resp.StatusCode) - return - } - defer conn1.Close() - - forwardWS2WS(conn, conn1) - return - } +func makeServers(cfg conf) { + var wsservers = []wsServer{} + var tcpservers = []tcpServer{} - if u.Scheme == "tcp" { - conn1, err := net.Dial("tcp", u.Host) + for _, c := range cfg.ProxyConfig { + u, err := url.Parse(c.Listen) if err != nil { - log.Errorln(err) - return + log.Fatalf("parse %s, error %s", c.Listen, err) } - defer conn1.Close() - forwardWS2TCP(conn, conn1) - return - } - log.Errorf("unsupported scheme %s", u.Scheme) -} - -func (srv *tcpServer) run() { - l, err := net.Listen("tcp", srv.addr) - if err != nil { - log.Errorln(err) - return - } - defer l.Close() - - for { - conn, err := l.Accept() - if err != nil { - log.Error(err) - return + switch u.Scheme { + case "ws": + exists := false + for i := 0; i < len(wsservers); i++ { + if wsservers[i].addr == u.Host { + exists = true + wsservers[i].rule = append(wsservers[i].rule, forwardRule{u.Path, c.Remote}) + break + } + } + if !exists { + wsservers = append(wsservers, wsServer{u.Host, []forwardRule{{u.Path, c.Remote}}}) + } + case "tcp": + tcpservers = append(tcpservers, tcpServer{u.Host, c.Remote}) + default: + log.Fatalf("unsupported scheme %s", u.Scheme) } - go srv.serve(conn) } -} - -func (srv *tcpServer) serve(c net.Conn) { - defer c.Close() - u, _ := url.Parse(srv.remote) - - if u.Scheme == "ws" || u.Scheme == "wss" { - conn1, resp, err := dialer.Dial(srv.remote, nil) - if err != nil { - log.Errorln(err) - return - } - resp.Body.Close() - if resp.StatusCode != http.StatusSwitchingProtocols { - log.Errorf("dial remote ws %d", resp.StatusCode) - return - } - defer conn1.Close() - - forwardWS2TCP(conn1, c) - return + for _, srv := range wsservers { + go srv.run() } - if u.Scheme == "tcp" { - conn1, err := net.Dial("tcp", u.Host) - if err != nil { - log.Errorln(err) - return - } - defer conn1.Close() - - forwardTCP2TCP(c, conn1) - return + for _, srv := range tcpservers { + go srv.run() } - - log.Errorf("unsupported scheme %s", u.Scheme) } diff --git a/tcp_server.go b/tcp_server.go new file mode 100644 index 0000000..67ca5b1 --- /dev/null +++ b/tcp_server.go @@ -0,0 +1,69 @@ +package main + +import ( + "net" + "net/http" + "net/url" + + log "github.com/fangdingjun/go-log/v5" +) + +type tcpServer struct { + addr string + remote string +} + +func (srv *tcpServer) run() { + l, err := net.Listen("tcp", srv.addr) + if err != nil { + log.Errorln(err) + return + } + defer l.Close() + + for { + conn, err := l.Accept() + if err != nil { + log.Error(err) + return + } + go srv.serve(conn) + } +} + +func (srv *tcpServer) serve(c net.Conn) { + defer c.Close() + + u, _ := url.Parse(srv.remote) + + if u.Scheme == "ws" || u.Scheme == "wss" { + conn1, resp, err := dialer.Dial(srv.remote, nil) + if err != nil { + log.Errorln(err) + return + } + resp.Body.Close() + if resp.StatusCode != http.StatusSwitchingProtocols { + log.Errorf("dial remote ws %d", resp.StatusCode) + return + } + defer conn1.Close() + + forwardWS2TCP(conn1, c) + return + } + + if u.Scheme == "tcp" { + conn1, err := net.Dial("tcp", u.Host) + if err != nil { + log.Errorln(err) + return + } + defer conn1.Close() + + forwardTCP2TCP(c, conn1) + return + } + + log.Errorf("unsupported scheme %s", u.Scheme) +} diff --git a/ws_server.go b/ws_server.go new file mode 100644 index 0000000..6cf2ee1 --- /dev/null +++ b/ws_server.go @@ -0,0 +1,88 @@ +package main + +import ( + "net" + "net/http" + "net/url" + + log "github.com/fangdingjun/go-log/v5" + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +var dialer = &websocket.Dialer{} + +type forwardRule struct { + local string + remote string +} + +type wsServer struct { + addr string + rule []forwardRule +} + +func (wss *wsServer) run() { + if err := http.ListenAndServe(wss.addr, wss); err != nil { + log.Errorln(err) + } +} + +func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + p := r.URL.Path + remote := "" + for _, ru := range wss.rule { + if ru.local == p { + remote = ru.remote + } + } + + if remote == "" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Errorln(err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + defer conn.Close() + + u, _ := url.Parse(remote) + + if u.Scheme == "ws" || u.Scheme == "wss" { + conn1, resp, err := dialer.Dial(remote, nil) + if err != nil { + log.Errorln(err) + return + } + resp.Body.Close() + if resp.StatusCode != http.StatusSwitchingProtocols { + log.Errorf("dial remote ws %d", resp.StatusCode) + return + } + defer conn1.Close() + + forwardWS2WS(conn, conn1) + return + } + + if u.Scheme == "tcp" { + conn1, err := net.Dial("tcp", u.Host) + if err != nil { + log.Errorln(err) + return + } + defer conn1.Close() + + forwardWS2TCP(conn, conn1) + return + } + log.Errorf("unsupported scheme %s", u.Scheme) +}