diff --git a/server_test.go b/server_test.go index b5d8f77..55cc01d 100644 --- a/server_test.go +++ b/server_test.go @@ -1,10 +1,10 @@ package main import ( - "bytes" "io/ioutil" "net" "net/http" + "net/url" "testing" "time" @@ -13,127 +13,125 @@ import ( "gopkg.in/yaml.v2" ) -func TestServer(t *testing.T) { - cfgfile := "config.example.yaml" - data, err := ioutil.ReadFile(cfgfile) - if err != nil { - log.Fatal(err) - } - var cfg conf - if err := yaml.Unmarshal(data, &cfg); err != nil { - log.Fatal(err) - } - makeServers(cfg) - l1, err := net.Listen("tcp", "127.0.0.1:2903") +func echoServer(addr string) { + l1, err := net.Listen("tcp", addr) if err != nil { - t.Fatal(err) + log.Errorln(err) + return } - go func() { - defer l1.Close() - for { - - c1, err := l1.Accept() - if err != nil { - log.Errorln(err) - return - } - go func(c net.Conn) { - defer c.Close() - data := make([]byte, 1024) - for { - n, err := c.Read(data) - if err != nil { - log.Errorln(err) - break - } - c.Write(data[:n]) - log.Infof("2903 receive: %s", string(data[:n])) - } - }(c1) + defer l1.Close() + + for { + c1, err := l1.Accept() + if err != nil { + log.Errorln(err) + return } - }() - l2, err := net.Listen("tcp", "127.0.0.1:2904") - if err != nil { - t.Fatal(err) - } - go func() { - defer l2.Close() - for { - c1, err := l2.Accept() - if err != nil { - log.Errorln(err) - } - go func(c net.Conn) { - defer c.Close() - data := make([]byte, 1024) - for { - n, err := c.Read(data) - if err != nil { - log.Errorln(err) - break - } - c.Write(data[:n]) - log.Infof("2904 receive: %s", string(data[:n])) + go func(c net.Conn) { + defer c.Close() + data := make([]byte, 1024) + for { + n, err := c.Read(data) + if err != nil { + log.Errorln(err) + break } - }(c1) - } - }() + c.Write(data[:n]) + log.Infof("%s receive: %s", addr, data[:n]) + } + }(c1) + } +} - time.Sleep(time.Second) - c1, resp, err := dialer.Dial("ws://127.0.0.1:2901/p1", nil) - if err != nil { - t.Fatal(err) +func sendAndRecv(addr string, msg string) string { + u, _ := url.Parse(addr) + if u.Scheme == "ws" || u.Scheme == "wss" { + return _sendAndRecvWS(addr, msg) } - if resp.StatusCode != http.StatusSwitchingProtocols { - t.Fatalf("dial ws code %d", resp.StatusCode) + if u.Scheme == "tcp" { + return _sendAndRecvTCP(addr, msg) } - err = c1.WriteMessage(websocket.BinaryMessage, []byte("p1")) + return "" +} + +func _sendAndRecvTCP(addr string, msg string) string { + u, _ := url.Parse(addr) + c, err := net.Dial("tcp", u.Host) if err != nil { - t.Fatal(err) + log.Errorln(err) + return "" } - _, d, err := c1.ReadMessage() + + _, err = c.Write([]byte(msg)) if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte("p1"), d) { - t.Errorf("failed msg not equal, expect p1, got %s", d) + log.Errorln(err) + return "" } - c2, resp, err := dialer.Dial("ws://127.0.0.1:2901/p2", nil) + + data := make([]byte, 100) + n, err := c.Read(data) if err != nil { - t.Fatal(err) + log.Errorln(err) + return "" } + return string(data[:n]) +} + +func _sendAndRecvWS(addr string, msg string) string { + c1, resp, err := dialer.Dial(addr, nil) if err != nil { - t.Fatal(err) + log.Errorln(err) + return "" } if resp.StatusCode != http.StatusSwitchingProtocols { - t.Fatalf("dial ws code %d", resp.StatusCode) + log.Errorf("dial ws code %d", resp.StatusCode) } - err = c2.WriteMessage(websocket.BinaryMessage, []byte("p2")) + err = c1.WriteMessage(websocket.BinaryMessage, []byte(msg)) if err != nil { - t.Fatal(err) + log.Errorln(err) + return "" } - _, d, err = c2.ReadMessage() + _, d, err := c1.ReadMessage() if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte("p2"), d) { - t.Errorf("failed msg not equal, expect p2, got %s", d) + log.Errorln(err) + return "" } + return string(d) +} - c3, err := net.Dial("tcp", "127.0.0.1:2905") - if err != nil { - t.Fatal(err) - } - _, err = c3.Write([]byte("c3")) +func TestServer(t *testing.T) { + cfgfile := "config.example.yaml" + + log.Default.Level = log.DEBUG + + data, err := ioutil.ReadFile(cfgfile) if err != nil { - t.Fatal(err) + log.Fatal(err) } - d2 := make([]byte, 20) - n, err := c3.Read(d2) - if err != nil { - t.Fatal(err) + var cfg conf + if err := yaml.Unmarshal(data, &cfg); err != nil { + log.Fatal(err) } - if !bytes.Equal([]byte("c3"), d2[:n]) { - t.Errorf("failed msg not equal, expect c3, got %s", d2[:n]) + + makeServers(cfg) + + go echoServer("127.0.0.1:2903") + go echoServer("127.0.0.1:2904") + + time.Sleep(time.Second) + + testdata := []struct { + addr string + msg string + }{ + {"ws://127.0.0.1:2901/p1", "p1"}, + {"ws://127.0.0.1:2901/p2", "p2"}, + {"tcp://127.0.0.1:2905", "c3"}, + } + for _, tt := range testdata { + _m := sendAndRecv(tt.addr, tt.msg) + if _m != tt.msg { + t.Errorf("expected %s, got %s", tt.msg, _m) + } } }