master
dingjun 4 years ago
parent 89cacc7f0f
commit 78a61c187e

@ -21,8 +21,8 @@ func makeServers(cfg conf) {
if err != nil {
log.Fatalf("parse %s, error %s", c.Listen, err)
}
switch u.Scheme {
switch u.Scheme {
case "ws":
exists := false
for i := 0; i < len(wsservers); i++ {
@ -35,7 +35,6 @@ func makeServers(cfg conf) {
if !exists {
wsservers = append(wsservers, wsServer{u.Host, []forwardRule{{u.Path, c.Remote}}})
}
case "tcp":
tcpservers = append(tcpservers, tcpServer{u.Host, c.Remote})
default:
@ -46,6 +45,7 @@ func makeServers(cfg conf) {
for _, srv := range wsservers {
go srv.run()
}
for _, srv := range tcpservers {
go srv.run()
}
@ -64,6 +64,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
var cfg conf
if err := yaml.Unmarshal(data, &cfg); err != nil {
log.Fatal(err)

@ -35,10 +35,12 @@ var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
var dialer = &websocket.Dialer{}
func forwardWS2WS(conn, conn1 *websocket.Conn) {
ch := make(chan struct{}, 2)
go func() {
for {
t, data, err := conn.ReadMessage()
@ -54,6 +56,7 @@ func forwardWS2WS(conn, conn1 *websocket.Conn) {
}
ch <- struct{}{}
}()
go func() {
for {
t, data, err := conn1.ReadMessage()
@ -68,14 +71,14 @@ func forwardWS2WS(conn, conn1 *websocket.Conn) {
}
}
ch <- struct{}{}
}()
<-ch
}
func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
ch := make(chan struct{}, 2)
go func() {
for {
_, data, err := conn1.ReadMessage()
@ -83,6 +86,7 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
log.Errorln(err)
break
}
_, err = conn2.Write(data)
if err != nil {
log.Errorln(err)
@ -91,14 +95,17 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
}
ch <- struct{}{}
}()
go func() {
buf := make([]byte, 1024)
for {
n, err := conn2.Read(buf)
if err != nil {
log.Errorln(err)
break
}
err = conn1.WriteMessage(websocket.BinaryMessage, buf[:n])
if err != nil {
log.Errorln(err)
@ -107,11 +114,13 @@ func forwardWS2TCP(conn1 *websocket.Conn, conn2 net.Conn) {
}
ch <- struct{}{}
}()
<-ch
}
func forwardTCP2TCP(c1, c2 net.Conn) {
ch := make(chan struct{}, 2)
go func() {
_, err := io.Copy(c1, c2)
if err != nil {
@ -119,6 +128,7 @@ func forwardTCP2TCP(c1, c2 net.Conn) {
}
ch <- struct{}{}
}()
go func() {
_, err := io.Copy(c2, c1)
if err != nil {
@ -126,6 +136,7 @@ func forwardTCP2TCP(c1, c2 net.Conn) {
}
ch <- struct{}{}
}()
<-ch
}
@ -142,6 +153,7 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorln(err)
@ -159,11 +171,12 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
resp.Body.Close()
defer conn1.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
log.Errorf("dial remote ws %d", resp.StatusCode)
return
}
defer conn1.Close()
forwardWS2WS(conn, conn1)
return
}
@ -175,6 +188,7 @@ func (wss *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
defer conn1.Close()
forwardWS2TCP(conn, conn1)
return
}
@ -188,6 +202,7 @@ func (srv *tcpServer) run() {
return
}
defer l.Close()
for {
conn, err := l.Accept()
if err != nil {
@ -210,11 +225,12 @@ func (srv *tcpServer) serve(c net.Conn) {
return
}
resp.Body.Close()
defer conn1.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
log.Errorf("dial remote ws %d", resp.StatusCode)
return
}
defer conn1.Close()
forwardWS2TCP(conn1, c)
return
}
@ -226,8 +242,10 @@ func (srv *tcpServer) serve(c net.Conn) {
return
}
defer conn1.Close()
forwardTCP2TCP(c, conn1)
return
}
log.Errorf("unsupported scheme %s", u.Scheme)
}

Loading…
Cancel
Save