use context to graceful shutdown

master
dingjun 6 years ago
parent 96e0332d16
commit 51e00f2c99

@ -1,6 +1,7 @@
package main
import (
"context"
"flag"
"fmt"
"io"
@ -10,6 +11,7 @@ import (
"os/signal"
"strings"
"syscall"
"time"
glog "github.com/fangdingjun/go-log"
proxyproto "github.com/pires/go-proxyproto"
@ -59,7 +61,7 @@ func getSNIServerName(buf []byte) string {
return msg.serverName
}
func forward(c net.Conn, data []byte, dst string) {
func forward(ctx context.Context, c net.Conn, data []byte, dst string) {
addr := dst
proxyProto := 0
@ -90,8 +92,8 @@ func forward(c net.Conn, data []byte, dst string) {
hdr.Version = 2
}
}
c1, err := net.Dial("tcp", addr)
dialer := &net.Dialer{Timeout: 10 * time.Second}
c1, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
glog.Error(err)
return
@ -101,7 +103,10 @@ func forward(c net.Conn, data []byte, dst string) {
if proxyProto != 0 {
glog.Debugf("send proxy proto v%d to %s", proxyProto, addr)
hdr.WriteTo(c1)
if _, err = hdr.WriteTo(c1); err != nil {
glog.Errorln(err)
return
}
}
if _, err = c1.Write(data); err != nil {
@ -121,7 +126,10 @@ func forward(c net.Conn, data []byte, dst string) {
ch <- struct{}{}
}()
<-ch
select {
case <-ch:
case <-ctx.Done():
}
}
func getDST(c net.Conn, name string) string {
@ -134,7 +142,7 @@ func getDefaultDST() string {
return cfg.Default
}
func serve(c net.Conn) {
func serve(ctx context.Context, c net.Conn) {
defer c.Close()
buf := make([]byte, 1024)
@ -146,7 +154,7 @@ func serve(c net.Conn) {
servername := getSNIServerName(buf[:n])
if servername == "" {
glog.Debugf("no sni, send to default")
forward(c, buf[:n], getDefaultDST())
forward(ctx, c, buf[:n], getDefaultDST())
return
}
dst := getDST(c, servername)
@ -154,7 +162,7 @@ func serve(c net.Conn) {
dst = getDefaultDST()
glog.Debugf("use default dst %s for sni %s", dst, servername)
}
forward(c, buf[:n], dst)
forward(ctx, c, buf[:n], dst)
}
var cfg conf
@ -188,28 +196,34 @@ func main() {
glog.Default.Level = lv
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, d := range cfg.Listen {
glog.Infof("listen on :%d", d)
l, err := net.Listen("tcp", fmt.Sprintf(":%d", d))
lc := &net.ListenConfig{}
l, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", d))
if err != nil {
glog.Fatal(err)
}
go func(l net.Listener) {
go func(ctx context.Context, l net.Listener) {
defer l.Close()
for {
c1, err := l.Accept()
if err != nil {
glog.Fatal(err)
glog.Error(err)
break
}
go serve(c1)
go serve(ctx, c1)
}
}(l)
}(ctx, l)
}
ch := make(chan os.Signal, 2)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
select {
case s := <-ch:
cancel()
glog.Printf("received signal %s, exit.", s)
}
}

@ -1,6 +1,7 @@
package main
import (
"context"
"crypto/tls"
"io/ioutil"
"net"
@ -36,7 +37,7 @@ func TestProxyProto(t *testing.T) {
log.Errorln(err)
return
}
go serve(conn)
go serve(context.Background(), conn)
}
}()
cert, err := tls.LoadX509KeyPair("server.crt", "server.key")

Loading…
Cancel
Save