diff --git a/main.go b/main.go index e8a996d..3f6e23e 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "io" @@ -10,6 +11,7 @@ import ( "os/signal" "strings" "syscall" + "time" glog "github.com/fangdingjun/go-log" proxyproto "github.com/pires/go-proxyproto" @@ -59,7 +61,7 @@ func getSNIServerName(buf []byte) string { return msg.serverName } -func forward(c net.Conn, data []byte, dst string) { +func forward(ctx context.Context, c net.Conn, data []byte, dst string) { addr := dst proxyProto := 0 @@ -90,8 +92,8 @@ func forward(c net.Conn, data []byte, dst string) { hdr.Version = 2 } } - - c1, err := net.Dial("tcp", addr) + dialer := &net.Dialer{Timeout: 10 * time.Second} + c1, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { glog.Error(err) return @@ -101,7 +103,10 @@ func forward(c net.Conn, data []byte, dst string) { if proxyProto != 0 { glog.Debugf("send proxy proto v%d to %s", proxyProto, addr) - hdr.WriteTo(c1) + if _, err = hdr.WriteTo(c1); err != nil { + glog.Errorln(err) + return + } } if _, err = c1.Write(data); err != nil { @@ -121,7 +126,10 @@ func forward(c net.Conn, data []byte, dst string) { ch <- struct{}{} }() - <-ch + select { + case <-ch: + case <-ctx.Done(): + } } func getDST(c net.Conn, name string) string { @@ -134,7 +142,7 @@ func getDefaultDST() string { return cfg.Default } -func serve(c net.Conn) { +func serve(ctx context.Context, c net.Conn) { defer c.Close() buf := make([]byte, 1024) @@ -146,7 +154,7 @@ func serve(c net.Conn) { servername := getSNIServerName(buf[:n]) if servername == "" { glog.Debugf("no sni, send to default") - forward(c, buf[:n], getDefaultDST()) + forward(ctx, c, buf[:n], getDefaultDST()) return } dst := getDST(c, servername) @@ -154,7 +162,7 @@ func serve(c net.Conn) { dst = getDefaultDST() glog.Debugf("use default dst %s for sni %s", dst, servername) } - forward(c, buf[:n], dst) + forward(ctx, c, buf[:n], dst) } var cfg conf @@ -188,28 +196,34 @@ func main() { glog.Default.Level = lv } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for _, d := range cfg.Listen { glog.Infof("listen on :%d", d) - l, err := net.Listen("tcp", fmt.Sprintf(":%d", d)) + lc := &net.ListenConfig{} + l, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", d)) if err != nil { glog.Fatal(err) } - go func(l net.Listener) { + go func(ctx context.Context, l net.Listener) { defer l.Close() for { c1, err := l.Accept() if err != nil { - glog.Fatal(err) + glog.Error(err) + break } - go serve(c1) + go serve(ctx, c1) } - }(l) + }(ctx, l) } ch := make(chan os.Signal, 2) signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) select { case s := <-ch: + cancel() glog.Printf("received signal %s, exit.", s) } } diff --git a/proto_test.go b/proto_test.go index bb83a92..78b7b26 100644 --- a/proto_test.go +++ b/proto_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "io/ioutil" "net" @@ -36,7 +37,7 @@ func TestProxyProto(t *testing.T) { log.Errorln(err) return } - go serve(conn) + go serve(context.Background(), conn) } }() cert, err := tls.LoadX509KeyPair("server.crt", "server.key")